from typing import List
import pickle

import torch
from torch import nn
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import callbacks as pl_callbacks
from pytorch_lightning.plugins import DDPPlugin
import argparse
import os
from torch.optim import AdamW
from transformers import (
    BertForPreTraining,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    BertConfig
)

from pruning import PruningCallback
from global_vars import Timers
from data.bert_dataset import BertDatasetModule
import argparse
from optimizers import Lamb
from arguments import (
    _add_regularization_args, _add_data_args, _add_initialization_args,
    _add_learning_rate_args,
    _add_network_size_args, _add_training_args,
    _add_checkpointing_args
)


# def inject_knowledge(batch: torch.tensor, knowledge_sentences: List[List[int]],
#                      cls_id: int, sep_id: int, pad_id: int):
def inject_knowledge(input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor,
                     masked_lm_labels: torch.Tensor, next_sentence_label: torch.Tensor,
                     knowledge_sentences: List[List[int]], knowledge_labels: List[List[int]],
                     cls_id: int, sep_id: int, pad_id: int):
    """Overwrite one segment of some batch rows with "knowledge" sentences.

    Every row of ``input_ids`` is read as ``[CLS] A [SEP] B [SEP] <pad>...``:
    the first ``sep_id`` in the row ends segment A, and the number of
    non-``pad_id`` tokens gives the end of segment B. Each knowledge sentence
    replaces segment A or segment B of a distinct row, chosen greedily as the
    shortest still-available segment that is at least as long as the knowledge
    sentence, so a rebuilt row never gets longer than the row it replaces.

    All five tensors are modified in place (row by row) and also returned.

    Args:
        input_ids: 2-D tensor of token ids, one row per example.
        attention_mask: tensor indexed like ``input_ids``; rebuilt rows get 1
            on real tokens and 0 on padding.
        token_type_ids: tensor indexed like ``input_ids``; rebuilt rows get 0
            for ``[CLS] A [SEP]``, 1 for ``B [SEP]`` and 0 on padding.
        masked_lm_labels: tensor indexed like ``input_ids``; ``-1`` marks
            positions without a label.
        next_sentence_label: 1-D tensor with one entry per row; set to 1 for
            every rebuilt row.
        knowledge_sentences: token-id lists to inject.
        knowledge_labels: ``knowledge_labels[i]`` holds the per-token labels of
            ``knowledge_sentences[i]`` and must have the same length.
        cls_id: id of the ``[CLS]`` token.
        sep_id: id of the ``[SEP]`` token.
        pad_id: id of the padding token.

    Returns:
        The tuple ``(input_ids, attention_mask, token_type_ids,
        masked_lm_labels, next_sentence_label)``.

    Raises:
        AssertionError: if there are more knowledge sentences than rows, if
            sentences and labels do not line up, or if some knowledge sentence
            cannot be placed in any unused row.
        IndexError: if a row contains no ``sep_id``.
    """
    num_ksentences = len(knowledge_sentences)
    assert len(input_ids) >= num_ksentences, \
        "more knowledge sentences than rows in the batch"
    assert len(knowledge_labels) == num_ksentences, \
        "knowledge_sentences and knowledge_labels must have the same length"

    seq_len = input_ids.size(1)

    # One candidate slot per segment of every row, plus the original segments
    # (tokens and labels) needed to rebuild the row afterwards.
    slots = []
    row_segments = dict()
    for row, sentence in enumerate(input_ids):
        sent_length = torch.sum(sentence != pad_id).item()
        sep_pos = torch.where(sentence == sep_id)[0][0].item()
        row_labels = masked_lm_labels[row]

        tokens_a = sentence[1:sep_pos].tolist()
        tokens_b = sentence[sep_pos + 1:sent_length - 1].tolist()
        row_segments[row] = (
            tokens_a,
            tokens_b,
            row_labels[1:sep_pos].tolist(),
            row_labels[sep_pos + 1:sent_length - 1].tolist(),
        )
        # The capacity of a slot is exactly the length of the segment it
        # replaces, so the rebuilt row can never outgrow the original one.
        slots.append(dict(row=row, length=len(tokens_a), pos='a'))
        slots.append(dict(row=row, length=len(tokens_b), pos='b'))

    # Sort sentences *together with* their labels so each sentence keeps its
    # own labels after the reordering; both sorts are stable.
    knowledge = sorted(zip(knowledge_sentences, knowledge_labels),
                       key=lambda pair: len(pair[0]))
    slots.sort(key=lambda slot: slot['length'])

    # Greedy matching: shortest knowledge sentence first, each one takes the
    # shortest fitting slot whose row has not been used yet.
    pairs = []
    used_rows = set()
    for ks, kl in knowledge:
        assert len(ks) == len(kl), \
            "a knowledge sentence and its labels must have the same length"
        for slot in slots:
            if slot['row'] not in used_rows and len(ks) <= slot['length']:
                pairs.append(((ks, kl), slot))
                used_rows.add(slot['row'])
                break
        else:
            raise AssertionError(
                "could not place every knowledge sentence in a distinct batch row")

    for (ks, kl), slot in pairs:
        row = slot['row']
        tokens_a, tokens_b, labels_a, labels_b = row_segments[row]
        if slot['pos'] == 'a':
            tokens_a, labels_a = list(ks), list(kl)
        else:
            tokens_b, labels_b = list(ks), list(kl)

        tokens = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
        labels = [-1] + labels_a + [-1] + labels_b + [-1]
        num_tokens = len(tokens)
        # [CLS] + A + first [SEP] all belong to the first segment.
        first_segment_end = len(tokens_a) + 2

        # Replacement rows are created with the dtype/device of their target
        # tensor so the in-place row assignment needs no implicit conversion.
        new_ids = torch.full((seq_len,), pad_id,
                             dtype=input_ids.dtype, device=input_ids.device)
        new_ids[:num_tokens] = torch.tensor(
            tokens, dtype=input_ids.dtype, device=input_ids.device)
        input_ids[row] = new_ids

        new_mask = torch.zeros(seq_len, dtype=attention_mask.dtype,
                               device=attention_mask.device)
        new_mask[:num_tokens] = 1
        attention_mask[row] = new_mask

        new_types = torch.zeros(seq_len, dtype=token_type_ids.dtype,
                                device=token_type_ids.device)
        new_types[first_segment_end:num_tokens] = 1
        token_type_ids[row] = new_types

        new_labels = torch.full((seq_len,), -1, dtype=masked_lm_labels.dtype,
                                device=masked_lm_labels.device)
        new_labels[:num_tokens] = torch.tensor(
            labels, dtype=masked_lm_labels.dtype, device=masked_lm_labels.device)
        masked_lm_labels[row] = new_labels

        next_sentence_label[row] = 1

    return input_ids, attention_mask, token_type_ids, masked_lm_labels, next_sentence_label


