import os
import random
import argparse
from collections import Counter
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer
import transformers
import torch

from huggingface_hub import login
# Hugging Face access tokens are read from the environment instead of being
# hard-coded as string literals in the source file.
access_token_read = os.environ.get('HF_TOKEN_READ') or os.environ.get('HF_TOKEN')
# NOTE(review): the write token is not referenced anywhere else in the visible code.
access_token_write = os.environ.get('HF_TOKEN_WRITE')
if access_token_read:
    login(token = access_token_read)
# Without a token, no explicit login is attempted and any credentials already
# configured for huggingface_hub are used.

print('CUDA available:', torch.cuda.is_available())

import sys
class Unbuffered:
    """File-like wrapper that flushes the wrapped stream after every write.

    Attribute lookups that the wrapper does not define itself (``fileno``,
    ``encoding``, ...) are forwarded to the wrapped stream.
    """

    def __init__(self, stream):
        self.stream = stream

    def write(self, data):
        """Write *data* to the wrapped stream, then flush it immediately."""
        target = self.stream
        target.write(data)
        target.flush()

    def writelines(self, datas):
        """Write every item of *datas* to the wrapped stream, then flush it."""
        target = self.stream
        target.writelines(datas)
        target.flush()

    def __getattr__(self, attr):
        # Only reached for names not found on the wrapper itself.
        return getattr(self.stream, attr)
sys.stdout = Unbuffered(sys.stdout)


def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

    ``type=bool`` cannot be used for this: ``bool('False')`` is True because
    any non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='meta-llama/Llama-2-7b-chat-hf',
                    help='model id or path passed to AutoTokenizer / transformers.pipeline')
# NOTE(review): args.load_model_from_disk and args.d are not read anywhere in the visible code.
parser.add_argument('--load_model_from_disk', type=_str2bool, default=False)
parser.add_argument('--dataset', type=str, default='scifact_oracle', help='dataset name')
parser.add_argument('--d', type=str, default='single_path', help='directory')
parser.add_argument('--output_dir', type=str, default="llama2_results")

args = parser.parse_args()

def read_data(dataset, data_dir="data"):
    """Load the train and test splits of a claim-verification dataset.

    Reads ``<data_dir>/<dataset>/train.csv`` and ``<data_dir>/<dataset>/test.csv``
    with the ``index`` column used as the DataFrame index.

    Args:
        dataset: name of the dataset sub-directory (e.g. ``scifact_oracle``).
        data_dir: root directory containing the dataset folders; defaults to
            the relative ``data`` directory.

    Returns:
        A 6-tuple of lists: train claims, train evidences, train labels,
        test claims, test evidences, test labels.
    """
    columns = ('claim', 'evidence', 'label')
    outputs = []
    for split in ('train', 'test'):
        frame = pd.read_csv(os.path.join(data_dir, dataset, f"{split}.csv"), index_col='index')
        outputs.extend(frame[column].tolist() for column in columns)
    return tuple(outputs)

def sample_t(labels_train, t=10, seed = 123):
    """Pick ``t`` random example indices per class for few-shot prompting.

    Args:
        labels_train: sequence of label strings; 'NOT_ENOUGH_INFO' and
            'NOT ENOUGH INFO' are both treated as the neutral class.
        t: number of examples to draw from each class.
        seed: seed for the sampling; the same seed always yields the same indices.

    Returns:
        A list of ``3 * t`` indices: ``t`` SUPPORTS indices, then ``t``
        NOT_ENOUGH_INFO indices, then ``t`` REFUTES indices.

    Raises:
        ValueError: if a class has fewer than ``t`` examples (or ``t`` is negative).
    """
    # A private generator produces the same draws as seeding the global `random`
    # module with `seed`, but leaves the global RNG state untouched.
    rng = random.Random(seed)
    nei_labels = ('NOT_ENOUGH_INFO', 'NOT ENOUGH INFO')
    groups = (
        ('SUPPORTS', [i for i, label in enumerate(labels_train) if label == 'SUPPORTS']),
        ('NOT_ENOUGH_INFO', [i for i, label in enumerate(labels_train) if label in nei_labels]),
        ('REFUTES', [i for i, label in enumerate(labels_train) if label == 'REFUTES']),
    )
    all_indexes = []
    for name, group in groups:
        if len(group) < t:
            raise ValueError(f"need {t} '{name}' examples for sampling, only {len(group)} available")
        all_indexes.extend(rng.sample(group, t))
    return all_indexes

def _clean_label(text):
    """Reduce raw generated text to a single label candidate.

    Strips surrounding whitespace, removes an optional leading ``label:``
    prefix (case-insensitive), and returns the first non-empty line. This
    replaces ``str.strip("label: ")``, which removes any of those characters
    rather than the prefix, and newline removal that glued follow-up lines
    onto the label.
    """
    cleaned = text.strip()
    if cleaned.lower().startswith('label:'):
        cleaned = cleaned[len('label:'):]
    for line in cleaned.splitlines():
        if line.strip():
            return line.strip()
    return ''


def generate_response(prompt_input):
    """Generate a short continuation of ``prompt_input`` and return it as a label string.

    Relies on the module-level ``pipeline`` and ``tokenizer`` created in the
    ``__main__`` section, so it can only be called after those exist.
    Decoding is sampled (``do_sample=True``, ``top_k=10``) and limited to 10
    new tokens; only the generated text (not the prompt) is returned.
    """
    sequences = pipeline(
        prompt_input,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        return_full_text=False,
        max_new_tokens=10,
    )
    return _clean_label(sequences[0]['generated_text'])



