import argparse
import logging
import os

import sacrebleu
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup
from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam

from utils.file_utils import TorchFileModule

# Module-level argument parser. Nothing in the visible part of this file adds
# arguments to it or parses with it.
# NOTE(review): confirm whether other modules import it before removing.
parser = argparse.ArgumentParser()
# Limit PyTorch's intra-op CPU parallelism to one thread for the whole process.
# This runs at import time, so it affects every importer of this module.
torch.set_num_threads(1)
# logger = get_logger()


class Base(pl.LightningModule):
    """LightningModule base that caches run settings and builds the optimizer.

    ``configure_optimizers`` reads ``self.model``, which this class never
    assigns; subclasses must set it before training starts.
    """

    def __init__(self, hparams, **kwargs) -> None:
        """Store the hyper-parameters and derive the output directories.

        Args:
            hparams: Attribute-style settings object. The constructor reads
                ``beam_size``, ``max_len``, ``save_filename`` and
                ``ckpt_save_num``; ``configure_optimizers`` later reads
                ``optimizer``, ``lr``, ``gpus``, ``num_nodes``, ``max_epochs``,
                ``max_steps``, ``batch_size`` and ``warmup_ratio``.
            **kwargs: Accepted for signature compatibility; unused here.
        """
        super().__init__()
        self.hparam_args = hparams
        self.beam_size = hparams.beam_size
        self.max_len = hparams.max_len
        self.fileutils = TorchFileModule()
        # ``save_filename`` is used as the parent directory of both outputs.
        self.ckpt_dir = os.path.join(hparams.save_filename, 'checkpoints')
        self.gen_dir = os.path.join(hparams.save_filename, 'gen_files')
        self.ckpt_save_num = hparams.ckpt_save_num
        self.bos_token = '<s>'
        self.eos_token = '</s>'

    def _world_size_from_hparams(self):
        """Return ``gpus * num_nodes`` from the hparams, never less than 1.

        ``None`` for either setting counts as 1. ``gpus == -1`` is Lightning's
        "use every available GPU" value and is resolved through
        ``torch.cuda.device_count()``. The result is clamped to 1 so that a CPU
        run (``gpus == 0``) cannot zero the divisor of the step computation.
        """
        gpus = self.hparam_args.gpus
        num_nodes = self.hparam_args.num_nodes
        if gpus is None:
            gpus = 1
        elif gpus == -1:
            gpus = torch.cuda.device_count()
        if num_nodes is None:
            num_nodes = 1
        return max(gpus * num_nodes, 1)

    def configure_optimizers(self):
        """Build the optimizer and a per-step cosine schedule with warm-up.

        Parameters whose name contains ``bias``, ``LayerNorm.bias`` or
        ``LayerNorm.weight`` get no weight decay; all others use 0.001.

        Returns:
            ``([optimizer], [scheduler_config])`` where the scheduler config
            asks Lightning to step the schedule every training step.

        Raises:
            ValueError: If ``hparams.optimizer`` is not one of ``'AdamW'``,
                ``'FusedAdam'`` or ``'DeepSpeedCPUAdam'``.
        """
        # Split the parameters into the decayed and the non-decayed group.
        no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
        decayed, undecayed = [], []
        for name, param in list(self.model.named_parameters()):
            if any(tag in name for tag in no_decay):
                undecayed.append(param)
            else:
                decayed.append(param)
        optimizer_grouped_parameters = [
            {'params': decayed, 'weight_decay': 0.001},
            {'params': undecayed, 'weight_decay': 0.0},
        ]

        optimizer_classes = {
            'AdamW': AdamW,
            'FusedAdam': FusedAdam,
            'DeepSpeedCPUAdam': DeepSpeedCPUAdam,
        }
        optimizer_name = self.hparam_args.optimizer
        if optimizer_name not in optimizer_classes:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError(f'optimizer setting error: {optimizer_name!r}')
        optimizer = optimizer_classes[optimizer_name](
            optimizer_grouped_parameters, lr=self.hparam_args.lr)

        # Warm-up / total step budget.
        num_workers = self._world_size_from_hparams()

        self.trainer.reset_train_dataloader()
        data_len = len(self.trainer.datamodule.train_dataloader().dataset)

        logging.info(f'number of workers {num_workers}, data length {data_len}')
        if self.hparam_args.max_epochs is not None:
            # Steps per epoch (samples / global batch) times the epoch count.
            global_batch = self.hparam_args.batch_size * num_workers
            num_train_steps = int(data_len / global_batch * self.hparam_args.max_epochs)
        else:
            num_train_steps = self.hparam_args.max_steps
        logging.info(f'num_train_steps : {num_train_steps}')
        num_warmup_steps = int(num_train_steps * self.hparam_args.warmup_ratio)
        logging.info(f'num_warmup_steps : {num_warmup_steps}')

        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_train_steps,
        )
        lr_scheduler = {
            'scheduler': scheduler,
            'monitor': 'loss',
            'interval': 'step',
            'frequency': 1,
        }
        return [optimizer], [lr_scheduler]


