import json
from collections import defaultdict
import re
from tqdm import tqdm

from transformers import GPT2Tokenizer
from nltk import sent_tokenize, word_tokenize
import numpy as np
from nltk import sent_tokenize, word_tokenize
from matplotlib import pyplot as plt
import sys
from itertools import chain

import numpy as np

sys.path.append('..')
from util.graph_relation import get_conceptnet, KnowledgeGraph

# plt.switch_backend('agg')

# GPT-2 tokenizer loaded at import time from a local checkpoint directory;
# get_len() uses it to measure text length in tokens.
tokenizer = GPT2Tokenizer.from_pretrained('../gpt2_en_ckpt_origin')

# Module-level knowledge-graph handle. It stays None until stat_conceptnet()
# or the __main__ block assigns the result of get_conceptnet().
graph = None

def get_len(text):
    """Return the number of token ids the module-level tokenizer produces for ``text``."""
    token_ids = tokenizer.encode(text)
    return len(token_ids)


def stat(in_file_name):
    """Print how many entries fit into a scene before a token budget is exceeded.

    Loads a JSON list of stories from ``in_file_name``. For each scene the
    token lengths of the entry descriptions are accumulated in order; as soon
    as the running total exceeds 1000 tokens, the number of entries seen
    before the overflowing one is recorded and the scene is abandoned. Scenes
    that never exceed the budget are not counted.

    Prints a list of ``(entries_before_overflow, scene_count)`` pairs sorted
    by the first element.
    """
    limit = 1000
    with open(in_file_name, encoding='utf-8') as f:
        stories = json.load(f)

    stats = defaultdict(int)
    for story in stories:
        for scene in story["scenes"]:
            total_len = 0
            for fitted, entry in enumerate(scene["entries"]):
                try:
                    total_len += get_len(entry['description'])
                except Exception:
                    # Best effort: an entry whose description is missing or
                    # cannot be tokenized contributes nothing. Unlike a bare
                    # ``except``, KeyboardInterrupt/SystemExit still propagate.
                    pass
                if total_len > limit:
                    stats[fitted] += 1
                    break
    print(sorted(stats.items(), key=lambda item: item[0]))


def stat_train_valid_test():
    """Print average length, keyword and triple statistics for the dynamic-persona splits.

    Processes the test, valid and train files under ``../data`` in that order.
    Requires the module-level ``graph`` to be initialised, because
    ``graph.get_triples`` is called for every keyword.
    """

    def stat_one_file(file_name):
        """Print the per-sample averages for one split file."""
        with open(file_name, encoding='utf-8') as f:
            a = json.load(f)

        print('*' * 40 + file_name + '*' * 40)
        print('samples:', len(a))

        context_len = persona_len = output_len = target_len = 0
        input_kw_num = story_kw_num = 0
        triple_num = 0

        for scene in tqdm(a):
            # Every entry except the last one counts as context.
            for entry in scene['entries'][:-1]:
                context_len += get_len(entry['description'])

            last_entry = scene['entries'][-1]
            assert len(last_entry['cards']) == 1
            persona_len += get_len(last_entry['cards'][0]['description'])
            output_len += get_len(last_entry['description'])
            peak_idx = scene['peak_idx']
            target_len += get_len(sent_tokenize(last_entry['description'])[peak_idx])

            # Copy the list: extending it below must not modify the loaded
            # scene's own 'context_kws'.
            kws_list = list(scene['context_kws'])
            input_kw_num += len(kws_list)
            for k in ['bedding_kws', 'target_kws', 'ending_kws']:
                x = list(chain(*scene[k]))
                story_kw_num += len(x)
                kws_list.extend(x)

            # Mean number of triples per keyword of this scene; a scene
            # without any keyword contributes 0 instead of dividing by zero.
            if kws_list:
                scene_triples = sum(len(graph.get_triples(kw)) for kw in kws_list)
                triple_num += scene_triples / len(kws_list)

        print(f"avg context_len:{context_len / len(a)}")
        print(f"avg persona_len:{persona_len / len(a)}")
        print(f"avg output_len:{output_len / len(a)}")
        print(f"avg target_len:{target_len / len(a)}")
        print(f"avg input kw num:{input_kw_num / len(a)}")
        print(f"avg story kw num:{story_kw_num / len(a)}")
        print(f"avg triple num:{triple_num / len(a)}")

    stat_one_file('../data/test_dynamic_persona.json')
    stat_one_file('../data/valid_dynamic_persona.json')
    stat_one_file('../data/train_dynamic_persona.json')


