import pathlib

import IPython
import hydra
from typing import *

import omegaconf
import torch
import pytorch_lightning as pl
import torch.nn as nn
import pandas as pd
import numpy as np
# from loguru import logger
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer, AdamW
from sklearn.metrics import confusion_matrix, f1_score

import logging

from LMBenchmarkEvaluator import BaseEvaluationModule, BaseLMModule
from Utils import ClassificationDataset

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class MCEvaluationModule(BaseEvaluationModule):
    """Data handling for multiple-choice benchmarks stored as CSV files.

    ``collate`` reads three kinds of fields from every example: ``question``,
    ``answer`` and one ``choices_<i>`` field per candidate answer
    (``choices_0``, ``choices_1``, ...).
    """

    def __init__(self, config):
        super().__init__(config)

    def dataloader(self, x_path: Union[str, pathlib.Path]):
        """Read the CSV at ``x_path`` and wrap its rows in a ``ClassificationDataset``.

        Missing cells are replaced by empty strings before the rows are
        converted to a list of per-row dictionaries.
        """
        df: pd.DataFrame = pd.read_csv(x_path, index_col=None).fillna('')
        # "records" is the documented orient name; the abbreviated "record"
        # spelling is not accepted by recent pandas releases.
        return ClassificationDataset(df.to_dict("records"))

    def collate(self, examples):
        """Tokenize a list of examples into one flat batch.

        Every example is expanded into one ``question + '?' + choice`` string
        per choice, so the encoded tensors have ``len(examples) * n_choices``
        rows while ``labels`` keeps one entry per example.

        Returns:
            A dict holding the tokenizer outputs plus ``labels`` (tensor built
            from the ``answer`` column) and ``text`` (the flat list of
            question/choice strings, in the same row order as the tensors).
        """
        batch_size = len(examples)
        df = pd.DataFrame(examples)
        # Count the choice columns explicitly rather than assuming that every
        # column other than two is a choice; identical for the plain
        # question/answer/choices_* schema, but robust to extra columns.
        n_choices = sum(1 for col in df.columns if str(col).startswith('choices_'))

        # TODO: remove this '?' in future
        sents = [
            (row['question'] + '?') + row[f'choices_{c}']
            for row in df.to_dict('records')
            for c in range(n_choices)
        ]

        results = self.tokenizer.batch_encode_plus(
            sents,
            add_special_tokens=True,
            max_length=self.hparams["max_length"],
            return_tensors='pt',
            return_token_type_ids=True,
            # return_attention_masks=True,
            pad_to_max_length=True,
            truncation=True,
        )

        # torch_labels = torch.nn.functional.one_hot(
        #     torch.from_numpy(df['answer'].values), num_classes=n_choices)
        torch_labels = torch.from_numpy(df['answer'].values)

        # Sanity checks: one encoded row per (example, choice) pair and one
        # label per example.
        assert results["input_ids"].shape[0] == batch_size * n_choices, \
            f"Invalid shapes {results['input_ids'].shape} {batch_size} {n_choices}"

        assert torch_labels.shape[0]*n_choices == results["input_ids"].shape[0], \
            f"Invalid shapes {results['input_ids'].shape} {torch_labels.shape}"
        return {
            **results,
            "labels": torch_labels,
            "text": sents,
        }


