from functools import reduce
import math
from itertools import chain
from scipy import stats
from scipy.stats import ttest_rel, ttest_ind, ttest_ind_from_stats, chi2
import pickle
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import numpy as np

from transformers import GPT2Tokenizer
# from core.modeling_gpt2 import GPT2LMHeadModel #swap when testing GPT2 model. 
from core.modeling_gpt2 import GPT2ELMHeadModel as GPT2LMHeadModel


from datasets import load_dataset
from sacremoses import MosesDetokenizer


variant = 'P'  # CBT subset name, passed to load_dataset("cbt", variant) in main()
model_path = 'corelm'  # complete your model dir
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# To force CPU (e.g. when the model does not fit in GPU memory):
# device = torch.device("cpu")
model.to(device)
# This script only evaluates: put the model in inference mode explicitly so dropout
# (and any other train-only behaviour) is disabled, instead of relying on the
# project-local from_pretrained to have done it.
model.eval()


def argmax(t):
    """Return the position of the largest element of tensor ``t`` as a Python int.

    ``torch.argmax`` without a ``dim`` argument indexes the flattened tensor.
    """
    index_tensor = torch.argmax(t)
    return int(index_tensor.item())

def preprocess(text):
    """Normalise one CBT text fragment before tokenisation.

    Quote variants become a plain double quote, full stops are padded with
    spaces, the bracket tokens -LRB-/-RRB- become literal parentheses and every
    remaining hyphen is removed.  The result is stripped and prefixed with a
    newline, so concatenated fragments end up newline-separated.
    """
    # Applied strictly in this order.  -LRB-/-RRB- must be rewritten before bare
    # hyphens are dropped.  The double-space collapse is a single non-overlapping
    # str.replace pass, so runs of 3+ spaces are only partially collapsed.
    substitutions = (
        ("“", '"'),
        ("”", '"'),
        ("''", '"'),
        ("``", '"'),
        ('.', ' . '),
        ('  ', ' '),
        ("-LRB-", '('),
        ("-RRB-", ')'),
        ("-", ""),
    )
    for old, new in substitutions:
        text = text.replace(old, new)
    return '\n' + text.strip()

def score_batch_1(example, max_length=1024):
    """Score one CBT example by picking the candidate the LM assigns the lowest loss.

    The story sentences are preprocessed and concatenated into a context.  Each
    candidate option is substituted for the ``XXXXX`` placeholder in the question.
    For every candidate, the tokens of ``context + question`` are left-truncated to
    the last ``max_length`` tokens and passed to the model with themselves as
    labels.  The candidate whose returned loss is smallest (first one on ties) is
    the prediction.

    Args:
        example: one dataset record; reads its 'sentences', 'question',
            'options' and 'answer' fields.
        max_length: maximum number of tokens fed to the model.  Defaults to 1024,
            the previously hard-coded value (GPT-2's context size).

    Returns:
        (pred, correct, incorrect): index of the predicted option, and 1/0 flags
        telling whether it equals the gold answer (``incorrect == 1 - correct``).

    Raises:
        ValueError: if ``max_length`` is not positive, or if the gold answer is
            not among the options.
    """
    if max_length < 1:
        # full[-0:] would silently keep the whole sequence instead of truncating.
        raise ValueError(f"max_length must be positive, got {max_length}")

    context = ''.join(preprocess(sent) for sent in example['sentences'])
    choices = [preprocess(example['question'].replace('XXXXX', option))
               for option in example['options']]

    encoded_context = tokenizer.encode(context)

    losses = []
    # Inference only: no_grad avoids building and holding an autograd graph
    # for every forward pass.
    with torch.no_grad():
        for choice in choices:
            full = encoded_context + tokenizer.encode(choice)
            # Keep the *last* max_length tokens, so truncation drops the start of
            # the context rather than the candidate sentence.
            input_ids = torch.tensor(full[-max_length:], device=device)
            output = model(input_ids, labels=input_ids)
            # NOTE(review): if this loss is a per-token mean over the whole
            # sequence, candidates of different token length are compared over
            # slightly different denominators - confirm this is intended.
            losses.append(output.loss.item())  # plain float: cheap to compare

    pred = losses.index(min(losses))
    answer_index = example['options'].index(example['answer'])
    correct = int(pred == answer_index)
    return pred, correct, 1 - correct

