import rouge
import torch
from rouge import Rouge
import json
import nltk
from nltk.corpus import stopwords
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
from nltk import ngrams
from nltk.tokenize import sent_tokenize, word_tokenize
from bert_score import score, BERTScorer
from transformers import AutoTokenizer, GPT2Tokenizer
import re
from itertools import chain
from tqdm import tqdm
from math import ceil
import os, sys
import traceback
import numpy as np
from multiprocessing import Pool
import argparse

# from transformers import RobertaTokenizerFast

origin_stop_words = stopwords.words('english')
# Candidate extra entries that were tried: ['she', "'s", "'re"]

# Stop-word list used for filtering: every English stop word as-is, then every
# word with a leading space, then with a trailing space, then with both.
stop_words = origin_stop_words + [
    template.format(word)
    for template in (' {}', '{} ', ' {} ')
    for word in origin_stop_words
]

# Checkpoint name and layer count referenced by the (commented-out) BERTScorer
# configurations in this file. Other checkpoints that were tried:
#   'sentence-transformers/bert-base-nli-mean-tokens'             -> 12 layers
#   'sentence-transformers/distilbert-base-nli-stsb-mean-tokens'  -> 6 layers
model_type = 'sentence-transformers/roberta-large-nli-stsb-mean-tokens'
layers = 24

max_length = 512

# tokenizer = AutoTokenizer.from_pretrained(model_type)
# tokenized_stop_words = set(chain(*[tokenizer.encode(word, add_special_tokens=False) for word in stop_words]))


# print(tokenized_stop_words)

from enum import Enum


class ComputePlace(Enum):
    """Selects which region of a generated/answer string ``process`` extracts.

    Regions are delimited by special tokens such as ``<|beginofbedding|>``,
    ``<|beginoftarget|>``, ``<|endoftarget|>``, ``<|beginofending|>`` and
    ``<|endofoutline|>``.
    """
    bedding = 1  # from <|beginofbedding|> up to <|beginofending|>
    bte_bedding = 2  # from <|beginofbedding|> up to <|beginoftarget|>
    except_outline = 3  # whole text with the outline section(s) removed
    all = 4  # whole text, unchanged
    plan_write_except_outline = 5  # everything after <|endofoutline|>
    bte_target = 6  # from <|beginoftarget|> up to <|beginofending|>
    tbe_target = 7  # from <|beginoftarget|> up to <|beginofbedding|>
    outline_target = 8  # from <|beginoftarget|> up to <|endoftarget|>
    except_outline_target = 9  # target span followed by everything from <|beginofbedding|>
    bte_outline = 10  # everything from <|beginofbedding|> onward
    bte_outline_target = 11  # <|beginoftarget|> up to <|beginofending|>; '' if either is missing


def clean_text(f):
    """Decorator: remove special marker tokens from the string ``f`` returns.

    Every occurrence of each token below is deleted and the result is
    stripped of surrounding whitespace. ``functools.wraps`` keeps the wrapped
    function's ``__name__`` and ``__doc__`` intact.
    """
    from functools import wraps

    # Built once per decorated function instead of on every call.
    special_tokens = ('<|endofcard|>', '<|endofprompt|>', '<|beginofbedding|>',
                      '<|beginoftarget|>', '<|beginofending|>', '<|endoftarget|>',
                      '<|endofoutline|>', '<|sepofoutline|>', '<end_card>')

    @wraps(f)
    def wrapper(*args, **kwargs):
        text = f(*args, **kwargs)
        for token in special_tokens:
            text = text.replace(token, '')
        return text.strip()

    return wrapper


# Outline sections in "dynamic" output: from an <|endofsentence|>,
# <|beginofbedding|> or <|beginofending|> marker up to the next <|endofoutline|>.
_DYNAMIC_OUTLINE_RE = re.compile(
    r'(<\|endofsentence\|>|<\|beginofbedding\|>|<\|beginofending\|>).*?<\|endofoutline\|>')


def _slice_between(text, start_token, end_token):
    """Return ``text`` from ``start_token`` (inclusive) to ``end_token`` (exclusive).

    Returns '' when ``start_token`` is absent. When ``end_token`` is absent the
    slice runs to the end of ``text``; feeding ``str.find``'s -1 straight into
    a slice would silently drop the final character instead.
    """
    start = text.find(start_token)
    if start == -1:
        return ''
    end = text.find(end_token)
    if end == -1:
        end = len(text)
    return text[start:end]


@clean_text
def process(text, mode: ComputePlace, is_dynamic=False):
    """Extract the region of ``text`` selected by ``mode``.

    The ``clean_text`` decorator then removes all known special tokens from
    the result and strips surrounding whitespace.

    Missing delimiters are handled explicitly: a missing start marker yields
    '' and a missing end marker means "up to the end of the text".

    Args:
        text: raw model output or reference string.
        mode: which region to extract (a ``ComputePlace`` member).
        is_dynamic: only used with ``ComputePlace.except_outline``; when True
            outline sections are removed with a regex instead of by position.

    Raises:
        Exception: if ``mode`` is not a handled ``ComputePlace`` member.
    """
    if mode == ComputePlace.bte_outline:
        idx = text.find('<|beginofbedding|>')
        return text[idx:] if idx != -1 else ''
    if mode == ComputePlace.bte_outline_target:
        idx = text.find('<|beginoftarget|>')
        idx2 = text.find('<|beginofending|>')
        if idx == -1 or idx2 == -1:
            return ''
        return text[idx:idx2]
    if mode == ComputePlace.bedding:
        return _slice_between(text, '<|beginofbedding|>', '<|beginofending|>')
    if mode == ComputePlace.bte_bedding:
        return _slice_between(text, '<|beginofbedding|>', '<|beginoftarget|>')
    if mode == ComputePlace.except_outline:
        if is_dynamic:
            # Remove everything from <|endofsentence|>, <|beginofbedding|> or
            # <|beginofending|> up to the following <|endofoutline|>.
            return _DYNAMIC_OUTLINE_RE.sub('', text)
        end_target = text.find('<|endoftarget|>')
        bedding = text.find('<|beginofbedding|>')
        if end_target == -1:
            # No end-of-target marker: there is no outline section to cut.
            return text
        if bedding == -1:
            # The outline runs to the end of the text.
            return text[:end_target]
        return text[:end_target] + text[bedding:]
    if mode == ComputePlace.except_outline_target:
        start = text.find('<|beginoftarget|>')
        end = text.find('<|endoftarget|>')
        bedding = text.find('<|beginofbedding|>')
        if end == -1:
            # Without <|endoftarget|>, let the target run up to the story.
            end = bedding if bedding != -1 else len(text)
        target = text[start:end] if start != -1 else ''
        story = text[bedding:] if bedding != -1 else ''
        return target + story
    if mode == ComputePlace.all:
        return text
    if mode == ComputePlace.plan_write_except_outline:
        _, sep, tail = text.partition('<|endofoutline|>')
        return tail if sep else text
    if mode == ComputePlace.bte_target:
        return _slice_between(text, '<|beginoftarget|>', '<|beginofending|>')
    if mode == ComputePlace.tbe_target:
        return _slice_between(text, '<|beginoftarget|>', '<|beginofbedding|>')
    if mode == ComputePlace.outline_target:
        return _slice_between(text, '<|beginoftarget|>', '<|endoftarget|>')
    raise Exception('Unsupported mode!')


