#from : https://github.com/cybertronai/bflm/blob/master/eval_lambada.py
# baseline: Accuracy: 0.4667 - gpt2

import argparse
import logging
import math
import os
import time

import pickle
from itertools import chain
from scipy.stats import ttest_rel, ttest_ind, ttest_ind_from_stats, chi2

import numpy as np
from scipy import stats
import torch
import torch.nn.functional as F
from tqdm import tqdm
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from tqdm import trange


from core.modeling_gpt2 import GPT2ELMHeadModel as GPT2LMHeadModel
from core.tokenization_gpt2 import GPT2Tokenizer
from torch.utils.data import DataLoader, Dataset, Subset


# Command-line interface. Parsing happens at import time, so every function
# below reads its options from the module-level ``args`` namespace.
parser = argparse.ArgumentParser()
parser.add_argument('--path', type=str, default='/ncluster/data/lambada/lambada_test_plain_text.txt',
                    help='location of lambada dataset')
parser.add_argument('--batch', type=int, default=8, help='batch size')
# 0 disables the limit (see the ``args.max_batches and ...`` check in main).
parser.add_argument('--max-batches', type=int, default=0,
                    help='stop after this many batches (0 = evaluate everything)')
parser.add_argument('--ignore-fragments', action='store_true',
                    help="exclude examples whose last word is not a single token from the score")
parser.add_argument('--preprocess', action='store_true',
                    help="normalize curly and ``/'' quotes to plain double quotes")
# The backslash is escaped so that the help text shows a literal "\n" instead
# of a line break that argparse would collapse into a space.
parser.add_argument('--jeff_suggestion', action='store_true',
                    help="use jeff's suggestion of prepending \\n to each example")
parser.add_argument('--dryrun', action='store_true', help="test preprocessing pipeline")
args = parser.parse_args()

# Prefer the GPU when one is available; the choice is also stored on ``args``.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.device = device


# Directory (or identifier) of the model to evaluate. It is empty here and must
# be filled in before running; it is also used as the prefix of the prediction
# file written by main().
model_name = ''#complete your model dir
enc = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.to(device)


# Path to the pickled coreference clusters predicted by the Learn2Ignore model.
# The coreference model implementation can be found here:
# https://github.com/shtoshni92/long-doc-coref
# The model's output dictionaries were serialized with pickle and are processed
# later in this script. Processing is required because the coreference model
# works on BERT-tokenized text: its tokens have to be aligned with the original
# words so that the entity annotations line up with the GPT-2 tokens.
# The alignment implemented below is very simple and unoptimized, but it was
# tested and gives exactly the same performance after alignment.
# NOTE(review): the file name mentions "cbt" although the default --path points
# at LAMBADA - confirm that this is the intended cluster file.
cluster_path = 'cbt_clusters_p.pkl'


def argmax(t):
    """Return the index of the largest element of tensor ``t`` as a Python int.

    The index refers to the flattened tensor, as no dimension is given.
    """
    best_index = t.argmax()
    return int(best_index.item())


# from https://github.com/openai/gpt-2/issues/131#issuecomment-492786058
def preprocess(text):
    """Normalize quote characters and return the text with a leading newline.

    Curly double quotes and the two-character ``''`` / `` sequences are each
    replaced by a plain double quote. Surrounding whitespace is then stripped
    and a single newline is prepended.
    """
    quote_variants = ("“", "”", "''", "``")
    for variant in quote_variants:
        text = text.replace(variant, '"')
    stripped = text.strip()
    return '\n' + stripped


