import logging
import re

import IPython
import pandas as pd
from collections import deque
from typing import List, Union, Iterable, Generator, Tuple, Dict
import omegaconf
from tqdm import tqdm

from Lexicalization.Lexicalizations import EdgeLexicalizer, PathLexicalizer, EntityLexicError
from Lexicalization.SpellChecker import SpellChecker

logger = logging.getLogger(__name__)


COMMONSENSE_MAPPING = {
    'Antonym': 'is antonym of',
    'AtLocation': 'is at location',
    'CapableOf': 'is capable of',
    'Causes': 'causes',
    'CausesDesire': 'causes desire of',
    'CreatedBy': 'is created by',
    'DefinedAs': 'is defined as',
    'DerivedFrom': 'is derived from',
    'Desires': 'desires',
    'DistinctFrom': 'is distinct from',
    'Entails': 'entails',
    'EtymologicallyDerivedFrom': 'is etymologically derived from',
    'EtymologicallyRelatedTo': 'is etymologically related to',
    'fe:ExcludesFE': 'excludes frame element',
    'FormOf': 'is form of',
    'HasA': 'has a',
    'HasContext': 'has context',
    'HasFirstSubevent': 'has first subevent',
    'HasFrameElement': 'has frame element',
    'HasInstance': 'has instance',
    'HasLastSubevent': 'has last subevent',
    'HasLexicalUnit': 'has lexical unit',
    'HasPrerequisite': 'has prerequisite',
    'HasProperty': 'has property',
    'HasSemType': 'has semantic type',
    'HasSubevent': 'has subevent',
    'HasSubframe': 'has subframe',
    'InstanceOf': 'is instance of',
    'IsA': 'is a',
    'IsCausativeOf': 'is causative of',
    'IsInchoativeOf': 'is inchoative of',
    'IsInheritedBy': 'is inherited by',
    'IsPerspectivizedIn': 'is perspectivized in',
    'IsPOSFormOf': 'is part-of-speech form of',
    'IsPrecededBy': 'is preceded by',
    'IsUsedBy': 'is used by',
    'LocatedNear': ' has location near',
    'MadeOf': 'is made of',
    'MannerOf': 'is a manner of',
    'MotivatedByGoal': 'is motivated by goal',
    'object': 'object',
    'OMWordnetOffset': 'has wordnet offset',
    'NotCapableOf': 'is not capable of',
    'NotDesires': 'does not desire',
    'NotHasProperty': 'does not have property',
    'PartOf': 'is part of',
    'PartOfSpeech': 'has part-of-speech of',
    'PerspectiveOn': 'is perspective on',
    'POSForm': 'is part-of-speech form',
    'Precedes': 'precedes',
    'PWordnetSynset': 'has wordnet synset',
    'ReceivesAction': 'has received action',
    'ReframingMapping': 'has reframing mapping',
    'RelatedTo': 'is related to',
    'st:RootType': 'has root type',
    'fe:RequiresFE': 'requires frame element',
    'subject': 'subject',
    'SeeAlso': 'see also',
    'SimilarTo': 'is similar to',
    'subClassOf': 'is subclass of',
    'SubframeOf': 'is subframe of',
    'st:SubType': 'is subtype of',
    'st:SuperType': 'is supertype of',
    'SymbolOf': 'is symbol of',
    'Synonym': 'is synonym of',
    'UsedFor': 'is used for',
    'Uses': 'uses',
    'DesireOf': 'desire of',
    'InheritsFrom': 'inherits from',
    'LocationOfAction': 'is location of action',
    'dbpedia/capital': 'is capital of',
    'dbpedia/field': 'is field of',
    'dbpedia/genre': 'is genre of',
    'dbpedia/genus': 'is genus of',
    'dbpedia/influencedBy': 'is influenced by',
    'dbpedia/knownFor': 'is known for',
    'dbpedia/language': 'has language',
    'dbpedia/leader': 'is leader of',
    'dbpedia/occupation': 'has occupation',
    'dbpedia/product': 'is product of',
    # new relations
    'rdfs:subClassOf': 'is subclass of',

    # ATOMIC Predicates:
    'at:xNeed': 'needs to',
    'at:xAttr': 'is',
    'at:xWant': 'want to',
    'at:oReact': 'will be',
    'at:oEffect': 'will',
    'at:xEffect': 'gets',
    'at:xIntent': 'wanted to be',
    'at:oWant': 'will',
    'at:xReact': 'will feel',
}


