import json
from matplotlib import pyplot as plt
plt.switch_backend('agg')
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('../gpt2_en_ckpt_origin')
import operator
from functools import reduce

def get_len(text):
    """Return how many tokens the module-level GPT-2 tokenizer produces for *text*."""
    token_ids = tokenizer.encode(text)
    return len(token_ids)

def extract_outline(text):
    """Extract the outline phrases that follow the first ``<|endoftarget|>`` token.

    The text is split on whitespace. Starting right after the first
    ``<|endoftarget|>`` token, words are accumulated and every
    ``<|sepofoutline|>`` token closes one phrase. Collection stops at the first
    ``<|beginofbedding|>`` token or at the end of the text. Words after the
    last ``<|sepofoutline|>`` that are never closed by a separator are dropped.

    Args:
        text: generated text containing the special marker tokens.

    Returns:
        List of outline phrases (the space-joined words between separators).
        Empty when ``<|endoftarget|>`` does not occur.
    """
    words = text.split()
    try:
        start = words.index('<|endoftarget|>') + 1
    except ValueError:
        return []

    # Only the first <|endoftarget|> section is read. Rescanning later markers
    # (as the old nested loop did) re-appended phrases already collected.
    outline = []
    accum = []
    for word in words[start:]:
        if word == '<|sepofoutline|>':
            outline.append(' '.join(accum))
            accum = []
        elif word == '<|beginofbedding|>':
            break
        else:
            accum.append(word)
    return outline

def extract_bedding(text):
    """Extract the bedding passage that follows the first ``<|beginofbedding|>`` token.

    The text is split on whitespace. The words after the first
    ``<|beginofbedding|>`` token are collected up to, but not including, the
    next ``<|beginofending|>`` token, or up to the end of the text.

    Args:
        text: generated text containing the special marker tokens.

    Returns:
        The collected words joined by single spaces. Returns ``''`` when the
        marker is absent.
    """
    words = text.split()
    try:
        start = words.index('<|beginofbedding|>') + 1
    except ValueError:
        return ''

    # Only the first <|beginofbedding|> section is read. The old nested loop
    # resumed the outer scan and duplicated words when the marker repeated.
    bedding = []
    for word in words[start:]:
        if word == '<|beginofending|>':
            break
        bedding.append(word)
    return ' '.join(bedding)


def get_len_distribution(data, file_name, bins, range, new=True, use_tokenizer=False):
    """Print length statistics for *data* and save a histogram of the lengths.

    Args:
        data: sequence of items to measure. When ``use_tokenizer`` is set,
            each item is measured with ``get_len`` (tokenizer token count);
            otherwise ``len()`` is used.
        file_name: path passed to ``plt.savefig``.
        bins: number of histogram bins, forwarded to ``plt.hist``.
        range: ``(low, high)`` histogram range, forwarded to ``plt.hist``.
            The name shadows the builtin but is kept for caller compatibility.
        new: if True, open a fresh figure before plotting; otherwise draw onto
            the current figure.
        use_tokenizer: measure token counts instead of ``len()``. Also writes
            the items with at most one token to ``bedding_filter.json``.
    """
    import numpy as np

    if use_tokenizer:
        lens = [get_len(i) for i in data]
    else:
        lens = [len(i) for i in data]

    print('total len = ', sum(lens))
    print('avg len = ', np.mean(lens))

    if use_tokenizer:
        # Reuse the token counts computed above. The previous version
        # tokenized every item a second time just to filter it.
        filter_data = [item for item, length in zip(data, lens) if length <= 1]
        with open("bedding_filter.json", 'w', encoding='utf-8') as f:
            json.dump(filter_data, f, indent=4, separators=[',', ':'])

    if new:
        plt.figure()
    plt.hist(lens, bins=bins, range=range, ec='black')
    plt.savefig(file_name)


def get_word_distribution(data, output_path):
    """Count how often each element occurs across all inner sequences of *data*.

    Every element of every inner iterable is tallied. The tallies are ordered
    by count, highest first; ties keep the order in which elements were first
    seen. The ordered tallies are written to *output_path* as indented JSON
    and also returned.

    Args:
        data: iterable of iterables whose elements are counted.
        output_path: destination JSON file.

    Returns:
        dict mapping element -> count, in descending-count order.
    """
    counts = {}
    for sequence in data:
        for token in sequence:
            counts[token] = counts.get(token, 0) + 1

    ranked = dict(sorted(counts.items(), key=lambda pair: pair[1], reverse=True))

    with open(output_path, 'w', encoding='utf-8') as out_file:
        json.dump(ranked, out_file, indent=4, separators=[',', ':'])
    return ranked


