import argparse
import glob
from itertools import combinations
import json
import os
from collections import defaultdict
import copy
import random
import pickle
from typing import List

import pandas as pd
from tqdm import tqdm
import torch
from torch import nn
import numpy as np
import torch
import pytorch_lightning as pl

from torch.optim import AdamW
from transformers import BertForPreTraining, BertConfig

from pretrain_bert import Transformer, train
from .helpers import inject_knowledge


def deepcopy_model(pl_model, temp_pl_model, trainer, checkpoint):
    """Reset ``temp_pl_model`` to a copy of ``pl_model`` and rebuild its optimizer.

    The wrapped network of ``pl_model`` is deep-copied onto ``temp_pl_model``,
    the trainer's optimizers are re-created for that temporary module, and
    each one is loaded with the matching entry of
    ``checkpoint['optimizer_states']``.  Tensors in the restored optimizer
    state are moved to the GPU with ``.cuda()``.

    Args:
        pl_model: module whose ``.model`` attribute is copied (left untouched).
        temp_pl_model: module that receives the copy in its ``.model`` attribute.
        trainer: object exposing ``init_optimizers(module)`` and a writable
            ``optimizers`` attribute.
        checkpoint: mapping containing an ``'optimizer_states'`` sequence.

    Returns:
        ``(temp_pl_model, first_optimizer)`` where ``first_optimizer`` is
        ``trainer.optimizers[0]``.
    """
    temp_pl_model.model = copy.deepcopy(pl_model.model)

    # Fresh optimizers bound to the copied parameters.
    fresh_optimizers, _, _ = trainer.init_optimizers(temp_pl_model)
    trainer.optimizers = fresh_optimizers

    # Restore each optimizer's saved state, pairing optimizers and states in order.
    saved_states = checkpoint['optimizer_states']
    for opt, saved in zip(trainer.optimizers, saved_states):
        opt.load_state_dict(saved)

        for param_state in opt.state.values():
            # Snapshot the keys; only values are replaced, never keys.
            for key in list(param_state):
                value = param_state[key]
                if isinstance(value, torch.Tensor):
                    param_state[key] = value.cuda()

    return temp_pl_model, trainer.optimizers[0]