def stat_train_valid_test_kw():
    """Print the average size of each keyword/node list for the ``*_add_node_onecard`` splits."""
    # (scene key, label used in the printed report), in reporting order.
    tracked = [
        ('context_kws', 'context_kw_num'),
        ('bedding_kws', 'bedding_kw_num'),
        ('target_kws', 'target_kw_num'),
        ('filter_bedding_kws', 'filter_bedding_kw_num'),
        ('intersect_nodes', 'intersect_nodes_num'),
    ]

    def stat_one_file(file_name):
        """Report the per-sample averages for a single JSON file."""
        with open(file_name, encoding='utf-8') as f:
            a = json.load(f)
            print('*' * 40 + file_name + '*' * 40)
            print('samples:', len(a))

            totals = {key: 0 for key, _ in tracked}
            for scene in tqdm(a):
                for key, _ in tracked:
                    totals[key] += len(scene[key])

            for key, label in tracked:
                print(f"avg {label}:{totals[key] / len(a)}")

    stat_one_file('../data/test_add_node_onecard.json')
    stat_one_file('../data/valid_add_node_onecard.json')
    stat_one_file('../data/train_add_node_onecard.json')


def stat_train_valid_test_card():
    """Print the words of non-narrator entries that also occur in the entry's cards.

    For test.json, valid.json and train.json (in that order) the 150 most
    frequent overlapping words are printed together with their counts.
    """
    # Same whitespace pattern as before, compiled once instead of per call.
    splitter = re.compile(' |\r|\n|\t')

    def stat_one_file(file_name):
        """Count and print the card/description word overlap for one file."""
        with open(file_name, encoding='utf-8') as f:
            a = json.load(f)

        print('*' * 40 + file_name + '*' * 40)
        print('samples:', len(a))

        entry_voc = defaultdict(int)
        for scene in tqdm(a):
            # The first entry of every scene is skipped.
            for entry in scene['entries'][1:]:
                if entry['role'] == 'narrator':
                    continue
                # A set gives O(1) membership tests; the original list made
                # every description word scan all card words.
                card_words = set()
                for card in entry['cards']:
                    card_words.update(splitter.split(card['description']))
                for word in splitter.split(entry['description']):
                    if word in card_words:
                        entry_voc[word] += 1

        ranked = sorted(entry_voc.items(), key=lambda item: item[1], reverse=True)
        print(ranked[:150])

    stat_one_file('test.json')
    stat_one_file('valid.json')
    stat_one_file('train.json')


def stat_one_character_entry():
    """Print input/output/card length averages for test.json, valid.json and train.json.

    In every scene the last entry is treated as the output and all earlier
    entries as the input; only the cards attached to the last entry are counted.
    """

    def stat_one_file(file_name):
        """Accumulate the token lengths for one file and print the averages."""
        with open(file_name, encoding='utf-8') as f:
            a = json.load(f)
            print('*' * 40 + file_name + '*' * 40)
            print('samples:', len(a))

            input_len = output_len = 0
            card_len = used_card_num = 0
            for scene in tqdm(a):
                entries = scene['entries']
                last_idx = len(entries) - 1
                for idx, entry in enumerate(entries):
                    if idx != last_idx:
                        input_len += get_len(entry['description'])
                        continue
                    # Last entry: its description is the output text.
                    output_len += get_len(entry['description'])
                    cards = entry['cards']
                    used_card_num += len(cards)
                    for card in cards:
                        card_len += get_len(card['description'])

            print('input avg len:', input_len / len(a))
            print('output avg len:', output_len / len(a))
            print('card avg len:', card_len / used_card_num)
            print('used card avg num per scene', used_card_num / len(a))

    stat_one_file('test.json')
    stat_one_file('valid.json')
    stat_one_file('train.json')


