from functools import reduce
import math
from itertools import chain
import os
from scipy import stats
from scipy.stats import ttest_rel, ttest_ind, ttest_ind_from_stats, chi2
import pickle
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import numpy as np

from transformers import GPT2Tokenizer
# from core.modeling_gpt2 import GPT2LMHeadModel #swap when testing GPT2 model. 
from core.modeling_gpt2 import GPT2ELMHeadModel as GPT2LMHeadModel


from datasets import load_dataset
from sacremoses import MosesDetokenizer


# ---------------------------------------------------------------------------
# Configuration. Every value can be overridden with an environment variable;
# the defaults are the values this script has always used.
# ---------------------------------------------------------------------------

# CBT configuration name, passed to load_dataset("cbt", variant) in main().
variant = os.environ.get('CBT_VARIANT', 'P')
# Directory holding the model and tokenizer under evaluation (complete your model dir).
model_path = os.environ.get('CORELM_MODEL_PATH', 'corelm')
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)

# Pickled coreference clusters for the CBT test examples, as predicted by the
# Learn2Ignore model (implementation: https://github.com/shtoshni92/long-doc-coref).
# The model's output dictionaries were serialized with pickle and are processed
# later in this script. That processing aligns the model's BERT-tokenized text
# with the original words, so the entity annotations can be carried over to the
# GPT-2 tokens. The alignment method is simple and unoptimized, but it was tested
# and gives exactly the same performance after alignment.
# NOTE: the default file name appears to correspond to variant 'P' (note the
# '_p' suffix); override it when changing the variant.
cluster_path = os.environ.get('CBT_CLUSTER_PATH', 'cbt_clusters_p.pkl')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Evaluation only: make sure dropout is disabled. Hugging Face from_pretrained
# normally returns models in eval mode already; this makes the intent explicit.
model.eval()


def argmax(t):
    """Return the index of the largest element of tensor ``t`` as a Python int.

    ``torch.argmax`` without a ``dim`` argument indexes the flattened tensor.
    """
    position = torch.argmax(t).item()
    return int(position)

def preprocess(text):
    """Normalise one CBT sentence and return it prefixed with a newline.

    The substitutions run in the order listed below:
    curly and doubled quote marks become a plain ``"``, every ``.`` is padded
    with spaces, double spaces are collapsed, the ``-LRB-``/``-RRB-``
    bracket tokens become ``(``/``)``, and all remaining hyphens are removed.

    NOTE: the double-space collapse is a single ``str.replace`` pass, so a run
    of three or more spaces is only partly collapsed (``"a.  b"`` becomes
    ``"a .  b"``). It is kept as-is because the model inputs and the
    precomputed coreference alignment depend on this exact text.
    """
    substitutions = (
        ("“", '"'),
        ("”", '"'),
        ("''", '"'),
        ("``", '"'),
        ('.', ' . '),
        ('  ', ' '),
        ("-LRB-", '('),
        ("-RRB-", ')'),
        ("-", ""),  # must run after the -LRB-/-RRB- replacements
    )
    for old, new in substitutions:
        text = text.replace(old, new)
    return '\n' + text.strip()

def score_batch_1(example, max_length=1024):
    """Score one raw CBT example with the language model alone (no entity ids).

    Each answer option is substituted for the ``XXXXX`` placeholder in the
    question. The preprocessed context sentences are prepended, and the model
    loss (computed with ``labels=input_ids``) is taken for the resulting
    sequence. The option with the lowest loss is the prediction.

    Args:
        example: A CBT record with ``'sentences'``, ``'question'``,
            ``'options'`` and ``'answer'`` fields.
        max_length: Only the last ``max_length`` tokens of each sequence are
            fed to the model (default 1024, the original hard-coded limit).
            Must be positive.

    Returns:
        ``(pred, correct, incorrect)``: the index of the predicted option,
        and 1/0 flags telling whether it matches the gold answer.
    """
    context = ''.join(preprocess(sent) for sent in example['sentences'])
    choices = [preprocess(example['question'].replace('XXXXX', option))
               for option in example['options']]

    encoded_context = tokenizer.encode(context)

    example_losses = []
    # Inference only: avoid building an autograd graph for every option.
    with torch.no_grad():
        for choice in choices:
            # Keep the tail so the option-filled question is never truncated away.
            token_ids = (encoded_context + tokenizer.encode(choice))[-max_length:]
            input_ids = torch.tensor(token_ids, device=device)
            output = model(input_ids, labels=input_ids)
            example_losses.append(output.loss.item())

    pred = example_losses.index(min(example_losses))
    answer_index = example['options'].index(example['answer'])

    correct = 1 if pred == answer_index else 0
    return pred, correct, 1 - correct

