import os
import sys
import json
import spacy
import pickle
from collections import Counter
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
"""
Plant Seeds in a Basic Seed Tray
Grow Avocado
Touch up the scratches
Replace a Piston Ring on a Dirtbike
Create a Tropical Bouquet
Repair a Door Frame
"""

def get_surr_context(step, step_goal, step2goal, cl=1):
    """Return the captions surrounding ``step`` within its goal, plus that goal.

    The window holds ``step`` and up to ``cl`` captions on each side. When the
    step sits near the start or end of the caption list, the unused budget on
    that side is given to the other side, so the window still totals up to
    ``2 * cl + 1`` captions whenever enough captions exist.

    Args:
        step: the step text to look up.
        step_goal: mapping goal -> record whose ``'caption'`` entry is a list
            of caption strings (presumably the goal's steps in order).
        step2goal: mapping step -> goal key into ``step_goal``.
        cl: context half-width (default 1, the previously hard-coded value).

    Returns:
        ``(context_captions, goal)``. If ``step`` has no goal, or the goal has
        no record in ``step_goal``, returns ``([step], 0)``.

    Raises:
        ValueError: if ``step`` is not among the goal's (normalized) captions.
    """
    try:
        goal = step2goal[step]
        captions = step_goal[goal]['caption']
    except KeyError:
        return [step], 0

    # Drop one trailing period (then surrounding whitespace) so captions can
    # be matched against the bare step text.
    captions = [c[:-1].strip() if c.endswith('.') else c for c in captions]
    step_index = captions.index(step)

    max_pre = step_index                       # captions available before the step
    max_post = len(captions) - step_index - 1  # captions available after the step
    if max_pre >= cl and max_post >= cl:
        pre = post = cl
    elif max_pre < cl:
        # Too close to the start: spend the unused "before" budget after the step.
        pre = max_pre
        post = min(2 * cl - max_pre, max_post)
    else:
        # Too close to the end: spend the unused "after" budget before the step.
        post = max_post
        pre = min(max_pre, 2 * cl - max_post)

    return captions[step_index - pre: step_index + post + 1], goal

def compare_p1_p2():
    """Compare phase-one retrieval against phase-two reranking on the WikiHow test set.

    Phase one (``gold.rerank.org.t30.test.json``) is read as a mapping
    step -> record with ``gold_goal`` and a ranked ``retrieved_goals`` list.
    Phase two (the DeBERTa ``.result`` file) is a list of records with
    ``step``, ``gold`` and a ``pred`` mapping goal -> score.

    For every step present in both, the top phase-one goal is compared with
    the highest-scoring phase-two goal. Steps that phase two gets right and
    phase one gets wrong ("wins") are printed tab-separated:
    step, gold, phase-one goal, phase-two goal, surrounding step context
    joined by `` || ``, and the step's goal. Finally
    ``<#shared steps> <wins> <losses>`` is printed.
    """
    with open("./data/wikihow/step_goal.json", "r") as f:
        step_goal = {x['task']: x for x in json.load(f)}
    with open("./data/wikihow/step2goal.json", "r") as f:
        step2goal = json.load(f)

    # compare phase one and two
    with open("./data/wikihow/gold.rerank.org.t30.test.json", "r") as f:
        r1 = json.load(f)
    with open("./data/wikihow/gold.rerank.org.t30.test.deberta.para_score.goal.c1.result", "r") as f:
        r2 = {x['step']: x for x in json.load(f)}

    steps = [s for s in r1 if s in r2]
    compare = []
    win = loss = 0
    for s in steps:
        p1, p2 = r1[s], r2[s]
        gold = p1['gold_goal']
        assert gold == p2['gold']  # both phases must agree on the gold label
        p1_goal = p1['retrieved_goals'][0]
        # Argmax over goal scores; max() returns the first of tied goals, just
        # like taking element 0 of a stable descending sort.
        p2_goal = max(p2['pred'].items(), key=lambda kv: kv[1])[0]
        if p1_goal != gold and p2_goal == gold:
            win += 1
            s_ctx, s_goal = get_surr_context(s, step_goal, step2goal)
            compare.append(f"{s}\t{gold}\t{p1_goal}\t{p2_goal}\t{' || '.join(s_ctx)}\t{s_goal}")
        elif p2_goal != gold and p1_goal == gold:
            loss += 1

    for item in compare:
        print(item)

    print(len(steps), win, loss)

