#!/usr/bin/env python
# !-*-coding:utf-8 -*-
import pandas as pd

from evaluator.meteor.meteor import Meteor
from evaluator.bleu.smooth_bleu import codenn_smooth_bleu
from evaluator.rouge.rouge import Rouge
from evaluator.cider.cider import Cider
import warnings
import argparse
import logging

# Silence library warnings (e.g. from the scorers) so only metric lines are printed.
warnings.filterwarnings(action='ignore')
# Root logging configuration shared by everything this script logs.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s - %(levelname)s - %(name)s] %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)


def Commitbleus(refs, preds):
    """Return the first value of ``codenn_smooth_bleu`` rounded to 4 decimals.

    Args:
        refs: One entry per example; each entry is a list whose first element
            ``r[0]`` is the reference token sequence (only ``r[0]`` is used).
        preds: One token sequence per example, aligned with ``refs``.

    Returns:
        ``round(codenn_smooth_bleu(...)[0], 4)``. Pairs where either side is
        empty are skipped. If the scorer raises, the error is logged and 0 is
        returned so one bad batch does not abort the whole evaluation run.
    """
    r_str_list = []
    p_str_list = []
    for r, p in zip(refs, preds):
        if len(r[0]) == 0 or len(p) == 0:
            continue
        # The scorer takes space-joined strings; each reference is wrapped in
        # its own list.
        r_str_list.append([" ".join(str(token) for token in r[0])])
        p_str_list.append(" ".join(str(token) for token in p))
    try:
        bleu_list = codenn_smooth_bleu(r_str_list, p_str_list)
    except Exception:
        # Best-effort scoring is deliberate, but the failure must be visible.
        # Unlike the old bare ``except:``, Ctrl-C / SystemExit now propagate.
        logging.getLogger(__name__).exception(
            "codenn_smooth_bleu failed; reporting BLEU as 0")
        bleu_list = [0, 0, 0, 0]
    return round(bleu_list[0], 4)


def read_to_list(filename):
    """Read a UTF-8 text file and return one lower-cased token list per line.

    Args:
        filename: Path of the file to read.

    Returns:
        A list with one entry per line, each being ``line.lower().split()``;
        blank lines yield ``[]``.
    """
    # ``with`` guarantees the handle is closed (the original leaked it).
    with open(filename, 'r', encoding="utf-8") as f:
        return [row.lower().split() for row in f]


def metetor_rouge_cider(refs, preds):
    """Print METEOR, ROUGE-L and CIDEr for aligned references and predictions.

    Each prediction ``preds[k]`` and reference ``refs[k][0]`` (token lists) is
    space-joined and stored under key ``k`` as a one-element list. The two
    dicts are then passed to each scorer's ``compute_score``. METEOR and
    ROUGE-L are printed scaled by 100; CIDEr is printed unscaled. All values
    are rounded to 2 decimals.
    """
    preds_dict = {}
    refs_dict = {}
    for k, pred_tokens in enumerate(preds):
        preds_dict[k] = [" ".join(pred_tokens)]
        refs_dict[k] = [" ".join(refs[k][0])]

    def _report(label, scorer, scale=None):
        # compute_score returns (corpus_score, per_example_scores); only the
        # corpus score is printed.
        corpus_score, _per_example = scorer.compute_score(refs_dict, preds_dict)
        shown = corpus_score if scale is None else corpus_score * scale
        print(label, round(shown, 2))

    _report("Meteor: ", Meteor(), 100)
    _report("Rouge-L: ", Rouge(), 100)
    _report("Cider: ", Cider())


def _indexed_tokens(idx, text):
    """Return ``[str(idx)]`` followed by the lower-cased whitespace tokens of *text*.

    ``text`` goes through ``str`` first so a non-string CSV cell (e.g. a
    number) cannot crash ``.lower()``. ``split()`` with no argument already
    treats line breaks as separators.
    """
    return [str(idx)] + str(text).lower().split()


def main(prefix='one', num_prompts=3):
    """Score each prompt variant's outputs stored in ``{prefix}_shot_output.csv``.

    The CSV must contain a ``references`` column and one column per prompt
    named ``{prefix}_prompt_output_{k}`` for ``k`` in ``1..num_prompts``.
    For every prompt column this prints BLEU (via ``Commitbleus``) followed by
    METEOR, ROUGE-L and CIDEr (via ``metetor_rouge_cider``).

    Args:
        prefix: File/column prefix (defaults to ``'one'``, the original value).
        num_prompts: Number of prompt-output columns to evaluate.

    Raises:
        KeyError: if an expected column is missing from the CSV.
    """
    # keep_default_na=False: empty cells become '' rather than float NaN, and
    # outputs that are literally "None"/"null"/"NA" stay text. Previously both
    # cases crashed on ``t.lower()``.
    df = pd.read_csv(f'{prefix}_shot_output.csv', keep_default_na=False)

    # References are the same for every prompt and are only read downstream,
    # so tokenize them once.
    # NOTE(review): every sequence is prefixed with its row index as a token.
    # That token also reaches METEOR/ROUGE/CIDEr and matches in every pair,
    # which may inflate those scores -- confirm whether only BLEU needs it.
    refs = [[_indexed_tokens(idx, text)]
            for idx, text in enumerate(df['references'])]

    for i in range(1, num_prompts + 1):
        print(f"prompt-{i}:")
        outputs = df[f'{prefix}_prompt_output_{i}']
        # LLM outputs may be wrapped in Markdown code fences; drop the fences.
        preds = [_indexed_tokens(idx, str(text).replace('```', ' '))
                 for idx, text in enumerate(outputs)]

        bleus_score = Commitbleus(refs, preds)
        print("BLEU: %.2f" % bleus_score)
        metetor_rouge_cider(refs, preds)

# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
