import logging
import IPython
from typing import List, Union, Iterable, Generator, Tuple, Dict
import omegaconf
import re
from Lexicalization.RuleBasedLexic import RuleBasedEdgeLexic, COMMONSENSE_MAPPING
from Lexicalization.SpellChecker import SpellChecker
from Lexicalization.MaskFiller import MaskFiller

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class RuleBasedMaskedEdgeLexic(RuleBasedEdgeLexic):
    """Edge lexicalizer that fills template gaps with a masked language model.

    Each supported predicate maps to one or more sentence templates with
    ``{s}`` (subject) and ``{o}`` (object) placeholders. Most templates also
    contain a ``[MASK]`` token, which is passed to ``MaskFiller.unmask`` so
    the connecting word is predicted rather than hard-coded.
    """

    def __init__(self, config: omegaconf.DictConfig) -> None:
        """Initialise the base lexicalizer and build the mask filler.

        Args:
            config: Configuration forwarded unchanged to both the parent
                class and ``MaskFiller``.
        """
        super().__init__(config)
        self.unmasker = MaskFiller(config)

    @staticmethod
    def _generate_templates() -> Dict[str, List[str]]:
        """Return the sentence templates, keyed by lower-cased predicate name.

        Returns:
            A mapping from lower-cased predicate (e.g. ``'usedfor'``) to a
            fresh list of template strings containing ``{s}`` and ``{o}``
            placeholders.
        """
        mappings = {
            'UsedFor': [
                '{s} is typically used [MASK] {o}',
                '{s} are typically used [MASK] {o}',
                # '{s} is typically [MASK] {o}',
                # '{s} are typically [MASK] {o}',
                'You can typically use {s} [MASK] {o}',
                '{s} can typically be used [MASK] {o}',
            ],
            'CapableOf': [
                '{s} is typically capable [MASK] {o}',
                '{s} are typically capable [MASK] {o}',
            ],
            # NOTE(review): the 'Causes' and 'Desires' templates contain no
            # [MASK] token; how MaskFiller.unmask treats mask-free input is
            # not visible from this file -- confirm this is intended.
            'Causes': [
                '{s} typically causes {o}',
                '{s} typically cause {o}',
            ],
            'CausesDesire': [
                # Fixed: these previously read "[MAKS]", so the templates
                # carried no real mask token.
                '{s} typically causes desire [MASK] {o}',
                '{s} typically cause desire [MASK] {o}',
            ],
            'Desires': ['{s} typically desires {o}'],
            # Predicates not yet supported by this lexicalizer:
            # 'NotDesires': ['does not typically desire', 'does not typically desire to'],
            # 'RelatedTo': ['is typically related to', 'is typically related'],
            # 'IsA': ['is'],
            # 'PartOf': ['is part of'],
            # 'CreatedBy': ['is created by'],
            # 'MannerOf': ['is a manner of', 'is manner of'],
        }

        # Create templates related to predicate verbalization. Keys are
        # lower-cased and each list is copied so callers get their own lists.
        all_rules: Dict[str, List[str]] = {
            k.lower(): list(vl)
            for k, vl in mappings.items()
        }

        return all_rules

    def _triple2str(self, s: str, p: str, o: str) -> Union[str, List[str]]:
        """Verbalise a triple with every template registered for its predicate.

        Args:
            s: Subject text substituted for ``{s}``.
            p: Predicate used as the lookup key in ``self.templates``
                (presumably populated by the parent class from
                ``_generate_templates``, whose keys are lower-cased --
                TODO confirm).
            o: Object text substituted for ``{o}``.

        Returns:
            The non-empty unmasked sentences, one per template, with the
            ``[CLS]`` / ``[SEP]`` markers stripped.

        Raises:
            ValueError: If no templates are registered for ``p``.
        """
        plist = self.templates.get(p, None)
        if plist is None:
            logger.error('Unknown predicate: %s %s %s', s, p, o)
            raise ValueError(f'Could not find predicate {p}')

        # Explicit loop with distinct names: the original comprehension reused
        # the name ``s`` for its loop variable, shadowing the subject argument.
        sentences: List[str] = []
        for template in plist:
            filled = self.unmasker.unmask(template.format(s=s, o=o))
            # Drop the special tokens left in the unmasked text.
            filled = filled.replace(' [SEP]', '').replace('[CLS] ', '')
            if filled:
                sentences.append(filled)
        return sentences



