from random import random, sample, seed
import json
import pandas as pd
import string

from utils import *


def generate_data(path1, mode1, path2, mode2, path_out, name1='reweight', name2='plan_ahead', sample_size=100):
    """Build a blind A/B evaluation sheet from two aligned generation result files.

    ``path1`` and ``path2`` are JSON lists of equal length; item ``i`` of each
    holds a ``'generated'`` text, and items of ``path1`` also hold the
    ``'prompt'``. Generations are post-processed with ``process_text`` using
    ``mode1``/``mode2``. A pair is discarded when either processed output
    contains ``'|beginof'`` or any character that is not alphanumeric, ASCII
    punctuation, space, tab or newline.

    ``sample_size`` valid pairs are drawn at random. For each pair, the two
    outputs are randomly assigned to columns ``out1``/``out2``. They are written
    with persona and premise to ``path_out``, indexed by ``id``. The system name
    behind each column is written to a sibling map file
    (``x.csv`` -> ``x#map.csv``).

    Raises:
        ValueError: if the two files differ in length, or (from
            ``random.sample``) if fewer than ``sample_size`` pairs are valid.
    """
    import os

    with open(path1, encoding='utf-8') as f:
        d1 = json.load(f)
    with open(path2, encoding='utf-8') as f:
        d2 = json.load(f)
    if len(d1) != len(d2):
        raise ValueError(f'result files differ in length: {len(d1)} vs {len(d2)}')

    whitespace = (' ', '\t', '\n')

    def invalid(text):
        # Leaked special tokens, or characters outside alnum/punctuation/whitespace.
        if '|beginof' in text:
            return True
        return any(
            not c.isalnum() and c not in string.punctuation and c not in whitespace
            for c in text
        )

    # Process every pair once and keep the clean ones, so sampled rows reuse them.
    processed = {}
    for idx, (x, y) in enumerate(zip(d1, d2)):
        t1 = process_text(x['generated'], mode1)
        t2 = process_text(y['generated'], mode2)
        if not (invalid(t1) or invalid(t2)):
            processed[idx] = (t1, t2)
    valid_idxs = list(processed)
    print('valid samples = ', len(valid_idxs))

    selected_idxs = sample(valid_idxs, sample_size)
    personas, prompts, gen1, gen2, names1, names2 = [], [], [], [], [], []
    for idx in selected_idxs:
        persona, prompt = process_text(d1[idx]['prompt'], TextType.prompt)
        personas.append(persona)
        prompts.append(prompt)
        t1, t2 = processed[idx]
        # Randomly swap the presentation order; the map file records which is which.
        if random() < 0.5:
            (a, b), (na, nb) = (t1, t2), (name1, name2)
        else:
            (a, b), (na, nb) = (t2, t1), (name2, name1)
        gen1.append(a)
        gen2.append(b)
        names1.append(na)
        names2.append(nb)

    data_frame = pd.DataFrame({'persona': personas, 'premise': prompts, 'out1': gen1, 'out2': gen2})
    data_frame.index.names = ['id']
    data_frame.to_csv(path_out, sep=',', encoding='utf-8')

    # splitext only splits off a real extension, so dots in directories or
    # extension-less paths cannot corrupt the map file name.
    root, ext = os.path.splitext(path_out)
    df2 = pd.DataFrame({'gen1': names1, 'gen2': names2})
    df2.to_csv(root + '#map' + ext, sep=',', encoding='utf-8')


def chunk_data(path, out_path, num=50, head=False):
    """Write a positional slice of a CSV to ``out_path`` (without the index).

    Args:
        path: input CSV file.
        out_path: destination CSV file.
        num: split point. By default the rows after the first ``num`` are kept.
        head: when true, keep the first ``num`` rows instead.
    """
    with open(path, encoding='utf-8') as f:
        df = pd.read_csv(f)
    w = df.iloc[:num] if head else df.iloc[num:]
    print(f'path: {path}, out_path: {out_path}')
    # Report the size of the chunk actually written, not the split point.
    print(f'chunk/all: {len(w)}/{len(df)}')
    w.to_csv(out_path, sep=',', encoding='utf-8', index=False)


