import logging
import math
import os
import pathlib
from itertools import islice
from typing import List, Dict, Tuple, Iterable, Union
import omegaconf
import IPython
import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelWithLMHead, PreTrainedModel, PreTrainedTokenizer
from pytorch_lightning import Trainer, LightningModule


# Module-level logger, named after this module, shared by every class and function below.
logger = logging.getLogger(__name__)


class SimpleDataset(Dataset):
    """Thin ``Dataset`` adapter over an in-memory list.

    Items are handed back exactly as stored; no transformation is applied.
    """

    def __init__(self, data: List):
        """Keep a reference to ``data`` (the list is not copied)."""
        self.data = data

    def __len__(self):
        """Return how many items the wrapped list holds."""
        size = len(self.data)
        return size

    def __getitem__(self, item):
        """Return whatever the wrapped list yields for index ``item``."""
        element = self.data[item]
        return element


class LMModule(LightningModule):
    """Lightning module that scores sentences with a pretrained language model.

    The tokenizer and the LM-head model are loaded by name through the
    ``Auto*`` factories and cached under ``~/model_cache``. The test loop
    produces one score per sentence (mean token-level cross-entropy, *not*
    exponentiated) and writes all of them to ``self.test_epoch_path`` as CSV.
    """

    # Upper bound on the tokenized sequence length; longer inputs are truncated.
    MAX_LENGTH = 256

    def __init__(self, model_name: str = 'roberta-base', *args, **kwargs):
        """Load tokenizer and model for ``model_name``.

        Args:
            model_name: Name passed to ``from_pretrained`` for both the
                tokenizer and the model.
            *args, **kwargs: Forwarded untouched to ``LightningModule``.
        """
        super().__init__(*args, **kwargs)
        self.model_name = model_name
        self.cache_dir = pathlib.Path(os.path.expanduser('~/model_cache'))
        # Use the bare class name: formatting the class object itself yields
        # "<class '...LMModule'>_output.csv", which is an awkward and
        # non-portable file name.
        self.test_epoch_path = self.cache_dir / f'{self.__class__.__name__}_output.csv'
        logger.info(f'Loading pretrained model {self.model_name}')
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=self.cache_dir)
        self.model: PreTrainedModel = AutoModelWithLMHead.from_pretrained(self.model_name, cache_dir=self.cache_dir)

        # reduction='none' keeps one loss value per token so that the
        # per-sentence aggregation can be done in _compute_loss.
        self.loss_fct = CrossEntropyLoss(reduction='none')

    def forward(self, batch):
        """Run the wrapped model on ``batch`` and return the first element of its output.

        ``batch`` is unpacked as keyword arguments, so it must only contain
        keys the model accepts. The first output element is presumably the
        vocabulary logits - confirm for each supported model family.
        """
        prediction_scores = self.model(**batch)[0]
        return prediction_scores

    def _compute_loss(self, batch, prediction_scores):
        """Return one averaged cross-entropy value per sentence.

        The labels are the (unshifted, unmasked) ``input_ids`` themselves.
        When the batch carries an ``attention_mask``, positions where the mask
        is zero are excluded from the average, so padding does not influence
        the score. Without a mask every position counts.
        """
        labels = batch['input_ids']
        token_loss = self.loss_fct(
            prediction_scores.view(-1, self.model.config.vocab_size),
            labels.view(-1)
        ).reshape(labels.shape)

        attention_mask = batch.get('attention_mask')
        if attention_mask is None:
            return token_loss.mean(1)

        weights = attention_mask.to(token_loss.dtype)
        # clamp guards against a division by zero for a fully masked row.
        return (token_loss * weights).sum(1) / weights.sum(1).clamp(min=1)

    def test_step(self, batch: Dict[str, torch.Tensor], batch_idx) -> Dict[str, torch.Tensor]:
        """Score one collated batch.

        Returns:
            A dict with the per-sentence ``loss`` tensor and the raw
            ``sentences`` of the batch.
        """
        # The raw strings are not a model input; strip them from a copy so the
        # caller's batch stays intact.
        lm_batch = batch.copy()
        lm_batch.pop('sentences')
        prediction_scores = self(lm_batch)
        loss = self._compute_loss(lm_batch, prediction_scores)
        return {
            'loss': loss,
            'sentences': batch['sentences'],
        }

    def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]) -> Dict[str, np.ndarray]:
        """Write all sentences and their scores to ``self.test_epoch_path``.

        Args:
            outputs: The dicts returned by ``test_step``.

        Returns:
            An empty dict; consumers read the results back from the CSV file.
        """
        frames = [
            pd.DataFrame({
                'sentences': out['sentences'],
                'loss': out['loss'].detach().cpu().numpy(),
            })
            for out in outputs
        ]
        if frames:
            # ignore_index gives every sentence a unique row number instead of
            # restarting at 0 for each batch.
            df: pd.DataFrame = pd.concat(frames, ignore_index=True)
        else:
            # pd.concat raises on an empty sequence; still emit a valid file.
            df = pd.DataFrame({'sentences': [], 'loss': []})
        self.test_epoch_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(self.test_epoch_path)
        return {}

    def get_dataloader(self, sents: Iterable[str], batch_size: int = 4):
        """Build a ``DataLoader`` over ``sents`` that tokenizes each batch on the fly.

        No ``shuffle`` argument is passed and ``drop_last`` is False.
        """
        return DataLoader(SimpleDataset(list(sents)),
                          collate_fn=self._tokenize_batch,
                          batch_size=batch_size, drop_last=False, num_workers=4)

    def _tokenize_batch(self, batch: List[str]):
        """Collate function: tokenize ``batch`` and keep the raw sentences alongside.

        A single sentence is encoded without padding (presumably so that
        tokenizers without a padding token can be used with batch size 1);
        larger batches are padded to the longest sequence in the batch. Both
        paths truncate to ``MAX_LENGTH`` tokens.

        Returns:
            A dict with the original ``sentences`` plus the tokenizer output.
        """
        if len(batch) == 1:
            token_ids = self.tokenizer.encode(
                batch[0],
                add_special_tokens=True,
                max_length=self.MAX_LENGTH,
                truncation=True,
            )
            tokenized = {
                'input_ids': torch.tensor(token_ids).reshape([1, -1])
            }
        else:
            # `padding=True` already selects "pad to longest"; the deprecated
            # `pad_to_max_length` flag was redundant next to it.
            tokenized = self.tokenizer.batch_encode_plus(
                batch,
                return_tensors='pt',
                return_token_type_ids=True,
                add_special_tokens=True,
                max_length=self.MAX_LENGTH,
                truncation=True,
                padding=True,
            )

        return {
            'sentences': batch,
            **tokenized
        }


