# Adapted from: https://github.com/cybertronai/bflm/blob/master/eval_lambada.py
# Baseline (gpt2): accuracy 0.4667

import argparse
import logging
import math
import os
import time

import pickle
from itertools import chain
from scipy.stats import ttest_rel, ttest_ind, ttest_ind_from_stats, chi2

import numpy as np
from scipy import stats
import torch
import torch.nn.functional as F
import tqdm
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from tqdm import trange

# from core.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModel #for baseline
from core.modeling_gpt2 import GPT2ELMHeadModel as GPT2LMHeadModel
from core.tokenization_gpt2 import GPT2Tokenizer
from torch.utils.data import DataLoader, Dataset, Subset


# Command-line interface. Parsed at import time; the resulting ``args`` object
# is read as a module-level global by the functions below.
parser = argparse.ArgumentParser()
parser.add_argument('--path', type=str, default='/ncluster/data/lambada/lambada_test_plain_text.txt',
                    help='location of lambada dataset')
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--max-batches', type=int, default=0,
                    help='stop after this many batches (0 = evaluate the whole dataset)')
parser.add_argument('--ignore-fragments', action='store_true',
                    help="skip examples whose final token is not the complete last word")
parser.add_argument('--preprocess', action='store_true',
                    help="normalize curly and TeX-style quotes to plain double quotes")
# The backslash is escaped so the help text shows a literal "\n" instead of
# embedding a real newline that argparse would collapse away.
parser.add_argument('--jeff_suggestion', action='store_true',
                    help="use jeff's suggestion of prepending \\n to each example")
parser.add_argument('--dryrun', action='store_true', help="test preprocessing pipeline")
args = parser.parse_args()

# Prefer the GPU when one is visible; the chosen device is also stored on args.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.device = device

# Tokenizer and model are loaded once at import time and shared as globals.
model_name = '' #complete your model dir
enc = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.to(device)
# Explicitly select inference mode so dropout is disabled during scoring,
# whatever mode the loader leaves the model in.
model.eval()


def argmax(t):
    """Return the index of the largest entry of tensor ``t`` as a Python int.

    No dimension is given to the reduction, so the index refers to the
    flattened tensor.
    """
    best_index = t.argmax()
    return int(best_index.item())


# from https://github.com/openai/gpt-2/issues/131#issuecomment-492786058
def preprocess(text):
    """Normalise quote marks in ``text`` and frame it with a leading newline.

    Curly double quotes and the two-character TeX-style quotes are each
    replaced by a plain ``"``. Surrounding whitespace is then stripped and a
    single newline is prepended to the result.
    """
    quote_variants = ("“", "”", "''", "``")
    for variant in quote_variants:
        text = text.replace(variant, '"')
    return '\n' + text.strip()


def score_batch(batch):
    """Score one batch of text lines on prediction of their final token.

    Each line is stripped (and prefixed with a newline when
    ``--jeff_suggestion`` is set), tokenized with the global ``enc``,
    right-padded with token id 0 to the longest line in the batch and fed to
    the global ``model``.

    Args:
        batch: iterable of strings, one example per item.

    Returns:
        Normally a 4-tuple ``(errors, total, loss, prediction_list)``:
        ``errors`` is the number of counted examples whose final token was
        mispredicted, ``total`` the number of counted examples, ``loss`` the
        detached loss reported by the model for the batch, and
        ``prediction_list`` a list with 1 (correct) or 0 (wrong) per counted
        example. Examples flagged as fragments are not counted when
        ``--ignore-fragments`` is set.

        With ``--dryrun`` the model is never called and the placeholder
        2-tuple ``(0, 1)`` is returned instead.
    """
    batch_encoded = []
    fragments = []
    for line in batch:
        line = line.strip()
        if args.jeff_suggestion:
            line = '\n' + line
        line_encoded = enc.encode(line)
        # An example is a "fragment" when the final token alone does not
        # decode to the whole last whitespace-separated word.
        encoded_last_word = enc.decode(line_encoded[-1:]).strip()
        actual_last_word = line.split()[-1].strip()
        fragments.append(encoded_last_word != actual_last_word)
        batch_encoded.append(line_encoded)

    lengths = [len(encoded) for encoded in batch_encoded]
    max_len = max(lengths)
    # Right-pad every sequence with token id 0 so the batch is rectangular.
    batch_padded = torch.tensor(
        [encoded + [0] * (max_len - len(encoded)) for encoded in batch_encoded]
    )
    batch_padded = batch_padded.to(device)
    if args.dryrun:
        # Only the tokenization/padding pipeline is exercised in a dry run.
        return 0, 1

    # Evaluation only: skip building the autograd graph for the forward pass.
    # NOTE(review): the padded positions (id 0) are part of ``labels`` as well,
    # so the reported loss presumably includes padding - confirm how the
    # model's loss treats them before trusting the perplexity.
    with torch.no_grad():
        output = model(batch_padded, labels=batch_padded)
    logits = output['logits']
    loss = output['loss']

    errors = 0
    total = 0
    prediction_list = []
    for i, (encoded, length) in enumerate(zip(batch_encoded, lengths)):
        if args.ignore_fragments and fragments[i]:
            continue
        last_idx = length - 1
        observed = encoded[last_idx]
        # The final token is compared with the argmax of the logits one
        # position earlier.
        predicted = argmax(logits[i][last_idx - 1])
        correct = observed == predicted
        total += 1
        errors += 0 if correct else 1
        prediction_list.append(1 if correct else 0)

    return errors, total, loss.detach(), prediction_list

