import string
from functools import partial
from itertools import chain

import pytorch_lightning as pl
from typing import Tuple, Dict, List, Callable, Any, Union, Generator

import os
import pandas as pd
import numpy as np
import hydra
import omegaconf
import pathlib
import IPython
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
import logging

from transformers import (
    AdamW,
    AutoModelWithLMHead,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup,
)

from LMBenchmarkEvaluator import BaseEvaluationModule
from Utils import ClassificationDataset, my_df_flatmap

logger = logging.getLogger(__name__)

import sys


def my_handler(type, value, tb):
    """Log an uncaught exception, including its traceback, via the module logger.

    Installed as ``sys.excepthook`` below so uncaught errors go through
    ``logging`` handlers. The ``type`` parameter shadows the builtin but is
    kept so the signature is unchanged.

    Args:
        type: The exception class.
        value: The exception instance.
        tb: The traceback object.
    """
    if issubclass(type, KeyboardInterrupt):
        # Let Ctrl-C behave as usual instead of being reported as a crash.
        sys.__excepthook__(type, value, tb)
        return
    # ``logger.exception`` reads ``sys.exc_info()``, which is empty inside an
    # excepthook, so the traceback used to be dropped; pass it explicitly.
    logger.error("Uncaught exception: %s", value, exc_info=(type, value, tb))


# Install exception handler
sys.excepthook = my_handler