class RuleBasedEdgeLexic(EdgeLexicalizer):
    """Lexicalize knowledge-graph edges into English sentences using fixed templates.

    Each edge ``(subject, object, attrs)`` (with ``attrs['predicate']``) is turned
    into one or more candidate sentences ``"<subject> <predicate phrase> <object>"``.
    When the config enables ``perplexity``, a ``SpellChecker`` chooses among the
    candidates; otherwise the first candidate is used.
    """

    # Case-insensitive fallback phrases for predicates without curated templates.
    # Keys are lower-cased because ``_cskg_to_edge`` lower-cases predicates.
    _FALLBACK_PHRASES: Dict[str, str] = {
        k.lower(): v.strip() for k, v in COMMONSENSE_MAPPING.items()
    }

    def __init__(self, config: omegaconf.DictConfig) -> None:
        """Build the template table and, if ``config.perplexity`` is truthy, a checker.

        Args:
            config: configuration object; only ``perplexity`` is read here, and the
                whole config is forwarded to ``SpellChecker``.
        """
        super().__init__()
        self.templates = self._generate_templates()
        # ``checker`` only exists when enabled; other methods test for it via hasattr.
        if hasattr(config, 'perplexity') and config.perplexity:
            self.checker = SpellChecker(config)
            logger.info(f'Using perplexity ({self.checker}) in {self}')

    @staticmethod
    def _generate_templates() -> Dict[str, List[str]]:
        """Return ``{lower-cased predicate: [format templates]}``.

        Every template contains ``{s}`` and ``{o}`` placeholders. For each verb
        phrase the plain form comes first, followed by variants with the article
        "a" and then "an" inserted before the object (e.g. ``"{s} is an {o}"``).
        """
        mappings = {
            'UsedFor': ['is typically used for',
                        'is typically used to',
                        'is typically used by',
                        'is typically for',
                        ],
            'CapableOf': ['is typically capable of', 'is typically capable to'],
            'Causes': ['typically causes'],
            'CausesDesire': ['typically causes desire of'],
            'Desires': ['typically desires'],
            'NotDesires': ['does not typically desire', 'does not typically desire to'],
            'RelatedTo': ['is typically related to', 'is typically related'],
            'IsA': ['is'],
            'PartOf': ['is part of'],
            'CreatedBy': ['is created by'],
            'MannerOf': ['is a manner of', 'is manner of'],
        }
        # Other UsedFor phrasings considered (not currently used):
        # "{0} is used for {1}", "{0} is for {1}", "You can use {0} to {1}",
        # "You can use {0} for {1}", "{0} are used to {1}", "{0} is used to {1}",
        # "{0} can be used to {1}", "{0} can be used for {1}"

        all_rules: Dict[str, List[str]] = {}
        for predicate, phrases in mappings.items():
            plain = [f'{{s}} {phrase} {{o}}' for phrase in phrases]
            # Article variants are built from the phrase itself: once a template is
            # rendered there is no '{p}' placeholder left to splice an article into.
            with_article = [
                f'{{s}} {phrase} {article} {{o}}'
                for phrase in phrases
                for article in ('a', 'an')
            ]
            all_rules[predicate.lower()] = plain + with_article
        return all_rules

    @staticmethod
    def _cskg_to_edge(cskg_edge: Tuple[str, str, Dict[str, str]]) -> Tuple[str, str, str]:
        """Reorder ``(subject, object, attrs)`` into ``(subject, predicate, object)``.

        The predicate is read from ``attrs['predicate']`` and lower-cased so it
        matches the keys produced by ``_generate_templates``.
        """
        return cskg_edge[0], cskg_edge[2]['predicate'].lower(), cskg_edge[1]

    def convert(self, cskg_edge: Tuple[str, str, Dict[str, str]]) -> Dict[str, Union[str, float]]:
        """Deprecated single-edge API; always raises. Use ``convert_corpus`` instead.

        Raises:
            DeprecationWarning: unconditionally.
        """
        raise DeprecationWarning('use of convert is deprecated; use convert_corpus instead')

    def convert_corpus(self, cskg_edges: Iterable[Tuple[str, str, Dict[str, str]]]) \
            -> pd.DataFrame:
        """Lexicalize every edge and return one row per edge.

        Args:
            cskg_edges: edges as ``(subject, object, attrs)`` with
                ``attrs['predicate']``; any iterable, including a one-shot generator.

        Returns:
            DataFrame with columns ``sentence``, ``perplexity`` (``-1`` when no
            checker is configured) and ``edge`` (the input edge).
        """
        # Materialize once: the edges are consumed by the lexicalization loop AND
        # stored in the output frame, so a generator would otherwise be exhausted.
        edges = list(cskg_edges)

        all_lexics = []
        for cskg_edge in tqdm(edges, desc='Create Candidates'):
            edge = self._cskg_to_edge(cskg_edge)
            cleaned_edge = tuple(map(self._clean_relation_entity, edge))
            all_lexics.append(self._triple2str(*cleaned_edge))

        if hasattr(self, 'checker'):
            # NOTE(review): assumes pick_best_sentence returns (sentences, scores),
            # one entry per edge -- confirm against SpellChecker.
            sents, scores = self.checker.pick_best_sentence(all_lexics)
        else:
            sents = [candidates[0] for candidates in all_lexics]
            scores = -1

        return pd.DataFrame.from_dict(
            {'sentence': sents, 'perplexity': scores, 'edge': edges})

    @staticmethod
    def _clean_relation_entity(s: str) -> str:
        """Strip URI-style prefixes from an entity or relation identifier.

        Strings without '/' are returned unchanged. Otherwise the first matching
        pattern wins -- ``/c/en/<term>``, ``/r/<rel>`` or ``<x>:<name>`` -- and
        underscores in the captured part are replaced with spaces.

        Raises:
            EntityLexicError: if ``s`` contains '/' but matches none of the patterns.
        """
        if '/' not in s:
            return s
        for p in [r'\/c\/en\/([^\s\/]*)', r'\/r\/([^\s\/]*)', r'[^:]:([^\s\/]*)']:
            m = re.findall(p, s)
            if len(m) > 0:
                return m[0].replace('_', ' ').strip()
        raise EntityLexicError(f'{s}')

    def _triple2str(self, s: str, p: str, o: str) -> List[str]:
        """Render all candidate sentences for the triple ``(s, p, o)``.

        Predicates without curated templates are logged and verbalized with a
        single generic sentence, using ``COMMONSENSE_MAPPING`` (case-insensitive)
        or, failing that, the predicate text itself.
        """
        plist = self.templates.get(p)
        if plist is None:
            logger.error(f'Unknown predicate: {s} {p} {o}')
            # Previously this fell through and iterated None (TypeError).
            phrase = self._FALLBACK_PHRASES.get(p, p)
            return [f'{s} {phrase} {o}'.replace('  ', ' ')]
        # replace() collapses the double space left by empty subject/object parts.
        return [t.format(s=s, o=o).replace('  ', ' ') for t in plist]

    @staticmethod
    def _pred2str(pred: str) -> str:
        """Return the ``COMMONSENSE_MAPPING`` phrase for ``pred`` (case-sensitive).

        Unknown predicates are logged and returned unchanged.
        """
        verb = COMMONSENSE_MAPPING.get(pred, None)
        if verb is None:
            logger.error(f'New predicate: {pred}')
            return pred
        else:
            return verb