def rouge_one_file(file_name):
    """Print average ROUGE between entry descriptions and card descriptions.

    For each of the first 1500 scenes:
    - The hypothesis is every entry description after the first one
      (narrator entries included).
    - The reference is the card descriptions of all non-narrator entries.

    Parts are joined with spaces; plain concatenation glued the last word of
    one description onto the first word of the next. Pairs where either side
    is 10 characters or shorter are skipped.
    """
    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
    print('*' * 40 + file_name + '*' * 40)
    print('samples:', len(a))

    hyps = []
    refs = []
    for scene in a[:1500]:
        hyp_parts = []
        ref_parts = []
        for entry in scene['entries'][1:]:
            hyp_parts.append(entry['description'])
            if entry['role'] == 'narrator':
                continue
            ref_parts.extend(card['description'] for card in entry['cards'])
        hyp = ' '.join(hyp_parts)
        ref = ' '.join(ref_parts)
        if len(hyp) > 10 and len(ref) > 10:
            hyps.append(hyp)
            refs.append(ref)

    print(len(hyps))
    print(len(refs))
    if not hyps:
        # Rouge cannot score an empty corpus.
        print('no valid hyp/ref pairs')
        return
    rouge_scorer = Rouge()
    scores = rouge_scorer.get_scores(hyps, refs, avg=True)
    print(scores)


def rouge_generate(file_name):
    """Print average ROUGE of each scene's ``generated`` text against ``cards``.

    ``cards`` is used verbatim. Scenes where either string is 10 characters
    or shorter are skipped.
    """
    with open(file_name, encoding='utf-8') as fin:
        scenes = json.load(fin)
        print('*' * 40 + file_name + '*' * 40)
        print('samples:', len(scenes))

        pairs = [(scene['generated'], scene['cards']) for scene in scenes]
        kept = [(h, r) for h, r in pairs if len(h) > 10 and len(r) > 10]
        hyps = [h for h, _ in kept]
        refs = [r for _, r in kept]

        scorer = Rouge()
        print(len(hyps))
        print(len(refs))
        print(scorer.get_scores(hyps, refs, avg=True))


def top_filter(x, ratio=0.001, return_idx=False):
    """Average the largest ``ceil(len(x) * ratio)`` values of ``x``.

    Assumes ``x`` is a 1-D tensor. ``ceil`` keeps at least one value for
    non-empty input, and the count is capped at ``len(x)``.

    Returns the mean as a Python float. With ``return_idx=True``, returns a
    ``(mean, index_of_largest_value)`` tuple instead.
    """
    count = len(x)
    k = min(count, ceil(count * ratio))
    top_values, top_indices = torch.topk(x, k)
    mean_value = top_values.mean().item()
    if not return_idx:
        return mean_value
    return mean_value, top_indices[0].item()


def bert_score_one_file(file_name, scorer, every_sent=True):
    """Print BERTScore of each scene's last entry description against its cards.

    Args:
        file_name: JSON file holding a list of scenes with an ``entries`` list.
        scorer: object whose ``score(cands, refs)`` returns ``(P, R, F)``
            tensors (e.g. a ``bert_score.BERTScorer``).
        every_sent: if True, pair every sentence of the last description with
            every non-empty card description. Otherwise compare the whole
            description with all card descriptions joined by spaces.

    Per scene, P/R/F are each reduced with ``top_filter``. The printed
    numbers are averages over the scenes that were scored successfully.
    Scoring failures are logged with a traceback and skipped.
    """
    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
    print('*' * 40 + file_name + '*' * 40)
    print('samples:', len(a))

    sum_p = 0
    sum_r = 0
    sum_f = 0
    cnt = 0
    for scene in tqdm(a):
        temp_hyps = []
        temp_refs = []
        if every_sent:
            # [-1:] rather than [-1] so a scene without entries is just skipped.
            for entry in scene['entries'][-1:]:
                assert entry['role'] != 'narrator'
                hyp_texts = sent_tokenize(entry['description'])
                for card in entry['cards']:
                    if word_tokenize(card['description']):
                        for text in hyp_texts:
                            temp_hyps.append(text)
                            temp_refs.append(card['description'])
        else:
            last_entry = scene['entries'][-1]
            # Space-join so adjacent card descriptions do not glue words together.
            temp_refs.append(' '.join(card['description'] for card in last_entry['cards']))
            temp_hyps.append(last_entry['description'])

        if not temp_hyps:
            continue
        try:
            p, r, f = scorer.score(temp_hyps, temp_refs)
            sum_p += top_filter(p)
            sum_r += top_filter(r)
            sum_f += top_filter(f)
            cnt += 1
        except Exception:
            # Best effort: report and keep scoring the remaining scenes.
            traceback.print_exc()

    if cnt == 0:
        print('no scene could be scored')
        return
    print('precision:', sum_p / cnt)
    print('recall:', sum_r / cnt)
    print('f1:', sum_f / cnt)