def score_batch(batch):
    """Score last-token prediction for a batch of raw text lines.

    Each line is stripped (and prefixed with a newline when
    ``--jeff_suggestion`` is set), encoded with the module-level tokenizer
    ``enc``, right-padded with token id 0 and fed to the module-level
    ``model``. For every example the prediction made at the second-to-last
    token is compared with the actual last token.

    Args:
        batch: iterable of text lines (strings).

    Returns:
        A 4-tuple ``(errors, total, loss, prediction_list)``:
        number of wrong last-token predictions, number of scored examples,
        the detached loss reported by the model, and a list holding 1
        (correct) or 0 (wrong) per scored example. With ``--dryrun`` the
        model is not called and the placeholder ``(0, 1, tensor(0.), [1])``
        is returned so that callers can unpack the same four values.
    """
    batch_encoded = []
    fragments = []
    for line in batch:
        line = line.strip()
        if args.jeff_suggestion:
            line = '\n' + line

        line_encoded = enc.encode(line)
        # An example is a "fragment" when its final token does not decode to
        # the complete last whitespace-separated word of the line.
        encoded_last_word = enc.decode(line_encoded[-1:]).strip()
        actual_last_word = line.split()[-1].strip()
        fragments.append(encoded_last_word != actual_last_word)
        batch_encoded.append(line_encoded)

    # array is ragged, so pad to turn into rectangular tensor
    lengths = [len(encoded) for encoded in batch_encoded]
    max_len = max(lengths)
    batch_padded = torch.tensor(
        [encoded + [0] * (max_len - len(encoded)) for encoded in batch_encoded]
    )
    batch_padded = batch_padded.to(device)
    if args.dryrun:
        # Same arity as the normal return value; 0 errors out of 1 example.
        return 0, 1, torch.tensor(0.0), [1]

    # Evaluation only: no autograd graph is needed for the forward pass.
    # NOTE(review): padded positions (token id 0) are part of ``labels`` and no
    # attention mask is passed, so for ragged batches the reported loss
    # presumably includes padding - confirm against the model's loss code.
    with torch.no_grad():
        output = model(batch_padded, labels=batch_padded)
    logits = output['logits']
    loss = output['loss']

    errors = 0
    total = 0
    prediction_list = []
    # At most ``args.batch`` examples are scored, as before.
    for i in range(min(args.batch, len(batch_encoded))):
        last_idx = lengths[i] - 1
        observed = batch_encoded[i][last_idx]
        # The logits at position k predict the token at position k + 1.
        predicted = argmax(logits[i][last_idx - 1])
        if args.ignore_fragments and fragments[i]:
            continue
        total += 1
        correct = observed == predicted
        errors += 0 if correct else 1
        prediction_list.append(1 if correct else 0)

    return errors, total, loss.detach(), prediction_list


def score_batch_1(batch):
    """Score last-token prediction for one document with entity annotations.

    Args:
        batch: a pair ``[words, entity_ids]`` as built in ``main``: the
            document split into words and one entity id per word. Every
            sub-token produced for a word inherits that word's entity id.

    Returns:
        A 4-tuple ``(errors, total, loss, prediction_list)``:
        number of wrong last-token predictions (0 or 1), number of scored
        examples (0 when the document is skipped because of
        ``--ignore-fragments``, otherwise 1), the detached loss reported by
        the model, and a list with 1 (correct) or 0 (wrong) for the scored
        example. With ``--dryrun`` the model is not called and the
        placeholder ``(0, 1, tensor(0.), [1])`` is returned so that ``main``
        can unpack the same four values.

    The document is always scored as a single example; ``--batch`` has no
    influence here.
    """
    words = batch[0]
    word_entities = batch[1]

    # Encode word by word so that each resulting token can be tagged with the
    # entity id of the word it came from.
    token_ids = []
    token_entities = []
    for idx, word in enumerate(words):
        piece_ids = enc.encode(' ' + word, is_split_into_words=True)
        token_ids.extend(piece_ids)
        token_entities.extend([word_entities[idx]] * len(piece_ids))

    # The document is a "fragment" when its final token does not decode to the
    # complete last word.
    encoded_last_word = enc.decode(token_ids[-1:]).strip()
    actual_last_word = words[-1]
    is_fragment = encoded_last_word != actual_last_word

    # A single document forms the whole batch, so no padding is required.
    batch_padded = torch.tensor([token_ids]).to(device)
    ents_padded = torch.tensor([token_entities]).to(device)

    if args.dryrun:
        # Same arity as the normal return value; 0 errors out of 1 example.
        return 0, 1, torch.tensor(0.0), [1]

    # Evaluation only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        output = model(batch_padded, entities=ents_padded, labels=batch_padded)
    logits = output['logits']
    loss = output['loss']

    errors = 0
    total = 0
    prediction_list = []
    if not (args.ignore_fragments and is_fragment):
        last_idx = len(token_ids) - 1
        observed = token_ids[last_idx]
        # The logits at position k predict the token at position k + 1.
        predicted = argmax(logits[0][last_idx - 1])
        correct = observed == predicted
        total = 1
        errors = 0 if correct else 1
        prediction_list.append(1 if correct else 0)

    return errors, total, loss.detach(), prediction_list