def score_batch_1_ents(example, max_length=1024):
    """Score one aligned CBT example with the entity-aware model.

    Args:
        example: A 3-item sequence ``(inputs, inputs_ents, answer_location)``:
            ``inputs`` holds one word list per answer option,
            ``inputs_ents[i][j]`` is the entity id of word ``j`` of option
            ``i``, and ``answer_location[0][0]`` is the index of the correct
            option.
        max_length: Only the last ``max_length`` subtokens of each option's
            sequence are fed to the model (default 1024, the original
            hard-coded limit). Must be positive.

    Returns:
        ``(pred, correct, incorrect)``: the index of the option with the lowest
        model loss, and 1/0 flags telling whether it is the correct option.
    """
    inputs = example[0]
    inputs_ents = example[1]
    answer_location = example[2]

    example_losses = []
    # Inference only: avoid building an autograd graph for every option.
    with torch.no_grad():
        for i, words in enumerate(inputs):
            word_ents = inputs_ents[i]
            token_ids = []
            token_ents = []
            for j, word in enumerate(words):
                # NOTE(review): depending on the transformers version,
                # is_split_into_words=True may itself prepend a space; verify
                # the leading space is not doubled.
                piece_ids = tokenizer.encode(' ' + word, is_split_into_words=True)
                token_ids.extend(piece_ids)
                # Every subtoken of a word inherits that word's entity id.
                token_ents.extend([word_ents[j]] * len(piece_ids))

            input_ids = torch.tensor(token_ids[-max_length:], device=device)
            # Entity ids are passed with an explicit batch dimension.
            input_ent_ids = torch.tensor([token_ents[-max_length:]], device=device)

            output = model(input_ids, entities=input_ent_ids, labels=input_ids)
            example_losses.append(output.loss.item())

    pred = example_losses.index(min(example_losses))

    correct = 1 if pred == answer_location[0][0] else 0
    return pred, correct, 1 - correct

def _mcnemar(base, current):
    """Return ``(statistic, p_value)`` of McNemar's test without continuity correction.

    Only discordant pairs count. ``base_only`` counts positions where the two
    inputs differ and ``base`` is 1; ``current_only`` counts positions where
    they differ and ``current`` is 1. The statistic is
    ``(base_only - current_only)**2 / (base_only + current_only)``, with a
    chi-square p-value (1 degree of freedom). When there are no discordant
    pairs the formula would divide by zero. The two inputs are then
    indistinguishable, so ``(0.0, 1.0)`` is returned.
    """
    base_only = 0
    current_only = 0
    for x, y in zip(base, current):
        if x != y:
            if x == 1:
                base_only += 1
            if y == 1:
                current_only += 1

    discordant = base_only + current_only
    if discordant == 0:
        return 0.0, 1.0
    stat = (base_only - current_only) ** 2 / discordant
    return stat, chi2.sf(stat, 1)


def statistical_tests(base, current):
    """Compare per-example correctness of a baseline and a new model and print the results.

    Runs a related (paired) t-test, an independent t-test and McNemar's test,
    printing each statistic and p-value.

    Args:
        base: Per-example 0/1 correctness flags of the baseline (1 = correct).
        current: Per-example 0/1 correctness flags of the model under test,
            aligned with ``base``.

    Returns:
        True. The results are only printed.
    """
    stat, p = ttest_rel(current, base)
    print(f'Statistical t-test related. p: {p} and stat: {stat}')

    stat, p = ttest_ind(current, base)
    print(f'Statistical t-test indepedent. p: {p} and stat: {stat}')

    mc_stat, mc_p = _mcnemar(base, current)
    print(f'McNemar p: {mc_p} and stat: {mc_stat}')

    return True