def bert_score_generate(file_name, scorer, every_sent=True, answer=False, max_sent=False):
    """Print BERTScore of generated (or answer) text against each scene's cards.

    Args:
        file_name: JSON file holding a list of scenes with ``generated``,
            ``cards`` (cards separated by ``<end_card>``) and optionally
            ``answer``.
        scorer: object whose ``score(cands, refs)`` returns ``(P, R, F)`` tensors.
        every_sent: if True, pair every sentence of ``generated`` with every
            non-empty card.
        answer: when ``every_sent`` is False, score ``answer`` instead of
            ``generated``.
        max_sent: when ``every_sent`` is False, score each sentence of the
            hypothesis separately against the concatenated cards.

    Per scene, P/R/F are each reduced with ``top_filter``. The printed values
    average over scenes that scored successfully.
    """
    def strip_markers(sent):
        return sent.replace('<|endoftext|>', '').replace('<|beginoftarget|>', '')

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
    print('*' * 40 + file_name + '*' * 40)
    print('samples:', len(a))

    sum_p = 0
    sum_r = 0
    sum_f = 0
    cnt = 0
    for scene in tqdm(a):
        temp_hyps = []
        temp_refs = []
        if every_sent:
            hyp_texts = sent_tokenize(scene['generated'].replace('<|endoftext|>', ''))
            cards = [card for card in scene['cards'].split('<end_card>') if word_tokenize(card)]
            for card in cards:
                for hyp in hyp_texts:
                    temp_hyps.append(hyp)
                    temp_refs.append(card)
        else:
            hyp = strip_markers(scene['answer'] if answer else scene['generated'])
            ref = scene['cards'].replace('<end_card>', '')
            hyp_texts = nltk.sent_tokenize(hyp) if max_sent else [hyp]
            temp_hyps.extend(hyp_texts)
            temp_refs.extend([ref] * len(hyp_texts))

        if not temp_hyps:
            continue
        try:
            p, r, f = scorer.score(temp_hyps, temp_refs)
            sum_p += top_filter(p)
            sum_r += top_filter(r)
            sum_f += top_filter(f)
            cnt += 1
        except Exception:
            # Best effort: report and keep scoring the remaining scenes.
            traceback.print_exc()

    if cnt == 0:
        print('no scene could be scored')
        return
    print('precision:', sum_p / cnt)
    print('recall:', sum_r / cnt)
    print('f1:', sum_f / cnt)


def bert_score_target(file_name, scorer, every_sent=True, answer=False, max_sent=False, begin_token='<|beginoftarget|>',
                      end_token='<|endoftarget|>', sort_out_file_name=None, max=False):
    """Print BERTScore of the target span (or every sentence) against the cards.

    Args:
        file_name: JSON list of scenes with ``generated``, ``cards`` and
            optionally ``answer``.
        scorer: object whose ``score(cands, refs)`` returns ``(P, R, F)`` tensors.
        every_sent: if True, pair every sentence of ``generated`` with every
            non-empty card.
        answer: when ``every_sent`` is False, use ``answer`` instead of
            ``generated``.
        max_sent: when ``every_sent`` is False, score each sentence of the
            target span separately.
        begin_token, end_token: delimiters of the target span. Scenes missing
            either delimiter are skipped.
        sort_out_file_name: if given, write the scored scenes (each annotated
            with ``recall``) sorted by descending recall to this JSON file.
        max: unused; kept for backward compatibility. It shadows the builtin,
            so this function must not call ``max()``.

    The printed P/R/F are divided by the total number of samples, so skipped
    or failed scenes effectively count as 0. ``cnt`` is the number of scenes
    that were actually scored.
    """
    def extract_target(sent):
        idx1 = sent.find(begin_token)
        idx2 = sent.find(end_token)
        if idx1 == -1 or idx2 == -1:
            return None
        return sent[idx1: idx2].replace(begin_token, '')

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
    print('*' * 40 + file_name + '*' * 40)
    print('samples:', len(a))

    sum_p = 0
    sum_r = 0
    sum_f = 0
    cnt = 0
    res = []
    for scene in tqdm(a):
        temp_hyps = []
        temp_refs = []
        if every_sent:
            hyp_texts = sent_tokenize(scene['generated'].replace('<|endoftext|>', ''))
            cards = [card for card in scene['cards'].split('<end_card>') if word_tokenize(card)]
            for card in cards:
                for hyp in hyp_texts:
                    temp_hyps.append(hyp)
                    temp_refs.append(card)
        else:
            hyp = extract_target(scene['answer'] if answer else scene['generated'])
            if hyp is None:
                continue
            ref = scene['cards'].replace('<end_card>', '')
            hyp_texts = nltk.sent_tokenize(hyp) if max_sent else [hyp]
            temp_hyps.extend(hyp_texts)
            temp_refs.extend([ref] * len(hyp_texts))

        if not temp_hyps:
            continue
        try:
            p, r, f = scorer.score(temp_hyps, temp_refs)
            scene_p, scene_r, scene_f = top_filter(p), top_filter(r), top_filter(f)
        except Exception:
            # Best effort: report and keep scoring the remaining scenes.
            traceback.print_exc()
            continue
        # Accumulate only after all three reductions succeeded.
        sum_p += scene_p
        sum_r += scene_r
        sum_f += scene_f
        cnt += 1
        # Use the reduced recall: r.item() raises for multi-element tensors
        # (every_sent / max_sent), which previously dropped every such scene.
        scene['recall'] = scene_r
        res.append(scene)

    denom = len(a) or 1  # avoid ZeroDivisionError on an empty file
    print('precision:', sum_p / denom)
    print('recall:', sum_r / denom)
    print('f1:', sum_f / denom)
    print('cnt:', cnt)

    if sort_out_file_name:
        with open(sort_out_file_name, 'w', encoding='utf-8') as f:
            sorted_res = sorted(res, key=lambda x: x['recall'], reverse=True)
            json.dump(sorted_res, f, indent=1, ensure_ascii=False)


def get_best_sent(file_name, out_file_name, scorer, answer=False, generate=True, max_samples=101):
    """Find each scene's sentence with the highest BERTScore recall against its cards.

    For each of the first ``max_samples`` scenes (``None`` means all scenes):
    - The ``answer`` and/or ``generated`` text is split into sentences.
    - Every sentence is scored against the cards joined together.
    - The best sentence and its ``top_filter``-reduced recall are stored on
      the scene under ``best_answer_sent``/``best_answer_sent_recall`` and
      ``best_generate_sent``/``best_generate_sent_recall``.

    Average recalls are printed and the annotated scenes are written to
    ``out_file_name`` as JSON.
    """
    def strip_markers(sent):
        return sent.replace('<|endoftext|>', '').replace('<|beginoftarget|>', '')

    def best_sentence(hyp, ref):
        """Return ``(recall, best sentence)`` for ``hyp``'s sentences vs ``ref``."""
        sents = nltk.sent_tokenize(hyp)
        if not sents:
            print('len(temp_hyps)==0!!!')
            return 0, None
        _, r, _ = scorer.score(list(sents), [ref] * len(sents))
        dr, best_idx = top_filter(r, return_idx=True)
        return dr, sents[best_idx]

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
    print('*' * 40 + file_name + '*' * 40)
    print('samples:', len(a))

    sum_r = 0
    answer_sum_r = 0
    res = []
    for scene in tqdm(a[:max_samples]):
        ref = scene['cards'].replace('<end_card>', '')
        if answer:
            dr, best_sent = best_sentence(strip_markers(scene['answer']), ref)
            answer_sum_r += dr
            scene['best_answer_sent'] = best_sent
            scene['best_answer_sent_recall'] = dr
        if generate:
            dr, best_sent = best_sentence(strip_markers(scene['generated']), ref)
            sum_r += dr
            scene['best_generate_sent'] = best_sent
            scene['best_generate_sent_recall'] = dr
        res.append(scene)

    cnt = len(res)
    if cnt:
        if generate:
            print('generate recall:', sum_r / cnt)
        if answer:
            print('answer recall', answer_sum_r / cnt)
    with open(out_file_name, 'w', encoding='utf-8') as outf:
        json.dump(res, outf, ensure_ascii=False, indent=1)