def stat_outline_ave_len(file_name):
    """Print the mean of ``len(scene['outline'])`` over all scenes in a JSON file."""
    with open(file_name, encoding='utf-8') as f:
        scenes = json.load(f)
        total = sum(len(scene['outline']) for scene in scenes)
        print('avg outline:', total / len(scenes))




def one_scene_intersection(scene):
    """Return how many nodes the context and target keyword hop sets share.

    Both sets come from ``graph.get_hops_set(..., hop=2)`` on the module-level
    ``graph``. A scene whose ``'bedding_kws'`` is empty yields the sentinel
    ``-1``, which stat_conceptnet() filters out before averaging.
    """
    if scene['bedding_kws']:
        context_kws, target_kws = scene['context_kws'], scene['target_kws']
        reachable_from_context = graph.get_hops_set(context_kws, hop=2)
        reachable_from_target = graph.get_hops_set(target_kws, hop=2)
        return len(reachable_from_context & reachable_from_target)
    return -1


def stat_conceptnet(file_name):
    """Print ConceptNet graph statistics and keyword statistics for one data file.

    Loads the graph through get_conceptnet() into the module-level ``graph``,
    prints its average degree and node count, and then runs the 2-hop
    intersection statistic over the scenes stored in ``file_name``.

    NOTE(review): the ``exit()`` call directly after ``stat_intersection(a)``
    terminates the process, so the distance statistics below it never run as
    the code is written. Confirm whether that early exit is still intended.
    """
    global graph
    graph = get_conceptnet()
    print('finish get graph')
    print('graph avg deg:', graph.get_avg_deg())
    print('graph node num:', graph.get_node_num())

    # exit()

    def stat_intersection(a):
        """Print the mean one_scene_intersection() result over the scenes in ``a``.

        Scenes are processed in a multiprocessing pool sized to the CPU count;
        results equal to -1 (scenes without bedding keywords) are dropped
        before averaging.
        """
        from functools import partial
        from multiprocessing import Pool
        from os import cpu_count
        # partial_work = partial(one_scene_intersection, graph=graph)
        print('start pool')
        with Pool(cpu_count()) as pool:
            res = list(tqdm(pool.imap(one_scene_intersection, a), total=len(a)))

        res = [x for x in res if x != -1]
        # NOTE(review): raises ZeroDivisionError when every scene returned -1.
        print('avg intersection for 2 hops:', sum(res) / len(res))

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
        stat_intersection(a)
        # NOTE(review): everything after this exit() is unreachable.
        exit()
        c_score = 0
        t_score = 0
        cnt = 0
        near_bed_kw_num = 0
        bedding_kw_num = 0
        context_kw_num = 0
        target_kw_num = 0
        appear_target_kw_num = 0
        for scene in tqdm(a):
            context_kws = graph.filter_points(scene['context_kws'])
            bedding_kws = graph.filter_points(scene['bedding_kws'])
            target_kws = graph.filter_points(scene['target_kws'])

            # Scenes left without bedding keywords after filtering are skipped.
            if not bedding_kws:
                continue
            cnt += 1

            bedding_kw_num += len(bedding_kws)
            context_kw_num += len(context_kws)
            target_kw_num += len(target_kws)

            # Count target keywords that already appear among the context keywords.
            for word in target_kws:
                if word in context_kws:
                    appear_target_kw_num += 1
            # continue
            c = 0
            t = 0

            # bedding = sent_tokenize(scene['entries'][-1]['description'])[:scene['peak_idx']]
            # new_bedding_kws = []
            # for sent in bedding:
            #     new_bedding_kws.extend(word_tokenize(sent))
            #
            # for word in bedding_kws:
            #     assert word in new_bedding_kws
            # continue
            # print(bedding_kws)
            # exit()
            for bed_kw in bedding_kws:
                # graph.get_dis results of each bedding keyword against the
                # context keywords (cs) and the target keywords (ts).
                cs = graph.get_dis(bed_kw, context_kws)
                ts = graph.get_dis(bed_kw, target_kws)
                c += cs
                t += ts

                if cs < 3 and ts < 3:
                    near_bed_kw_num += 1

            # Per-scene means over the bedding keywords.
            c /= len(bedding_kws)
            t /= len(bedding_kws)

            c_score += c
            t_score += t

        # All averages are taken over the scenes that were not skipped.
        print('c_score:', c_score / cnt)
        print('t_score:', t_score / cnt)
        print('near bed kw num:', near_bed_kw_num / cnt)
        print('bedding kws num:', bedding_kw_num / cnt)
        print('context kws num:', context_kw_num / cnt)
        print('target kws num:', target_kw_num / cnt)
        print('appear target kws num:', appear_target_kw_num / cnt)