def _mcnemar(base, current):
    """Return (stat, p) of McNemar's chi-square test, without continuity correction.

    ``base`` and ``current`` are paired per-example correctness flags (1 = correct).
    Only discordant pairs contribute.  With no discordant pairs the two systems
    are indistinguishable, so (0.0, 1.0) is returned instead of dividing by zero.
    """
    b = np.asarray(base)
    c = np.asarray(current)
    if b.shape != c.shape:
        raise ValueError(f"paired inputs differ in shape: {b.shape} vs {c.shape}")
    discordant = b != c
    y_n = int(np.count_nonzero(discordant & (b == 1)))  # base right, current wrong
    n_y = int(np.count_nonzero(discordant & (c == 1)))  # base wrong, current right
    if y_n + n_y == 0:
        return 0.0, 1.0
    stat = (y_n - n_y) ** 2 / (y_n + n_y)
    return stat, chi2.sf(stat, 1)


def statistical_tests(base, current):
    """Print paired and unpaired significance tests comparing two prediction lists.

    Args:
        base: per-example correctness flags of the reference system.
        current: per-example correctness flags of the evaluated system, in the
            same example order as ``base``.

    Returns:
        True once all results have been printed.
    """
    stat, p = ttest_rel(current, base)
    print(f'Statistical t-test related. p: {p} and stat: {stat}')

    stat, p = ttest_ind(current, base)
    print(f'Statistical t-test indepedent. p: {p} and stat: {stat}')

    mc_stat, mc_p = _mcnemar(base, current)
    print(f'McNemar p: {mc_p} and stat: {mc_stat}')

    return True

def _binom_pvalue(successes, trials, alternative):
    """Exact binomial-test p-value for ``successes`` out of ``trials`` (null p=0.5).

    ``scipy.stats.binom_test`` was removed in SciPy 1.12.  Its replacement
    ``binomtest`` (SciPy >= 1.7) returns a result object, so use that when
    available and fall back to the old function otherwise.
    """
    binomtest = getattr(stats, 'binomtest', None)
    if binomtest is not None:
        return binomtest(successes, n=trials, alternative=alternative).pvalue
    return stats.binom_test(successes, n=trials, alternative=alternative)


def main():
    """Evaluate the model on the CBT test split and compare it with GPT-2.

    Prints the accuracy and binomial-test p-values, then pickles the per-example
    0/1 correctness list to ``<model_path>/cbt_<variant>_pred.pkl``.  If a GPT-2
    prediction file of the same length exists, it also runs the paired
    significance tests against it.
    """
    print(f'Loading model from: {model_path}')

    dataset = load_dataset("cbt", variant)
    print(f'Dataset variant: {variant}')

    test_dataset = dataset['test']

    # 1 if the model picked the gold option, else 0.  One entry per test example,
    # in dataset order, so the list can be paired with another model's list.
    prediction_list = []
    for example in tqdm(test_dataset):
        _, correct, _ = score_batch_1(example)
        prediction_list.append(1 if correct == 1 else 0)

    total = len(prediction_list)
    success = sum(prediction_list)
    errors = total - success

    print(f"Results for CBT-{variant} with model {model_path}:")
    if total == 0:
        print("No test examples found; skipping accuracy and significance tests.")
        return
    print("Accuracy: %.4f" % (1 - errors / total,))

    # NOTE(review): the null hypothesis is p=0.5 (SciPy's default).  Chance
    # accuracy for a multiple-choice example is 1/len(options) - confirm that
    # p=0.5 is the intended baseline.
    print('\n Binomial testing:')
    ts = _binom_pvalue(success, total, 'two-sided')
    print(f"Two-sided: {ts}")
    gr = _binom_pvalue(success, total, 'greater')
    print(f"Greater: {gr}")
    ls = _binom_pvalue(success, total, 'less')
    print(f"Less: {ls}")

    outfile = model_path + '/cbt_' + variant + '_pred.pkl'
    print(f'Saving prediction correctness in: {outfile}')
    with open(outfile, 'wb') as f:
        pickle.dump(prediction_list, f, protocol=pickle.HIGHEST_PROTOCOL)

    gpt2_file = '../checkpoint/gpt2/cbt_' + variant + '_pred.pkl'
    print('Loading GPT2 predictions.')
    try:
        # pickle can execute code on load: only point this at files this
        # script produced itself.
        with open(gpt2_file, 'rb') as pr:
            gpt2_preds = pickle.load(pr)
    except FileNotFoundError:
        print(f'No GPT2 predictions found at {gpt2_file}; skipping paired tests.')
        return

    if len(gpt2_preds) != total:
        print(f'GPT2 predictions have {len(gpt2_preds)} entries but {total} were '
              f'expected; skipping paired tests.')
        return

    statistical_tests(gpt2_preds, prediction_list)

if __name__=='__main__':
    main()