def get_idf_sents(path='../data'):
    """Return the text corpus used to fit BERTScore IDF weights.

    Reads ``<path>/train.json``, a list of scenes. The corpus is every entry
    description followed by every card description of each scene's last
    entry. The result is cached in ``<path>/train_doc_idf.json``; the cache
    used to be written to a hard-coded '../data' regardless of ``path``.
    """
    def get_idf_docs():
        cache_file = os.path.join(path, 'train_doc_idf.json')
        if os.path.exists(cache_file):
            with open(cache_file, encoding='utf-8') as f:
                return json.load(f)
        with open(os.path.join(path, 'train.json'), encoding='utf-8') as f:
            a = json.load(f)

        hyps = []
        refs = []
        for scene in a:
            hyps.extend(entry['description'] for entry in scene['entries'])
            refs.extend(card['description'] for card in scene['entries'][-1]['cards'])
        with open(cache_file, 'w', encoding='utf-8') as fi:
            json.dump(hyps + refs, fi, ensure_ascii=False)
        print('finish get_idf_sent')
        return hyps + refs

    def get_idf_split_sent():
        # Sentence-level variant (currently unused); also resolved under ``path``.
        cache_file = os.path.join(path, 'train_sent_idf.json')
        if os.path.exists(cache_file):
            with open(cache_file, encoding='utf-8') as f:
                return json.load(f)
        with open(os.path.join(path, 'train.json'), encoding='utf-8') as f:
            a = json.load(f)

        hyps = []
        refs = []
        for scene in a:
            for entry in scene['entries']:
                hyps.extend(sent_tokenize(entry['description']))
            for card in scene['entries'][-1]['cards']:
                refs.extend(sent_tokenize(card['description']))
        with open(cache_file, 'w', encoding='utf-8') as fi:
            json.dump(hyps + refs, fi, ensure_ascii=False)
        print('finish get_idf_sent')
        return hyps + refs

    # return get_idf_split_sent()
    return get_idf_docs()


def prune_sent(sent):
    """Drop stop words from ``sent`` and re-join the remaining tokens with spaces.

    Tokens come from ``nltk.word_tokenize``. Filtering uses membership in the
    module-level ``stop_words`` list.
    """
    kept = []
    for token in nltk.word_tokenize(sent):
        if token in stop_words:
            continue
        kept.append(token)
    return ' '.join(kept)


def plot_score(cards, text, scorer):
    """Plot BERTScore recall of growing sentence prefixes of ``text``.

    For i = 1..n, the first i sentences of ``text`` are scored against the
    space-joined ``cards``. Each sentence index and sentence is printed, and
    the recall curve is shown with matplotlib.
    """
    from matplotlib import pyplot as plt

    sentences = nltk.sent_tokenize(text)
    reference = ' '.join(cards)

    def prefix_recall(prefix):
        _, recall, _ = scorer.score([prefix], [reference])
        return recall.mean().item()

    scores = []
    for i, sentence in enumerate(sentences):
        scores.append(prefix_recall(' '.join(sentences[:i + 1])))
        print(i, sentence)

    plt.plot(np.arange(0, len(scores), 1), scores)
    plt.show()


def mean_ave_precision(file_name):
    """Print the mean per-scene average precision of ``predict_outline`` vs ``outline``.

    Walking ``predict_outline`` in order, each item that also appears in
    ``outline`` is a hit. The scene score is the mean of precision@rank over
    the hit positions; note the divisor is the number of hits, not
    ``len(outline)``. A scene with no hits scores 0.
    """
    def avep(hyp, ref):
        precisions = []
        hits = 0
        for rank, word in enumerate(hyp, start=1):
            if word in ref:
                hits += 1
                precisions.append(hits / rank)
        return sum(precisions) / hits if hits else 0

    with open(file_name, encoding='utf-8') as fin:
        total = 0
        scenes = json.load(fin)
        for scene in tqdm(scenes):
            total += avep(scene['predict_outline'], scene['outline'])

        print('map score:', total / len(scenes))


def compute_sent_bleu(tup):
    """Return smoothed sentence BLEU-2 for a ``(reference, hypothesis)`` pair.

    Intended for ``Pool.map(compute_sent_bleu, zip(refs, hyps))``.
    ``sentence_bleu`` expects a *list* of tokenized references. Passing a
    single token list makes every token a separate (character-level)
    reference, so the reference is wrapped in a list here.
    """
    reference, hypothesis = tup[0], tup[1]
    return sentence_bleu([nltk.word_tokenize(reference)], nltk.word_tokenize(hypothesis),
                         weights=(0.5, 0.5, 0, 0),
                         smoothing_function=SmoothingFunction().method1)


def rouge_bleu_after_target(file_name):
    """Print average ROUGE of generated vs. answer text in the bedding..ending span.

    The span runs from ``<|beginofbedding|>`` (inclusive) to
    ``<|beginofending|>`` (exclusive). Scenes where either marker is missing
    from either text are skipped. (A BLEU pass over the same pairs, via
    ``compute_sent_bleu`` and a process pool, used to live here and is
    disabled.)
    """
    begin_marker = '<|beginofbedding|>'
    end_marker = '<|beginofending|>'

    def extract(text):
        start = text.find(begin_marker)
        stop = text.find(end_marker)
        if start == -1 or stop == -1:
            return None
        return text[start:stop]

    with open(file_name, encoding='utf-8') as fin:
        scenes = json.load(fin)
        pairs = [(extract(s['generated']), extract(s['answer'])) for s in scenes]
        kept = [(h, r) for h, r in pairs if h is not None and r is not None]
    hyps = [h for h, _ in kept]
    refs = [r for _, r in kept]

    scorer = Rouge()
    print(len(hyps))
    scores = scorer.get_scores(hyps, refs, avg=True)
    print('rouge score for {}: {}'.format(file_name, scores))