def get_len_distribution(data, bins, range, new=True, save=None):
    """Plot a histogram of ``data`` and optionally save it to a file.

    Args:
        data: Values to histogram.
        bins: Forwarded to ``plt.hist`` as ``bins``.
        range: Forwarded to ``plt.hist`` as ``range``. The name shadows the
            builtin inside this function but is kept for existing callers.
        new: When true, a fresh figure is created before plotting.
        save: Optional output path; when truthy the figure is written there.
    """
    if new:
        plt.figure()
    plt.hist(data, bins=bins, range=range, ec='black')
    if save:
        # Save before show(): with an interactive backend the figure may be
        # released once the window is closed, so saving afterwards can
        # produce an empty image.
        plt.savefig(save)
    plt.show()


def stat_bedding():
    """Print and plot bedding statistics for the train, valid and test ``*_add_node`` files.

    The bedding of a scene is the part of the last entry's description that
    precedes sentence index ``peak_idx``.
    """

    def stat_bedding_one_file(file_name):
        """Report average bedding sentence/word counts and plot the word-count histogram."""
        with open(file_name, encoding='utf-8') as f:
            a = json.load(f)
            sent_sum, word_sum = 0, 0
            sent_list, word_list = [], []
            for scene in tqdm(a):
                text = scene['entries'][-1]['description']
                peak_idx = scene['peak_idx']
                bedding_sents = sent_tokenize(text)[:peak_idx]

                sent_sum += peak_idx
                sent_list.append(peak_idx)

                scene_words = sum(len(word_tokenize(sent)) for sent in bedding_sents)
                word_sum += scene_words
                word_list.append(scene_words)

            print(f"{file_name} statistic")
            print(f"avg sent num:{sent_sum / len(a)}")
            print(f"avg word num:{word_sum / len(a)}")
            get_len_distribution(word_list, 40, (0, 200))

    for split in ['train', 'valid', 'test']:
        stat_bedding_one_file(f"../data/{split}_add_node.json")


def stat_kws_graph_dis(file_name, graph: KnowledgeGraph, window=1, sent_level=True):
    """Plot the distribution of ``graph.get_dis`` values between keywords and their predecessors.

    With ``sent_level`` true, each keyword is compared against the keywords of
    the previous ``window`` keyword groups; otherwise against the previous
    ``window`` individual keywords. The histogram is saved as
    ``sent_level{sent_level}_{window}.png``.
    """
    from itertools import chain

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
        lens = []
        for scene in tqdm(a):
            groups = scene['bedding_kws'] + scene['ending_kws']
            if not sent_level:
                # Keyword level: every keyword forms its own one-element group.
                groups = [[kw] for kw in chain(*groups)]
            for i in range(1, len(groups)):
                before_kws = list(chain(*groups[max(0, i - window):i]))
                for kw in groups[i]:
                    lens.append(graph.get_dis(kw, before_kws, max_hop=4))

        get_len_distribution(lens, 5, (0, 5), save=f'sent_level{sent_level}_{window}.png')