def stat_generated_outline(path):
    """Report outline statistics for one generated-result file.

    Loads ``../result/<path>``, which is read as a list of records that each
    carry a ``'generated'`` text. It then:

    - extracts each record's outline phrases;
    - prints the total number of phrases;
    - writes a histogram of phrases-per-record and a phrase-count JSON.

    Both output files are named after *path* with its final extension removed.

    Args:
        path: file name of the result JSON, relative to ``../result/``.
    """
    import os

    with open("../result/{}".format(path), 'r', encoding='utf-8') as f:
        data = json.load(f)

    outlines = [extract_outline(i['generated']) for i in data]
    # A linear sum of per-record lengths also works on an empty file.
    # reduce(operator.add, ...) was quadratic and raised TypeError on [].
    print('len = ', sum(len(outline) for outline in outlines))

    # Strip only the final extension. split('.')[0] cut "..._1.0.json" and
    # "..._1.0_truetarget.json" both down to "..._1", so different result
    # files overwrote each other's outputs.
    base_path = os.path.splitext(path)[0]
    # The ".png" is explicit because base_path may now contain a dot
    # (e.g. "_1.0"). savefig would otherwise read the text after that dot
    # as the image format.
    get_len_distribution(outlines, f'{base_path}_outline_lens.png', 20, (0,100))
    get_word_distribution(outlines, f'{base_path}_generated_word_cnt.json')


def stat_ori_data_outline(filter=False):
    """Report outline-keyword statistics for the original train/valid/test data.

    For each split, loads ``../data/<split>_add_node_ending_onecard.json``.
    Each sample's outline is its bedding keyword list followed by its
    ``'ending_kws'`` list. The function then writes a length histogram and a
    keyword-count JSON for the split.

    Args:
        filter: if True, take the bedding part from ``'filter_bedding_kws'``
            instead of ``'bedding_kws'``, and tag the outputs ``filter``
            instead of ``unfilter``. The name shadows the builtin but is kept
            for callers.
    """
    bedding_key = 'filter_bedding_kws' if filter else 'bedding_kws'
    suffix = 'filter' if filter else 'unfilter'

    for split in ('train', 'valid', 'test'):
        data_path = f"../data/{split}_add_node_ending_onecard.json"
        with open(data_path, 'r', encoding='utf-8') as f:
            samples = json.load(f)

        outlines = [sample[bedding_key] + sample['ending_kws'] for sample in samples]

        get_len_distribution(outlines, f'outline_lens_{split}_{suffix}', 20, (0,20))
        get_word_distribution(outlines, f'generated_word_cnt_{split}_{suffix}.json')


def stat_generate_bedding():
    """Plot the token-length distribution of generated bedding passages.

    Reads ``../result/gpt2_kg_explicit_outline.json`` and extracts the bedding
    span from each record's ``'generated'`` text. The spans are passed to
    ``get_len_distribution`` with tokenizer-based lengths, 50 bins over
    (0, 200), and ``beddings_lens`` as the output name.
    """
    result_path = "../result/gpt2_kg_explicit_outline.json"
    with open(result_path, 'r', encoding='utf-8') as f:
        records = json.load(f)

    beddings = [extract_bedding(record['generated']) for record in records]
    get_len_distribution(
        beddings,
        'beddings_lens',
        50,
        (0,200),
        use_tokenizer=True,
    )


if __name__ == "__main__":
    # Result files under ../result/ whose generated outlines are analysed.
    # Earlier experiments stay commented out so they can be re-enabled.
    result_files = [
        'kg_gate_combine2_ending_outline_sqrt_onecard_1.0.json',
        # 'gpt2_kg_gate_explicit_outline_onecard_1.0.json',
        # 'gpt2_kg_gate_explicit_outline_onecard_1.0_truetarget.json',
        # 'gpt2_baseline_explicit_outline_onecard_truetarget.json',
        # 'gpt2_baseline_explicit_filter_outline_truetarget.json',
        # "gpt2_kg_explicit_outline.json",
    ]

    # stat_ori_data_outline()
    for result_file in result_files:
        stat_generated_outline(result_file)