def compute_kw_usage(file_name):
    """Print the average fraction of predicted outline keywords used in the story.

    The story is ``generated`` from ``<|beginofbedding|>`` onward. A keyword
    counts as used if it occurs as a substring of the story. The per-scene
    ratio is averaged over scenes that have the marker and a non-empty
    ``predict_outline``; an empty outline used to raise ZeroDivisionError.
    """
    marker = '<|beginofbedding|>'
    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)

    res = 0
    cnt = 0
    for scene in a:
        idx = scene['generated'].find(marker)
        if idx == -1:
            continue
        keywords = scene['predict_outline']
        if not keywords:
            continue
        story = scene['generated'][idx:]
        used = sum(1 for word in keywords if word in story)
        res += used / len(keywords)
        cnt += 1

    if cnt == 0:
        print('cnt:0, res:n/a')
        return
    print('cnt:{}, res:{}'.format(cnt, res / cnt))


'''
The functions below form the useful public API of this module.
'''


def compute_self_bleu(file_name, mode: ComputePlace = ComputePlace.except_outline, bleu_n=1):
    """Print the mean intra-sample self-BLEU-n of the generated texts.

    Each text (region chosen by ``mode``) is split into sentences. Every
    sentence is BLEU-scored against the other sentences of the same text,
    and the scores are averaged.

    Texts whose scoring raises count as 0 and are tallied in ``wrong_num``.
    This presumably includes texts with no or only one sentence, which leave
    no references. The final score averages over all samples in the file.
    """
    weights = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (1.0 / 3, 1.0 / 3, 1.0 / 3), (0.25, 0.25, 0.25, 0.25)]
    weight = weights[bleu_n - 1]
    smoothing = SmoothingFunction().method1  # stateless; build once

    wrong_num = 0

    def sample_self_bleu(text):
        nonlocal wrong_num
        tokenized = [word_tokenize(sent) for sent in sent_tokenize(text)]
        try:
            total = 0
            for idx, hypothesis in enumerate(tokenized):
                others = tokenized[:idx] + tokenized[idx + 1:]
                total += sentence_bleu(others, hypothesis, weights=weight,
                                       smoothing_function=smoothing)
            return total / len(tokenized)
        except Exception:
            # Narrowed from a bare ``except`` so Ctrl-C still interrupts.
            wrong_num += 1
            return 0

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
    hyps = [process(scene['generated'], mode) for scene in a]

    score = 0
    for hyp in hyps:
        score += sample_self_bleu(hyp)
    if a:
        score /= len(a)

    print(f'self-bleu score for {file_name}: {score}, bleu-{bleu_n}')
    print(f"wrong_num:{wrong_num}")


def compute_rouge(file_name, mode: ComputePlace = ComputePlace.bedding):
    """Print ROUGE of generated vs. answer regions selected by ``mode``.

    Scenes where either extracted region is empty are dropped. Every averaged
    ROUGE value is then multiplied by kept/total, so dropped scenes
    effectively score 0.
    """
    with open(file_name, encoding='utf-8') as fin:
        scenes = json.load(fin)
        hyps, refs = [], []
        for scene in scenes:
            hyp_text = process(scene['generated'], mode)
            ref_text = process(scene['answer'], mode)
            if not (hyp_text and ref_text):
                continue
            hyps.append(hyp_text)
            refs.append(ref_text)

        score = Rouge().get_scores(hyps, refs, avg=True)

        kept, total = len(hyps), len(scenes)
        for metric in score.values():
            for key in metric:
                metric[key] = metric[key] * kept / total

        print(f'rouge score for {file_name}: {score}\n len(hyps): {len(hyps)}')


def bleu_bedding(file_name, bleu_n=1, begin_token='<|beginofbedding|>', end_token='<|beginofending|>'):
    """Print mean smoothed sentence BLEU-n between generated and answer spans.

    The span runs from ``begin_token`` (inclusive) to ``end_token``
    (exclusive). A missing ``begin_token`` gives an empty span. A missing
    ``end_token`` extends the span to the end of the text; using
    ``str.find``'s -1 as a slice bound silently dropped the last character.
    """
    def extract(text):
        start = text.find(begin_token)
        if start == -1:
            return ''
        end = text.find(end_token)
        if end == -1:
            end = len(text)
        return text[start:end]

    weights = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (1.0 / 3, 1.0 / 3, 1.0 / 3), (0.25, 0.25, 0.25, 0.25)]
    weight = weights[bleu_n - 1]
    smoothing = SmoothingFunction().method1  # stateless; build once

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)

    hyps = []
    refs = []
    for scene in a:
        hyps.append(extract(scene['generated']))
        refs.append(extract(scene['answer']))

    score = 0
    for hyp, ref in tqdm(zip(hyps, refs)):
        score += sentence_bleu([nltk.word_tokenize(ref)], nltk.word_tokenize(hyp),
                               weights=weight, smoothing_function=smoothing)
    if hyps:
        score /= len(hyps)

    print(f'bleu score for {file_name}: {score}, bleu-{bleu_n}')


def compute_bleu(file_name, mode: ComputePlace = ComputePlace.bedding, bleu_n=1, is_dynamic=False):
    """Print mean smoothed sentence BLEU-n of generated vs. answer regions.

    Both texts of every scene are reduced with ``process(text, mode,
    is_dynamic)``. Each hypothesis is scored against its single reference,
    and the sum is divided by the number of scenes.
    """
    weight = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (1.0 / 3, 1.0 / 3, 1.0 / 3), (0.25, 0.25, 0.25, 0.25)][bleu_n - 1]

    with open(file_name, encoding='utf-8') as fin:
        scenes = json.load(fin)
        hyps, refs = [], []
        for scene in scenes:
            hyps.append(process(scene['generated'], mode, is_dynamic))
            refs.append(process(scene['answer'], mode, is_dynamic))

        total = 0
        for hyp, ref in tqdm(zip(hyps, refs)):
            reference_tokens = [nltk.word_tokenize(ref)]
            total += sentence_bleu(reference_tokens, nltk.word_tokenize(hyp),
                                   weights=weight, smoothing_function=SmoothingFunction().method1)

        total /= len(scenes)

        print(f'bleu score for {file_name}: {total}, bleu-{bleu_n}')