def _align_words_to_subtokens(words, subtokens):
    """Map each space-separated word to the coreference-model subtokens it covers.

    Walks ``words`` and ``subtokens`` in lockstep. A subtoken is attached to
    the current word in three cases: it equals the word, it is a substring of
    the word, or it is a substring once a leading ``##`` continuation marker
    is removed. Otherwise the current word is finished and the walk moves to
    the next word.

    Args:
        words: Words of the preprocessed text (``str.split(' ')`` output).
        subtokens: Subtokens the coreference model produced for the same
            text. The list is copied, so the caller's data is not modified.

    Returns:
        A list with one entry per word. Entry ``i`` holds the indices of the
        subtokens aligned to ``words[i]`` and is empty if none matched.
    """
    subtokens = list(subtokens)
    # Pre-sized so alignment[i] always belongs to words[i], even when a word
    # matches no subtoken (e.g. the empty strings split(' ') yields for
    # repeated spaces). Appending lazily would shift every later word's entry.
    alignment = [[] for _ in words]
    i = j = 0
    while i < len(words) and j < len(subtokens):
        word, piece = words[i], subtokens[j]
        if word == piece:
            alignment[i].append(j)
            i += 1
            j += 1
        elif piece in word:
            # NOTE: greedy substring match. A short subtoken that occurs
            # inside the current word (e.g. 'a' after 'cats') is attached to it.
            alignment[i].append(j)
            j += 1
        elif piece.startswith('##'):
            # Strip the continuation marker and retry. If it still does not
            # match, the next iteration re-examines the stripped piece.
            subtokens[j] = piece[2:]
            if subtokens[j] in word:
                alignment[i].append(j)
                j += 1
        else:
            i += 1  # the current word is exhausted; move to the next one
    return alignment


def _entity_ids_for_words(num_words, alignment, clusters):
    """Give every word covered by a multi-mention coreference cluster that cluster's id.

    Clusters with at most one mention are skipped. The others get ids 1, 2, ...
    in order, and 0 means "no entity". Each mention is
    ``((start_idx, end_idx), _)`` with an inclusive subtoken span. A word
    covered by several clusters keeps the id of the last one.

    Args:
        num_words: Number of words in the text.
        alignment: Output of ``_align_words_to_subtokens`` for the text.
        clusters: The coreference model's clusters for the text.

    Returns:
        A list of ``num_words`` entity ids.
    """
    # Each subtoken index is aligned to at most one word, so a dict suffices.
    word_of_subtoken = {
        tok: word_idx
        for word_idx, toks in enumerate(alignment)
        for tok in toks
    }
    ents = [0] * num_words
    entity_id = 1
    for cluster in clusters:
        if len(cluster) <= 1:
            continue
        for (start_idx, end_idx), _ in cluster:
            for tok in range(start_idx, end_idx + 1):
                word_idx = word_of_subtoken.get(tok)
                if word_idx is not None:
                    ents[word_idx] = entity_id
        entity_id += 1
    return ents