LABELS = ('SUPPORTS', 'REFUTES', 'NOT_ENOUGH_INFO')
PROMPT_PREFIX = "Please perform the task of claim verification: you are given a claim and a piece of evidence, your goal is to classify the pair out of 'SUPPORTS', 'REFUTES' and 'NOT_ENOUGH_INFO'. Here are a few examples:"
PROMPT_SUFFIX = "What is the label for the following pair out of 'SUPPORTS', 'REFUTES' and 'NOT_ENOUGH_INFO'? Answer with the label only. "


def _normalize_label(label):
    """Map the spaced 'NOT ENOUGH INFO' spelling onto 'NOT_ENOUGH_INFO'.

    The prompt and the prediction mapping use the underscore form, so gold
    labels spelled with spaces would otherwise never match a prediction.
    """
    return 'NOT_ENOUGH_INFO' if label == 'NOT ENOUGH INFO' else label


def _build_prompt(idxs, claims, evidences, labels):
    """Return the few-shot prompt: instruction, demonstrations for ``idxs``, question."""
    demonstrations = " ".join(
        f"claim: {claims[idx]} evidence: {evidences[idx]} label: {labels[idx]}"
        for idx in idxs
    )
    return PROMPT_PREFIX + " " + demonstrations + " " + PROMPT_SUFFIX


if __name__ == '__main__':

    # load dataset
    train_claims, train_evidences, train_labels, test_claims, test_evidences, test_labels = read_data(dataset=args.dataset)
    train_labels = [_normalize_label(label) for label in train_labels]
    test_labels = [_normalize_label(label) for label in test_labels]
    print(f'{args.dataset} dataset loaded')

    model = args.model
    print('model: ', model)

    tokenizer = AutoTokenizer.from_pretrained(model)
    # Reuse the tokenizer loaded above rather than letting the pipeline load a second copy.
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    os.makedirs(args.output_dir, exist_ok=True)
    # Use the last path component of the model id: str.lstrip("meta-llama/") strips a
    # character set, not a prefix, and mangles other ids (e.g. "mistralai/...").
    model_name = os.path.basename(model.rstrip('/'))
    output_path = os.path.join(args.output_dir, f"{args.dataset}_{model_name}.csv")

    results = {key: [] for key in ('t', 'f1', 'acc', 'seed', 'prompt', 'preds',
                                   'f1_mapped', 'acc_mapped', 'preds_mapped')}
    for t in range(1, 6):   # 1-shot to 5-shot (t examples per class)
        f1s, accs, f1s_mapped, accs_mapped = [], [], [], []
        for s in range(123, 223):   # 100 seeds

            idxs = sample_t(train_labels, t=t, seed=s)
            prompt = _build_prompt(idxs, train_claims, train_evidences, train_labels)
            print("Using the prompt:", prompt)
            # generate_response samples (do_sample=True); seeding makes each run reproducible.
            transformers.set_seed(s)

            y_preds = []
            for claim, evidence in tqdm(zip(test_claims, test_evidences), total=len(test_claims)):
                test_example = f"claim: {claim} evidence: {evidence}"
                try:
                    response = generate_response(prompt + test_example)
                except Exception as err:  # record the failure and continue; Ctrl-C still stops the run
                    tqdm.write(f"generation failed: {err!r}")
                    response = "Error"
                y_preds.append(response)
            # Any response that is not exactly a valid label counts as NOT_ENOUGH_INFO.
            y_preds_mapped = [pred if pred in LABELS else 'NOT_ENOUGH_INFO' for pred in y_preds]

            # count what the model predicts
            print("prediction distribution:", Counter(y_preds))
            print("prediction distribution (mapped):", Counter(y_preds_mapped))

            # evaluate
            acc = round(accuracy_score(test_labels, y_preds), 4)
            f1 = f1_score(test_labels, y_preds, average='macro').round(4)
            acc_mapped = round(accuracy_score(test_labels, y_preds_mapped), 4)
            f1_mapped = f1_score(test_labels, y_preds_mapped, average='macro').round(4)
            print("number of shots: {}; sampling seed {}; f1: {}; acc: {}; f1_mapped: {}; acc_mapped: {};".format(t, s, f1, acc, f1_mapped, acc_mapped))
            f1s.append(f1)
            accs.append(acc)
            f1s_mapped.append(f1_mapped)
            accs_mapped.append(acc_mapped)

            for key, value in (('t', t), ('f1', f1), ('acc', acc), ('seed', s),
                               ('prompt', prompt), ('preds', y_preds),
                               ('f1_mapped', f1_mapped), ('acc_mapped', acc_mapped),
                               ('preds_mapped', y_preds_mapped)):
                results[key].append(value)

        print('t: ', t, 'mean f1: ', np.mean(f1s), 'mean acc: ', np.mean(accs), 'mean f1_mapped: ', np.mean(f1s_mapped), 'mean acc_mapped: ', np.mean(accs_mapped))
        print()

        # Save after every shot count so a late failure does not lose finished runs;
        # the final write contains all results.
        pd.DataFrame(results).to_csv(output_path, index=False)