def statistical_tests(base_preds,current_preds):
    """Print significance tests comparing two sets of per-example outcomes.

    Args:
        base_preds: iterable of per-batch lists of outcome flags for the
            baseline model (the caller passes 1 for a correct prediction and
            0 for a wrong one).
        current_preds: the same structure for the current model, covering the
            same examples in the same order.

    Prints the result of a paired t-test, an independent two-sample t-test
    and McNemar's test (chi-square form, one degree of freedom, no
    continuity correction).

    Returns:
        True, always.

    Raises:
        ValueError: if the two inputs do not contain the same number of
            examples; the paired tests would be meaningless.
    """
    # Flatten the per-batch lists into one flat sequence per model.
    base = list(chain.from_iterable(base_preds))
    current = list(chain.from_iterable(current_preds))
    if len(base) != len(current):
        raise ValueError(
            f'paired tests need equally many predictions, got '
            f'{len(base)} baseline vs {len(current)} current'
        )

    stat, p = ttest_rel(current, base)
    print(f'Statistical t-test related. p: {p} and stat: {stat}')

    stat, p = ttest_ind(current, base)
    print(f'Statistical t-test indepedent. p: {p} and stat: {stat}')

    b = np.asarray(base)
    c = np.asarray(current)

    # Discordant pairs: examples on which the two models disagree.
    differs = b != c
    y_n = int(np.count_nonzero(differs & (b == 1)))  # baseline right, current wrong
    n_y = int(np.count_nonzero(differs & (c == 1)))  # baseline wrong, current right

    # Without any discordant pair the statistic would be 0/0; the models are
    # indistinguishable on this data, so report stat 0 (p-value 1).
    discordant = y_n + n_y
    mc_stat = (y_n - n_y) ** 2 / discordant if discordant else 0.0
    mc_p = chi2.sf(mc_stat, 1)
    print(f'McNemar p: {mc_p} and stat: {mc_stat}')

    return True

def _binom_pvalue(successes, trials, alternative):
    """Return the p-value of an exact binomial test against p = 0.5.

    Uses ``stats.binomtest`` when the installed SciPy provides it and falls
    back to the older ``stats.binom_test`` otherwise.
    """
    if hasattr(stats, 'binomtest'):
        return stats.binomtest(successes, n=trials, alternative=alternative).pvalue
    return stats.binom_test(successes, n=trials, alternative=alternative)


def main():
    """Evaluate the global model on the dataset at ``args.path``.

    Steps: read the dataset (optionally normalising quotes), score it batch by
    batch, print accuracy, mean loss and perplexity, print binomial tests of
    the number of correct predictions, pickle the per-example correctness
    flags and finally compare them against stored GPT-2 baseline predictions.

    With ``--dryrun`` only the preprocessing/tokenization pipeline is
    exercised and no metrics, files or statistical tests are produced.
    """
    print(model_name)
    # Close the file deterministically and decode it explicitly as UTF-8
    # instead of relying on the platform default encoding.
    with open(args.path, encoding='utf-8') as f:
        ds_raw = f.read()
    if args.preprocess:
        ds_raw = preprocess(ds_raw)

    ds = ds_raw.strip().split('\n')

    # special handling for jsonl file
    if args.path.endswith('.jsonl'):
        # special handling for file from Jeff
        lines = []
        for line in ds:
            # Cut the text out of {"text": "..."} and let Python decode the
            # escape sequences by evaluating it as a triple-quoted literal.
            # NOTE(security): eval() executes whatever the file contains -
            # only use this path with trusted dataset files.
            candidate2 = line[len('{"text": "'):-len('"}')]
            candidate2 = f'''"""{candidate2}"""'''
            lines.append(eval(candidate2))
        ds = lines
    data_loader = DataLoader(ds, batch_size=args.batch, shuffle=False)

    errors = 0
    total = 0
    losses = []
    accum_preds = []
    for i, batch in enumerate(data_loader):
        if args.dryrun:
            # In dry-run mode score_batch returns early with a 2-tuple
            # placeholder, so there is nothing to accumulate.
            score_batch(batch)
        else:
            errors_batch, total_batch, loss, preds = score_batch(batch)
            errors += errors_batch
            total += total_batch
            losses.append(loss)
            accum_preds.append(preds)
        if args.max_batches and i >= args.max_batches - 1:
            break

    if args.dryrun:
        print('Dry run finished: preprocessing pipeline ran without errors.')
        return

    if total == 0 or not losses:
        # E.g. every example was skipped as a fragment; avoid dividing by zero.
        print('No examples were scored; nothing to report.')
        return

    # Unweighted mean of the per-batch losses.
    loss = sum(losses)/len(losses)
    print("Accuracy: %.4f" % (1 - errors / total,))
    print(f"Loss: {loss}")
    print(f"PPL: {math.exp(loss)}")

    #binomial testing
    print('\n Binomial testing:')
    success = total-errors
    ts = _binom_pvalue(success, total, 'two-sided')
    print(f"Two-sided: {ts}")
    gr = _binom_pvalue(success, total, 'greater')
    print(f"Greater: {gr}")
    ls = _binom_pvalue(success, total, 'less')
    print(f"Less: {ls}")

    # NOTE(review): plain concatenation - model_name needs a trailing path
    # separator for the file to land inside the model directory.
    outfile = model_name+'lambada_pred.pkl'
    print(f'Saving prediction correctness in: {outfile}')
    with open(outfile, 'wb') as f:
        pickle.dump(accum_preds, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(f'Loading GPT2 predictions.')
    # NOTE(security): pickle.load can execute arbitrary code; the baseline
    # file must come from a trusted earlier run of this script.
    with open('../checkpoint/gpt2/lambada_pred.pkl', 'rb') as pr:
        gpt2_preds = pickle.load(pr)

    statistical_tests(gpt2_preds,accum_preds)

# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()