def eval_model_batch(model, tokenizer, optimizer, examples, raw_batch, control=None, effective_batch_size=64):
    """Return the mean masked-LM loss of ``model`` on knowledge injected into ``raw_batch``.

    Every example is turned into a token sequence whose object token is
    replaced by the mask token; the sequences are handed to
    ``inject_knowledge`` together with ``raw_batch``, and the model is scored
    (without gradients) on the resulting knowledge labels in chunks of
    ``effective_batch_size`` rows.

    Args:
        model: network called with ``input_ids``, ``attention_mask`` and
            ``token_type_ids``; its first output is used as the vocabulary
            prediction scores.
        tokenizer: tokenizer providing ``convert_tokens_to_ids``, ``encode``,
            ``vocab_size`` and the mask/cls/sep/pad token ids.
        optimizer: unused; kept so the signature mirrors ``update_model``.
        examples: iterable of groups; each group is a list of dicts with
            ``"object"`` and ``"verbalization"`` keys (plus
            ``"negative_object"`` when ``control == 'wordnet'``).  The
            sequences of a group are concatenated into one knowledge sentence.
        raw_batch: 5-item batch ``(input_ids, attention_mask, token_type_ids,
            masked_lm_labels, next_sentence_label)``.
        control: optional perturbation. ``'negation'`` prefixes the object
            with ``"not "`` in the text, ``'unused'`` replaces the target
            label with a random ``[unusedK]`` token (K in 1..992),
            ``'wordnet'`` replaces it with the example's negative object.
            Any other value leaves the example unchanged.
        effective_batch_size: number of rows per forward pass.

    Returns:
        float: the masked-LM loss averaged over the evaluated chunks.

    Raises:
        ValueError: if the object token id does not occur in the encoded
            verbalization (raised by ``list.index``).
        ZeroDivisionError: if the injected batch has no rows.
    """
    model.eval()

    # convert training examples to knowledge sentences and labels
    knowledge_sentences = []
    knowledge_labels = []

    for triple in examples:
        sentence = []
        label = []
        for ex in triple:
            object_ids = tokenizer.convert_tokens_to_ids([ex["object"]])
            assert len(object_ids) == 1
            object_id = object_ids[0]

            text = ex["verbalization"]

            if control == 'negation':
                text = text.replace(ex["object"], "not " + ex["object"])

            token_ids = tokenizer.encode(text, add_special_tokens=False)

            # Mask the first occurrence of the object token.
            mask_token_index = token_ids.index(object_id)
            token_ids[mask_token_index] = tokenizer.mask_token_id

            # -1 is the ignore index of the loss below: only the masked
            # position contributes.
            mask_label = [-1] * len(token_ids)
            if control == 'unused':
                mask_label[mask_token_index] = tokenizer.convert_tokens_to_ids([f"[unused{int(random.random()*992)+1}]"])[0]
            elif control == 'wordnet':
                mask_label[mask_token_index] = tokenizer.convert_tokens_to_ids([ex["negative_object"]])[0]
            else:
                mask_label[mask_token_index] = object_id

            sentence += token_ids
            label += mask_label

        knowledge_sentences.append(sentence)
        knowledge_labels.append(label)

    input_ids, attention_mask, token_type_ids, masked_lm_labels, next_sentence_label = raw_batch

    inject_result = inject_knowledge(input_ids, attention_mask, token_type_ids, masked_lm_labels, next_sentence_label,
                                     knowledge_sentences, knowledge_labels,
                                     tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id)

    input_ids, attention_mask, token_type_ids, masked_lm_labels, next_sentence_label, masked_knowledge_labels = inject_result

    batch_size = input_ids.shape[0]
    ebs = effective_batch_size
    loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

    # Accumulate a plain float across chunks.  The previous version re-bound
    # the accumulator to the current chunk's loss tensor on every iteration,
    # so only the last chunk was reflected in the result.
    total_loss = 0.0
    num_chunks = 0
    with torch.no_grad():
        for i in range(0, batch_size, ebs):
            outputs = model(input_ids=input_ids[i:i+ebs],
                            attention_mask=attention_mask[i:i+ebs],
                            token_type_ids=token_type_ids[i:i+ebs])

            prediction_scores = outputs[0]

            masked_lm_loss = loss_fct(prediction_scores.view(-1, tokenizer.vocab_size),
                                      masked_knowledge_labels[i:i+ebs].view(-1))

            total_loss += masked_lm_loss.item()
            num_chunks += 1

    # Divide by the number of chunks actually evaluated; batch_size / ebs is
    # only equal to it when the batch is an exact multiple of ebs.
    return total_loss / num_chunks


