import matplotlib.pyplot as plt
import argparse
import pandas as pd
import os
from sentence_transformers import SentenceTransformer, util
import json

parser = argparse.ArgumentParser(
    description='Score seq2seq generations against claims and evidence using '
                'sentence-transformer cosine similarity.')
parser.add_argument('--data_dir', type=str, default="data",
                    help='root directory containing <dataset>/train.jsonl and '
                         '<dataset>/test.jsonl')
parser.add_argument('--dataset_train', type=str, default="cfever",
                    help='dataset whose train.jsonl is paired with the "val" '
                         'prediction files')
parser.add_argument('--dataset_test', type=str, default="cfever",
                    help='dataset whose test.jsonl is paired with the "test" '
                         'prediction files')
parser.add_argument('--seq2seq_model_path', type=str, default="test/supervised_t5_small_cfever/supervised_t5_small_cfever/supervised_t5_small_cfever",
                    help='directory holding epoch_<t>_<split>_split_predictions.json files')
parser.add_argument('--output_dir', type=str, default='test/supervised_t5_small_cfever/supervised_t5_small_cfever/supervised_t5_small_cfever',
                    help='directory where epoch_<t>_<split>_scores.csv files are written')
# The original help text concatenated the example names with no separator
# ("...such as all-MiniLM-L6-v2all-mpnet-base-v2").
parser.add_argument('--smodel', type=str, default="all-mpnet-base-v2",
                    help='sentence transformer model, e.g. all-MiniLM-L6-v2 '
                         'or all-mpnet-base-v2')


# Parsed at import time: the rest of the script reads the module-level `args`.
args = parser.parse_args()

def read_data(data_dir, dataset_train, dataset_test):
    """Load claims, evidence and labels for the train and test splits.

    Reads ``{data_dir}/{dataset_train}/train.jsonl`` and
    ``{data_dir}/{dataset_test}/test.jsonl`` (JSON Lines) and pulls out their
    ``claim``, ``evidence`` and ``label`` columns.

    Returns:
        A 6-tuple of pandas Series:
        ``(claims_train, evidences_train, labels_train,
           claims_test, evidences_test, labels_test)``.
    """
    def _load(dataset, split):
        # One JSON object per line -> one DataFrame row per example.
        frame = pd.read_json(f'{data_dir}/{dataset}/{split}.jsonl', lines = True)
        evidences = frame['evidence']
        labels = frame['label']
        claims = frame['claim']
        return claims, evidences, labels

    train_part = _load(dataset_train, 'train')
    test_part = _load(dataset_test, 'test')
    return train_part + test_part

def read_generated_data(seq2seq_model_path, epoch, split):
    """Load the texts generated by the seq2seq model for one epoch and split.

    Reads ``{seq2seq_model_path}/epoch_{epoch}_{split}_split_predictions.json``,
    which must be a JSON array of objects each carrying a ``generated_text``
    field, and returns those texts in file order.

    Args:
        seq2seq_model_path: directory containing the prediction files.
        epoch: epoch number embedded in the file name.
        split: split name embedded in the file name (e.g. ``'val'``/``'test'``).

    Returns:
        list of generated strings.

    Raises:
        FileNotFoundError: if the prediction file does not exist.
        KeyError: if a record lacks ``generated_text``.
    """
    path = f'{seq2seq_model_path}/epoch_{epoch}_{split}_split_predictions.json'
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale encoding, since generated text may well be non-ASCII.
    with open(path, 'r', encoding='utf-8') as f:
        records = json.load(f)
    texts = [record['generated_text'] for record in records]
    print(f'Number of generated {split} texts: {len(texts)}')
    return texts

def obtain_scores(claims_dev, evidences_dev, generated):
    """Compute the three aligned similarity columns for one split.

    Each value is the per-example output of ``pair_distance`` for the pair:

    * ``'generated_claim'``    -- generated text vs. claim
    * ``'generated_evidence'`` -- generated text vs. evidence
    * ``'claim_evidence'``     -- claim vs. evidence (independent of ``generated``)

    Returns:
        dict mapping the three names above to their score arrays, in that order.
    """
    # A dict literal evaluates its values left to right, so pair_distance is
    # called in the same order as the keys appear.
    return {
        'generated_claim': pair_distance(generated, claims_dev),
        'generated_evidence': pair_distance(generated, evidences_dev),
        'claim_evidence': pair_distance(claims_dev, evidences_dev),
    }

def pair_distance(sentences1, sentences2):
    """Return the cosine similarity of each aligned pair of sentences.

    Element ``i`` of the result is the cosine similarity between the
    embeddings of ``sentences1[i]`` and ``sentences2[i]``, both encoded with
    the module-level ``smodel`` SentenceTransformer.  Despite the function's
    name, the values are similarities (higher = more alike), not distances.

    Args:
        sentences1: sequence of texts.
        sentences2: sequence of texts, aligned with and as long as
            ``sentences1``.

    Returns:
        numpy array with one score per pair.

    Raises:
        ValueError: if the two inputs encode to a different number of rows.
    """
    import torch

    # atleast_2d treats a single-string input as a one-row batch, the same
    # way util.cos_sim handles 1-D embeddings.
    embeddings1 = torch.atleast_2d(smodel.encode(sentences1, convert_to_tensor=True))
    embeddings2 = torch.atleast_2d(smodel.encode(sentences2, convert_to_tensor=True))
    if embeddings1.shape[0] != embeddings2.shape[0]:
        # Taking the diagonal of a non-square matrix used to silently
        # truncate to the shorter input; fail loudly instead.
        raise ValueError(
            f'pair_distance expects aligned inputs of equal length, got '
            f'{embeddings1.shape[0]} and {embeddings2.shape[0]} sentences')

    # Row-wise cosine similarity of aligned pairs: O(n) memory, instead of
    # building the full n x n util.cos_sim matrix just to keep its diagonal.
    cosine_scores = torch.nn.functional.cosine_similarity(embeddings1, embeddings2, dim=-1)
    return cosine_scores.cpu().numpy()

if __name__ == '__main__':
    (claims_train, evidences_train, labels_train,
     claims_test, evidences_test, labels_test) = read_data(
        args.data_dir, args.dataset_train, args.dataset_test)
    # Module-level global: pair_distance reads `smodel` directly.
    smodel = SentenceTransformer(args.smodel)
    os.makedirs(args.output_dir, exist_ok=True)

    # Reference data each prediction split is scored against.
    # NOTE: per the original author's note this is a hack -- the "val"
    # prediction files actually cover the *training* set, so they are
    # paired with the train claims/evidence/labels.
    split_data = {
        'val': (claims_train, evidences_train, labels_train),
        'test': (claims_test, evidences_test, labels_test),
    }
    num_epochs = 21  # prediction files exist for epochs 0..20 inclusive

    for epoch in range(num_epochs):
        for split, (claims, evidences, labels) in split_data.items():
            print('split', split)
            generated = read_generated_data(args.seq2seq_model_path, epoch, split)
            # Catch misaligned prediction files before the (expensive)
            # encoding step rather than failing later when building the frame.
            if len(generated) != len(claims):
                raise ValueError(
                    f'epoch {epoch}, split {split!r}: {len(generated)} generated '
                    f'texts but {len(claims)} reference examples')

            scores = obtain_scores(claims, evidences, generated)
            scores['label'] = labels

            df = pd.DataFrame.from_dict(scores)
            df.to_csv(f'{args.output_dir}/epoch_{epoch}_{split}_scores.csv', index=False)

