import json
from utils import *
from bert_score import BERTScorer
from metric import get_idf_sents
import os
from tqdm import tqdm
import nltk

from munkres import Munkres, DISALLOWED, UnsolvableMatrix, print_matrix
import munkres
import sys
sys.path.append('..')
from classification.predictor import RobertaPredictor

import numpy as np


def km(matrix, max=5):
    """Assign each row to a distinct column so that the total score is maximal.

    Uses the Hungarian (Kuhn-Munkres) algorithm from the ``munkres`` package.
    That solver minimises cost, so every score is first turned into the cost
    ``max - score``.

    Args:
        matrix: 2-D list of scores (rows x columns).
        max: Value each score is subtracted from to obtain a cost; it should
            be at least as large as the largest score in ``matrix``.
            (The name shadows the builtin but is kept for compatibility with
            existing keyword callers.)

    Returns:
        A list with the column index chosen for each (row, column) pair the
        solver returns, in the solver's order.
    """
    # Convert the maximisation problem into a minimisation problem.
    cost_matrix = munkres.make_cost_matrix(
        matrix, lambda score: max - score
    )
    # Build the solver here instead of relying on a global `m` that only
    # exists when this file is executed as a script.
    solver = Munkres()
    assignment = solver.compute(cost_matrix)
    return [column for _row, column in assignment]

def max_pred(matrix):
    """Return, for each row of ``matrix``, the index of its largest entry.

    The arg-max is taken along the last axis, so a 2-D score table yields one
    predicted column index per row.
    """
    score_table = np.asarray(matrix)
    return score_table.argmax(axis=-1)



def cal_delta(matrix):
    """Measure how much matched pairs outscore mismatched pairs.

    ``matrix`` is read as an n x n score table (n = ``len(matrix)``) in which
    ``matrix[i][i]`` is the score of a matched (positive) pair and every other
    ``matrix[i][j]`` with ``j < n`` is the score of a mismatched (negative)
    pair. Only the first n entries of each row are used.

    Args:
        matrix: Sequence of n rows (n >= 2), each holding at least n numbers.

    Returns:
        A tuple ``(delta, positive_score, negative_score, all_delta)``:
            delta: mean over rows of (diagonal score - mean of the row's
                off-diagonal scores).
            positive_score: mean of the diagonal scores.
            negative_score: mean of all off-diagonal scores.
            all_delta: list with the per-row differences that ``delta``
                averages.

    Raises:
        ValueError: if ``matrix`` has fewer than two rows, because a row then
            has no mismatched entries to compare against.
    """
    num = len(matrix)
    if num < 2:
        raise ValueError(
            f"cal_delta needs at least a 2x2 score matrix, got {num} row(s)"
        )

    match = [matrix[i][i] for i in range(num)]
    mismatch = []        # every off-diagonal score, in row-major order
    line_mismatch = []   # per-row mean of the off-diagonal scores
    for i in range(num):
        row_negatives = [matrix[i][j] for j in range(num) if j != i]
        mismatch.extend(row_negatives)
        line_mismatch.append(np.mean(row_negatives))

    row_deltas = np.array(match) - np.array(line_mismatch)
    all_delta = row_deltas.tolist()
    delta = np.mean(row_deltas)  # TODO: compute the variance as well
    positive_score = np.mean(match)
    negative_score = np.mean(mismatch)

    return delta, positive_score, negative_score, all_delta