def compute_score(score_file, map_file=None, limit=10):
    """Count per system how often it was preferred for reflection and coherence.

    For each of the first ``limit`` rows, the value 1.0 credits the system named
    in the map's second-to-last column. Any other value (including NaN) credits
    the system in the map's last column.

    Args:
        score_file: CSV with ``reflection`` and ``coherence`` columns.
        map_file: map CSV whose last two columns name the system shown as
            out1/out2 per row. Required despite the ``None`` default, which is
            kept for backward compatibility.
        limit: number of leading rows to score; ``None`` scores all rows.

    Returns:
        ``{system: {'reflection': n, 'coherence': m}}`` (also printed).

    Raises:
        ValueError: if ``map_file`` is not given.
    """
    if map_file is None:
        raise ValueError('map_file is required')
    with open(score_file, encoding='utf-8') as f:
        scores = pd.read_csv(f)
    with open(map_file, encoding='utf-8') as f:
        maps = pd.read_csv(f)
    pd.set_option('display.max_columns', None)

    col1 = maps.iloc[:limit, -2].tolist()
    col2 = maps.iloc[:limit, -1].tolist()
    print(col1)
    print(col2)
    # One entry per system name that occurs, so no name can raise KeyError.
    res = {name: {'reflection': 0, 'coherence': 0} for name in col1 + col2}
    for criterion in ('reflection', 'coherence'):
        for idx, x in enumerate(scores[criterion].iloc[:limit]):
            winner = col1[idx] if x == 1.0 else col2[idx]
            res[winner][criterion] += 1

    print(res)
    return res


def compute_amt_score(amt_files, map_file, model_name='reweight'):
    """Aggregate AMT A/B judgements and report per-system win/lose/tie counts.

    Input layout:
        Each AMT export must contain ``Input.id`` and ``AssignmentStatus``
        columns. Coherence answers are in columns ``[-8:-5]`` and persona
        answers in ``[-5:-2]``. Within a group the columns are read as
        (A wins, A loses, tie).

    Aggregation:
        Rejected assignments are skipped. Each item is resolved by majority vote,
        and no strict majority counts as a tie. The outcome is credited to both
        systems via ``map_file`` (column 0 = item id, 1 = system shown as A,
        2 = system shown as B).

    Reporting:
        Items ``model_name`` lost or tied are printed as ``reject ids``,
        excluding items where all three annotators agreed on a win or a loss
        for A. A Wilcoxon signed-rank test on ``model_name``'s results is
        printed per criterion.

    Args:
        amt_files: paths of AMT batch-result CSV files.
        map_file: map CSV written by ``generate_data``.
        model_name: system analysed in the reject-id report and the test.

    Returns:
        ``{'coh': {system: [win, lose, tie]}, 'per': {system: [win, lose, tie]}}``.
    """
    from scipy.stats import wilcoxon

    def read_votes(amt_file):
        # Return (coherence votes, persona votes) as [item_id, choice] pairs,
        # skipping rejected assignments. Choice: A win 0, A lose 1, tie 2.
        with open(amt_file, encoding='utf-8') as f:
            amt = pd.read_csv(f, encoding='utf-8')
        item_ids = amt['Input.id']
        status = amt['AssignmentStatus']
        row_num = len(amt)
        print(row_num)

        def collect(columns):
            votes = []
            reject_cnt = 0
            for i in range(row_num):
                if status.iloc[i] == 'Rejected':
                    reject_cnt += 1
                    continue
                # The first checked box in the group is the annotator's answer.
                for choice, checked in enumerate(columns.iloc[i]):
                    if checked == True:  # noqa: E712 -- also matches numpy bools
                        votes.append([item_ids.iloc[i], choice])
                        break
            print("rejected cnt = ", reject_cnt)
            print("valid cnt = ", len(votes))
            return votes

        # Answer groups are addressed positionally from the end of the table.
        return collect(amt.iloc[:, -8:-5]), collect(amt.iloc[:, -5:-2])

    coh_votes, per_votes = [], []
    for amt_file in amt_files:
        coh, per = read_votes(amt_file)
        coh_votes.extend(coh)
        per_votes.extend(per)

    all_agree_ids = {}

    def aggregate(votes, key):
        # Majority vote per item; without a strict majority the item is a tie (2).
        counts = {}
        for item_id, choice in votes:
            counts.setdefault(item_id, [0, 0, 0])[choice] += 1
        # Items where all three annotators agreed on a win or a loss for A.
        all_agree_ids[key] = [i for i, c in counts.items() if c[0] == 3 or c[1] == 3]
        result = []
        for item_id, c in counts.items():
            best = max(c)
            result.append([item_id, c.index(best) if c.count(best) == 1 else 2])
        return result

    scores = {'coh': aggregate(coh_votes, 'coh'), 'per': aggregate(per_votes, 'per')}

    with open(map_file, encoding='utf-8') as f:
        map_pd = pd.read_csv(f, encoding='utf-8')
    # Positional columns via .iloc (the id column has no header name).
    id2name = dict(zip(map_pd.iloc[:, 0], map_pd.iloc[:, 1]))
    id2nameb = dict(zip(map_pd.iloc[:, 0], map_pd.iloc[:, 2]))
    print(id2name)

    res = {'coh': {}, 'per': {}}

    def stat(key):
        idstat = {}
        tally = res[key]
        for item_id, s in scores[key]:
            name, nameb = id2name[item_id], id2nameb[item_id]
            # B's outcome mirrors A's: win <-> lose, a tie stays a tie.
            sb = 2 if s == 2 else 1 - s
            for system in (name, nameb):
                idstat.setdefault(system, [[], [], []])
                tally.setdefault(system, [0, 0, 0])
            idstat[name][s].append(item_id)
            idstat[nameb][sb].append(item_id)
            tally[name][s] += 1
            tally[nameb][sb] += 1
        print(key)
        print(idstat)
        model_ids = idstat.get(model_name, [[], [], []])
        lost_or_tied = set(model_ids[1] + model_ids[2])
        print(f'reject ids: {lost_or_tied - set(all_agree_ids[key])}')

    stat('coh')
    stat('per')
    print(f'all_agree_ids:{all_agree_ids}')
    print(res)

    def sign_test(key):
        win, lose, tie = res[key][model_name]
        # Ties are zero differences; wilcoxon's default zero_method ('wilcox') discards them.
        return wilcoxon([1] * win + [-1] * lose + [0] * tie)

    coh_p = sign_test('coh')
    per_p = sign_test('per')
    print(f'coh_p: {coh_p}, per_p: {per_p}')
    return res