def count_num(nlp, sent_list):
    """Count lower-cased verb lemmas over a list of sentences, most frequent first.

    Args:
        nlp: a loaded spaCy pipeline; ``nlp(sentence)`` must yield tokens
            exposing ``pos_`` and ``lemma_``.
        sent_list: sentences to analyse (progress is shown with tqdm).

    Returns:
        List of ``(lemma, count)`` tuples for tokens tagged ``VERB``, sorted by
        descending count; equal counts keep first-encountered order.
    """
    verb_counts = Counter(
        token.lemma_.lower()
        for sentence in tqdm(sent_list)
        for token in nlp(sentence)
        if token.pos_ == 'VERB'
    )
    return verb_counts.most_common()



def _collect_expansion_steps(items):
    """Split caption records into expanded steps and their expansion sub-steps.

    A caption whose source id splits into exactly two ``.``-separated parts
    counts as expanded when ``"<id>.0"`` also occurs in the same record's
    ``caption_source``. Its sub-steps ``"<id>.0"``, ``"<id>.1"``, ... are
    collected in order until the first missing index (at most 100).

    Args:
        items: records with parallel ``caption`` / ``caption_source`` lists.

    Returns:
        ``(has_ep_step, ep_step)``: captions of expanded steps, and captions of
        their expansion sub-steps.
    """
    has_ep_step, ep_step = [], []
    for item in items:
        cap, cap_source = item['caption'], item['caption_source']
        # First position of each source id (what list.index() would return),
        # precomputed so each lookup is O(1).
        first_pos = {}
        for pos, src in enumerate(cap_source):
            first_pos.setdefault(src, pos)
        for index, src in enumerate(cap_source):
            if len(src.split(".")) != 2 or f"{src}.0" not in first_pos:
                continue
            has_ep_step.append(cap[index])
            for i in range(100):
                sub_id = f"{src}.{i}"
                if sub_id not in first_pos:
                    break
                ep_step.append(cap[first_pos[sub_id]])
    return has_ep_step, ep_step


def _verb_rank_shifts(has_ep_count, ep_count, top_k=100):
    """Compare verb ranks between expanded-step and expansion-step verb counts.

    Ranks are 0-based positions in each count list, which are expected to be
    sorted by descending frequency (as ``count_num`` returns). Only verbs in
    the top ``top_k`` of at least one list are considered. A verb absent from
    the other full list is given rank ``len(list)`` (one past the end) instead
    of raising ``ValueError``.

    Returns:
        ``(decrease_words, increase_words)``: lists of
        ``[verb, has_ep_rank, ep_rank]``, each sorted by
        ``has_ep_rank - ep_rank`` descending. Verbs present in both top lists
        with equal rank are omitted.
    """
    full_has_ep_words = [x[0] for x in has_ep_count]
    full_ep_words = [x[0] for x in ep_count]
    has_ep_words = full_has_ep_words[:top_k]
    ep_words = full_ep_words[:top_k]

    def rank_or_end(word, words):
        return words.index(word) if word in words else len(words)

    increase_words, decrease_words = [], []
    for hi, w in enumerate(has_ep_words):
        if w in ep_words:
            ei = ep_words.index(w)
            if hi > ei:
                increase_words.append([w, hi, ei])
            elif ei > hi:
                decrease_words.append([w, hi, ei])
    # Verbs that leave / enter the top-k are appended after the shared ones.
    for w in has_ep_words:
        if w not in ep_words:
            decrease_words.append([w, has_ep_words.index(w), rank_or_end(w, full_ep_words)])
    for w in ep_words:
        if w not in has_ep_words:
            increase_words.append([w, rank_or_end(w, full_has_ep_words), ep_words.index(w)])

    def rank_gap(entry):
        return entry[1] - entry[2]

    decrease_words.sort(key=rank_gap, reverse=True)
    increase_words.sort(key=rank_gap, reverse=True)
    return decrease_words, increase_words