def generate_data(in_path, out_path, num=10):
    """Build a paired evaluation set by cloning the first record of each group.

    The JSON list at ``in_path`` is processed in consecutive groups of ``num``
    records. Within a group, every record after the first is replaced by a
    deep copy of the first record, except that it keeps

    * its own ``entries[-1]['cards']``, and
    * a ``context_kws`` made of its own original ``persona_kws`` followed by
      the first record's ``context_kws`` with the first ``len(persona_kws)``
      entries of that first record removed.

    Note that ``persona_kws`` itself is NOT restored: after the rewrite each
    cloned record carries the first record's ``persona_kws``.

    A trailing group with fewer than ``num`` records is left unchanged.
    The resulting list is written to ``out_path`` as UTF-8 JSON.

    Args:
        in_path: Path of the input JSON file (a list of records).
        out_path: Path of the JSON file to write.
        num: Number of consecutive records per group (default 10).

    Raises:
        ValueError: if ``num`` is smaller than 1.
    """
    from copy import deepcopy

    if num < 1:
        raise ValueError(f"num must be a positive integer, got {num}")

    with open(in_path, encoding='utf-8') as f:
        d = json.load(f)

    for i in tqdm(range(0, len(d), num)):
        end = i + num
        if end > len(d):
            # Incomplete trailing group: leave those records as they are.
            break
        # Snapshot every member's persona keywords before the records are
        # overwritten by copies of the group's first record.
        persona_kws = {j: d[j]['persona_kws'] for j in range(i, end)}
        # Context keywords of the first record without its persona prefix.
        c_kws = d[i]['context_kws'][len(persona_kws[i]):]
        for j in range(i + 1, end):
            cards = d[j]['entries'][-1]['cards']
            d[j] = deepcopy(d[i])
            d[j]['entries'][-1]['cards'] = cards
            d[j]['context_kws'] = persona_kws[j] + c_kws

    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(d, f, ensure_ascii=False, indent=1)


def score_one_model(path, mode, scorer, num=10):
    """Score one model's generations on persona/output matching and report it.

    The JSON list at ``path`` is read; for every record the persona is taken
    from ``process_text(record['prompt'], TextType.prompt)`` and the output
    text from ``process_text(record['generated'], mode)``. Records are then
    processed in consecutive groups of ``num``. For each group a
    ``num`` x ``num`` table is built where entry ``[j][q]`` is the score of the
    j-th persona of the group against the q-th output of the group, so the
    diagonal holds the matched pairs.

    Per group the function accumulates
      * accuracy: fraction of rows whose arg-max column is the diagonal, and
      * the values returned by ``cal_delta`` for the table.

    The per-row deltas of all groups are dumped to ``delta/<file name>`` and a
    one-line summary (averages over groups plus the variance of all per-row
    deltas) is printed. A trailing group with fewer than ``num`` records is
    ignored.

    Args:
        path: Path of the result JSON file; records need ``'prompt'`` and
            ``'generated'`` keys.
        mode: Text type passed to ``process_text`` for the generated text.
        scorer: Object with ``score(outputs, personas)``. Its result is sliced
            and reduced with ``.max().item()`` - presumably a tensor/array
            with one score per pair; confirm against the scorer in use.
        num: Number of consecutive records forming one group.

    Returns:
        None. Results are written to disk and printed.
    """
    with open(path, encoding='utf-8') as f:
        d = json.load(f)

    personas = []
    outputs = []
    for x in d:
        # For prompts, process_text yields a (persona, prompt) pair; only the
        # persona is needed here.
        persona, _prompt = process_text(x['prompt'], TextType.prompt)
        output = process_text(x['generated'], mode)
        personas.append(persona)
        outputs.append(output)

    # Each output is scored as a single unit (one "sentence" per output).
    # Sentence-level scoring would use nltk.sent_tokenize(o) instead.
    output_sents = [[o] for o in outputs]

    acc = 0
    sample_cnt = 0
    delta, positive_score, negative_score = 0, 0, 0
    all_delta = []
    for i in tqdm(range(0, len(personas), num)):
        end = i + num
        if end > len(personas):
            break
        sample_cnt += 1
        scores = []
        idxs = []  # cumulative pair count after each (persona, output) cell
        temp_personas = []
        temp_outputs = []
        # num x num pairs: every persona of the group against every output of
        # the group, persona-major.
        for j in range(i, end):
            for q in range(i, end):
                for sent in output_sents[q]:
                    temp_personas.append(personas[j])
                    temp_outputs.append(sent)
                idxs.append(len(temp_personas))

        # One score per (output, persona) pair, num * num pairs in total.
        r = scorer.score(temp_outputs, temp_personas)

        # Fold the flat scores back into a num x num table; a cell's score is
        # the best score among its sentences (0 if the cell has none).
        start = 0
        temp = []
        for q, idx in enumerate(idxs):
            if start == idx:
                temp.append(0)
            else:
                temp.append(r[start: idx].max().item())
            if q % num == num - 1:
                scores.append(temp)
                temp = []
            start = idx

        # Greedy per-row prediction; km(scores) would give a one-to-one
        # assignment instead.
        preds = max_pred(scores)
        tmp_delta, tmp_positive_score, tmp_negative_score, tmp_all_delta = cal_delta(scores)

        # Use a separate index name so the outer group index is not shadowed.
        w = sum(1 for row, val in enumerate(preds) if row == val)
        acc += w / len(preds)

        delta += tmp_delta
        positive_score += tmp_positive_score
        negative_score += tmp_negative_score
        all_delta += tmp_all_delta

    if sample_cnt == 0:
        # Fewer than `num` records: there is nothing to average, and the
        # summary below would divide by zero.
        print(f"num:{num}, path:{path}, no complete group of {num} records; nothing to score")
        return

    # Make sure the output directory exists so the finished scoring run is
    # not lost to a FileNotFoundError.
    os.makedirs('delta', exist_ok=True)
    with open(f"delta/{path.split('/')[-1]}", 'w') as f:
        json.dump(all_delta, f)
    # The variance is computed over the per-row deltas of all groups.
    print(f"num:{num}, path:{path}, acc:{acc / sample_cnt}, delta:{delta / sample_cnt}, var:{np.var(all_delta)}, 'positive_score':{positive_score/sample_cnt}, 'negative_score': {negative_score/sample_cnt}, sample_cnt:{sample_cnt}")