def stat_kws_hopset_work(scene, sent_level=False, window=1):
    """Collect hop-set sizes for the keyword windows of a single scene.

    Args:
        scene: Dict with ``'bedding_kws'`` and ``'ending_kws'`` (lists of
            keyword lists).
        sent_level: When true a window spans whole keyword groups, otherwise
            individual keywords.
        window: Number of preceding groups/keywords that form the window.
            Previously this was read from an undefined global name.

    Returns:
        A list with one list per hop (1 and 2); each inner list holds
        ``len(graph.get_hops_set(window_keywords, hop))`` for every window.
        The module-level ``graph`` is used.
    """
    from itertools import chain

    max_hop = 2
    lens = [[] for _ in range(max_hop)]

    kws = scene['bedding_kws'] + scene['ending_kws']
    if not sent_level:
        # Keyword level: every keyword forms its own one-element group.
        kws = [[kw] for kw in chain(*kws)]

    for i in range(1, len(kws)):
        before_kws = list(chain(*kws[max(0, i - window):i]))
        for hop in range(1, max_hop + 1):
            lens[hop - 1].append(len(graph.get_hops_set(before_kws, hop)))

    return lens


def stat_kws_hopset(file_name, graph: "KnowledgeGraph", window=1, sent_level=True):
    """Plot and print hop-set size statistics for every scene in ``file_name``.

    Scenes are processed by stat_kws_hopset_work() in a multiprocessing pool.
    ``window`` and ``sent_level`` are forwarded to the workers so that the
    computed statistics match the labels in the output.

    NOTE(review): the workers read the module-level ``graph``, not the
    ``graph`` argument of this function.
    """
    from functools import partial
    from multiprocessing import Pool
    from os import cpu_count

    max_hop = 2
    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)

    work = partial(stat_kws_hopset_work, sent_level=sent_level, window=window)
    with Pool(cpu_count()) as pool:
        per_scene = list(tqdm(pool.imap(work, a), total=len(a)))

    # Merge the per-scene results into one list per hop.
    lens = [[] for _ in range(max_hop)]
    for scene_lens in per_scene:
        for i in range(max_hop):
            lens[i].extend(scene_lens[i])

    for hop in range(1, max_hop + 1):
        sizes = lens[hop - 1]
        if not sizes:
            # No window was produced at all; max() and the mean are undefined.
            print(f'sent_level:{sent_level}, windows size:{window}, hop: {hop}, no keyword windows found')
            continue
        max_hop_num = max(sizes)
        get_len_distribution(sizes, 10, (0, max_hop_num),
                             save=f'sent_level{sent_level}_window{window}_hop{hop}.png')
        print(
            f'sent_level:{sent_level}, windows size:{window}, hop: {hop}, avg hop num:{sum(sizes) / len(sizes)}')


def stat_kws_in_hopset_ratio(file_name):
    """Print how often a keyword lies in the hop=2 set of the keyword right before it.

    Keywords are the flattened ``'bedding_kws'`` followed by ``'ending_kws'``
    of each scene; the module-level ``graph`` provides ``get_hops_set``.
    """
    from itertools import chain

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
        cnt = 0
        sum_cnt = 0
        for scene in tqdm(a):
            kws_list = list(chain(*(scene['bedding_kws'] + scene['ending_kws'])))
            # Walk over consecutive (previous, current) keyword pairs.
            for before, now in zip(kws_list, kws_list[1:]):
                sum_cnt += 1
                if now in graph.get_hops_set([before], hop=2):
                    cnt += 1
        print(f"cnt: {cnt}, sum_cnt: {sum_cnt}, ratio: {cnt / sum_cnt}")