def fleiss_kappa_amt(amt_files, map_file, model_name='reweight'):
    """Print Fleiss' kappa of AMT annotations, with votes oriented to ``model_name``.

    Input layout:
        Each AMT export must contain ``Input.id`` and ``AssignmentStatus``
        columns. Coherence answers are in columns ``[-8:-5]`` and persona
        answers in ``[-5:-2]``; each group is read as (A wins, A loses, tie).
        ``map_file`` is the map CSV from ``generate_data``; its first two
        columns are the item id and the system shown as A.

    Processing:
        Non-rejected votes are re-expressed from ``model_name``'s side
        (win 0, lose 1, tie 2), keeping at most three votes per item.
        Items lacking exactly three coherence and three persona votes are
        dropped. Majority win/lose/tie counts are then printed.

    Output:
        Items on which the three annotators all disagreed (``[1, 1, 1]``) are
        excluded. Fleiss' kappa (statsmodels, ``method="uniform"``) is printed
        for both criteria.

    Args:
        amt_files: paths of AMT batch-result CSV files.
        map_file: map CSV mapping item ids to the system shown as A.
        model_name: system from whose side votes are counted.
    """
    with open(map_file, encoding='utf-8') as f:
        map_pd = pd.read_csv(f, encoding='utf-8')
    # Positional columns via .iloc: 0 = item id, 1 = system shown as A.
    ids = map_pd.iloc[:, 0].tolist()
    id2name = dict(zip(ids, map_pd.iloc[:, 1]))

    # Per criterion and item: vote counts [model wins, model loses, tie].
    all_scores = {key: {item_id: [0, 0, 0] for item_id in ids} for key in ('per', 'coh')}

    def score_one_file(amt_file):
        with open(amt_file, encoding='utf-8') as f:
            amt = pd.read_csv(f, encoding='utf-8')
        item_ids = amt['Input.id']
        status = amt['AssignmentStatus']
        print(len(amt))
        for key, columns in (('coh', amt.iloc[:, -8:-5]), ('per', amt.iloc[:, -5:-2])):
            for i in range(len(amt)):
                if status.iloc[i] == 'Rejected':
                    continue
                item_id = item_ids.iloc[i]
                for j, checked in enumerate(columns.iloc[i]):
                    if checked == True:  # noqa: E712 -- also matches numpy bools
                        counts = all_scores[key][item_id]
                        if sum(counts) < 3:  # keep at most three annotators per item
                            # Column j is from A's side; flip win/lose when A is the other system.
                            s = j if j == 2 or id2name[item_id] == model_name else 1 - j
                            counts[s] += 1
                        break

    for amt_file in amt_files:
        score_one_file(amt_file)
    print(all_scores)

    # Fleiss' kappa needs the same number of raters for every item, so keep only
    # items that received exactly three votes for both criteria.
    complete = []
    for item_id in ids:
        coh, per = all_scores['coh'][item_id], all_scores['per'][item_id]
        if sum(coh) == 3 and sum(per) == 3:
            complete.append(item_id)
        else:
            print(item_id, coh, per)
    coh_rows = [all_scores['coh'][i] for i in complete]
    per_rows = [all_scores['per'][i] for i in complete]

    def majority(counts):
        # Index of the strict-majority answer, or 2 (tie) when no answer leads.
        best = max(counts)
        return counts.index(best) if counts.count(best) == 1 else 2

    def compute_win_lose_tie(rows):
        res = [0, 0, 0]
        for counts in rows:
            res[majority(counts)] += 1
        return res

    print(f'coh: {compute_win_lose_tie(coh_rows)}')
    print(f'per: {compute_win_lose_tie(per_rows)}')

    # Drop items on which the three annotators each picked a different answer.
    coh_rows = [q for q in coh_rows if q != [1, 1, 1]]
    print(len(coh_rows))
    per_rows = [q for q in per_rows if q != [1, 1, 1]]
    print(len(per_rows))

    from statsmodels.stats.inter_rater import fleiss_kappa
    print(f'coh:{fleiss_kappa(coh_rows, method="uniform")}')
    print(f'per:{fleiss_kappa(per_rows, method="uniform")}')