class StopCallback(pl.Callback):
    """Ask the trainer to stop once ``global_step`` reaches ``stop_iter``.

    A falsy ``stop_iter`` (``None`` or ``0``) disables the callback.
    """

    def __init__(self, stop_iter=None) -> None:
        self.stop_iter = stop_iter

    def on_batch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """Set ``trainer.should_stop`` when the step limit has been reached."""
        limit = self.stop_iter
        if not limit:
            return

        step = trainer.global_step
        if step <= 0 or step < limit:
            return

        # The local decision is passed through the training-type plugin's
        # reduce_boolean_decision before being stored on the trainer.
        decision = trainer.training_type_plugin.reduce_boolean_decision(True)
        trainer.should_stop = decision


class Transformer(pl.LightningModule):
    """LightningModule that pre-trains a ``BertForPreTraining`` model.

    The model is built from a ``BertConfig`` assembled from the constructor
    arguments. ``forward`` returns the sum of a masked-LM loss and a
    sentence-pair loss, both computed with ``CrossEntropyLoss(ignore_index=-1)``.

    Batches are dicts with the keys ``'text'``, ``'padding_mask'``, ``'types'``,
    ``'labels'`` and ``'is_random'``.
    """

    # Where training batches are dumped by ``_dump_training_batch`` and how
    # many dumped batches trigger process exit.
    _BATCH_DUMP_PATH = 'downstream/knowledge/data/batches_256.pkl'
    _BATCH_DUMP_LIMIT = 1000

    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        num_attention_heads: int,
        seq_length: int,
        max_position_embeddings: int,
        train_iters: int,
        save: str,
        load: str,
        data_path: str,
        tokenizer_model_type: str = 'bert-base-uncased',
        lr: float = 1e-4,
        lr_decay_style: str = 'linear',
        min_lr: float = 1e-5,
        lr_decay_iters: int = 990000,
        weight_decay: float = 1e-2,
        adam_epsilon: float = 1e-8,
        warmup: float = .01,
        clip_grad: float = 1.0,
        **kwargs
    ):
        """Build the tokenizer, the BERT config and the model.

        All arguments are recorded through ``save_hyperparameters()``. Within
        this class only the following are read again:

        Args:
            num_layers: number of hidden layers of the BERT encoder.
            hidden_size: hidden size; the intermediate size is 4x this value.
            num_attention_heads: attention heads per layer.
            max_position_embeddings: size of the position-embedding table.
            train_iters: scheduler length when ``lr_decay_iters`` is ``None``.
            tokenizer_model_type: name passed to ``AutoTokenizer.from_pretrained``.
            lr: AdamW learning rate.
            lr_decay_iters: scheduler length (takes precedence over
                ``train_iters`` when not ``None``).
            weight_decay: weight decay for parameters that are not biases or
                LayerNorm weights.
            adam_epsilon: AdamW epsilon.
            warmup: fraction of the scheduler length used for linear warmup.

        ``seq_length``, ``save``, ``load``, ``data_path``, ``lr_decay_style``,
        ``min_lr``, ``clip_grad`` and ``**kwargs`` are only stored as
        hyper-parameters here.
        """
        super().__init__()
        self.timers = Timers()
        self.save_hyperparameters()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_type)

        # Use the tokenizer's public id properties rather than looking private
        # ``_xxx_token`` attributes up in the vocabulary.
        tokenizer = self.tokenizer
        self.cls_id = tokenizer.cls_token_id
        self.sep_id = tokenizer.sep_token_id
        self.mask_id = tokenizer.mask_token_id
        self.pad_id = tokenizer.pad_token_id

        self.config = BertConfig(
            # ``vocab_size`` is the keyword BertConfig actually reads; the old
            # ``vocab_size_or_config_json_file`` name is not a BertConfig
            # parameter and left the vocabulary size at the library default.
            vocab_size=self.tokenizer.vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=4 * hidden_size,
            hidden_act='gelu',
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=max_position_embeddings,
            type_vocab_size=2,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            pad_token_id=self.pad_id,
            gradient_checkpointing=False,
            position_embedding_type='absolute',
            use_cache=True
        )

        self.model = BertForPreTraining(self.config)
        # -1 marks positions/examples that must not contribute to the loss.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        # Running sums of the logged training losses, keyed by log name.
        self._total_loss_dict = dict()

    def _dump_training_batch(self, example):
        """Append ``example`` to the pickle file at ``_BATCH_DUMP_PATH``.

        The whole file is re-read and re-written on every call. When the file
        already exists and the accumulated list reaches ``_BATCH_DUMP_LIMIT``
        entries, the process exits with status 1 *before* writing.

        NOTE(review): this runs on every training forward pass and terminates
        the run; it looks like a data-collection hook — confirm it is still
        wanted. The pickle file is only ever produced by this method; do not
        point ``_BATCH_DUMP_PATH`` at untrusted data.
        """
        path = self._BATCH_DUMP_PATH
        examples = [example]
        if os.path.exists(path):
            with open(path, 'rb') as f:
                examples = pickle.load(f) + examples
            if len(examples) >= self._BATCH_DUMP_LIMIT:
                exit(1)

        with open(path, 'wb') as f:
            pickle.dump(examples, f)

    def forward(self, **inputs):
        """Run the model on one batch and compute the pre-training losses.

        Args:
            **inputs: batch dict with ``'text'`` (input ids), ``'padding_mask'``
                (attention mask), ``'types'`` (token type ids), ``'labels'``
                (masked-LM labels, -1 = ignored) and ``'is_random'``
                (sentence-pair labels).

        Returns:
            Tuple ``(total_loss, masked_lm_loss, next_sentence_loss)`` where
            ``total_loss`` is the sum of the other two.
        """
        self.timers('forward').start()

        input_ids = inputs['text']
        attention_mask = inputs['padding_mask']
        token_type_ids = inputs['types']
        masked_lm_labels = inputs['labels']
        next_sentence_label = inputs['is_random']

        if self.training:
            self._dump_training_batch(
                [input_ids, attention_mask, token_type_ids, masked_lm_labels, next_sentence_label])

        # Knowledge injection through inject_knowledge(...) is currently
        # disabled; the batch is fed to the model unchanged.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        prediction_scores, seq_relationship_score = outputs[:2]
        masked_lm_loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
        next_sentence_loss = self.loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
        total_loss = masked_lm_loss + next_sentence_loss
        outputs = (total_loss, masked_lm_loss, next_sentence_loss)
        self.timers('forward').stop()
        return outputs

    def on_train_start(self):
        """Start the ``'interval time'`` timer when training begins."""
        self.timers('interval time').start()

    def training_step(self, batch, batch_idx):
        """Compute the losses for one batch, accumulate and log them.

        Returns:
            Dict with ``'loss'`` (total loss), ``'mlm_loss'`` and ``'sop_loss'``.
        """
        # NOTE(review): this flips only this module's own ``training`` flag
        # (sub-modules are untouched) and leaves it False after the step.
        # ``forward`` uses the flag to decide whether to dump the batch.
        self.training = True
        outputs = self(**batch)
        self.training = False
        loss = outputs[0]

        loss_dict = {'train/loss': outputs[0].item(),
                     'train/mlm_loss': outputs[1].item(),
                     'train/sop_loss': outputs[2].item()}
        # Running sums; something outside this class is expected to average
        # and reset them (ConsoleCallback in this file does).
        for key in loss_dict:
            self._total_loss_dict[key] = self._total_loss_dict.get(key, 0.) + loss_dict[key]
        self.log_dict(loss_dict, sync_dist=True)
        return {'loss': loss, 'mlm_loss': outputs[1], 'sop_loss': outputs[2]}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Log the total validation loss and ``exp(min(20, loss))`` as perplexity."""
        outputs = self(**batch)
        val_loss = outputs[0]
        # The exponent is capped at 20 to keep the logged value finite.
        ppl = np.exp(min(20, val_loss.item()))
        self.log_dict({'valid/loss': val_loss, 'valid/ppl': ppl}, sync_dist=True)
        return {'loss': val_loss}

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay).

        Biases and LayerNorm weights are excluded from weight decay. The
        scheduler is stepped every optimizer step.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        decay_params, no_decay_params = [], []
        for name, param in self.model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": self.hparams.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        if self.hparams.lr_decay_iters is not None:
            num_iters = self.hparams.lr_decay_iters
        else:
            num_iters = self.hparams.train_iters

        num_iters = max(1, num_iters)
        warmup_iter = self.hparams.warmup * num_iters

        # Only AdamW is wired up; the imported Lamb optimizer is not used here.
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.lr, eps=self.hparams.adam_epsilon)

        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_iter, num_training_steps=num_iters
        )

        scheduler = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1
        }
        return [optimizer], [scheduler]


class ConsoleCallback(pl.Callback):
    """Print averaged training losses and timers every ``every_n_step`` steps.

    The callback reads and resets ``pl_module._total_loss_dict`` (running loss
    sums) and reports timers held in ``pl_module.timers``.

    It takes everything it needs from its own configuration and from the
    trainer, so it does not depend on a module-level ``args`` object (which
    only exists when this file is executed as a script).
    """

    def __init__(self, every_n_step):
        """
        Args:
            every_n_step: logging period in global steps; also used as the
                divisor when averaging the accumulated losses and timers.
        """
        self.every_n_step = every_n_step

    def on_batch_end(self, trainer, pl_module):
        """Emit one console line when ``global_step`` hits the logging period."""
        step = trainer.global_step
        if step == 0 or step % self.every_n_step != 0:
            return

        # ``max_steps`` is the step budget the trainer was created with.
        total_steps = trainer.max_steps if trainer.max_steps is not None else -1
        log_string = ' iteration {:8d}/{:8d} |'.format(step, total_steps)
        num_iterations = self.every_n_step

        # Average each running sum over the logging period, then reset it.
        for key in pl_module._total_loss_dict:
            avg = pl_module._total_loss_dict[key] / float(num_iterations)
            log_string += ' {}: {:.6E} |'.format(key, avg)
            pl_module._total_loss_dict[key] = 0.0

        # Only report timers that have actually been created.
        known_timers = pl_module.timers.timers
        timers_to_log = [
            name
            for name in ('forward', 'backward-backward', 'interval time')
            if name in known_timers
        ]

        if pl_module.global_rank == 0:
            print(log_string)

        pl_module.timers.log(timers_to_log, normalizer=num_iterations)


def train(args):
    """Build data, model, callbacks and trainer from ``args`` and run training.

    Args:
        args: parsed command-line namespace. Its whole ``__dict__`` is forwarded
            as keyword arguments to both ``PruningCallback`` and ``Transformer``.

    Training resumes from ``<args.save>/last.ckpt`` when that file exists.
    When ``args.save_initial_checkpoint`` is set, the untrained model and the
    tokenizer are saved to ``<args.save>/checkpoint-0/`` before fitting.
    """
    pl.seed_everything(args.seed, workers=True)

    # Fixed values written onto the namespace — presumably read by the
    # dataset code; verify against BertDatasetModule.
    args.rank = 0
    args.model_parallel_size = 1

    # build tokenizer and assign max position embeddings here
    # so that we keep track of this in the trained model
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_model_type)
    tokenizer.max_len = args.max_position_embeddings

    data_module = BertDatasetModule(args, tokenizer)
    pruning_cb = PruningCallback(**vars(args))

    tb_logger = pl_loggers.TensorBoardLogger(args.tensorboard_dir)
    csv_logger = pl_loggers.CSVLogger(args.tensorboard_dir)
    lr_monitor_cb = pl_callbacks.LearningRateMonitor()

    # pretraining: keep every periodic checkpoint plus a rolling "last" one
    checkpoint_cb = pl.callbacks.ModelCheckpoint(
        args.save,
        save_top_k=-1,
        save_last=True,
        every_n_train_steps=args.save_interval or 50000,
    )

    last_ckpt = args.save + "/last.ckpt"
    if os.path.exists(last_ckpt):
        resume_ckpt = last_ckpt
    else:
        resume_ckpt = None
    print(f"Starting from checkpoint: {resume_ckpt}.")

    model = Transformer(**vars(args))

    trainer_kwargs = dict(
        max_steps=args.train_iters,
        gpus=args.num_gpus,
        num_nodes=args.num_nodes,
        precision=16,
        plugins=DDPPlugin(find_unused_parameters=False),
        gradient_clip_val=args.clip_grad,
        progress_bar_refresh_rate=args.log_interval,
        accumulate_grad_batches=args.accumulate_grad_batches,
        limit_val_batches=args.eval_iters,
        val_check_interval=args.eval_interval,
        callbacks=[checkpoint_cb, lr_monitor_cb, pruning_cb, StopCallback(args.stop_iter)],
        distributed_backend='ddp',
        logger=[tb_logger, csv_logger],
        resume_from_checkpoint=resume_ckpt,
    )
    trainer = pl.Trainer(**trainer_kwargs)

    if args.save_initial_checkpoint:
        ckpt_path = args.save + "/checkpoint-0/"
        print(ckpt_path)
        model.model.save_pretrained(ckpt_path)
        model.tokenizer.save_pretrained(ckpt_path)

    trainer.fit(model, data_module)


if __name__ == "__main__":
    # Script entry point: assemble the command-line parser from the shared
    # argument groups, add the pruning options, then launch training.
    parser = argparse.ArgumentParser()

    # Registration order is kept as-is (it determines the order of the
    # argument groups on the parser).
    for register_args in (
        _add_regularization_args,
        _add_checkpointing_args,
        _add_initialization_args,
        _add_learning_rate_args,
        _add_network_size_args,
        _add_data_args,
        _add_training_args,
    ):
        register_args(parser)
    parser = PruningCallback.add_argparse_args(parser)

    # Kept as a module-level name: other code in this file reads ``args``.
    args = parser.parse_args()
    train(args)
