from collections import defaultdict
import pdb
import math
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.bleu_score import SmoothingFunction
from bert_score import score
smooth = SmoothingFunction()

# Compute the distinct-n (DIS-1 / DIS-2) diversity scores of the outputs.
def calc_diversity(txt):
    """Compute the distinct-1 and distinct-2 ratios of tokenized sentences.

    For each n in (1, 2) the score is the number of distinct n-grams divided
    by the total number of n-grams, counted over every sentence in ``txt``.

    Args:
        txt: iterable of sentences, each a sequence of string tokens.

    Returns:
        A pair ``([div1, div2], text)`` where ``text`` is the formatted
        ``'DIS-1 = ... \\t DIS-2 = ...'`` summary line. A ratio is ``0.0``
        when no n-gram of that order exists (e.g. empty input, or only
        single-token sentences for distinct-2) instead of dividing by zero.
    """
    distinct = [set(), set()]  # unique unigrams / bigrams seen so far
    totals = [0, 0]            # number of unigrams / bigrams seen so far
    for words in txt:
        for n in range(2):
            for idx in range(len(words) - n):
                # An n-gram is keyed by its tokens joined with single spaces.
                distinct[n].add(' '.join(words[idx:idx + n + 1]))
                totals[n] += 1
    div1 = len(distinct[0]) / totals[0] if totals[0] else 0.0
    div2 = len(distinct[1]) / totals[1] if totals[1] else 0.0
    return [div1, div2], 'DIS-1 = {:.4f} \t DIS-2 = {:.4f}'.format(div1, div2)


def get_all_dist(file_path):
    """Score one generation-result file and write a readable report beside it.

    The file is scanned line by line. For every line that starts with ``S-``
    the text after its first field is the dialogue input, the next line
    (first field dropped) is a reference, and the line after that (first two
    fields dropped) is the model output. Every line that starts with ``P-``
    contributes its numeric fields to the perplexity estimate. Inputs that
    occur several times are grouped: the first model output is kept and all
    references are collected.

    The report is written to ``file_path + '.dialogue'`` and holds one section
    per distinct input followed by the metric summary, which is also printed.

    Args:
        file_path: path of the result file to evaluate.

    Returns:
        A pair ``(rs_str, d1)``: the formatted metric summary and the
        ``[DIS-1, DIS-2]`` list computed over all model outputs.

    Raises:
        IndexError: if an ``S-`` line is not followed by two more lines.
    """
    # Both files are handled as UTF-8 because the report contains non-ASCII
    # labels; relying on the locale default can fail when writing them.
    with open(file_path, encoding='utf-8') as result_file:
        lines = result_file.readlines()

    dialogues = {}    # input text -> {'ref': [reference, ...], 'res': output}
    all_sys_txt = []  # tokenized output of every record
    all_ref_txt = []  # one single-reference list per record (corpus_bleu shape)
    filter_txt = []   # tokenized outputs, one per distinct input
    score_sum = 0
    pcount = 0
    for index, line in enumerate(lines):
        prefix = line[:2]
        if prefix == 'S-':
            source = ' '.join(line.split()[1:])
            target = ' '.join(lines[index + 1].split()[1:])
            hypothesis = ' '.join(lines[index + 2].split()[2:])
            all_sys_txt.append(hypothesis.split())
            all_ref_txt.append([target.split()])
            entry = dialogues.get(source)
            if entry is None:
                # First time this input is seen: keep its output for the
                # de-duplicated diversity score.
                entry = dialogues[source] = {'ref': [], 'res': hypothesis}
                filter_txt.append(hypothesis.split())
            entry['ref'].append(target)
        elif prefix == 'P-':
            # Scores are scaled by ln(2) here and divided by ln(2) again when
            # the perplexity is formed below.
            scaled = [float(x) * math.log(2) for x in line.split()[1:]]
            score_sum += sum(scaled)
            pcount += len(scaled)

    max_bleu = []
    with open(file_path + '.dialogue', 'w', encoding='utf-8') as res_txt:
        for source, entry in dialogues.items():
            hyp_tokens = entry['res'].split()
            res_txt.write('\n' + '*' * 100 + '\n')
            res_txt.write('对话输入：' + source + '\n')
            res_txt.write('模型输出：' + entry['res'] + '\n')
            res_txt.write('参考例句：' + '\n')
            ref_tokens = []
            for reference in entry['ref']:
                ref_tokens.append(reference.split())
                res_txt.write('         ' + reference + '\n')
            # Best 2-gram sentence BLEU of the output against any reference.
            max_bleu.append(max(
                sentence_bleu([tokens], hyp_tokens, weights=(0.5, 0.5),
                              smoothing_function=smooth.method1)
                for tokens in ref_tokens
            ))

        d1, t1 = calc_diversity(all_sys_txt)
        _, t2 = calc_diversity(filter_txt)
        b2 = corpus_bleu(all_ref_txt, all_sys_txt, weights=(0.5, 0.5),
                         smoothing_function=smooth.method1)
        b4 = corpus_bleu(all_ref_txt, all_sys_txt,
                         smoothing_function=smooth.method1)

        # Avoid ZeroDivisionError when the file has no P- lines or no
        # dialogues; the perplexity is then reported as nan.
        mean_max_bleu = sum(max_bleu) / len(max_bleu) if max_bleu else 0.0
        if pcount:
            ppl = 2 ** (-score_sum / pcount / math.log(2))
        else:
            ppl = float('nan')

        rs_str = ''.join([
            '\nPPL = {:.2f}\n'.format(ppl),
            'BLEU-2gram = {:.2f}\n'.format(b2 * 100),
            'BLEU-4gram = {:.2f}\n'.format(b4 * 100),
            t1 + '\n',
            '去重后得分：' + t2 + '\n',
            'MAX Sentence BLEU-2gram Mean = {:.2f}\n'.format(mean_max_bleu * 100),
        ])
        res_txt.write(rs_str)

    print(file_path)
    print(rs_str)

    return rs_str, d1