def _align_text_and_clusters(ds, output_clusters):
    """Build word lists and per-word entity ids for every (example, option) pair.

    For each example, the preprocessed context is combined with each
    option-filled question and split on single spaces. The resulting words are
    aligned with the coreference output stored at the same option index, and
    each word is tagged with an entity id.

    Args:
        ds: The CBT split (indexable; records with ``'sentences'``,
            ``'question'``, ``'options'``, ``'answer'``).
        output_clusters: Coreference output per example, parallel to ``ds``.

    Returns:
        ``[docs, docs_ents, docs_answer]``. ``docs[k][o]`` is the word list for
        example ``k`` and option ``o``. ``docs_ents[k][o]`` holds the matching
        entity ids. ``docs_answer[k][o]`` is ``[gold option index]``.

    Raises:
        ValueError: If ``ds`` and ``output_clusters`` differ in length.
    """
    if len(ds) != len(output_clusters):
        raise ValueError(
            f'dataset has {len(ds)} examples but {len(output_clusters)} '
            f'cluster entries were loaded'
        )

    docs, docs_ents, docs_answer = [], [], []
    for doc_id in tqdm(range(len(ds))):
        example = ds[doc_id]
        example_clusters = output_clusters[doc_id]
        answer = [example['options'].index(example['answer'])]

        context = ''.join(preprocess(sent) for sent in example['sentences'])
        choices = [preprocess(example['question'].replace('XXXXX', option))
                   for option in example['options']]

        doc, doc_ents, doc_answer = [], [], []
        # Options and their coreference outputs are paired by index.
        for di in range(min(len(choices), len(example_clusters))):
            option_clusters = example_clusters[di]
            words = (context + choices[di]).strip().split(' ')
            subtokens = option_clusters['tokenized_doc']['sentences'][0]

            alignment = _align_words_to_subtokens(words, subtokens)
            ents = _entity_ids_for_words(len(words), alignment, option_clusters['clusters'])

            doc.append(words)
            doc_ents.append(ents)
            doc_answer.append(answer)

        docs.append(doc)
        docs_ents.append(doc_ents)
        docs_answer.append(doc_answer)

    return [docs, docs_ents, docs_answer]


def _binomial_p_value(successes, trials, alternative):
    """Return the p-value of an exact binomial test of ``successes`` against p=0.5.

    Uses ``scipy.stats.binomtest`` when available and falls back to the older
    ``scipy.stats.binom_test``, which newer SciPy releases have removed.
    """
    if hasattr(stats, 'binomtest'):
        return stats.binomtest(successes, n=trials, p=0.5, alternative=alternative).pvalue
    return stats.binom_test(successes, n=trials, p=0.5, alternative=alternative)


def main():
    """Evaluate the entity-aware model on the CBT test split and compare it with GPT-2.

    Steps:
    1. Load the CBT test split and the pickled coreference clusters.
    2. Align the clusters with the words and score every example.
    3. Print accuracy and binomial tests.
    4. Save the per-example correctness list to ``<model_path>/cbt_<variant>_pred.pkl``.
    5. Run significance tests against the GPT-2 list in
       ``../checkpoint/gpt2/cbt_<variant>_pred.pkl``.

    Raises:
        ValueError: If the clusters do not match the dataset, or nothing was scored.
    """
    print(f'Loading model from: {model_path}')

    dataset = load_dataset("cbt", variant)
    print(f'Dataset variant: {variant}')

    test_dataset = dataset['test']

    print(f'Loading clusters from {cluster_path}')
    # pickle can execute arbitrary code while loading: only use trusted files.
    with open(cluster_path, 'rb') as f:
        output_clusters = pickle.load(f)

    print('Aligning datasets...')
    docs, docs_ents, docs_answer = _align_text_and_clusters(test_dataset, output_clusters)

    print('Starting predictions...')
    prediction_list = []  # 1 = correct, 0 = wrong, one entry per example
    for example in tqdm(zip(docs, docs_ents, docs_answer), total=len(docs)):
        _, correct, _ = score_batch_1_ents(example)
        prediction_list.append(1 if correct == 1 else 0)

    total = len(prediction_list)
    if total == 0:
        raise ValueError('No examples were scored; cannot compute accuracy.')
    success = sum(prediction_list)
    errors = total - success

    print(f"Results for CBT-{variant} with model {model_path}:")
    print("Accuracy: %.4f" % (1 - errors / total,))

    print('\n Binomial testing:')
    for label, alternative in (('Two-sided', 'two-sided'),
                               ('Greater', 'greater'),
                               ('Less', 'less')):
        print(f"{label}: {_binomial_p_value(success, total, alternative)}")

    outfile = os.path.join(model_path, f'cbt_{variant}_pred.pkl')
    print(f'Saving prediction correctness in: {outfile}')
    with open(outfile, 'wb') as f:
        pickle.dump(prediction_list, f, protocol=pickle.HIGHEST_PROTOCOL)

    print('Loading GPT2 predictions.')
    with open('../checkpoint/gpt2/cbt_' + variant + '_pred.pkl', 'rb') as pr:
        gpt2_preds = pickle.load(pr)

    statistical_tests(gpt2_preds, prediction_list)

if __name__=='__main__':
    main()