def _plot_verb_rank_shifts(decrease_words, increase_words, out_path, figsize=(20, 10), fontsize=12):
    """Save a bar chart of per-verb rank gaps, each bar labelled with its verb."""
    interval, width = 20, 16
    n_dec = len(decrease_words)
    x1 = np.arange(n_dec)
    x2 = np.arange(n_dec, n_dec + len(increase_words))

    fig = plt.figure(figsize=figsize)
    bar1 = plt.bar(x1 * interval, [x[1] - x[2] for x in decrease_words], width=width)
    bar2 = plt.bar(x2 * interval, [x[1] - x[2] for x in increase_words], width=width)

    # Decrease labels hang just below each bar tip; increase labels sit just above it.
    for (verb, _, _), rect in zip(decrease_words, bar1):
        plt.text(rect.get_x() + rect.get_width() / 2.0, rect.get_height() - 0.5, verb,
                 ha='center', va='top', rotation='vertical', fontsize=fontsize)
    for (verb, _, _), rect in zip(increase_words, bar2):
        plt.text(rect.get_x() + rect.get_width() / 2.0, rect.get_height() + 0.5, verb,
                 ha='center', va='bottom', rotation='vertical', fontsize=fontsize)

    plt.ylabel('cluster 1 rank - cluster 2 rank')
    plt.xticks([])
    plt.box(False)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(fig)


def expansion_verb_analysis():
    """Compare verb usage in expanded steps vs. their expansion sub-steps.

    If the cached counts file does not exist, waits for user confirmation.
    It then collects expanded steps and their sub-steps from the caption file,
    counts VERB lemmas in both groups with spaCy, prints the top 100 of each,
    and caches the counts as JSON.

    Otherwise, it loads the cached counts and saves a bar chart of each
    frequent verb's rank shift between the two groups to
    ``./data/wikihow/human/verb_distribution.pdf``.
    """
    cache_path = "./data/wikihow/human/verb_annotation_train_null.json"
    if not os.path.exists(cache_path):
        # Pause before recomputing from scratch; the answer itself is ignored.
        _ = input("From scratch? ")
        with open("./data/wikihow/human/step_goal.para.d1.all_base.all_expansion.rerank.goal.c1.train_null00.50.howto1k.all.json", "r") as f:
            d = json.load(f)
        has_ep_step, ep_step = _collect_expansion_steps(d)
        print(len(has_ep_step), len(ep_step))

        nlp = spacy.load("en_core_web_sm")
        has_ep_count = count_num(nlp, has_ep_step)
        ep_count = count_num(nlp, ep_step)

        for x in has_ep_count[:100]:
            print(x)
        print("=====================")
        for x in ep_count[:100]:
            print(x)

        with open(cache_path, "w+") as f:
            json.dump({'has_ep_count': has_ep_count, 'ep_count': ep_count}, f, indent=2)
    else:
        with open(cache_path, "r") as f:
            d = json.load(f)
        decrease_words, increase_words = _verb_rank_shifts(d['has_ep_count'], d['ep_count'])
        _plot_verb_rank_shifts(decrease_words, increase_words,
                               './data/wikihow/human/verb_distribution.pdf')

def draw_part():
    """Plot the 10 largest verb rank drops and the 10 largest rank rises.

    Reads the cached counts ``verb_annotation_train_null.json`` (written by
    ``expansion_verb_analysis``) and works on the top-100 verbs of
    ``has_ep_count`` and ``ep_count``. For each verb it computes
    ``has_ep_rank - ep_rank``. The chart is saved to
    ``./data/wikihow/human/verb_distribution_part.pdf``.
    """
    with open("./data/wikihow/human/verb_annotation_train_null.json", "r") as f:
        d = json.load(f)
    has_ep_count, ep_count = d['has_ep_count'], d['ep_count']
    full_has_ep_words = [x[0] for x in has_ep_count]
    full_ep_words = [x[0] for x in ep_count]
    has_ep_words = full_has_ep_words[:100]
    ep_words = full_ep_words[:100]

    def rank_or_end(word, words):
        # A verb missing from the full list is ranked one past its end instead
        # of letting list.index() raise ValueError.
        return words.index(word) if word in words else len(words)

    increase_words, decrease_words = [], []
    for hi, w in enumerate(has_ep_words):
        if w in ep_words:
            ei = ep_words.index(w)
            if hi > ei:
                increase_words.append([w, hi, ei])
            elif ei > hi:
                decrease_words.append([w, hi, ei])
    # Verbs that leave / enter the top 100 are appended after the shared ones.
    for w in has_ep_words:
        if w not in ep_words:
            decrease_words.append([w, has_ep_words.index(w), rank_or_end(w, full_ep_words)])
    for w in ep_words:
        if w not in has_ep_words:
            increase_words.append([w, rank_or_end(w, full_has_ep_words), ep_words.index(w)])

    max_index = 10
    decrease_words = sorted(decrease_words, key=lambda x: x[1] - x[2], reverse=True)[-max_index:]
    increase_words = sorted(increase_words, key=lambda x: x[1] - x[2], reverse=True)[:max_index]

    interval = 20
    width = 16
    fs = 10
    x1 = np.arange(len(decrease_words))
    gap1 = [(x[1] - x[2]) for x in decrease_words]
    x2 = np.arange(len(decrease_words), len(decrease_words) + len(increase_words))
    gap2 = [(x[1] - x[2]) for x in increase_words]

    fig = plt.figure()
    bar1 = plt.bar(x1 * interval, gap1, width=width)
    bar2 = plt.bar(x2 * interval, gap2, width=width)

    # Label each bar with its verb: below the tip for decreases, above for increases.
    for t, rect in zip([x[0] for x in decrease_words], bar1):
        plt.text(rect.get_x() + rect.get_width() / 2.0, rect.get_height() - 0.5, t, ha='center', va='top',
                 rotation='vertical', fontsize=fs)
    for t, rect in zip([x[0] for x in increase_words], bar2):
        plt.text(rect.get_x() + rect.get_width() / 2.0, rect.get_height() + 0.5, t, ha='center', va='bottom',
                 rotation='vertical', fontsize=fs)

    plt.ylabel('cluster 1 rank - cluster 2 rank')
    plt.xticks([])
    plt.box(False)
    plt.tight_layout()
    plt.savefig('./data/wikihow/human/verb_distribution_part.pdf')
    plt.close(fig)