def statistical_tests(base_preds,current_preds):
    """Print significance tests comparing two sets of per-example outcomes.

    Args:
        base_preds: nested lists of 0/1 correctness flags for the baseline.
        current_preds: nested lists of 0/1 correctness flags for the model
            under evaluation. After flattening, both must describe the same
            examples in the same order.

    Prints the paired t-test, the independent t-test and McNemar's test
    (chi-square with one degree of freedom, no continuity correction).

    Returns:
        Always ``True``.
    """
    #flatten to work easier
    base = list(chain.from_iterable(base_preds))
    current = list(chain.from_iterable(current_preds))

    stat, p = ttest_rel(current, base)
    print(f'Statistical t-test related. p: {p} and stat: {stat}')

    stat, p = ttest_ind(current, base)
    print(f'Statistical t-test indepedent. p: {p} and stat: {stat}')

    base_arr = np.asarray(base)
    current_arr = np.asarray(current)

    # Discordant pairs: examples on which the two systems disagree.
    discordant = base_arr != current_arr
    # y_n: baseline correct, current wrong; n_y: baseline wrong, current correct.
    y_n = int(np.count_nonzero(discordant & (base_arr == 1)))
    n_y = int(np.count_nonzero(discordant & (current_arr == 1)))

    n_discordant = y_n + n_y
    if n_discordant == 0:
        # Without any disagreement the statistic would be 0/0; there is no
        # evidence of a difference, so report stat 0 (which yields p = 1).
        mc_stat = 0.0
    else:
        mc_stat = (y_n - n_y)**2 / n_discordant
    mc_p = chi2.sf(mc_stat, 1)
    print(f'McNemar p: {mc_p} and stat: {mc_stat}')

    return True