# pdb.set_trace()

import os

# Evaluate every checkpoint result file in the model directory, write each
# file name and its metric summary to one metrics file, and finish with the
# checkpoint that achieved the highest DIS-2 score.
model_dir = 'base_model_reddit_sentence_oracles_decay_20_t4'
dis_file = model_dir + '.metric'
max_d = 0
max_f = ''
max_s = ''
# The context manager closes the metrics file even if an evaluation fails;
# UTF-8 is explicit because the summaries contain non-ASCII text.
with open(model_dir + '/' + dis_file, 'w+', encoding='utf-8') as dis:
    all_files = sorted(os.listdir(model_dir))
    for f in all_files:
        if f.endswith('.result.txt') and 'checkpoint' in f:
            print(f)
            dis.write(f)
            rs, d1 = get_all_dist(model_dir + '/' + f)
            # d1 is [DIS-1, DIS-2]; the best checkpoint is chosen by DIS-2.
            if d1[1] > max_d:
                max_d = d1[1]
                max_f = f
                max_s = rs
            dis.write(rs)
    dis.write('\nMAX:' + max_f + max_s)


# file_path = 'base_model_reddit_sentence_oracles_decay_20_t4/checkpoint48.pt.result.txt'
# get_all_dist(file_path)




# print('BLEU-1gram:',sentence_bleu(ref, sys,weights=(1, 0, 0, 0),smoothing_function=smooth.method1))
        # print('BLEU-2gram:',sentence_bleu(ref, sys,weights=(0.5, 0.5, 0, 0),smoothing_function=smooth.method1))
        # print('BLEU-3gram:',sentence_bleu(ref, sys,weights=(0.33, 0.33, 0.33, 0),smoothing_function=smooth.method1))
        # print('BLEU-4gram:',sentence_bleu(ref, sys,weights=(0.25, 0.25, 0.25, 0.25),smoothing_function=smooth.method1))

# from transformers import AutoTokenizer, AutoModelForMaskedLM

# tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")