class SpellChecker:
    """Scores candidate sentences with an ``LMModule`` and picks the best one per group.

    Lower scores are treated as better (see ``pick_best_sentence``).
    """

    MODELS = ['bert-large-cased', 'roberta-base', 'gpt2',
              'facebook/bart-large', 'xlm-roberta-large', 'transfo-xl-wt103']

    def __init__(self, config: omegaconf.DictConfig):
        """Create the checker.

        Args:
            config: Mapping whose ``'perplexity_model'`` entry names one of
                ``MODELS``.

        Raises:
            AssertionError: If the configured model is not in ``MODELS``.
        """
        self.model_name = config['perplexity_model']
        assert self.model_name in self.MODELS, (
            f'Unsupported perplexity model {self.model_name!r}; '
            f'expected one of {self.MODELS}'
        )
        logger.info('SpellChecker is using {}.'.format(self.model_name))
        self.m = LMModule(self.model_name)

    def perplexity(self, sents: List[str]) -> List[float]:
        """Score sentences one by one via ``calc_perplxty``.

        ``calc_perplxty`` currently always raises, so this method does too for
        any non-empty input; use ``perplexity_corpus`` instead.
        """
        return [self.calc_perplxty(s) for s in tqdm(sents, desc='Perplexity')]

    def calc_perplxty(self, s):
        """Deprecated single-sentence scorer; always raises ``DeprecationWarning``."""
        raise DeprecationWarning('calc_perplxty is deprecated; use perplexity_corpus instead')

    def perplexity_corpus(self, sents: Iterable[str]) -> np.ndarray:
        """Score every sentence in ``sents`` and return the scores as an array.

        Models whose name contains 'transfo' or 'gpt' are run with batch
        size 1; all others with batch size 4.
        """
        single_sentence_only = any(tag in self.model_name for tag in ('transfo', 'gpt'))
        batch_size = 1 if single_sentence_only else 4

        logger.info('Setting up loader for preplexity computation')
        loader: DataLoader = self.m.get_dataloader(sents, batch_size=batch_size)
        return self._lightning_evaluation_loop(loader)

    @staticmethod
    def _default_gpus():
        """Pick the ``gpus`` argument for the Trainer based on the visible devices.

        Keeps the historical choice (device 1) when at least two GPUs are
        visible, and otherwise falls back to device 0 or to the CPU so the
        evaluation does not fail on smaller machines.
        """
        visible = torch.cuda.device_count()
        if visible > 1:
            return [1]
        if visible == 1:
            return [0]
        return None

    def _lightning_evaluation_loop(self, loader):
        """Run ``Trainer.test`` over ``loader`` and read the scores back from the CSV
        found at ``self.m.test_epoch_path``."""
        trainer = Trainer(gpus=self._default_gpus())
        logger.info('Compute perplexity of sentences')
        trainer.test(model=self.m, test_dataloaders=loader)
        df = pd.read_csv(self.m.test_epoch_path)
        return df['loss'].values

    def _my_evaluation_loop(self, loader):
        """Hand-rolled alternative to ``_lightning_evaluation_loop``.

        Calls ``test_step`` on every batch itself and returns the concatenated
        per-sentence scores. ``test_epoch_end`` is still invoked so the CSV
        file is written as in the Lightning-driven loop.
        """
        self.m.freeze()
        outputs = []
        n_sentences = len(loader.dataset)
        # The context manager closes the progress bar even if a step fails.
        with torch.no_grad(), tqdm(total=n_sentences, desc='sentences checked') as pbar:
            for batch_idx, batch in enumerate(loader):
                outputs.append(self.m.test_step(batch, batch_idx))
                # The last batch may be smaller than loader.batch_size.
                pbar.update(len(batch['sentences']))

        if outputs:
            self.m.test_epoch_end(outputs)
        # test_epoch_end returns an empty dict, so collect the scores directly
        # from the step outputs rather than indexing its return value.
        per_batch = [out['loss'].cpu().numpy() for out in outputs]
        perplex: np.ndarray = np.concatenate(per_batch) if per_batch else np.array([], dtype=float)

        if len(perplex) > n_sentences:
            # Report the inconsistency instead of opening an interactive
            # shell, which would block unattended runs.
            logger.error('The output is larger than the input')
        return perplex

    def pick_best_sentence(self, corpus: Iterable[Iterable[str]]) -> Tuple[np.ndarray, np.ndarray]:
        """Select the lowest-scoring candidate of every row in ``corpus``.

        Args:
            corpus: One row of candidate sentences per item. Rows may differ in
                length; the missing cells that ``pd.DataFrame`` creates for
                shorter rows are never selected.

        Returns:
            ``(best_sentences, best_scores)``, both with one entry per row. A
            row without any candidate yields ``None`` and ``inf``.
        """
        df: pd.DataFrame = pd.DataFrame(corpus)
        candidates = df.values
        flat_candidates = candidates.flatten()
        present = ~pd.isna(flat_candidates)

        if not present.any():
            # Nothing to score: skip the model run entirely.
            return np.full(len(df), None, dtype=object), np.full(len(df), np.inf)

        # Float array with an infinite sentinel: missing cells can never win
        # (or tie with) a real candidate, and the scores keep a numeric dtype.
        flat_scores = np.full(flat_candidates.shape, np.inf, dtype=float)
        flat_scores[present] = self.perplexity_corpus(flat_candidates[present])

        scores_2d = flat_scores.reshape(df.shape)
        best_idx = scores_2d.argmin(axis=1)
        best_score = scores_2d.min(axis=1)
        best_sent = np.take_along_axis(candidates, best_idx.reshape([-1, 1]), axis=1).flatten()

        return best_sent, best_score