if __name__ == '__main__':
    # The paired evaluation files can be rebuilt with generate_data, e.g.
    #   generate_data('../data/test_dynamic_split.json',
    #                 '../data/test_dynamic_persona_pair10.json')
    #   generate_data('../baselines/plan-ahead/data/test_plan_ahead.json',
    #                 '../baselines/plan-ahead/data/test_plan_ahead_pair10.json')

    # Hungarian-algorithm solver, bound at module level under the name `m`.
    m = Munkres()

    # Alternative scorer (BERTScore) that can be swapped in:
    #   model_type = 'sentence-transformers/roberta-large-nli-stsb-mean-tokens'
    #   layers = 24
    #   scorer = BERTScorer(lang='en', model_type=model_type, rescale_with_baseline=False,
    #                       idf=True, num_layers=layers, nthreads=os.cpu_count(),
    #                       idf_sents=get_idf_sents(), batch_size=64)
    scorer = RobertaPredictor(
        '../classification/models/persona/version_7',
        gpu=0,
        bc=100,
    )
    print('finish create scorer')

    # Group sizes to evaluate; each result file is scored once per size.
    group_sizes = [10]
    for num in group_sizes:
        score_one_model(
            "../result/plan_write_pair10.json",
            TextType.bte_outline,
            scorer,
            num=num,
        )
        # Other result files that have been scored with this script:
        #   '../result/gpt2_baseline_bte_version0_1.0_persona_pair.json'  (TextType.bte)
        #   '../baselines/plan-ahead/result/plan_ahead_bte_onecard_version_num11_1.0.json'  (TextType.bte)
        #   '../result/kg_combine4_reweight_outline_onecard_version17_1.0_pair10.json'  (TextType.plan_tbe)
        #   '../result/plan_ahead_bte_onecard_1.0.json'  (TextType.bte)
        #   '../result/gpt2_baseline_bte_onecard_1.0.json'  (TextType.bte)