def update_model(model, tokenizer, optimizer, examples, raw_batch):
    """Take one optimizer step on ``raw_batch`` with knowledge sentences injected.

    Every example is turned into a token sequence whose object token is
    replaced by the mask token; the sequences are handed to
    ``inject_knowledge`` together with ``raw_batch``.  The combined masked-LM
    and next-sentence loss is back-propagated chunk by chunk (gradient
    accumulation) and followed by a single ``optimizer.step()``.

    Args:
        model: network called with ``input_ids``, ``attention_mask`` and
            ``token_type_ids``; its first two outputs are used as the
            vocabulary prediction scores and the next-sentence scores.
        tokenizer: tokenizer providing ``convert_tokens_to_ids``, ``encode``,
            ``vocab_size`` and the mask/cls/sep/pad token ids.
        optimizer: optimizer stepped once and then zeroed.
        examples: iterable of groups; each group is a list of dicts with
            ``"object"`` and ``"verbalization"`` keys.  The sequences of a
            group are concatenated into one knowledge sentence.
        raw_batch: 5-item batch ``(input_ids, attention_mask, token_type_ids,
            masked_lm_labels, next_sentence_label)``.

    Raises:
        ValueError: if the object token id does not occur in the encoded
            verbalization (raised by ``list.index``).
    """
    model.train()

    # convert training examples to knowledge sentences and labels
    knowledge_sentences = []
    knowledge_labels = []

    for triple in examples:
        sentence = []
        label = []
        for ex in triple:
            object_ids = tokenizer.convert_tokens_to_ids([ex["object"]])
            assert len(object_ids) == 1
            object_id = object_ids[0]

            token_ids = tokenizer.encode(ex["verbalization"], add_special_tokens=False)

            # Mask the first occurrence of the object token.
            mask_token_index = token_ids.index(object_id)
            token_ids[mask_token_index] = tokenizer.mask_token_id

            # -1 is the ignore index of the loss below.
            mask_label = [-1] * len(token_ids)
            mask_label[mask_token_index] = object_id

            sentence += token_ids
            label += mask_label

        knowledge_sentences.append(sentence)
        knowledge_labels.append(label)

    input_ids, attention_mask, token_type_ids, masked_lm_labels, next_sentence_label = raw_batch

    inject_result = inject_knowledge(input_ids, attention_mask, token_type_ids, masked_lm_labels, next_sentence_label,
                                     knowledge_sentences, knowledge_labels,
                                     tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id)

    input_ids, attention_mask, token_type_ids, masked_lm_labels, next_sentence_label, masked_knowledge_labels = inject_result

    batch_size = input_ids.shape[0]
    # An eighth of the batch per forward/backward pass.  Clamp to 1 so batches
    # with fewer than 8 rows do not produce a zero step for range().
    ebs = max(1, batch_size // 8)
    loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

    # NOTE(review): each chunk's mean loss is back-propagated unscaled, so the
    # accumulated gradient is the sum over chunks rather than their average.
    # Kept as-is to preserve existing results; confirm this is intended.
    for i in range(0, batch_size, ebs):

        outputs = model(input_ids=input_ids[i:i+ebs],
                        attention_mask=attention_mask[i:i+ebs],
                        token_type_ids=token_type_ids[i:i+ebs])

        prediction_scores, seq_relationship_score = outputs[:2]

        masked_lm_loss = loss_fct(prediction_scores.view(-1, tokenizer.vocab_size), masked_lm_labels[i:i+ebs].view(-1))
        next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label[i:i+ebs].view(-1))
        total_loss = masked_lm_loss + next_sentence_loss

        total_loss.backward()

    optimizer.step()
    optimizer.zero_grad()
    