class T2TModel(pl.LightningModule):
    """Text-to-text LightningModule that fine-tunes and evaluates a pretrained LM.

    Benchmark CSV rows are converted into ``(input, output)`` text pairs
    (``_convert_to_t2t``), tokenized in ``collate``, trained with the loss the
    wrapped model returns, and evaluated by comparing generated text with the
    target text (``_process_t2t_results``).

    Config keys read here: ``model``, ``learning_rate``, ``adam_epsilon``,
    ``benchmark_path``, ``batch_size``, ``cpu_limit``, ``max_length`` and the
    optional ``model_cache_dir`` / ``gen_max_length``.
    """

    # Fallback download cache for pretrained weights; override via ``model_cache_dir``.
    _DEFAULT_MODEL_CACHE_DIR = "/nas/home/qasemi/model_cache"
    # Fallback cap on generated sequence length; override via ``gen_max_length``.
    _DEFAULT_GEN_MAX_LENGTH = 200

    def __init__(self, config: omegaconf.dictconfig.DictConfig):
        """Store the config as hparams and load the tokenizer and model named by ``config["model"]``."""
        super().__init__()
        self.hparams = dict(config)
        self.root_path = pathlib.Path(__file__).parent.absolute()

        cache_dir = self.hparams.get("model_cache_dir", self._DEFAULT_MODEL_CACHE_DIR)
        self.tokenizer = AutoTokenizer.from_pretrained(config["model"], cache_dir=cache_dir, use_fast=False)
        logger.warning('Loading AutoModelWithLMHead : {}'.format(config["model"]))
        self.embedder = AutoModelWithLMHead.from_pretrained(config["model"], cache_dir=cache_dir)

    def forward(self, batch):
        """Call the wrapped model with the entries of ``batch`` as keyword arguments."""
        return self.embedder(**batch)

    def _step(self, batch):
        """Return the loss of the wrapped model on one collated batch.

        Target padding positions are set to -100 (the label value Hugging Face
        models exclude from the loss). The target ids are cloned first so the
        caller's batch is not modified in place.
        """
        lm_labels = batch['output_tokens']["input_ids"].clone()
        lm_labels[lm_labels == self.tokenizer.pad_token_id] = -100
        # decoder_input_ids are not passed; presumably the model derives them
        # from the labels (T5 shifts them right) — confirm for other models.
        outputs = self(
            {
                **batch['input_tokens'],
                'decoder_attention_mask': batch['output_tokens']['attention_mask'],
                'lm_labels': lm_labels,
            }
        )
        # With labels supplied, the first model output is the loss.
        return outputs[0]

    def training_step(self, batch, batch_idx):
        """Compute the training loss and expose it to Lightning's logger as ``train_loss``."""
        loss = self._step(batch)
        return {
            "loss": loss,
            "log": {"train_loss": loss},
        }

    def validation_step(self, batch, batch_idx):
        """Validation uses exactly the same computation as ``test_step``."""
        return self.test_step(batch, batch_idx)

    def test_step(self, batch, batch_idx):
        """Compute the loss and generate predictions for one batch.

        Returns a dict with the loss, the raw input/target texts, and the
        generated token ids (``gen_outs``), which ``_process_t2t_results``
        later decodes.
        """
        loss = self._step(batch)
        gen_outs = self.embedder.generate(
            **batch['input_tokens'],
            max_length=self.hparams.get('gen_max_length', self._DEFAULT_GEN_MAX_LENGTH),
        )
        return {
            "loss": loss,
            'input_text': batch['input_text'],
            'output_text': batch['output_text'],
            'gen_outs': gen_outs,
        }

    def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """Report test accuracy and mean loss, and write per-example rows to ``Test_Results.csv``."""
        accuracy, df_out, _loss_mean = self._process_t2t_results(outputs)

        logger.info(f'Test_Accuracy: {accuracy}')

        # Relative path: written into the current working directory.
        df_out.to_csv("Test_Results.csv")
        logger.info(f'Test_Avg_loss {_loss_mean}')

        # Pass the step explicitly; otherwise every evaluation (e.g. zero-shot
        # and post-training) is recorded without a distinguishing step.
        step = self.trainer.global_step
        self.logger.experiment.add_scalar('Test_Avg_loss', _loss_mean, global_step=step)
        self.logger.experiment.add_scalar('Test_Accuracy', accuracy, global_step=step)

        return {
            'Test accuracy': accuracy,
        }

    def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """Report eval accuracy and mean loss, and write per-example rows to ``Eval_Results.csv``."""
        logger.info('Running validation_epoch_end')
        accuracy, df_out, _loss_mean = self._process_t2t_results(outputs)
        logger.info(f'Eval_Accuracy: {accuracy}')
        # Relative path: written into the current working directory.
        df_out.to_csv("Eval_Results.csv")
        logger.info(f'Eval_Avg_Loss {_loss_mean}')

        # Pass the step explicitly so successive validation runs form a curve
        # instead of overlapping.
        step = self.trainer.global_step
        self.logger.experiment.add_scalar('Eval_Avg_Loss', _loss_mean, global_step=step)
        self.logger.experiment.add_scalar('Eval_Accuracy', accuracy, global_step=step)

        return {
            'Eval_accuracy': accuracy,
            "progress_bar": {
                "Eval_accuracy": accuracy,
            }
        }

    def _process_t2t_results(self, outputs):
        """Aggregate step outputs into ``(accuracy, results DataFrame, mean loss)``.

        Generated ids are decoded with special tokens removed and cut after the
        first '.' (a '.' is re-appended). A row counts as correct when the
        generated text equals the target, ignoring case, surrounding whitespace
        and a trailing period. Multiple-choice targets often have no final '.',
        while the truncated generation always ends with one.
        """
        generated = chain.from_iterable(
            self.tokenizer.batch_decode(o['gen_outs'].detach().cpu().tolist(), skip_special_tokens=True)
            for o in outputs
        )
        pd_output_gen = (
            pd.Series(list(generated), dtype=object)
            .apply(lambda s: s.split('.')[0] + '.')
            .str.strip()
        )
        pd_output_target = pd.Series(
            list(chain.from_iterable(o['output_text'] for o in outputs)), dtype=object
        ).str.strip()
        pd_input = pd.Series(
            list(chain.from_iterable(o['input_text'] for o in outputs)), dtype=object
        ).str.strip()
        assert len(pd_input) == len(pd_output_target) == len(pd_output_gen), \
            f"Length mismatch {len(pd_input)} {len(pd_output_target)} {len(pd_output_gen)}"
        df_out = pd.DataFrame({'Input': pd_input, 'Target': pd_output_target, 'Generated': pd_output_gen})

        def _normalize(col: pd.Series) -> pd.Series:
            # Drop trailing period(s)/whitespace and case before comparing.
            return col.str.rstrip('.').str.rstrip().str.lower()

        accuracy = (_normalize(df_out['Target']) == _normalize(df_out['Generated'])).mean()

        _loss_mean = torch.stack([o['loss'] for o in outputs]).mean().detach().cpu()
        return accuracy, df_out, _loss_mean

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters (no weight-decay groups or scheduler).

        ``AdamW`` resolves to the later ``from transformers import AdamW``,
        which shadows the ``torch.optim`` import.
        """
        return AdamW(self.parameters(), lr=float(self.hparams["learning_rate"]),
                     eps=float(self.hparams["adam_epsilon"]))

    def _build_loader(self, file_name: str, shuffle: bool = False) -> DataLoader:
        """Wrap ``<benchmark_path>/<file_name>`` in a DataLoader that batches with ``collate``."""
        return DataLoader(
            self.dataloader(pathlib.Path(self.hparams["benchmark_path"]) / file_name),
            batch_size=self.hparams["batch_size"],
            collate_fn=self.collate,
            # Honour the configured CPU limit for every split (val/test used a hard-coded 16).
            num_workers=self.hparams['cpu_limit'],
            shuffle=shuffle,
        )

    @pl.data_loader
    def train_dataloader(self):
        """Shuffled loader over ``train.csv``."""
        return self._build_loader('train.csv', shuffle=True)

    @pl.data_loader
    def val_dataloader(self):
        """Loader over ``eval.csv`` (not shuffled)."""
        loader = self._build_loader('eval.csv')
        logger.debug('Successfully loaded validation dataset')
        return loader

    @pl.data_loader
    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        """Loader over ``test.csv`` (not shuffled)."""
        return self._build_loader('test.csv')

    @staticmethod
    def _count_indexed_columns(cols, prefix: str) -> int:
        """Count consecutive columns ``<prefix>0``, ``<prefix>1``, ... present in ``cols``."""
        n = 0
        while f'{prefix}{n}' in cols:
            n += 1
        return n

    def _convert_to_t2t(self, ins: pd.Series) -> Generator[pd.Series, None, None]:
        """Yield ``{'input', 'output'}`` text pairs for one benchmark row.

        Two row formats are recognised:

        * multiple choice (``question``, ``answer``, ``choices_0..choices_k``):
          one pair whose input is the question followed by lettered choices and
          whose output is the correct choice's text;
        * generation (``prompt``, ``refs_0..refs_k``): one pair per reference.

        Raises:
            ValueError: if the row matches neither format.
        """
        cols = set(ins.index)
        if 'choices_0' in cols:
            question = ins['question']
            if '?' not in question:
                question = question + '?'
            # Use however many choice columns exist (capped at 26 letters)
            # instead of assuming exactly four.
            n_choices = self._count_indexed_columns(cols, 'choices_')
            in_sent = question + "".join(
                ' ({}) {}'.format(letter, ins[f'choices_{i}'])
                for i, letter in zip(range(n_choices), string.ascii_uppercase)
            )
            # int() tolerates answers read as floats/strings (e.g. 1.0 or "1").
            out_sent = ins['choices_{}'.format(int(ins['answer']))]
            yield pd.Series({
                'input': in_sent,
                'output': out_sent,
            })
        elif 'prompt' in cols and 'refs_0' in cols:
            in_sent = ins['prompt']
            for i in range(self._count_indexed_columns(cols, 'refs_')):
                yield pd.Series({
                    'input': in_sent,
                    'output': ins[f'refs_{i}'],
                })
        else:
            raise ValueError(f'Unrecognized benchmark row format; columns: {list(ins.index)}')

    def dataloader(self, x_path: Union[str, pathlib.Path]):
        """Read a benchmark CSV and return a dataset of ``{'input', 'output'}`` records.

        Despite its name this builds the dataset, not a ``DataLoader``. Missing
        cells are replaced by empty strings before conversion.
        """
        df: pd.DataFrame = pd.read_csv(x_path).fillna('')
        df_t2t = my_df_flatmap(df, func=self._convert_to_t2t)
        # "records" is the supported spelling; the "record" abbreviation is
        # deprecated and rejected by pandas 2.x.
        return ClassificationDataset(df_t2t.to_dict("records"))

    def _encode_texts(self, texts: List[str]) -> Dict[str, Any]:
        """Tokenize ``texts`` to fixed-length (``max_length``) PyTorch tensors."""
        # NOTE(review): `return_attention_masks` / `pad_to_max_length` are
        # older transformers spellings (newer: `return_attention_mask`,
        # `padding='max_length'`) — confirm against the pinned version.
        encoded = self.tokenizer.batch_encode_plus(
            texts,
            add_special_tokens=True,
            max_length=self.hparams["max_length"],
            return_tensors='pt',
            return_token_type_ids=False,
            return_attention_masks=True,
            pad_to_max_length=True,
            truncation=True,
        )
        assert encoded["input_ids"].shape[0] == len(texts), \
            f"Invalid shapes {encoded['input_ids'].shape} {len(texts)}"
        return {**encoded}

    def collate(self, examples):
        """Batch ``{'input', 'output'}`` records into token tensors plus the raw texts."""
        df = pd.DataFrame(examples)
        input_text = df['input'].values.tolist()
        output_text = df['output'].values.tolist()
        return {
            'input_tokens': self._encode_texts(input_text),
            'input_text': input_text,
            'output_tokens': self._encode_texts(output_text),
            'output_text': output_text,
        }


@hydra.main(config_path='../Configs/LMBenchEval.yaml')
def main(config: omegaconf.dictconfig.DictConfig):
    """Evaluate the pretrained model zero-shot, then optionally fine-tune and re-evaluate.

    Args:
        config: Hydra config; besides the keys ``T2TModel`` reads, uses
            ``gpus``, ``accumulate_grad_batches``, ``limit_train_batches``,
            ``max_epochs``, ``val_check_interval``, ``warmup_steps`` and
            ``do_train``.
    """
    _module = T2TModel(config)

    logger.info('Creating Trainer')
    trainer = pl.Trainer(
        gradient_clip_val=0,
        gpus=str(config.gpus),
        show_progress_bar=True,
        accumulate_grad_batches=config["accumulate_grad_batches"],
        limit_train_batches=config["limit_train_batches"],
        max_epochs=config["max_epochs"],
        min_epochs=1,
        val_check_interval=config['val_check_interval'],
        limit_val_batches=100,
        weights_summary='top',
        # NOTE(review): reuses the `warmup_steps` config key as the number of
        # sanity-check validation batches — confirm this is intended.
        num_sanity_val_steps=config.warmup_steps,
        resume_from_checkpoint=None,
    )

    logger.info('Running 0-Shot Results')
    trainer.test(_module)

    if config.do_train:
        logger.info('Running Train')
        trainer.fit(_module)

        # Re-test only after training; without training this would just repeat
        # the zero-shot run above (and overwrite its Test_Results.csv).
        logger.info('Running Trained Results')
        trainer.test(_module)


if __name__ == "__main__":
    # Script entry point; the Hydra decorator on `main` supplies its config argument.
    main()