# Default set of item ids whose AMT assignments reject_amt() marks for
# rejection when no ids are passed; assign it before calling reject_amt().
reject_ids = None


def reject_amt(in_file, out_file, ids=None):
    """Fill the AMT ``Reject`` column for assignments on the given item ids.

    Every assignment whose ``Input.id`` is in ``ids`` (defaulting to the
    module-level ``reject_ids``) gets a rejection reason in ``Reject``, unless its
    ``AssignmentStatus`` is already ``'Rejected'``. The whole table is written
    to ``out_file``, and the number of newly marked rows is printed.

    Raises:
        ValueError: if neither ``ids`` nor ``reject_ids`` is set.
    """
    if ids is None:
        ids = reject_ids
    if ids is None:
        raise ValueError('no ids to reject: pass ids= or set reject_ids')
    ids = set(ids)

    with open(in_file, encoding='utf-8') as f:
        amt = pd.read_csv(f, encoding='utf-8')

    reason = ('You need to make a more fine-grained comparison when annotating persona consistency, '
              'as it is very common that neither of the two systems can clearly reflect the given persona. '
              'Be extra careful when you choose tie.')
    mask = amt['Input.id'].isin(ids) & (amt['AssignmentStatus'] != 'Rejected')
    # Assign through the DataFrame itself. Writing into a separately held
    # amt['Reject'] Series may not reach amt (Copy-on-Write, or when an
    # all-empty float column has to be upcast to hold text).
    amt['Reject'] = amt['Reject'].astype(object)
    amt.loc[mask, 'Reject'] = reason
    print(amt.columns)
    print(int(mask.sum()))
    amt.to_csv(out_file, sep=',', encoding='utf-8', header=True, index=False)


if __name__ == '__main__':
    # Experiment driver: enable a step by uncommenting its call below.
    # reject_ids = {1, 2, 3, 4, 5, 6, 10, 13, 15, 16, 19, 22, 26, 27, 28, 30, 32, 33, 34, 35, 36, 38, 39, 40, 41, 42, 44,
    #               49, 52,
    #               53, 57, 58, 59, 61, 62,
    #               64, 65, 68, 69, 70, 72, 74, 78, 80, 81, 83, 84, 85, 87, 89, 90, 93, 94, 98}
    # reject_amt('files_human_eval/ours_gpt2_100.csv',
    #            'files_human_eval/ours_gpt2_100_autoreject.csv')
    # fleiss_kappa_amt([
    #     'files_human_eval/ours_planahead_post50_reject2.csv',
    #     'files_human_eval/ours_planahead_50_twoannotators_rescore.csv',
    #     'files_human_eval/ours_planahead_50_oneannotator.csv'
    # ], 'files_human_eval/reweight#plan_ahead#evaluation#map.csv')
    fleiss_kappa_amt(['files_human_eval/ours_gpt2_100_reject2.csv'],
                     'files_human_eval/reweight#gpt2#evaluation#map.csv')
    # compute_amt_score(['files_human_eval/ours_gpt2_100_reject2.csv'],
    #                   'files_human_eval/reweight#gpt2#evaluation#map.csv')

    # compute_amt_score(
    #     [
    #         # 'files_human_eval/ours_planahead_post50_reject2.csv',
    #         # 'files_human_eval/ours_planahead_50_twoannotators_rescore.csv',
    #         # 'files_human_eval/ours_planahead_50_oneannotator.csv'
    #         'files_human_eval/ours_gpt2_100.csv'
    #     ],
    #     #  'files_human_eval/ours_planahead_post50_threeannotators.csv'],
    # 'files_human_eval/reweight#plan_ahead#evaluation#map.csv')

    # Disabled by commenting out, not by calling exit(): exit() is an
    # interactive-shell helper injected by the `site` module.
    # generate_data('../result/final/kg_combine4_reweight_outline_onecard_version17_1.0.json',
    #               TextType.plan_tbe,
    #               '../result/final/gpt2_baseline_bte_onecard_1.0_ori.json',
    #               TextType.bte,
    #               'files_human_eval/reweight#gpt2#evaluation.csv',
    #               )

    # chunk_data('files_human_eval/reweight#plan_ahead#evaluation.csv',
    #    'files_human_eval/reweight#plan_ahead#evaluation_post50.csv')

