import matplotlib.pyplot as plt
import argparse
import pandas as pd
import numpy as np
import os
# from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import accuracy_score, f1_score
import json
import glob
import evaluate
from transformers import AutoTokenizer

# Command-line options as (flag, default) pairs, in the order they are
# registered. Every option is parsed as a string.
_CLI_OPTIONS = (
    ('--data_dir', "data"),
    ('--dataset_train', "cfever"),
    ('--dataset_test', "cfever"),
    ('--seq2seq_model_path', "test/supervised_t5_small_cfever/supervised_t5_small_cfever/supervised_t5_small_cfever"),
    ('--output_dir', 'test/supervised_t5_small_cfever/supervised_t5_small_cfever/supervised_t5_small_cfever'),
    ('--experiment_id', 'supervised_t5_small_cfever'),
    ('--metric_name', 'rouge'),
)

parser = argparse.ArgumentParser()
for _flag, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=str, default=_default)

# Parsed at module level: the functions below read `args` as a global.
args = parser.parse_args()

def read_data(data_dir, dataset_train, dataset_test):
    """Load the train split of one dataset and the test split of another.

    Reads ``{data_dir}/{dataset_train}/train.jsonl`` and
    ``{data_dir}/{dataset_test}/test.jsonl`` as JSON-lines files.

    Returns:
        A 6-tuple ``(claims_train, evidences_train, labels_train,
        claims_test, evidences_test, labels_test)`` holding the ``claim``,
        ``evidence`` and ``label`` columns of each file.
    """
    def _load_columns(path):
        # One JSON object per line; pull out the three columns used downstream.
        frame = pd.read_json(path, lines=True)
        evidence_col = frame['evidence']
        label_col = frame['label']
        claim_col = frame['claim']
        return claim_col, evidence_col, label_col

    train_columns = _load_columns(f'{data_dir}/{dataset_train}/train.jsonl')
    test_columns = _load_columns(f'{data_dir}/{dataset_test}/test.jsonl')
    return (*train_columns, *test_columns)

def read_generated_data(seq2seq_model_path, epoch, split):
    """Load the texts generated for one epoch and split.

    Reads ``{seq2seq_model_path}/epoch_{epoch}_{split}_split_predictions.json``,
    which must contain a JSON list of objects that each have a
    ``generated_text`` key, and prints how many texts were found.

    Args:
        seq2seq_model_path: Directory that holds the prediction files.
        epoch: Epoch number used in the file name.
        split: Split name used in the file name (e.g. ``'val'`` or ``'test'``).

    Returns:
        A list with the ``generated_text`` value of every entry, in file order.
    """
    predictions_path = f'{seq2seq_model_path}/epoch_{epoch}_{split}_split_predictions.json'
    # Decode explicitly as UTF-8 (the standard JSON encoding) instead of the
    # platform default, so non-ASCII generated text loads the same everywhere.
    with open(predictions_path, 'r', encoding='utf-8') as f:
        texts = [record['generated_text'] for record in json.load(f)]
    print(f'Number of generated {split} texts: {len(texts)}')
    return texts

def obtain_scores(claims_dev, evidences_dev, generated, metric, metric_name):
    """Score three pairings of texts, one example at a time.

    The pairings are generated-vs-claim, generated-vs-evidence and
    claim-vs-evidence. Each pairing is walked in lockstep and every aligned
    pair is scored separately through ``run_metric``.

    Returns:
        A dict with keys ``generated_claim_<metric_name>``,
        ``generated_evidence_<metric_name>`` and
        ``claim_evidence_<metric_name>``, each mapping to the list of
        per-example scores for that pairing.
    """
    comparisons = (
        ('generated_claim', generated, claims_dev),
        ('generated_evidence', generated, evidences_dev),
        ('claim_evidence', claims_dev, evidences_dev),
    )
    results = {}
    for pairing_name, left_texts, right_texts in comparisons:
        per_example = []
        for left, right in zip(left_texts, right_texts):
            # Each call scores a single-item batch so we get one value per row.
            per_example.append(run_metric([left], [right], metric, metric_name))
        results[f'{pairing_name}_{metric_name}'] = per_example
    return results

