from typing import Dict, Union, List

import numpy as np
import omegaconf
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
import logging

from Utils import ClassificationDataset
# Module-level logger; test_end reports completion of the embedding run through it.
logger = logging.getLogger(__name__)


class MyEmbedder(pl.LightningModule):
    """Lightning module that turns sentences into one fixed-size vector each.

    A pretrained Hugging Face encoder (``config["language_model"]``) is loaded
    and put in eval mode. Each sentence is pooled into a single vector, using
    either the first token (``pooling_method == "cls"``) or a padding-aware mean
    over tokens (``pooling_method == "mean"``). ``test_end`` concatenates the
    embeddings returned by ``test_step`` and writes them to ``tmp_path`` with
    ``np.save``.

    Config keys read: ``language_model``, ``pooling_method``, ``batch_size``,
    ``cpu_limit``, ``max_length`` and, optionally, ``cache_dir``.
    """

    # Model/tokenizer cache location used when the config has no ``cache_dir``.
    DEFAULT_CACHE_DIR = "/nas/home/qasemi/model_cache"

    def __init__(self, config: omegaconf.dictconfig.DictConfig, tmp_path: str):
        """Load the tokenizer and encoder named by ``config["language_model"]``.

        Args:
            config: experiment configuration (see class docstring for keys).
            tmp_path: destination passed to ``np.save`` in ``test_end``.
        """
        super().__init__()
        hparams = dict(config)
        self.hparams = hparams
        self.tmp_path = tmp_path
        cache_dir = hparams.get("cache_dir", self.DEFAULT_CACHE_DIR)
        self.tokenizer = AutoTokenizer.from_pretrained(config["language_model"],
                                                       cache_dir=cache_dir,
                                                       use_fast=False)
        self.embedder = AutoModel.from_pretrained(config["language_model"],
                                                  cache_dir=cache_dir)
        self.embedder.eval()

    def forward(self, batch) -> torch.Tensor:
        """Embed a tokenized batch into one pooled vector per row.

        Args:
            batch: mapping holding 2-D ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` tensors, as produced by ``_collate``. The
                mapping is not modified.

        Returns:
            Tensor with one pooled embedding per row of ``input_ids``.

        Raises:
            ValueError: if ``pooling_method`` is neither ``cls`` nor ``mean``
                (checked before running the encoder).
        """
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]
        assert len(input_ids.shape) == 2, "LM only take two-dimensional input"
        assert len(attention_mask.shape) == 2, "LM only take two-dimensional input"
        assert len(token_type_ids.shape) == 2, "LM only take two-dimensional input"

        # Fail fast on a bad config instead of after an expensive forward pass.
        pooling = self.hparams['pooling_method'].lower()
        if pooling not in ('cls', 'mean'):
            raise ValueError('Invalid Pooling {}'.format(self.hparams['pooling_method']))

        # RoBERTa-family models are run without token_type_ids. Use a local
        # variable rather than overwriting the caller's batch entry.
        if "roberta" in self.hparams["language_model"]:
            token_type_ids = None

        results = self.embedder(input_ids=input_ids, attention_mask=attention_mask,
                                token_type_ids=token_type_ids)

        # Index instead of tuple-unpacking. Newer transformers return a
        # ModelOutput, which cannot be unpacked like a tuple; ``[0]`` yields the
        # first output (per-token hidden states) for tuple and ModelOutput alike.
        token_embeddings = results[0]

        if pooling == 'cls':
            return token_embeddings[:, 0, :]

        # Masked mean. _collate pads every sentence to max_length, so padded
        # positions must be excluded from the average. The clamp guards
        # against division by zero for an all-padding row.
        mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
        summed = (token_embeddings * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def test_step(self, batch, batch_idx) -> Dict[str, torch.Tensor]:
        """Return this batch's embeddings (as a NumPy array) and its raw sentences."""
        return {
            "embeddings": self(batch).detach().cpu().numpy(),
            "sents": batch['sents'],
        }

    def test_end(self, outputs):
        """Concatenate all batch embeddings and save them to ``self.tmp_path``.

        Args:
            outputs: list of dicts returned by ``test_step``.

        Returns:
            An empty dict (no metrics are reported).
        """
        _embeddings = np.concatenate([o["embeddings"] for o in outputs], axis=0)
        _sents = [s for o in outputs for s in o["sents"]]

        # Each saved row must line up with one input sentence.
        assert len(_sents) == len(_embeddings), f'{len(_sents)} <> {_embeddings.shape}'
        np.save(self.tmp_path, _embeddings)
        logger.warning("ending the embedding")
        return {}

    def get_dataloader(self, data) -> Union[DataLoader, List[DataLoader]]:
        """Wrap ``data`` in a ``ClassificationDataset`` and batch it with ``_collate``.

        ``shuffle=False`` keeps the saved embedding rows in the input order.
        """
        return DataLoader(
            ClassificationDataset(data),
            batch_size=self.hparams["batch_size"], collate_fn=self._collate,
            num_workers=self.hparams['cpu_limit'], shuffle=False,
        )

    def _collate(self, examples):
        """Tokenize a list of sentences into padded, truncated tensors.

        Args:
            examples: the sentences of one batch.

        Returns:
            The tokenizer's tensors (``input_ids``, ``attention_mask``,
            ``token_type_ids``, ...) plus the raw sentences under ``"sents"``.
        """
        batch_size = len(examples)
        results = self.tokenizer.batch_encode_plus(
            examples,
            add_special_tokens=True,
            max_length=self.hparams["max_length"],
            return_tensors='pt',
            return_token_type_ids=True,
            # Same behavior as the deprecated ``pad_to_max_length=True``.
            padding="max_length",
            truncation=True,
        )

        assert results["input_ids"].shape[0] == batch_size, \
            f"Invalid shapes {results['input_ids'].shape} {batch_size}"

        return {
            **results,
            "sents": examples,
        }