import os
from sentence_transformers import SentenceTransformer, models
import jsonlines
import random
from transformers import AutoTokenizer, AutoModel
import torch
from datasets import load_dataset, concatenate_datasets, Dataset, DatasetBuilder
import argparse
from collections import Counter
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
from sklearn.model_selection import train_test_split
import numpy as np
from numpy import dot, mean, absolute
from numpy.linalg import norm
import pandas as pd
from tqdm import tqdm

import sys
class Unbuffered:
    """Wrap a text stream so that every write is flushed immediately.

    Any attribute not defined here is looked up on the wrapped stream, so
    the wrapper can stand in for the stream it wraps.
    """

    def __init__(self, stream):
        """Remember the stream that writes are forwarded to."""
        self.stream = stream

    def write(self, data):
        """Write ``data`` to the wrapped stream, then flush it."""
        target = self.stream
        target.write(data)
        target.flush()

    def writelines(self, datas):
        """Write every item of ``datas`` to the wrapped stream, then flush it."""
        target = self.stream
        target.writelines(datas)
        target.flush()

    def __getattr__(self, attr):
        """Delegate lookups of unknown attributes to the wrapped stream."""
        # Only reached when normal attribute lookup on the wrapper fails.
        return getattr(self.stream, attr)
# Replace stdout with the flushing wrapper so every print is flushed as soon
# as it is written.
sys.stdout = Unbuffered(sys.stdout)

def _str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because every non-empty string is truthy, so ``--abs False`` would
    silently enable the option.

    Args:
        value: The raw command-line string (or an actual bool).

    Returns:
        True for "true"/"t"/"yes"/"y"/"1", False for "false"/"f"/"no"/"n"/"0"
        (case-insensitive). The empty string maps to False, which is what
        ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line options for the evaluation script.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='bert-base-nli-mean-tokens')   #bert-base-uncased
parser.add_argument('--load_model_from_disk', type=_str2bool, default=False)
parser.add_argument('--dataset', type=str, default='scifact_oracle', help='dataset name')
parser.add_argument('--norm', type=_str2bool, default=False)
parser.add_argument('--abs', type=_str2bool, default=True)
# Only these two metrics are implemented by `distance`; reject anything else
# up front instead of failing in the middle of a run.
parser.add_argument('--dis', type=str, default='euclidean', choices=['euclidean', 'cosine'])
parser.add_argument('--pooling', type=str, default='mean')
parser.add_argument('--output_dir', type=str, default="seed_results")
parser.add_argument('--cross_eval', type=_str2bool, default=False, help='cross dataset evaluation')



# Parsed at module level: `distance` and the main script read this global.
args = parser.parse_args()

def distance(a, b):
    """Return a similarity score between two vectors (larger means closer).

    Despite the name, the result is a similarity: callers pick the class
    with the highest value. The metric is chosen by the module-level
    ``args.dis``:

    * ``'euclidean'``: ``1 - ||a - b||`` (L2 norm of the difference).
    * ``'cosine'``: cosine similarity ``a.b / (||a|| * ||b||)``.

    Args:
        a: First vector.
        b: Second vector.

    Returns:
        The similarity score as computed above.

    Raises:
        ValueError: If ``args.dis`` names an unsupported metric.
    """
    if args.dis == 'euclidean':
        return 1 - norm(a - b)
    if args.dis == 'cosine':
        return dot(a, b) / (norm(a) * norm(b))
    # Previously an unknown metric fell through and hit an unbound local.
    raise ValueError(f"unsupported distance metric: {args.dis!r}")


def predict_dis(mean_vecs, X_dev_sampled):
    """Assign each sample the index of the best-scoring class vector.

    Every sample is scored with ``distance`` against ``mean_vecs[0]``,
    ``mean_vecs[1]`` and ``mean_vecs[2]``; the index of the largest score is
    the prediction.

    Args:
        mean_vecs: Indexable holding the three class vectors.
        X_dev_sampled: Iterable of sample vectors to classify.

    Returns:
        A list with one predicted class index (0, 1 or 2) per sample.
    """
    predictions = []
    for sample in tqdm(X_dev_sampled):
        scores = np.array([distance(sample, mean_vecs[k]) for k in range(3)])
        predictions.append(scores.argmax())
    return predictions