def main():
    """Run the entity-aware last-word evaluation end to end.

    Steps:
      1. Read the dataset at ``--path`` (one example per line, optionally
         quote-normalized, with special handling for ``.jsonl`` files).
      2. Load the pickled coreference clusters from ``cluster_path`` and align
         them with the whitespace-split words of every example.
      3. Score every example with ``score_batch_1`` and print accuracy, mean
         loss and perplexity.
      4. Print binomial tests on the number of correct predictions, save the
         per-example correctness flags, and compare them with the stored
         GPT-2 predictions via ``statistical_tests``.

    Raises:
        ValueError: if the dataset and the cluster file differ in length, or
            if no example was scored.
    """
    print(model_name)
    with open(f'{args.path}') as dataset_file:
        ds_raw = dataset_file.read()
    if args.preprocess:
        ds_raw = preprocess(ds_raw)

    ds = ds_raw.strip().split('\n')

    # special handling for jsonl file
    if args.path.endswith('.jsonl'):
        # special handling for file from Jeff: cut the text out of the
        # {"text": "..."} wrapper and let Python decode its escape sequences.
        # SECURITY NOTE: eval() executes whatever the file contains - only use
        # this on trusted dataset files.
        lines = []
        for line in ds:
            candidate2 = line[len('{"text": "'):-len('"}')]
            candidate2 = f'''"""{candidate2}"""'''
            lines.append(eval(candidate2))
        ds = lines

    #add the clusters and make it double list
    #resolve and construct per batch.
    # SECURITY NOTE: unpickling can execute arbitrary code - only load cluster
    # files from a trusted source.
    with open(cluster_path, 'rb') as cluster_file:
        output_clusters = pickle.load(cluster_file)

    def _add_alignment(alignment, word_idx, token_idx):
        """Attach ``token_idx`` to the entry of ``word_idx``, or open a new entry."""
        if word_idx < len(alignment):
            alignment[word_idx].append(token_idx)
        else:
            alignment.append([token_idx])

    def aline_text_and_clusters(ds,output_clusters):
        """Project coreference clusters onto the whitespace-split words.

        For each document the tokens of the coreference model
        (``output['tokenized_doc']['sentences'][0]``) are greedily matched
        against the words of the text. Every word covered by a mention of a
        non-singleton cluster then receives that cluster's id (1, 2, ...);
        all other words get 0.

        Note that a '##' prefix is stripped from the stored tokens in place
        while matching.

        Returns:
            A list with one ``[words, entity_ids]`` pair per document.
        """
        if len(ds) != len(output_clusters):
            raise ValueError(
                f'Dataset has {len(ds)} examples but the cluster file has '
                f'{len(output_clusters)} entries.'
            )
        tok_data = []
        alignment_data = []
        missmatches = []
        for doc_id in tqdm(range(0,len(ds))):
            out_text = output_clusters[doc_id]['tokenized_doc']['sentences'][0]
            data = ds[doc_id].strip().split(' ')
            # tok_out_alignment[k] lists the token indices matched to word k.
            tok_out_alignment = []
            i = 0
            j = 0
            while i < len(data) and j < len(out_text):
                if data[i] == out_text[j]:
                    # Token equals the whole word: both advance.
                    _add_alignment(tok_out_alignment, i, j)
                    i += 1
                    j += 1
                elif out_text[j] in data[i]:
                    # Token is a piece of the current word: stay on the word.
                    _add_alignment(tok_out_alignment, i, j)
                    j += 1
                elif out_text[j].startswith('##'):
                    # Drop the continuation prefix and retry. If the stripped
                    # token still does not match, the next iteration falls
                    # through to one of the other branches.
                    out_text[j] = out_text[j][2:]
                    if out_text[j] in data[i]:
                        _add_alignment(tok_out_alignment, i, j)
                        j += 1
                else:
                    i += 1 #word should be over, move to the next

            if len(tok_out_alignment) != len(data):
                missmatches.append(doc_id)
            tok_data.append(data)
            alignment_data.append(tok_out_alignment)

        if missmatches:
            print(f'Warning: {len(missmatches)} document(s) could not be fully '
                  f'aligned with their cluster tokens.')

        docs = []
        for doc_id in tqdm(range(len(tok_data))):
            doc = tok_data[doc_id]
            align = alignment_data[doc_id]
            clusters = output_clusters[doc_id]['clusters']

            # Reverse index: token index -> indices of the words it belongs to.
            token_to_words = {}
            for word_idx, token_indices in enumerate(align):
                for token_idx in token_indices:
                    token_to_words.setdefault(token_idx, []).append(word_idx)

            doc_ents = [0] * len(doc)
            doc_ent_id = 1
            for cluster in clusters:
                if len(cluster) > 1: #no singletons
                    for mention in cluster:
                        (start_idx, end_idx), _ = mention
                        # Mention spans are inclusive on both ends.
                        for ent_index in range(start_idx, end_idx + 1):
                            for word_idx in token_to_words.get(ent_index, ()):
                                doc_ents[word_idx] = doc_ent_id
                    doc_ent_id += 1

            docs.append([doc,doc_ents])
        return docs

    print('Aligning clusters with text...')
    docs = aline_text_and_clusters(ds,output_clusters)

    print('Starting calculations...')
    errors = 0
    total = 0
    losses = []
    accum_preds = []
    for i, batch in tqdm(enumerate(docs), total=len(docs)):
        errors_batch, total_batch, loss, preds = score_batch_1(batch)
        errors += errors_batch
        total += total_batch
        losses.append(loss)
        accum_preds.append(preds)
        if args.max_batches and i >= args.max_batches - 1:
            break

    if not losses or total == 0:
        raise ValueError('No examples were scored; check --path, '
                         '--max-batches and --ignore-fragments.')

    loss = sum(losses)/len(losses)
    print("Accuracy: %.4f" % (1 - errors / total,))
    print(f"Loss: {loss}")
    print(f"PPL: {math.exp(loss)}")

    #binomial testing
    print('\n Binomial testing:')
    success = total-errors

    def _binomial_p_value(alternative):
        """Return the binomial-test p-value, whichever SciPy API is present."""
        # ``binom_test`` was deprecated and later removed from SciPy;
        # ``binomtest`` is its replacement and returns a result object.
        if hasattr(stats, 'binomtest'):
            return stats.binomtest(success, n=total, alternative=alternative).pvalue
        return stats.binom_test(success, n=total, alternative=alternative)

    ts = _binomial_p_value('two-sided')
    print(f"Two-sided: {ts}")
    gr = _binomial_p_value('greater')
    print(f"Greater: {gr}")
    ls = _binomial_p_value('less')
    print(f"Less: {ls}")

    # NOTE(review): plain concatenation - ``model_name`` needs a trailing path
    # separator for the file to end up inside the model directory.
    outfile = model_name+'lambada_pred.pkl'
    print(f'Saving prediction correctness in: {outfile}')
    with open(outfile, 'wb') as f:
        pickle.dump(accum_preds, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(f'Loading GPT2 predictions.')
    with open('../checkpoint/gpt2/lambada_pred.pkl', 'rb') as pr:
        gpt2_preds = pickle.load(pr)

    statistical_tests(gpt2_preds,accum_preds)


# Run the evaluation only when the file is executed as a script.
if __name__ == '__main__':
    main()