def compute_distinct(file_name, mode: ComputePlace = ComputePlace.except_outline, distinct_n=1, generate=True):
    """Print corpus-level distinct-n: unique n-grams / total n-grams.

    The n-grams of every scene's region (``generated`` if ``generate`` else
    ``answer``, reduced with ``process(text, mode)``) are pooled. If there are
    no n-grams at all, the score is reported as 0.0 instead of raising
    ZeroDivisionError.
    """
    field = 'generated' if generate else 'answer'
    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
    hyps = [process(scene[field], mode) for scene in a]

    n_grams = []
    for hyp in hyps:
        n_grams.extend(ngrams(word_tokenize(hyp), distinct_n))

    distinct = len(set(n_grams)) / len(n_grams) if n_grams else 0.0
    print(f'distinct score for {file_name}: {distinct}, distinct-{distinct_n}')


def compute_outlen(file_name, mode: ComputePlace = ComputePlace.except_outline, generate=True):
    """Print the average whitespace-token length of the region selected by ``mode``.

    Uses ``generated`` when ``generate`` is True, otherwise ``answer``.
    """
    key = 'generated' if generate else 'answer'
    with open(file_name, encoding='utf-8') as fin:
        scenes = json.load(fin)
        texts = [process(scene[key], mode) for scene in scenes]
        total_len = sum(len(text.split()) for text in texts)

        print(f'out len for {file_name}: {total_len / len(scenes)}')


def compute_repetition(file_name, mode: ComputePlace = ComputePlace.except_outline, generate=True):
    """Print the fraction of samples containing at least one repeated word 4-gram.

    Uses the region selected by ``mode`` of ``generated`` (or of ``answer``
    when ``generate`` is False).
    """
    def has_repeat(text):
        grams = list(ngrams(word_tokenize(text), 4))
        return 0 if len(set(grams)) == len(grams) else 1

    key = 'generated' if generate else 'answer'
    with open(file_name, encoding='utf-8') as fin:
        scenes = json.load(fin)
        texts = [process(scene[key], mode) for scene in scenes]
        repeated = 0
        for text in texts:
            repeated += has_repeat(text)

        print(f'repetition score for {file_name}: {repeated / len(scenes)}')


def compute_outline_prf(file_name, is_dynamic=False):
    """Print precision/recall/F1 between generated and reference outline keywords.

    Keywords are extracted from ``generated`` and ``answer`` and de-duplicated.
    Two averages are printed:
    - over all scenes;
    - over scenes where P/R/F1 is defined, i.e. both keyword lists are
      non-empty.

    ``wrong_num`` counts scenes with an empty keyword list on either side, and
    ``hyp_ref_wrong_num`` those empty on both sides. A scene with keywords on
    both sides but no overlap is a valid score of 0; it is not "wrong".
    """
    print('is_dynamic = ', is_dynamic)

    def extract_outline(text):
        if not is_dynamic:
            idx = text.find('<|endoftarget|>')
            idx2 = text.find('<|beginofbedding|>')
            t = text[idx: idx2].replace('<|endoftarget|>', '')
            words = t.split('<|sepofoutline|>')
            return list({word.strip() for word in words})
        outlines = re.findall(
            r'(<\|endofsentence\|>|<\|beginofbedding\|>|<\|beginofending\|>)(.*?)<\|endofoutline\|>', text)
        res = []
        for outline in outlines:
            res += [i.strip() for i in outline[1].split('<|sepofoutline|>')]
        return list(set(res))

    wrong_num = 0
    hyp_ref_wrong_num = 0

    def compute_prf(hyp, ref):
        nonlocal wrong_num, hyp_ref_wrong_num
        if not hyp or not ref:
            # Precision or recall is undefined with an empty side.
            wrong_num += 1
            if not hyp and not ref:
                hyp_ref_wrong_num += 1
            return 0, 0, 0
        p = sum(1 for word in hyp if word in ref) / len(hyp)
        r = sum(1 for word in ref if word in hyp) / len(ref)
        f1 = 2 * (p * r) / (p + r) if p + r > 0 else 0.0
        return p, r, f1

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)

    p = 0
    r = 0
    f1 = 0
    for scene in a:
        _p, _r, _f1 = compute_prf(extract_outline(scene['generated']), extract_outline(scene['answer']))
        p += _p
        r += _r
        f1 += _f1

    total = len(a) or 1  # avoid ZeroDivisionError on an empty file
    valid = (len(a) - wrong_num) or 1  # all sums are 0 when nothing is valid
    print(f"file_name:{file_name}, precision:{p / total},recall:{r / total},f1:{f1 / total}")
    print(
        f"file_name:{file_name}, precision:{p / valid},recall:{r / valid},"
        f"f1:{f1 / valid}")

    print(f"wrong_num:{wrong_num}")
    print(f"hyp_ref_wrong_num:{hyp_ref_wrong_num}")


def compute_outline_repetition(file_name):
    """Report how often outline tokens repeat in generated vs. label outlines.

    Each scene in the JSON list stored at ``file_name`` must provide ``'generated'``
    and ``'answer'`` strings. For each string, the outline is the text between the
    first ``<|endoftarget|>`` marker and the first ``<|beginofbedding|>`` marker (or
    the end of the text when the latter is missing). Outline separators
    (``<|sepofoutlinesent|>``, ``<|sepofoutline|>``) are treated as whitespace, and
    the rest is split into whitespace-delimited tokens.

    For the generated and the label outlines, this prints:

    * the total token count;
    * the number of repeated tokens (``total - distinct`` per scene, summed);
    * the repetition percentage.

    Per-scene outlines are also printed for inspection.

    Args:
        file_name: Path to a UTF-8 JSON file containing a list of scene dicts.
    """
    separators = ('<|sepofoutlinesent|>', '<|sepofoutline|>')

    def extract_outline(text):
        """Return the whitespace-split outline tokens of ``text`` (empty if absent)."""
        start = text.find('<|endoftarget|>')
        if start == -1:
            # No target-end marker: there is no outline section to extract.
            return []
        end = text.find('<|beginofbedding|>')
        if end == -1:
            # find() returns -1 when missing; slicing to -1 would drop the last character.
            end = len(text)
        section = text[start:end].replace('<|endoftarget|>', '')
        # Separators may be glued to words, so turn them into whitespace instead of
        # discarding every token that merely contains one.
        for sep in separators:
            section = section.replace(sep, ' ')
        return section.split()

    generate_cnt = {'total': 0, 'repetition': 0}
    label_cnt = {'total': 0, 'repetition': 0}

    def repetition_cnt(outline, counter):
        """Accumulate token count and duplicate count of one outline into ``counter``."""
        counter['total'] += len(outline)
        counter['repetition'] += len(outline) - len(set(outline))

    def repetition_rate(counter):
        """Return the repetition percentage, or 0.0 when no tokens were seen."""
        if not counter['total']:
            return 0.0
        return counter['repetition'] * 100 / counter['total']

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)

    for scene in a:
        generate_outline = extract_outline(scene['generated'])
        label_outline = extract_outline(scene['answer'])
        print('generate_outline = ', generate_outline)
        print('label outline = ', label_outline)
        repetition_cnt(generate_outline, generate_cnt)
        repetition_cnt(label_outline, label_cnt)

    print(
        f"label outline: total = {label_cnt['total']}, repetition = {label_cnt['repetition']}, {repetition_rate(label_cnt)}%")
    print(
        f"generate outline: total = {generate_cnt['total']}, repetition = {generate_cnt['repetition']}, {repetition_rate(generate_cnt)}%")