def evaluate(args):
    """Measure, per checkpoint, the loss on each rule type before and after one update.

    For every checkpoint in the hard-coded schedule the examples are cut into
    splits; for each split a fresh copy of the checkpointed model is
    evaluated on all rule types, updated once on the selected training rule
    (``args.training_set``), and evaluated again.  The nested results are
    saved with ``np.save`` to ``args.out_dir``.

    Result layout: one entry per checkpoint, each holding one entry per
    (repeat, split); each of those holds, per training set, a "before" list
    followed by an "after" list with one loss per rule type.

    Args:
        args: parsed command-line namespace providing ``training_set``,
            ``data_fname``, ``pkl_fname``, ``ckpt_dir``, ``num_repeats``,
            ``constant_learning_rate``, ``control`` and ``out_dir``.
    """
    pl.seed_everything(0)

    res = []

    rule_types = ["statement", "property", "implicit_rule"]
    # Only the rule type selected on the command line is used for training.
    training_sets = [[args.training_set]]

    with open(args.data_fname) as f:
        raw_examples = [json.loads(line) for line in f]

    # remove negation for now
    examples = [
        ex for ex in raw_examples
        if not any(' not ' in ex[rule]["verbalization"] for rule in rule_types)
    ]

    # put 1/5 of training data in each batch
    num_splits = 5
    print("Total examples:", len(examples))
    ex_subset = examples

    random.shuffle(ex_subset)

    # NOTE: pickle executes arbitrary code on load; only point --pkl-fname at
    # trusted files.
    with open(args.pkl_fname, 'rb') as f:
        batch_exs = pickle.load(f)
        batch_exs = [[y.cpu() for y in x] for x in batch_exs]

    # Clamp to 1 so fewer than num_splits examples cannot yield a zero step.
    ks_size = max(1, len(ex_subset) // num_splits)

    # alternative schedule: [9] + list(range(4999, 100000, 5000))
    for ckpt in tqdm([9] + list(range(49999, 1000000, 50000))):
        ckpt_fname = os.path.join(args.ckpt_dir, "epoch=0-step={}.ckpt".format(ckpt))
        print(ckpt_fname)
        pl_model = Transformer.load_from_checkpoint(ckpt_fname)
        temp_pl_model = Transformer.load_from_checkpoint(ckpt_fname)

        tokenizer = pl_model.tokenizer
        # Move the reference network to the GPU in eval mode; copies made
        # from it below inherit that placement.
        pl_model.model.cuda().eval()

        # load checkpoint data for initializing optimizer
        checkpoint = torch.load(ckpt_fname)
        trainer = pl.Trainer(gpus=1, resume_from_checkpoint=ckpt_fname)

        log_probs = []

        for repeats in range(args.num_repeats):
            for i in tqdm(range(0, len(ex_subset), ks_size)):
                ex_batch = ex_subset[i:i+ks_size]

                example_probs = []

                for train_set in training_sets:
                    # NOTE(review): `i` is an example offset (a multiple of
                    # ks_size), not a split index, so the training batches
                    # picked here are ks_size apart in batch_exs — confirm.
                    raw_batch_train = copy.deepcopy(batch_exs[i])
                    # NOTE(review): for repeats == 0 this is batch_exs[0],
                    # the same batch used for training on the first split;
                    # -1 - repeats may have been intended — confirm.
                    raw_batch_test = copy.deepcopy(batch_exs[-1*repeats])
                    raw_batch_standard = copy.deepcopy(batch_exs[-1*repeats])

                    raw_batch_train = [x.cuda() for x in raw_batch_train]
                    raw_batch_test = [x.cuda() for x in raw_batch_test]
                    raw_batch_standard = [x.cuda() for x in raw_batch_standard]

                    training_examples = [[example[rule] for rule in train_set] for example in ex_batch]

                    # reset model and load optimizer
                    temp_pl_model, optimizer = deepcopy_model(pl_model, temp_pl_model, trainer, checkpoint)

                    optimizer.zero_grad()
                    if args.constant_learning_rate:
                        for g in optimizer.param_groups:
                            g['lr'] = 0.00001

                    model = temp_pl_model.model.cuda()

                    # loss on every rule type before the update
                    example_probs.append([])
                    for rule in rule_types:
                        log_prob = eval_model_batch(model, tokenizer, optimizer, [[ex[rule]] for ex in ex_batch], raw_batch_standard, control=args.control)
                        example_probs[-1].append(log_prob)

                    if train_set:
                        update_model(model, tokenizer, optimizer, training_examples, raw_batch_train)

                    # loss on every rule type after the update
                    example_probs.append([])
                    for rule in rule_types:
                        log_prob = eval_model_batch(model, tokenizer, optimizer, [[ex[rule]] for ex in ex_batch], raw_batch_test, control=args.control)
                        example_probs[-1].append(log_prob)

                log_probs.append(example_probs)

        # Record each checkpoint exactly once.  Appending inside the repeat
        # loop stored the same (still growing) list once per repeat.
        res.append(log_probs)

    np.save(args.out_dir, np.array(res))

def main():
    """Parse the command-line options and run ``evaluate`` with them."""
    # (flag, add_argument keyword arguments), registered in this order.
    # alternative checkpoint location: /datadrive/bert_checkpoints/
    option_specs = [
        ('--ckpt-dir', dict(type=str, default="/datadrive/checkpoints/bert_bsz256_len128_2phase/")),
        ('--data-fname', dict(type=str, default="downstream/logic/data/train_filtered.jsonl")),
        ('--pkl-fname', dict(type=str, default='downstream/logic/data/batches_256.pkl')),
        ('--out-dir', dict(type=str, default="downstream/logic/output/res_256.npy")),
        ('--constant-learning-rate', dict(action='store_true')),
        ('--control', dict(type=str, default="none")),
        ('--num-repeats', dict(type=int, default=1)),
        ('--training-set', dict(type=str, default="statement")),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    evaluate(cli.parse_args())

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
    