class PlModelModule(Base):
    """Lightning module pairing a tokenizer with a seq2seq model wrapper.

    Training returns the wrapper's ``"loss"`` output. Validation generates
    with beam search, scores the decoded text with sacreBLEU, and writes the
    generated files and checkpoints through ``self.fileutils``.
    """

    def __init__(self, tokenizer, model, hparams, **kwargs):
        """Wire the tokenizer, model wrapper and loss function together.

        Args:
            tokenizer: Must expose ``pad_token_id``, ``bos_token_id``,
                ``eos_token_id``, ``mask_token_id`` and ``batch_decode``.
            model: Wrapper whose inner network is reachable as ``model.model``
                and which provides ``set_external`` and ``forward_seq2seq``.
            hparams: Settings object forwarded to ``Base``.
            **kwargs: Forwarded to ``Base``.
        """
        super().__init__(hparams, **kwargs)
        self.tokenizer = tokenizer
        self.model = model
        self.model.train()  # this activates dropout

        self.pad_token_id = self.tokenizer.pad_token_id
        self.bos_token_id = self.tokenizer.bos_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

        # Padding positions are excluded from the loss.
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=self.pad_token_id)
        self.vocab_size = self.model.model.config.vocab_size

        # Best validation BLEU so far. It starts unset so that the first
        # validation can seed it at any epoch, not only at epoch 0.
        self.best_bleu = None

        self.model.set_external(
            loss_fn=self.loss_fn,
            vocab_size=self.vocab_size,
            mask_token_id=self.tokenizer.mask_token_id
        )

    def make_input(self, inputs):
        """Translate a dataloader batch into the wrapper's keyword arguments.

        Args:
            inputs: Mapping with ``input_ids``, ``dec_input_ids`` and
                ``label_ids``.

        Returns:
            Dict with the source/decoder attention masks (float; 1.0 where the
            token is not padding), ``src_ids``, ``labels`` and
            ``decoder_input_ids``.
        """
        return {
            'src_attention_mask': inputs['input_ids'].ne(self.pad_token_id).float(),
            'decoder_attention_mask': inputs['dec_input_ids'].ne(self.pad_token_id).float(),
            'src_ids': inputs['input_ids'],
            'labels': inputs['label_ids'],
            'decoder_input_ids': inputs['dec_input_ids'],
        }

    def forward(self, inputs):
        """Run the wrapper on a raw batch, passing the current global step."""
        model_kwargs = self.make_input(inputs)
        return self.model(current_step=self.global_step, return_dict=True, **model_kwargs)

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        # validation_step switches the model to eval mode; switch it back.
        self.model.train()
        return self(batch)["loss"]

    def save_middle_ckpt(self, loss):
        """Save a checkpoint named after the current global step (bleu fixed to 0)."""
        ckpt_path = os.path.join(self.ckpt_dir, 'step{}.pt'.format(str(self.global_step)))
        self.fileutils.save_one(plself=self, loss=loss, bleu=0, filename=ckpt_path)

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Generate translations for one batch and compute its loss.

        Returns:
            Tuple ``(n, nmt_loss, nmt_candidate, nmt_reference, nmt_input)``:
            batch size, the loss returned by ``forward_seq2seq``, decoded
            beam-search outputs, decoded labels, and decoded source text (the
            source keeps its special tokens).
        """
        self.model.eval()
        modified_inputs = self.make_input(batch)

        labels = modified_inputs['labels']
        src_ids = modified_inputs['src_ids']
        n = labels.shape[0]

        generated = self.model.model.generate(
            src_ids,
            num_beams=self.beam_size,
            max_length=self.max_len,
            early_stopping=True
        )

        def decode(ids, skip_special_tokens):
            # Move ids to plain Python lists before handing them to the tokenizer.
            return self.tokenizer.batch_decode(
                ids.detach().cpu().tolist(), skip_special_tokens=skip_special_tokens
            )

        nmt_input = decode(src_ids, False)
        nmt_candidate = decode(generated, True)
        nmt_reference = decode(labels, True)

        _, nmt_loss = self.model.forward_seq2seq(**modified_inputs)
        return n, nmt_loss, nmt_candidate, nmt_reference, nmt_input

    def validation_epoch_end(self, outputs):
        """Aggregate validation outputs, compute BLEU, write files and log.

        Args:
            outputs: The tuples returned by ``validation_step``.
        """
        weighted_losses, candidates, references, sources = [], [], [], []
        tot = 0
        for n, nmt_loss, nmt_candidate, nmt_reference, nmt_input in outputs:
            # Weight each batch loss by its size for a per-sample average.
            weighted_losses.append(nmt_loss * n)
            candidates.extend(nmt_candidate)
            references.extend(nmt_reference)
            sources.extend(nmt_input)
            tot += n

        print('current steps: ', self.global_step)

        loss = torch.sum(torch.tensor(weighted_losses, dtype=torch.float)) / tot

        # Score only pairs where both the candidate and the reference are non-empty.
        pairs = [(cand, ref) for cand, ref in zip(candidates, references)
                 if ref != '' and cand != '']
        candis = [cand for cand, _ in pairs]
        refs = [ref for _, ref in pairs]

        bleu = sacrebleu.corpus_bleu(candis, [refs]).score

        if self.current_epoch == 0:
            # NOTE(review): ``refs`` is filtered above but ``sources`` is not,
            # so inp.inp and ref.ref can differ in line count.
            self.fileutils.write_lines(refs, os.path.join(self.gen_dir, 'ref.ref'))
            self.fileutils.write_lines(sources, os.path.join(self.gen_dir, 'inp.inp'))
            self.best_bleu = bleu
        elif self.best_bleu is None:
            # First validation after epoch 0 (e.g. a resumed run): nothing to
            # compare against yet, so seed the running best.
            self.best_bleu = bleu
        else:
            self.best_bleu = max(bleu, self.best_bleu)

        candi_filename = f'epoch{self.current_epoch:03}_bleu{bleu:.4f}.candi'
        self.fileutils.write_lines(candis, os.path.join(self.gen_dir, candi_filename))

        self.fileutils.ckpt_save(plself=self, loss=loss, score=bleu, score_name='bleu')

        self.log('valid/nmt_loss', loss, prog_bar=False)
        self.log("current_step", self.global_step, prog_bar=True)
        self.log("valid/epoch_end_bleu", bleu, prog_bar=True)
        self.log("valid/best_bleu", self.best_bleu, prog_bar=True)