def main():
    """Demo: verbalize a few (subject, object) pairs and print the best-scoring sentence for each.

    Candidate sentences are generated from the 'usedfor' templates and ranked
    with a ``SpellChecker`` backed by 'roberta-base'.
    """
    from pprint import pprint

    def _generate_templates() -> Dict[str, List[str]]:
        """Build, per lower-cased relation name, every template variant.

        Each template contains ``{s}`` and ``{o}`` placeholders. Variants are
        the plain verbalization, the same with a leading 'A'/'An', and each of
        those with 'a'/'an' inserted in front of the object.
        """
        mappings = {
            'UsedFor': ['is typically used for', 'is typically used to', 'is typically used by'],
            'CapableOf': ['is typically capable of', 'is typically capable to'],
            'Causes': ['typically causes'],
            'CausesDesire': ['typically causes desire of'],
            'Desires': ['typically desires'],
            'NotDesires': ['does not typically desire', 'does not typically desire to'],
            'RelatedTo': ['is typically related to', 'is typically related'],
            'IsA': ['is'],
            'PartOf': ['is part of'],
            'CreatedBy': ['is created by'],
            'MannerOf': ['is a manner of', 'is manner of'],

        }

        # create templates related to predicate verbalization
        all_rules: Dict[str, List[str]] = {
            k.lower(): ['{s} {p} {o}'.replace('{p}', v) for v in vl]
            for k, vl in mappings.items()
        }

        # add templates with a leading A and An
        all_rules = {
            k: rl + [f'{article} {r}' for r in rl for article in ('A', 'An')]
            for k, rl in all_rules.items()
        }

        # add the article between predicate and object. The '{p}' placeholder
        # is already substituted at this point, so the article has to be
        # anchored on '{o}'; replacing '{p}' would only duplicate templates.
        all_rules = {
            k: rl + [r.replace('{o}', f'{article} {{o}}') for r in rl for article in ('a', 'an')]
            for k, rl in all_rules.items()
        }

        logger.info(f'##########\nall rules:\n{all_rules}\n##########\n')
        return all_rules

    # SpellChecker reads config['perplexity_model']; an empty config fails
    # with a KeyError before anything is computed.
    spl = SpellChecker(omegaconf.OmegaConf.create({'perplexity_model': 'roberta-base'}))
    triples = [
        ('net', 'catch fish'),
        ('play basketball', 'recreation'),
        ('jump at chance', 'opportunity'),
        ('shallow water', 'childrens pool'),
    ]
    rules = _generate_templates()

    all_sents = [[r.format(s=s, o=o) for r in rules['usedfor']] for (s, o) in triples]
    best_sents, best_scores = spl.pick_best_sentence(all_sents)
    pprint(list(zip(best_sents, best_scores)))


# Run the demo in main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