def run_metric(sentences1, sentences2, metric, metric_name):
    """Score ``sentences1`` against ``sentences2`` with an already-loaded metric.

    Args:
        sentences1: Texts passed as predictions/hypotheses (first argument).
        sentences2: Texts passed as references (second argument).
        metric: Metric object matching ``metric_name``.
        metric_name: One of ``'meteor'``, ``'bleu'``, ``'bleurt'``,
            ``'rouge'``, ``'sacrebleu'`` or ``'bartscore'``; selects how the
            metric is called and which field of its result is returned.

    Returns:
        The score extracted from the metric's result. For ``'bleu'`` the
        value ``0`` is returned when the metric raises.

    Raises:
        ValueError: If ``metric_name`` is not one of the supported names.
    """
    if metric_name == 'meteor':
        scores = metric.compute(predictions=sentences1, references=sentences2)['meteor']
    elif metric_name == 'bleu':
        try:
            scores = metric.compute(predictions=sentences1, references=sentences2)['bleu']
        except Exception:
            # Deliberate best-effort: a failing BLEU computation counts as a
            # zero score instead of aborting the whole run.
            # NOTE(review): presumably triggered by degenerate inputs such as
            # an empty prediction - confirm and narrow the exception type.
            scores = 0
    elif metric_name == 'bleurt':
        # Only the first score is returned, i.e. the first sentence pair.
        scores = metric.compute(predictions=sentences1, references=sentences2)['scores'][0]
    elif metric_name == 'rouge':
        scores = metric.get_scores(sentences1, sentences2)['rouge-l']['f']
    elif metric_name == 'sacrebleu':
        scores = metric.compute(predictions=sentences1, references=sentences2)['score']
    elif metric_name == 'bartscore':
        # Only the first score is returned, i.e. the first sentence pair.
        scores = metric.score(sentences1, sentences2, batch_size=4)[0]
    else:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(f'Unsupported metric_name: {metric_name!r}')
    return scores

def load_metric(metric_name):
    """Create the metric object for ``metric_name``.

    ``'meteor'``, ``'bleu'``, ``'sacrebleu'`` and ``'bleurt'`` are loaded
    through ``evaluate.load`` using the module-level ``args.experiment_id``.
    ``'rouge'`` builds a ``rouge.Rouge`` scorer restricted to ROUGE-L with
    stemming enabled. ``'bartscore'`` builds a ``BARTScorer`` on ``cuda:0``
    from the ``facebook/bart-large-cnn`` checkpoint. The ``rouge`` and
    ``BARTScore`` packages are imported only when those metrics are requested.

    Args:
        metric_name: Name of the metric to load.

    Returns:
        The loaded metric object.

    Raises:
        ValueError: If ``metric_name`` is not one of the supported names.
    """
    if metric_name in ('meteor', 'bleu', 'sacrebleu'):
        # These three are loaded identically, by their own name.
        return evaluate.load(metric_name, experiment_id=args.experiment_id)
    if metric_name == 'bleurt':
        return evaluate.load('bleurt', 'bleurt-large-512', experiment_id=args.experiment_id)
    if metric_name == 'rouge':
        import rouge
        return rouge.Rouge(metrics=['rouge-l'], stemming=True)
    if metric_name == 'bartscore':
        from BARTScore.bart_score import BARTScorer
        return BARTScorer(device='cuda:0', checkpoint='facebook/bart-large-cnn')
    # Previously this fell through and failed with an UnboundLocalError.
    raise ValueError(f'Unsupported metric_name: {metric_name!r}')



if __name__ == '__main__':
    # Entry point: for every epoch and split, score the generated texts
    # against the dataset claims/evidences and write one CSV of per-example
    # scores (plus the gold label) per epoch/split.
    claims_train, evidences_train, labels_train, claims_test, evidences_test, labels_test = read_data(
        args.data_dir, args.dataset_train, args.dataset_test)

    metric_name = args.metric_name
    # The destination folder is the same for every epoch and split, so it is
    # created once up front (this also creates args.output_dir itself).
    scores_dir = f'{args.output_dir}/aggregate_{metric_name}'
    os.makedirs(scores_dir, exist_ok=True)

    print('metric_name', metric_name)
    metric = load_metric(metric_name)

    # HACK: predictions stored under the 'val' split are scored against the
    # train data, i.e. the "validation" set is actually the train set.
    references_by_split = {
        'val': (claims_train, evidences_train, labels_train),
        'test': (claims_test, evidences_test, labels_test),
    }

    # Prediction files are expected for epochs 0 through 20 inclusive.
    for t in range(0, 21):
        print('epoch', t)
        for split, (claims, evidences, labels) in references_by_split.items():
            print('split', split)
            generated = read_generated_data(args.seq2seq_model_path, t, split)
            scores = obtain_scores(claims, evidences, generated, metric, metric_name=metric_name)
            scores['label'] = labels
            df = pd.DataFrame.from_dict(scores)
            df.to_csv(f'{scores_dir}/epoch_{t}_{split}_scores.csv', index=False)