def stat_kws_in_hopset_ratio_all(file_name):
    """Print the 1-hop set size and the ratio of keywords that fall inside it.

    For every keyword after the first one in a scene, the hop=1 set is built
    from all preceding keywords plus the scene's context and target keywords.
    The function prints how many keywords are members of that set and the
    mean size of the set.
    """
    from itertools import chain

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
        cnt = 0
        sum_cnt = 0
        len_cnt = []
        for scene in tqdm(a):
            context_kws_list = scene['context_kws'] + list(chain(*scene['target_kws']))
            kws_list = list(chain(*(scene['bedding_kws'] + scene['ending_kws'])))
            for i, now in enumerate(kws_list[1:], start=1):
                sum_cnt += 1
                before = kws_list[:i] + context_kws_list
                hop_set = graph.get_hops_set(before, hop=1)
                if now in hop_set:
                    cnt += 1
                len_cnt.append(len(hop_set))
        print(f"cnt: {cnt}, sum_cnt: {sum_cnt}, ratio: {cnt / sum_cnt}, mean_set_size = {np.mean(len_cnt)}")


def stat_kws_alpha_ratio(file_name):
    """Print the share of keywords whose first character is alphabetic.

    Keywords are the flattened ``'bedding_kws'`` and ``'ending_kws'`` of every
    scene in ``file_name``. An empty keyword counts as non-alphabetic.
    """
    from itertools import chain

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)

    cnt = 0
    sum_cnt = 0
    for scene in tqdm(a):
        kws_list = list(chain(*(scene['bedding_kws'] + scene['ending_kws'])))
        sum_cnt += len(kws_list)
        # word[:1] instead of word[0]: an empty keyword yields '' (isalpha()
        # is False) rather than raising IndexError.
        cnt += sum(1 for word in kws_list if word[:1].isalpha())

    print(f"cnt:{cnt}, sum_cnt:{sum_cnt}, ratio:{cnt / sum_cnt}")


def stat_persona_kws_in_hopset_ratio_all(file_name):
    """Print how often a target keyword lies in the hop=2 set of what precedes it.

    For each target keyword of a scene the hop set is built from the scene's
    ``'persona_kws'`` plus the target keywords that come before it. The
    function prints the hit count, the total, their ratio and the mean size
    of the hop sets. The module-level ``graph`` provides ``get_hops_set``.
    """
    from itertools import chain

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)

    cnt = 0
    sum_cnt = 0
    len_cnt = []
    for scene in tqdm(a):
        persona_kws = scene['persona_kws']
        target_kws = list(chain(*scene['target_kws']))
        # The target sentence itself is not needed here, so the description
        # is no longer sentence-tokenized for every scene.
        for i, cur in enumerate(target_kws):
            hop_set = graph.get_hops_set(persona_kws + target_kws[:i], hop=2)
            sum_cnt += 1
            if cur in hop_set:
                cnt += 1
            len_cnt.append(len(hop_set))

    print(f"cnt: {cnt}, sum_cnt: {sum_cnt}, ratio: {cnt / sum_cnt}, mean_set_size = {np.mean(len_cnt)}")