def evaluate(mean_vecs, X_dev, y_truth):
    """Score nearest-class-vector predictions against gold labels.

    Args:
        mean_vecs: The three class vectors, ordered SUPPORTS,
            NOT_ENOUGH_INFO, REFUTES.
        X_dev: Sample vectors to classify.
        y_truth: Gold string labels, one per sample.

    Returns:
        A tuple ``(acc, f1)``: accuracy and macro-averaged F1, both rounded
        to four decimals.
    """
    # Class index -> label name (0: supports, 1: not enough info, 2: refutes).
    index_to_label = ('SUPPORTS', 'NOT_ENOUGH_INFO', 'REFUTES')
    y_pred = [index_to_label[i] for i in predict_dis(mean_vecs, X_dev)]

    # Gold labels may spell the neutral class with spaces (sample_t accepts
    # both spellings). Predictions always use the underscore form, so the
    # gold labels are normalised to it; otherwise those rows could never
    # match and macro-F1 would gain a spurious extra class.
    y_truth = ['NOT_ENOUGH_INFO' if label == 'NOT ENOUGH INFO' else label for label in y_truth]

    acc = round(accuracy_score(y_truth, y_pred), 4)
    # Builtin round works whether f1_score returns a NumPy scalar or a float.
    f1 = round(f1_score(y_truth, y_pred, average='macro'), 4)

    return acc, f1

def diff(claims, evidences, model, abs):
    """Encode claims and evidences and return their embedding differences.

    Args:
        claims: Claim inputs passed to ``model.encode``.
        evidences: Evidence inputs passed to ``model.encode``.
        model: Object with an ``encode`` method. Assumes the returned
            embeddings support elementwise subtraction (e.g. NumPy arrays);
            TODO confirm.
        abs: When equal to ``True`` the elementwise absolute difference is
            returned, otherwise the signed difference.

    Returns:
        ``evidence - claim`` embeddings, optionally passed through
        ``numpy.absolute``.
    """
    claim_vecs = model.encode(claims)
    evidence_vecs = model.encode(evidences)
    delta = evidence_vecs - claim_vecs
    # The explicit `== True` comparison is kept: only True (or values equal
    # to it) enable the absolute value, not arbitrary truthy objects.
    return absolute(delta) if abs == True else delta


def filter_features(df):
    """Split a frame into its feature columns, labels, and both combined.

    Args:
        df: DataFrame with a ``label`` column; the ``claim`` and
            ``evidence`` columns are kept as features when present.

    Returns:
        A tuple ``(features, labels, combined)``: the claim/evidence
        columns, the labels as a plain list, and a frame holding the
        feature columns followed by ``label``.
    """
    label_column = df['label']
    labels = label_column.tolist()
    features = df.filter(items=['claim', 'evidence'], axis=1)
    combined = pd.concat([features, label_column], axis=1)
    return features, labels, combined

def read_data(dataset):
    """Load the train and test CSV splits of a dataset.

    Both files are read from ``./data/<dataset>/`` with the ``index``
    column as the frame index, then split by ``filter_features``. The train
    frame and its feature columns are printed.

    Args:
        dataset: Name of the directory under ``./data``.

    Returns:
        ``(train_DF, train_labels, train_features,
        test_DF, test_labels, test_features)``.
    """
    train_path = f"./data/{dataset}/train.csv"
    test_path = f"./data/{dataset}/test.csv"

    # load the split
    train_frame = pd.read_csv(train_path, index_col='index')
    test_frame = pd.read_csv(test_path, index_col='index')

    train_features, train_labels, train_DF = filter_features(train_frame)
    test_features, test_labels, test_DF = filter_features(test_frame)

    print(train_DF)
    print(train_features)
    return (
        train_DF, train_labels, train_features,
        test_DF, test_labels, test_features,
    )