def compute_bert_score(file_name, scorer, every_sent=True, answer=False, max_sent=False, mode=None,
                       sort_out_file_name=None, on_bedding=False, clamp_recall=None):
    """Compute BERTScore P/R/F1 for every scene in ``file_name`` and print the averages.

    Each scene is scored independently with ``scorer.score(hyps, refs)``. The
    resulting precision/recall/F1 values are reduced with ``top_filter`` (defined
    elsewhere in this file) and summed. The printed averages divide by the number
    of scenes in the file, so skipped or failed scenes contribute 0.

    Args:
        file_name: Path to a UTF-8 JSON file holding a list of scene dicts. Keys used:
            ``'generated'``, ``'answer'``, and ``'cards'`` (cards separated by
            ``<end_card>``).
        scorer: Object exposing ``score(cands, refs) -> (P, R, F)``, e.g. a
            ``bert_score.BERTScorer``.
        every_sent: Pair every sentence of the generated text with every non-empty
            card. In this mode ``answer``, ``max_sent``, ``mode`` and ``on_bedding``
            are ignored.
        answer: Score ``scene['answer']`` instead of ``scene['generated']``. Scenes
            whose extraction is empty are skipped without counting as illegal.
        max_sent: Split the hypothesis into sentences and pair each with the same
            reference.
        mode: ``ComputePlace`` passed to ``process`` to select the scored span.
        sort_out_file_name: If given, dump successfully scored scenes, each with an
            added ``'recall'`` field, sorted by recall descending, to this path.
        on_bedding: Use ``process(scene['answer'], mode)`` as the reference instead
            of the concatenated cards.
        clamp_recall: Clamp each scene's recall at 0. ``None`` (default) keeps the
            historical behaviour: use the module-level ``is_plan_write`` flag set by
            the command-line entry point, or ``False`` if that flag does not exist
            (e.g. when this function is imported).
    """
    if clamp_recall is None:
        # The original code read the global ``is_plan_write`` directly. That global
        # only exists when the file runs as __main__, so imported use raised
        # NameError for every scene.
        clamp_recall = globals().get('is_plan_write', False)

    with open(file_name, encoding='utf-8') as f:
        a = json.load(f)
    print('*' * 40 + file_name + '*' * 40)
    print('samples:', len(a))

    sum_p = 0
    sum_r = 0
    sum_f = 0
    cnt = 0  # scenes scored successfully
    illegal_cnt = 0  # generated scenes whose extracted hypothesis was empty
    res = []

    for scene in tqdm(a):
        temp_hyps = []
        temp_refs = []
        if every_sent:
            generated = scene['generated'].replace('<|endoftext|>', '')
            hyp_sents = sent_tokenize(generated)
            # Drop empty/whitespace-only fragments produced by splitting on <end_card>.
            cards = [card for card in scene['cards'].split('<end_card>') if word_tokenize(card)]
            for card in cards:
                for sent in hyp_sents:
                    temp_hyps.append(sent)
                    temp_refs.append(card)
        else:
            if answer:
                hyp = process(scene['answer'], mode)
                if not hyp:
                    continue
            else:
                hyp = process(scene['generated'], mode)
                if not hyp:
                    illegal_cnt += 1
                    continue

            if on_bedding:
                ref = process(scene['answer'], mode)
            else:
                ref = scene['cards'].replace('<end_card>', '')

            if max_sent:
                for sent in nltk.sent_tokenize(hyp):
                    temp_hyps.append(sent)
                    temp_refs.append(ref)
            else:
                temp_hyps.append(hyp)
                temp_refs.append(ref)

        if not temp_hyps:
            continue

        try:
            p, r, f1 = scorer.score(temp_hyps, temp_refs)
            sum_p += top_filter(p)
            dr = top_filter(r)
            if clamp_recall:
                dr = max(0, dr)
            sum_r += dr
            sum_f += top_filter(f1)
            cnt += 1
            scene['recall'] = dr
            res.append(scene)
        except Exception:
            # Best effort: report the failure and continue. The scene adds 0 to the averages.
            traceback.print_exc()

    # Averages span every scene in the file. max(..., 1) avoids ZeroDivisionError on an empty file.
    n_scenes = max(len(a), 1)
    print('precision:', sum_p / n_scenes)
    print('recall:', sum_r / n_scenes)
    print('f1:', sum_f / n_scenes)
    print('cnt:', cnt)
    print('illegal_cnt', illegal_cnt)

    if sort_out_file_name:
        sorted_res = sorted(res, key=lambda x: x['recall'], reverse=True)
        with open(sort_out_file_name, 'w', encoding='utf-8') as out_f:
            json.dump(sorted_res, out_f, indent=1, ensure_ascii=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, required=True, help='path to result file')
    parser.add_argument('--bleu', action='store_true')
    parser.add_argument('--recall', action='store_true')
    parser.add_argument('--rouge', action='store_true')
    parser.add_argument('--outline_prf', action='store_true')
    parser.add_argument('--distinct', action='store_true')
    parser.add_argument('--out_len', action='store_true')
    parser.add_argument('--repetition', action='store_true')
    parser.add_argument('--outline_repetition', action='store_true')

    parser.add_argument('--max', action='store_true',
                        help='calculate max BERTScore recall based on the whole story')
    parser.add_argument('--label', action='store_true',
                        help='base calculations on the answer instead of the generated text')
    parser.add_argument('--bedding', action='store_true',
                        help='compute BERTScore between generated and reference bedding')
    parser.add_argument('--all', action='store_true')

    args = parser.parse_args()
    path = args.path

    # The model family is inferred from substrings of the result-file path.
    # NOTE: every 'bte_outline' path also satisfies is_bte, so branches must test
    # is_bte_outline first and use elif afterwards.
    is_tbe = 'tbe' in path
    is_bte = 'bte' in path or 'plan_ahead' in path or 'fusion' in path or 's2s' in path
    is_outline = 'outline' in path or 'wo_knowledge' in path
    is_plan_write = 'plan_write' in path  # also read as a module global by compute_bert_score
    is_dynamic = 'dynamic' in path and not is_outline
    is_target = 'target' in path
    is_bte_outline = 'bte_outline' in path
    print(
        f"is_tbe = {is_tbe}, is_bte = {is_bte}, is_outline = {is_outline}, is_plan_write = {is_plan_write}, is_dynamic = {is_dynamic}, is_bte_outline = {is_bte_outline}")
    # Use parser.error rather than assert: asserts are stripped under `python -O`.
    if not (is_tbe or is_bte or is_outline or is_plan_write or is_bte_outline):
        parser.error(f"cannot infer the model type from --path {path!r}")
    # Dynamic handling is deliberately disabled after being reported above.
    is_dynamic = False

    if args.bleu:
        for i in range(1, 3):  # BLEU-1 and BLEU-2
            if is_bte_outline:
                compute_bleu(path, bleu_n=i, mode=ComputePlace.bte_outline)
            elif is_tbe:
                compute_bleu(path, bleu_n=i, mode=ComputePlace.all)
            elif is_bte:
                compute_bleu(path, bleu_n=i, mode=ComputePlace.all if args.all else ComputePlace.bte_bedding,
                             is_dynamic=is_dynamic)
            elif is_plan_write:
                compute_bleu(path, bleu_n=i,
                             mode=ComputePlace.plan_write_except_outline if args.all else ComputePlace.bte_bedding,
                             is_dynamic=is_dynamic)
            elif not is_target:
                compute_bleu(path, bleu_n=i, mode=ComputePlace.except_outline if args.all else ComputePlace.bedding,
                             is_dynamic=is_dynamic)
            else:
                compute_bleu(path, bleu_n=i,
                             mode=ComputePlace.except_outline_target if args.all else ComputePlace.bedding,
                             is_dynamic=is_dynamic)

    elif args.recall:
        scorer = BERTScorer(lang='en', model_type=model_type, rescale_with_baseline=False, idf=True, num_layers=layers,
                            nthreads=os.cpu_count(), idf_sents=get_idf_sents(), batch_size=64)
        print('finish create scorer')

        def get_sort_out_file_name(result_path):
            """Return ``<result_path without extension>_sort.json``."""
            # splitext handles paths without an extension; rfind('.') returned -1
            # there and chopped off the last character.
            return os.path.splitext(result_path)[0] + '_sort.json'

        if args.max:  # whole story vs. cards, taking the max over sentences
            if is_bte_outline:
                compute_bert_score(path, scorer, every_sent=False, answer=args.label,
                                   sort_out_file_name=None, mode=ComputePlace.bte_outline, max_sent=True)
            elif is_bte:
                compute_bert_score(path, scorer, every_sent=False, answer=args.label,
                                   sort_out_file_name=None, mode=ComputePlace.all, max_sent=True)
            elif is_plan_write:
                compute_bert_score(path, scorer, every_sent=False, answer=args.label,
                                   sort_out_file_name=None, mode=ComputePlace.plan_write_except_outline, max_sent=True)
            else:
                compute_bert_score(path, scorer, every_sent=False, answer=args.label,
                                   sort_out_file_name="sorted_bert_score.json", mode=ComputePlace.except_outline,
                                   max_sent=True)

        elif args.bedding:  # generated bedding vs. reference bedding
            if is_bte or is_plan_write:
                compute_bert_score(path, scorer, every_sent=False, answer=args.label,
                                   sort_out_file_name=None, mode=ComputePlace.bte_bedding, max_sent=False,
                                   on_bedding=True)
            else:
                compute_bert_score(path, scorer, every_sent=False, answer=args.label,
                                   sort_out_file_name=None, mode=ComputePlace.bedding, max_sent=False, on_bedding=True)

        else:  # target vs. cards
            if is_bte_outline:
                compute_bert_score(path, scorer, every_sent=False, answer=args.label,
                                   sort_out_file_name=get_sort_out_file_name(path), mode=ComputePlace.bte_outline_target)
            elif is_bte or is_plan_write:
                compute_bert_score(path, scorer, every_sent=False, answer=args.label,
                                   sort_out_file_name=get_sort_out_file_name(path), mode=ComputePlace.bte_target)
            elif is_tbe:
                compute_bert_score(path, scorer, every_sent=False, answer=args.label,
                                   sort_out_file_name=get_sort_out_file_name(path), mode=ComputePlace.tbe_target)
            elif is_outline:
                compute_bert_score(path, scorer, every_sent=False, answer=args.label,
                                   sort_out_file_name=get_sort_out_file_name(path), mode=ComputePlace.outline_target)
            else:
                print('invalid path!')

    elif args.rouge:
        if is_bte:
            compute_rouge(path, mode=ComputePlace.all if args.all else ComputePlace.bte_bedding)
        elif is_plan_write:
            compute_rouge(path, mode=ComputePlace.plan_write_except_outline if args.all else ComputePlace.bte_bedding)
        else:
            compute_rouge(path, mode=ComputePlace.except_outline if args.all else ComputePlace.bedding)

    elif args.distinct:
        for n in range(1, 5):  # distinct-1 .. distinct-4
            if args.label:
                compute_distinct(path, mode=ComputePlace.all, distinct_n=n, generate=False)
            elif is_plan_write:
                compute_distinct(path, mode=ComputePlace.plan_write_except_outline, distinct_n=n)
            elif is_bte or is_tbe:
                compute_distinct(path, mode=ComputePlace.all, distinct_n=n)
            elif is_outline:
                compute_distinct(path, mode=ComputePlace.except_outline, distinct_n=n)

    elif args.out_len:
        if args.label:
            compute_outlen(path, mode=ComputePlace.all, generate=False)
        elif is_plan_write:
            compute_outlen(path, mode=ComputePlace.plan_write_except_outline)
        elif is_bte or is_tbe:
            compute_outlen(path, mode=ComputePlace.all)
        elif is_outline:
            compute_outlen(path, mode=ComputePlace.except_outline)

    elif args.repetition:
        if args.label:
            compute_repetition(path, mode=ComputePlace.all, generate=False)
        elif is_plan_write:
            compute_repetition(path, mode=ComputePlace.plan_write_except_outline)
        elif is_bte or is_tbe:
            compute_repetition(path, mode=ComputePlace.all)
        elif is_outline:
            compute_repetition(path, mode=ComputePlace.except_outline)

    elif args.outline_prf:
        compute_outline_prf(path, is_dynamic=is_dynamic)

    elif args.outline_repetition:
        compute_outline_repetition(path)

    else:
        print('invalid execution')