def stat_father_kws_dis(file_name):
    """Print and plot how far back the nearest hop=1 neighbour of each keyword lies.

    For every keyword the preceding sequence is the earlier keywords of the
    scene followed by the scene's context and target keywords. That sequence
    is scanned from its end; the 1-based position of the first element found
    in ``graph.get_hops_set([keyword], hop=1)`` is recorded as the distance.

    Prints the hit count, total and ratio, then the largest distance (0 when
    no keyword had a match), and plots the distance histogram.
    """
    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)

    cnt = 0
    sum_cnt = 0
    dis = []
    for scene in tqdm(a):
        context_kws_list = scene['context_kws'] + list(chain(*scene['target_kws']))
        kws_list = list(chain(*(scene['bedding_kws'] + scene['ending_kws'])))
        for i, now in enumerate(kws_list):
            sum_cnt += 1
            before = kws_list[:i] + context_kws_list
            hop_set = graph.get_hops_set([now], hop=1)
            # Scan from the most recent element backwards; steps starts at 1.
            for steps, previous in enumerate(reversed(before), start=1):
                if previous in hop_set:
                    dis.append(steps)
                    cnt += 1
                    break

    print(f"cnt: {cnt}, sum_cnt: {sum_cnt}, ratio: {cnt / sum_cnt}")
    # default=0 avoids a ValueError when no keyword had any match.
    print(max(dis, default=0))
    get_len_distribution(dis, 50, (0, 300))


def table2_stat(file_name):
    """Print mean target length, keyword count and in-graph keyword count for one file.

    For each scene the target is the sentence at index ``peak_idx`` of the
    last entry's description, measured with get_len(). Keywords are the
    flattened ``'bedding_kws'`` and ``'ending_kws'``; the in-graph count is
    the length of ``graph.filter_points(keywords)`` on the module-level graph.
    """
    print('file_name = ', file_name)
    # Explicit UTF-8, consistent with every other loader in this file; the
    # platform default encoding may not decode the data correctly.
    with open(file_name, 'r', encoding='utf-8') as f:
        data = json.load(f)

    target_lens = []
    keyword_lens = []  # bedding + ending
    node_in_kg = []

    for scene in tqdm(data):
        last_entry = scene['entries'][-1]
        sents = sent_tokenize(last_entry['description'])
        target = sents[scene['peak_idx']]
        target_lens.append(get_len(target))

        kws_list = list(chain(*(scene['bedding_kws'] + scene['ending_kws'])))

        keyword_lens.append(len(kws_list))
        node_in_kg.append(len(graph.filter_points(kws_list)))

    print('target_len = ', np.mean(target_lens))
    print('keywords_len = ', np.mean(keyword_lens))
    print('node_in_kg = ', np.mean(node_in_kg))


if __name__ == '__main__':
    # Load ConceptNet once into the module-level ``graph``;
    # stat_train_valid_test() reads it from there.
    graph = get_conceptnet()
    stat_train_valid_test()
    # NOTE(review): this exit() ends the script here, so nothing below it
    # runs as the code is written. The remaining calls look like earlier
    # experiments kept for reference.
    exit()
    for split in ['train', 'valid', 'test']:
        table2_stat(f"../data/{split}_dynamic_persona.json")
        print('=' * 100)
    # stat_father_kws_dis('../data/test_dynamic_persona.json')
    # stat_persona_kws_in_hopset_ratio_all("../data/test_dynamic_persona.json")
    # stat_kws_in_hopset_ratio('../data/test_dynamic.json')
    # stat_kws_in_hopset_ratio_all('../data/test_dynamic.json')
    # stat_kws_alpha_ratio('../data/test_dynamic.json')
    exit()
    # Hop-set statistics for window sizes 1 to 3 at keyword level.
    for window in range(1, 4):
        # stat_kws_graph_dis(f'../data/test_dynamic.json', graph, window=window, sent_level=False)
        stat_kws_hopset(f'../data/test_dynamic.json', graph, window=window, sent_level=False)
    # stat('test_stories.json')
    # stat('valid_stories.json')
    # stat('train_stories.json')
    # stat_train_valid_test()
    # stat_train_valid_test_kw()
    # stat_train_valid_test_card()
    # stat_one_character_entry()
    # stat_outline_ave_len('gpt2_persona_guide_true_outline.json')
    # get_conceptnet()
    # stat_conceptnet('test_peak_context_target_kw.json')
    # stat_bedding()