def human_plot(case=1, show=True):
    """Plot human-evaluation label counts for DeBERTa-UL, DeBERTa and SP.

    Draws two side-by-side grouped bar charts, "linkable" (left) and
    "unlinkable" (right). Each has one bar per system for each label
    (exact / helpful / related / unhelpful). Both panels get the same y-range,
    and the figure is saved.

    Args:
        case: ``1`` plots the counts saved to ``human_main.pdf``; any other
            value plots the alternative counts saved to ``human_part.pdf``.
        show: whether to call ``plt.show()`` after saving (the original
            function always did, so this defaults to True).
    """
    # Per panel: counts for (DeBERTa-UL, DeBERTa, SP), one value per label.
    if case == 1:
        linkable = ([289, 232, 138, 83], [295, 221, 143, 87], [202, 174, 148, 175])
        unlinkable = ([68, 139, 218, 260], [84, 156, 193, 246], [45, 96, 162, 395])
        out_path = './data/wikihow/human/human_main.pdf'
    else:
        linkable = ([134, 58, 27, 11], [129, 67, 19, 13], [93, 54, 34, 32])
        unlinkable = ([35, 54, 48, 52], [31, 47, 57, 67], [16, 31, 54, 89])
        out_path = './data/wikihow/human/human_part.pdf'

    systems = ['DeBERTa-UL', 'DeBERTa', 'SP']
    label = ['exact', 'helpful', 'related', 'unhelpful']
    fs = 15
    w = 0.3
    x = np.arange(len(label))

    _, ax = plt.subplots(nrows=1, ncols=2)
    for panel, counts, title in zip(ax, (linkable, unlinkable), ("linkable", "unlinkable")):
        for offset, system, values in zip((-w, 0, w), systems, counts):
            panel.bar(x + offset, values, width=w, align='center', label=system)
        panel.set_xticks(x)
        panel.set_xticklabels(label, fontsize=fs, rotation=45)
        panel.set_title(title, fontsize=fs)
    ax[1].set_yticks([])

    # Give both panels the union of their autoscaled ranges so bar heights are
    # comparable no matter which panel holds the tallest bar.
    low = min(panel.get_ylim()[0] for panel in ax)
    high = max(panel.get_ylim()[1] for panel in ax)
    for panel in ax:
        panel.set_ylim(low, high)

    plt.legend()  # attaches to the current axes, i.e. the last subplot created
    plt.tight_layout()
    plt.savefig(out_path)
    if show:
        plt.show()


if __name__ == "__main__":
    # Choose the analysis by name from the first command-line argument,
    # e.g. `python script.py compare_p1_p2`. With no argument, human_plot()
    # runs, as before.
    _tasks = {
        "draw_part": draw_part,
        "expansion_verb_analysis": expansion_verb_analysis,
        "compare_p1_p2": compare_p1_p2,
        "human_plot": human_plot,
    }
    _task_name = sys.argv[1] if len(sys.argv) > 1 else "human_plot"
    if _task_name not in _tasks:
        sys.exit(f"unknown task {_task_name!r}; choose one of: {', '.join(_tasks)}")
    _tasks[_task_name]()