def sample_t(labels_train, t=10, seed = 123):
    """Draw ``t`` random positions per class from a list of labels.

    The global ``random`` generator is seeded with ``seed`` first, so the
    same arguments always give the same draw.

    Args:
        labels_train: String labels; the neutral class may be spelled
            ``NOT_ENOUGH_INFO`` or ``NOT ENOUGH INFO``.
        t: Number of positions to draw for each class.
        seed: Seed for the global ``random`` generator.

    Returns:
        A list of ``3 * t`` positions: SUPPORTS first, then the neutral
        class, then REFUTES.

    Raises:
        ValueError: From ``random.sample`` when a class has fewer than
            ``t`` examples.
    """
    random.seed(seed)

    def positions_of(*names):
        # One scan of the labels per class.
        return [pos for pos, label in enumerate(labels_train) if label in names]

    pools = (
        positions_of('SUPPORTS'),
        positions_of('NOT_ENOUGH_INFO', 'NOT ENOUGH INFO'),
        positions_of('REFUTES'),
    )

    picked = []
    for pool in pools:
        picked += random.sample(pool, t)
    return picked

if __name__ == '__main__':
    # Few-shot evaluation: for each shot count t and each seed, build one
    # mean difference vector per class from t sampled training pairs,
    # classify the evaluation split by nearest class vector, and record
    # accuracy and macro-F1.

    # Named `use_abs` so the builtin `abs` is not shadowed module-wide.
    use_abs = args.abs
    print('dataset:', args.dataset)
    print('cross_eval:', args.cross_eval)

    if args.load_model_from_disk:
        # Assemble the sentence encoder from a transformer plus a pooling
        # layer using the requested pooling mode.
        word_embedding_model = models.Transformer(args.model)
        pooling_model = models.Pooling(
            word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(),
            pooling_mode=args.pooling,
        )
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    else:
        model = SentenceTransformer(args.model)

    # read_data returns the train and test splits; the test split is the
    # evaluation ("dev") set and is encoded once.
    train_DF, train_labels, train_features, dev_DF, dev_labels, dev_features = read_data(dataset=args.dataset)
    X_dev = diff(dev_features['claim'].tolist(), dev_features['evidence'].tolist(), model, use_abs)
    y_dev = dev_labels

    t_list = []
    f1_list = []
    acc_list = []
    seed_list = []
    for t in list(range(1, 10)) + list(range(10, 60, 10)):
        f1s = []
        accs = []
        for s in range(123, 224):
            # sample_t returns 3*t positions: supports, neutral, refutes.
            train_index = sample_t(train_labels, t=t, seed=s)
            class_positions = (
                train_index[:t],
                train_index[t:2 * t],
                train_index[2 * t:3 * t],
            )
            vecs = []
            for positions in class_positions:
                subset = train_features.iloc[positions, :]
                diff_ = diff(subset['claim'].tolist(), subset['evidence'].tolist(), model, use_abs)
                # Class vector = mean of the sampled difference vectors.
                vecs.append(mean(diff_, axis=0))

            acc, f1 = evaluate(vecs, X_dev, y_dev)
            f1s.append(f1)
            accs.append(acc)

            t_list.append(t)
            f1_list.append(f1)
            acc_list.append(acc)
            seed_list.append(s)
        print('t: ', t)
        print('mean f1: ', np.mean(f1s), 'mean acc: ', np.mean(accs))
        print()

    # save results
    results = pd.DataFrame({'t': t_list, 'f1': f1_list, 'acc': acc_list, 'seed': seed_list})
    os.makedirs(args.output_dir, exist_ok=True)
    file_stem = args.dataset + '_' + args.model + '_' + args.pooling + '_' + str(use_abs)
    # A model id such as "org/name" (or a local path) contains path
    # separators; left in the file name they point at a directory that does
    # not exist and the save fails after all the work is done.
    file_stem = file_stem.replace('/', '_').replace(os.sep, '_')
    suffix = '_cross_eval.csv' if args.cross_eval else '.csv'
    results.to_csv(args.output_dir + '/' + file_stem + suffix, index=False)