class MCLMModule(MCEvaluationModule, BaseLMModule):
    """Multiple-choice scorer built on a pretrained encoder.

    Each (question, choice) row is encoded, pooled, passed through dropout and
    a single-output linear layer; the per-row scores are then grouped per
    question and normalised with a softmax over the choices.
    """

    # Number of choices assumed when it cannot be inferred from the batch.
    DEFAULT_N_CHOICES = 4

    def __init__(self, config):
        super().__init__(config)
        logger.warning('AutoModel : {}'.format(config["model"]))
        # NOTE(review): the cache directory is a hard-coded, user-specific path.
        self.embedder = AutoModel.from_pretrained(config["model"], cache_dir="/nas/home/qasemi/model_cache")
        self.embedder.train()

        self.dropout = nn.Dropout(self.hparams['hidden_dropout_prob'])

        # One score per (question, choice) row, initialised like the encoder's
        # own weights.
        self.classifier = nn.Linear(self.embedder.config.hidden_size, 1, bias=True)
        self.classifier.weight.data.normal_(mean=0.0, std=self.embedder.config.initializer_range)
        self.classifier.bias.data.zero_()

    def _n_choices_for(self, batch) -> int:
        """Infer the number of choices per question from ``batch``.

        Uses the ratio between encoded rows and labels (one label per
        question, as produced by ``collate``); falls back to
        ``DEFAULT_N_CHOICES`` when the batch has no usable ``labels`` entry.
        """
        try:
            n_questions = batch["labels"].shape[0]
        except (KeyError, AttributeError, IndexError):
            return self.DEFAULT_N_CHOICES
        n_rows = batch["input_ids"].shape[0]
        if n_questions > 0 and n_rows % n_questions == 0:
            return n_rows // n_questions
        return self.DEFAULT_N_CHOICES

    def forward(self, batch):
        """Score every choice of every question in ``batch``.

        Args:
            batch: mapping with two-dimensional ``input_ids``,
                ``attention_mask`` and ``token_type_ids`` tensors.

        Returns:
            Tensor of shape ``(n_questions, n_choices)``. Despite the local
            name ``logits`` these are softmax-normalised scores.

        Raises:
            ValueError: if ``hparams['pooling_method']`` is neither ``cls``
                nor ``mean``.
        """
        for key in ("input_ids", "attention_mask", "token_type_ids"):
            assert len(batch[key].shape) == 2, "LM only take two-dimensional input"

        # Keep the decision in a local variable: overwriting the batch entry
        # with None would make a second forward pass on the same batch fail
        # in the shape assertion above.
        token_type_ids = None if "roberta" in self.hparams["model"] else batch["token_type_ids"]

        results = self.embedder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"],
                                token_type_ids=token_type_ids)

        # Index rather than unpack so that both plain tuples and dict-like
        # model outputs yield the first element (the token embeddings).
        token_embeddings = results[0]

        pooling_method = self.hparams['pooling_method'].lower()
        if pooling_method == 'cls':
            pooled = token_embeddings[:, 0, :]
        elif pooling_method == 'mean':
            # NOTE(review): averages over every position, padding included;
            # confirm whether an attention-mask-weighted mean is intended.
            pooled = token_embeddings.mean(dim=1)
        else:
            raise ValueError('Invalid Pooling {}'.format(self.hparams['pooling_method']))

        grouped_by_question = self.dropout(pooled)
        # Regroup the flat per-row scores into one row per question.
        logits = torch.nn.functional.softmax(
            self.classifier(grouped_by_question).reshape([-1, self._n_choices_for(batch)]),
            dim=1
        )
        return logits

    def _collect_evaluation_results(self, outputs, mytag):
        """Aggregate per-batch evaluation outputs, log them and dump CSV reports.

        Args:
            outputs: list of per-batch dicts providing ``{mytag}_loss``,
                ``{mytag}_batch_logits``, ``{mytag}_batch_labels`` and ``text``.
            mytag: prefix used for the dict keys, scalar names and file names.

        Side effects:
            Writes ``{mytag}_dump.csv`` (all examples) and
            ``{mytag}_errors.csv`` (misclassified examples only) to the
            current working directory and logs loss/accuracy scalars.

        Returns:
            Dict with the mean loss under ``{mytag}_loss`` and the accuracy
            under ``progress_bar``.
        """
        _loss_mean = torch.stack([o[f'{mytag}_loss'] for o in outputs]).mean()
        _logits = torch.cat([o[f"{mytag}_batch_logits"] for o in outputs])
        _labels = torch.cat([o[f"{mytag}_batch_labels"] for o in outputs])
        _predictions = torch.argmax(_logits, dim=1)
        # Explicit float conversion so the ratio never relies on how an
        # integer tensor is divided.
        val_acc = (_labels == _predictions).float().mean()

        logger.info(f'{mytag}_acc={val_acc}, {mytag}_loss={_loss_mean}')

        # Attach the training step when the module exposes one; without a step
        # every evaluation is written to the same point of the curve.
        step = getattr(self, "global_step", None)
        self.logger.experiment.add_scalar(f'{mytag}_loss', _loss_mean, global_step=step)
        self.logger.experiment.add_scalar(f'{mytag}_acc', val_acc, global_step=step)

        # Both label columns are converted to numpy so the frame holds plain
        # numbers rather than tensor objects.
        df = pd.DataFrame({
            'predicted_label': _predictions.detach().cpu().numpy(),
            'true_label': _labels.detach().cpu().numpy(),
        })

        # One row of sentences per question; the width follows the logits.
        n_choices = _logits.shape[1]
        all_sents = [s for o in outputs for s in o['text']]
        df_text = (
            pd.DataFrame(np.array(all_sents).reshape(-1, n_choices))
            .apply(
                func=self._sents_to_mc,
                axis=1
            )
        )
        fixed_cols = list(df_text.columns)
        df = df.join(df_text)

        df.to_csv(f"{mytag}_dump.csv")

        is_error = df['true_label'] != df['predicted_label']
        df.loc[is_error, [*fixed_cols, 'true_label', 'predicted_label']].to_csv(
            f'{mytag}_errors.csv'
        )

        return {
            f'{mytag}_loss': _loss_mean,
            "progress_bar": {
                f"{mytag}_accuracy": val_acc,
            }
        }

    @staticmethod
    def _sents_to_mc(pd_s: pd.Series) -> pd.Series:
        """Rebuild one multiple-choice record from its ``question?choice`` strings.

        Args:
            pd_s: the strings belonging to a single question, one per choice,
                in choice order.

        Returns:
            Series with ``question`` and ``choices_<i>`` entries. ``question``
            is NaN when the strings do not share one question; all choices are
            NaN when they are not pairwise distinct.
        """
        questions, answers = [], []
        for sent in pd_s:
            # Split on the LAST '?': it is the separator inserted between
            # question and choice, whereas earlier ones belong to the question.
            q_part, sep, a_part = sent.rpartition('?')
            if sep:
                questions.append(q_part)
                answers.append(a_part)
            else:
                # No separator at all: treat the whole string as the question.
                questions.append(sent)
                answers.append(np.nan)

        q_set = set(questions)
        if len(q_set) == 1:
            q = q_set.pop()
            if '?' not in q:
                q = q + '?'
        else:
            q = np.nan

        n_choices = len(answers)
        if len(set(answers)) == n_choices:
            # Keep the original order so ``choices_<i>`` lines up with the
            # label indices (a set would scramble it).
            a_list = answers
        else:
            a_list = [np.nan] * n_choices

        return pd.Series({
            'question': q,
            **{f'choices_{i}': c for i, c in enumerate(a_list)}
        })


@hydra.main(config_path='../Configs/LMBenchEval.yaml')
def main(config: omegaconf.dictconfig.DictConfig):
    """Hydra entry point: build the module, fit it, then run the test loop.

    All trainer settings that are not fixed below are read from ``config``.
    """
    lm_module = MCLMModule(config)

    # Collect the trainer options first so the constructor call stays short.
    trainer_options = {
        "gradient_clip_val": 0,
        "gpus": config.gpus,
        "show_progress_bar": True,
        "accumulate_grad_batches": config["accumulate_grad_batches"],
        "max_epochs": config["max_epochs"],
        "min_epochs": 1,
        "val_check_interval": config['val_check_interval'],
        "weights_summary": 'top',
        "num_sanity_val_steps": config.warmup_steps,
        "resume_from_checkpoint": None,
    }
    runner = pl.Trainer(**trainer_options)

    runner.fit(lm_module)
    runner.test(lm_module)


# Run the entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()