"""
Contains a training routine to finetune GPT2 on some dataset.

Also contains a function to measure the geometric mean probability of words in a text.
"""

import pandas as pd
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
import torch.nn.functional as F
import torch
import random
import sys
sys.path.append("../")
from emb2emb.utils import read_all
import argparse
import numpy as np
import os

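# Disabled helper, kept for reference: it packs tensors up to max_seq_len, presumably
# to train with an accumulated batch size (since GPT2 is so big).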
# def pack_tensor(new_tensor, packed_tensor, max_seq_len):
#    if packed_tensor is None:
#        return new_tensor, True, None
#    if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:
#        return packed_tensor, False, new_tensor
#    else:
#        packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)
#        return packed_tensor, True, None


def finetune_gpt(params):
    """
    Finetunes GPT2 on some dataset.

    Reads the training texts with ``read_all(params.dataset)``, fine-tunes the
    pretrained 'gpt2' checkpoint with its language-modeling loss and writes
    the resulting ``state_dict`` to ``params.output_path`` once all epochs
    are done.

    Args:
        params: Namespace-like object providing ``device``, ``dataset``,
            ``batch_size``, ``num_epochs``, ``lr``, ``warmup_steps`` and
            ``output_path``.

    Returns:
        The fine-tuned ``GPT2LMHeadModel`` (left in training mode).
    """

    device = torch.device(params.device)

    # Load data
    text_data = np.array(read_all(params.dataset))

    # Get the tokenizer and model
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # we need a padding token because we're dealing with short sentences
    tokenizer.pad_token = tokenizer.eos_token
    pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    print(pad_token_id)
    model = GPT2LMHeadModel.from_pretrained('gpt2')

    model = model.to(device)
    model.train()

    # Hyperparameters
    batch_size = params.batch_size
    epochs = params.num_epochs
    lr = params.lr
    warmup_steps = params.warmup_steps

    # The linear schedule needs the real number of optimizer steps: with a
    # non-positive total the multiplier is clamped to 0 right after warmup,
    # i.e. the learning rate would be 0 and the model would not be updated.
    steps_per_epoch = (len(text_data) + batch_size - 1) // batch_size
    total_steps = max(1, epochs * steps_per_epoch)

    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )

    loss = 0

    print("Training...")
    for epoch in range(epochs):

        # shuffle the data
        indices = list(range(len(text_data)))
        random.shuffle(indices)

        text_data = text_data[indices]

        print(f"Training epoch {epoch}")
        print(loss)
        for stidx in range(0, len(text_data), batch_size):

            if (stidx % 10000) == 0:
                print(f"Currently at E{epoch}:{stidx}.")

            # prepare batch
            text_batch = text_data[stidx:stidx + batch_size]
            cur_batch_size = len(text_batch)

            encoded = tokenizer.batch_encode_plus(
                [f"{t}<|endoftext|>" for t in text_batch], padding=True)
            max_len = max(len(ids) for ids in encoded["input_ids"])

            # make sure we pad att mask and input to all having the same length
            att_mask = torch.zeros(
                (cur_batch_size, max_len), device=model.device)
            # explicit integer dtype: the ids are used as embedding indices
            inp = torch.full((cur_batch_size, max_len),
                             fill_value=pad_token_id, dtype=torch.long,
                             device=model.device)
            rows = zip(encoded["input_ids"], encoded["attention_mask"])
            for i, (ids_i, mask_i) in enumerate(rows):
                inp[i, :len(ids_i)] = torch.tensor(
                    ids_i, dtype=torch.long, device=model.device)
                att_mask[i, :len(mask_i)] = torch.tensor(
                    mask_i, device=model.device)

            # Exclude padding from the loss. The pad token is the EOS token,
            # so we rely on the attention mask (0 only on padding) rather
            # than on the token id; the EOS that ends each text stays a target.
            labels = inp.clone()
            labels[att_mask == 0] = -100

            outputs = model(inp, attention_mask=att_mask, labels=labels)
            loss = outputs[0]
            loss.backward()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

    print("Saving model...")
    # torch.save does not create missing directories (e.g. the default 'savedir/')
    output_dir = os.path.dirname(params.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), params.output_path)
    return model


def evaluate_fluency(params, texts):
    """
    Evaluates the fluency of some generated text by assessing the geometric mean
    probability of each of its tokens according to a language model.

    The language model is the pretrained 'gpt2' checkpoint with the weights
    stored at ``params.output_path`` loaded on top. Every text gets an EOS
    token appended; the first token of a text is not scored because nothing
    precedes it.

    Args:
        params: Namespace-like object providing ``device``, ``output_path``
            and ``batch_size``.
        texts: Sequence of strings to score (must support ``len`` and slicing).

    Returns:
        A list with one 1-D CPU tensor per batch; entry ``j`` of a tensor is
        the geometric mean token probability of the ``j``-th text in that
        batch. A text with no scored token (only the EOS token) yields NaN,
        because its length in the average is 0.
    """
    device = torch.device(params.device)

    # Get the tokenizer and model
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # we need a padding token because we're dealing with short sentences
    tokenizer.pad_token = tokenizer.eos_token
    pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    print(pad_token_id)
    # map_location lets a checkpoint saved on one device be evaluated on another
    state_dict = torch.load(params.output_path, map_location=device)
    model = GPT2LMHeadModel.from_pretrained('gpt2', state_dict=state_dict)

    model = model.to(device)
    model.eval()
    text_data = texts

    geometric_means = []
    print("Evaluating...")
    # inference only: do not record an autograd graph for every batch
    with torch.no_grad():
        for stidx in range(0, len(text_data), params.batch_size):

            if (stidx % 10000) == 0:
                print(f"Currently at index {stidx}.")

            # prepare batch
            text_batch = text_data[stidx:stidx + params.batch_size]
            cur_batch_size = len(text_batch)

            encoded = tokenizer.batch_encode_plus(
                [f"{t}<|endoftext|>" for t in text_batch], padding=True)
            max_len = max(len(ids) for ids in encoded["input_ids"])

            # make sure we pad att mask and input to all having the same length
            att_mask = torch.zeros(
                (cur_batch_size, max_len), device=model.device)
            # explicit integer dtype: the ids are used as indices below
            inp = torch.full((cur_batch_size, max_len),
                             fill_value=pad_token_id, dtype=torch.long,
                             device=model.device)
            rows = zip(encoded["input_ids"], encoded["attention_mask"])
            for i, (ids_i, mask_i) in enumerate(rows):
                inp[i, :len(ids_i)] = torch.tensor(
                    ids_i, dtype=torch.long, device=model.device)
                att_mask[i, :len(mask_i)] = torch.tensor(
                    mask_i, device=model.device)

            outputs = model(inp, attention_mask=att_mask)
            logits = outputs[0]

            # position t predicts token t + 1, so the prediction made at the
            # last position (after the EOS token) has no target and is dropped
            log_probs = torch.log_softmax(logits, dim=-1)[:, :-1, :]
            targets = inp[:, 1:]

            # pick the log-probability of each actual next token; gather avoids
            # building a (batch, seq, vocab) mask and cannot produce -inf * 0
            token_log_probs = log_probs.gather(
                -1, targets.unsqueeze(-1)).squeeze(-1)

            # average over real (non-padding) target positions only
            target_mask = att_mask[:, 1:]
            length = target_mask.sum(1)
            average_log_prob = (token_log_probs * target_mask).sum(1) / length
            # exp(mean log p) is the geometric mean of the token probabilities
            geometric_means.append(average_log_prob.exp().cpu())

    return geometric_means


if __name__ == "__main__":
    # Command-line entry point: with --train the model is finetuned on the
    # dataset, otherwise the dataset is scored with the saved model.
    cli = argparse.ArgumentParser(description='Finetune GPT2')

    # Options are registered in this order (paths first).
    option_table = (
        ("--dataset", dict(type=str, default='../data/yelp/train_all',
                           help="Path to dataset")),
        ("--output_path", dict(type=str, default='savedir/model.pt',
                               help="Output path")),
        ("--device", dict(type=str, required=True)),
        ("--batch_size", dict(type=int, default=16)),
        ("--lr", dict(type=float, default=0.00001)),
        ("--num_epochs", dict(type=int, default=1)),
        ("--train", dict(action="store_true")),
        ("--warmup_steps", dict(type=int, default=0)),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    # Unrecognised arguments are rejected explicitly with a ValueError.
    params, leftovers = cli.parse_known_args()
    if leftovers:
        raise ValueError("Got unknown parameters " + str(leftovers))

    print(params)

    if not params.train:
        # Evaluation: average the per-text scores over all batches.
        eval_texts = read_all(params.dataset)
        batch_scores = evaluate_fluency(params, eval_texts)
        mean_score = torch.cat(batch_scores, dim=0).mean()
        print("Avg score: ", mean_score.item())
    else:
        finetune_gpt(params)

    print("<<<JOB_FINISHED>>>")
