import logging
import os
import pathlib
from functools import partial

import pdb
import IPython
import omegaconf
from transformers import AutoTokenizer, AutoModelWithLMHead, PreTrainedModel, PreTrainedTokenizer
from transformers import pipeline

logger = logging.getLogger(__name__)


class MaskFiller:
    """Fill the masked position of a sentence with a pretrained language model.

    The instance wraps a ``transformers`` ``fill-mask`` pipeline and keeps only
    those completions whose newly introduced words all belong to ``white_list``.
    Callers may always mark the gap with ``'[MASK]'``; it is translated to the
    loaded tokenizer's own mask token before the pipeline is invoked.
    """

    # Checkpoint names accepted for ``config['unmask_model']``.
    MODELS = ['bert-large-cased', 'roberta-base',
              'facebook/bart-large', 'xlm-roberta-large', 'transfo-xl-wt103']

    # Placeholder callers can use regardless of which model is loaded.
    MASK_PLACEHOLDER = '[MASK]'

    # Mask token of the loaded tokenizer; ``__init__`` overrides this default.
    mask_token = MASK_PLACEHOLDER

    def __init__(self, config: 'omegaconf.DictConfig'):
        """Load the tokenizer, the model and the fill-mask pipeline.

        Args:
            config: Configuration whose ``'unmask_model'`` entry names one of
                the checkpoints listed in ``MODELS``.

        Raises:
            AssertionError: If the configured model is not listed in ``MODELS``.
        """
        self.model_name = config['unmask_model']
        assert self.model_name in self.MODELS, (
            f'Unsupported unmask_model {self.model_name!r}; expected one of {self.MODELS}'
        )
        logger.info('SpellChecker is using {}.'.format(self.model_name))

        # Downloaded weights are cached under the user's home directory.
        self.cache_dir = pathlib.Path(os.path.expanduser('~/model_cache'))
        logger.info(f'Loading pretrained model {self.model_name}')
        tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=self.cache_dir)
        model: PreTrainedModel = AutoModelWithLMHead.from_pretrained(self.model_name, cache_dir=self.cache_dir)
        self.unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer, framework='pt')

        # Not every supported checkpoint spells its mask token '[MASK]'
        # (RoBERTa-style tokenizers typically use '<mask>'), so remember the
        # real one; fall back to the placeholder if the tokenizer defines none.
        self.mask_token = getattr(tokenizer, 'mask_token', None) or self.MASK_PLACEHOLDER

        # The only words a completion is allowed to introduce.
        self.white_list = {'for', 'by', 'to', 'through', 'from', 'as', 'in', 'when'}

    def unmask(self, sent: str) -> str:
        """Return ``sent`` with its masked position filled by a white-listed word.

        Args:
            sent: Sentence containing ``'[MASK]'`` or the loaded model's own
                mask token.

        Returns:
            ``sent`` unchanged when it contains no mask; otherwise the
            ``'sequence'`` of the first pipeline candidate (in the order the
            pipeline returned them) whose added words are all in
            ``white_list``; ``""`` when no candidate qualifies.

        Raises:
            Exception: Whatever the underlying pipeline raises, re-raised
                after the offending sentence has been logged.
        """
        placeholder = self.MASK_PLACEHOLDER
        mask_token = self.mask_token
        if placeholder not in sent and mask_token not in sent:
            return sent

        # The pipeline looks for the tokenizer's mask token, so translate the
        # generic placeholder when the two differ.
        if mask_token != placeholder:
            sent = sent.replace(placeholder, mask_token)

        try:
            candidates = self.unmasker(sent)
        except Exception:
            # logger.exception keeps the traceback; bare raise preserves it too.
            logger.exception(f'Error string {sent}')
            raise

        # Tokens that do not count as "added": everything already in the input
        # plus BERT-style boundary markers, which presumably can show up in the
        # decoded sequence depending on the transformers version.
        known_tokens = set(sent.split() + ['[CLS]', '[SEP]'])

        for candidate in candidates:
            added = set(candidate['sequence'].split()).difference(known_tokens)
            if added.issubset(self.white_list):
                return candidate['sequence']

        return ""

    def __call__(self, sent: str):
        """Shorthand for :meth:`unmask`."""
        return self.unmask(sent)