import logging
import pathlib
import re
from typing import List, Tuple, NoReturn, Optional, Union, Dict, Set

import IPython
from tqdm import tqdm

from Utils import CoreQuisiteBase
from .Lexicalizations import EntityLexicError, EdgeLexicalizer, SimplePathLexic
from .RuleBasedLexic import COMMONSENSE_MAPPING
from .Utils import OMCSBase
import pandas as pd

logger = logging.getLogger(__name__)


class CNQueryEdgeLexicBase(EdgeLexicalizer, OMCSBase, CoreQuisiteBase):
    """Lexicalize an edge by retrieving a sentence that mentions its parts from cached CN web files.

    Relies on ``self.cn_web_paths`` (entries whose last ``/`` component names a file under
    ``self.cache_path``), ``self.cache_path``, ``self._get_predicate_base`` and
    ``self._lookup_in_file`` being provided by a base class or subclass
    (NOTE(review): presumably ``OMCSBase`` / ``CoreQuisiteBase`` -- confirm).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def convert(self, edge: Tuple[str, str, str]) -> Tuple[Optional[str], float]:
        """
        Converts the edge to the lexicalized sentence associated with it.
        The module assumes the edge consists of the string label of source/target nodes and the edge between them,
        in the format (source_label, edge_label, target_label)
        Args:
            edge: Tuple[NODE_T str, EDGE_T str, NODE_T str]

        Returns: Tuple[Optional[str], float]
            the sentence associated with the edge (None when no cached sentence matches),
            paired with a constant score of 0.0
        """
        edge_clean = [self._clean_relation_entity(part) for part in edge]
        return self._triple2str(*edge_clean), 0.0

    def _triple2str(self, s: str, p: str, e: str) -> Optional[str]:
        """Return the shortest cached sentence mentioning subject, predicate and object, or None.

        Args:
            s: cleaned source-node label.
            p: cleaned relation name; mapped to a phrase through ``_pred2str``.
            e: cleaned target-node label.

        Raises:
            AssertionError: if a cached CN web file is missing.
        """
        p_str = self._pred2str(p)
        # FIXME: only the base form of the predicate is searched, for performance reasons.
        # (Previously the full phrase ``p_str`` was searched first, with the base as fallback.)
        # The query does not depend on the file, so it is built once for all files.
        query = [s, self._get_predicate_base(p_str), e]

        all_matches = []
        for wp in self.cn_web_paths:
            file_path = self.cache_path / wp.split('/')[-1]
            assert file_path.exists(), f'{file_path} does not exist.'
            all_matches += self._lookup_in_file(words=query, path=file_path)

        if not all_matches:
            logger.debug(f'No matches for: {[s, p_str, e]}')
            return None

        return self._filter_matches(matches=all_matches, words=[s, p_str, e])

    def _filter_matches(self, matches: List[Union[str, tuple]], words: List[str]) -> str:
        """Return the shortest sentence among *matches* (the first one on ties).

        Args:
            matches: non-empty list of look-up results. Each is either a sentence string or a
                tuple whose first element is the sentence; the ``_lookup_in_file``
                implementations in this module return 1-tuples ``(text,)`` by default.
            words: the query words; currently unused, kept for subclasses.

        Raises:
            ValueError: if *matches* is empty.
        """
        # Unwrap tuples first: measuring len() of a 1-tuple would always give 1
        # and silently pick the first match instead of the shortest sentence.
        texts = [m[0] if isinstance(m, tuple) else m for m in matches]
        return min(texts, key=len)

    @staticmethod
    def _pred2str(pred: str) -> str:
        """Map a relation name to its phrase via ``COMMONSENSE_MAPPING``; unknown names are logged and returned unchanged."""
        verb = COMMONSENSE_MAPPING.get(pred)
        if verb is None:
            logger.error(f'New predicate: {pred}')
            return pred
        return verb

    @staticmethod
    def _clean_relation_entity(s: str) -> str:
        """Extract a readable name from a node/relation identifier.

        Strings without ``/`` are returned as-is. Otherwise the patterns are tried in order:
        ``/c/en/<concept>``, ``/r/<relation>``, then ``<prefix>:<name>``. The first capture of
        the first matching pattern is returned with ``_`` replaced by spaces.

        Raises:
            EntityLexicError: if no pattern matches.
        """
        if '/' not in s:
            return s
        for pattern in (r'\/c\/en\/([^\s\/]*)', r'\/r\/([^\s\/]*)', r'[^:]:([^\s\/]*)'):
            found = re.findall(pattern, s)
            if found:
                # Several matches are possible; the first one is used.
                return found[0].replace('_', ' ')
        raise EntityLexicError(f'{s}')


class CNQueryEdgeLexicPandas(CNQueryEdgeLexicBase):
    """Edge lexicalizer that keeps each cached CN web file's English rows in memory as a DataFrame
    and answers look-ups with a substring scan over the ``text`` column."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # str(cached file path) -> DataFrame holding that file's English rows.
        self.memory: Dict[str, pd.DataFrame] = {}
        self._load_data()

    @staticmethod
    def _bad_line_kwargs() -> Dict[str, object]:
        """``read_csv`` options that warn about and skip malformed lines, for old and new pandas.

        ``error_bad_lines``/``warn_bad_lines`` were deprecated in pandas 1.3 (replaced by
        ``on_bad_lines``) and removed in pandas 2.0.
        """
        version = re.match(r'(\d+)\.(\d+)', pd.__version__)
        if version is None or (int(version.group(1)), int(version.group(2))) >= (1, 3):
            return {'on_bad_lines': 'warn'}
        return {'error_bad_lines': False, 'warn_bad_lines': True}

    def _load_data(self):
        """Read every cached CN web file and keep its rows whose ``language_id`` is ``'en'``.

        Files are tab separated and only columns 1, 4 and 6 are loaded. They must provide at
        least ``language_id`` and ``text``; NOTE(review): ``score`` is presumably the third, confirm.

        Raises:
            AssertionError: if a cached file is missing.
        """
        logger.info('Loading data for CNQueryEdgeLexicPandas')
        read_kwargs = self._bad_line_kwargs()
        for p in self.cn_web_paths:
            path: pathlib.Path = self.cache_path / p.split('/')[-1]
            assert path.exists(), f'{path} does not exist.'

            df = pd.read_csv(path, sep='\t', usecols=[1, 4, 6], **read_kwargs)
            self.memory[str(path)] = df[df['language_id'] == 'en']

    def _lookup_in_file(self, words: List[str], path: pathlib.Path, keep_score: bool = False, add_words: bool = False) \
            -> List[tuple]:
        """Return every row of the cached file at *path* whose text contains all *words*.

        A row matches when its ``text`` is a string containing every element of *words* as a
        case-sensitive substring. Non-string (e.g. NaN) texts never match.

        Args:
            words: strings that must all occur in the sentence.
            path: a cached file loaded by ``_load_data``.
            keep_score: append ``float(row['score'])`` to each result.
            add_words: prepend *words* itself to each result.

        Returns:
            One tuple per matching row, in file order: ``([words,] text[, score])``.
        """
        memory = self.memory[str(path)]
        # Testing the text column directly avoids building a Series per row
        # (as DataFrame.apply(axis=1) does), and guards against NaN texts.
        mask = pd.Series(
            [isinstance(text, str) and all(w in text for w in words) for text in memory['text']],
            index=memory.index,
            dtype=bool,
        )
        matches: pd.DataFrame = memory[mask]

        logger.debug('look up %s resulted in %d rows.', words, len(matches))
        if matches.empty:
            return []
        logger.debug('First match is: %s', matches.iloc[0])

        prefix = (words,) if add_words else ()
        texts = matches['text'].tolist()
        if keep_score:
            scores = matches['score'].tolist()
            return [(*prefix, text, float(score)) for text, score in zip(texts, scores)]
        return [(*prefix, text) for text in texts]


class CNQueryEdgeLexicInMem(CNQueryEdgeLexicPandas):
    """Pandas lexicalizer variant that narrows look-ups with an inverted token index.

    On load, every sentence of every cached file is tokenized into an index
    ``{str(path): {token: {row_label, ...}}}``. A look-up intersects the posting sets of the
    query tokens, then re-checks the candidates with the parent's substring test. Results are
    therefore a subset of the parent's scan: a query word that occurs only inside a longer
    token (e.g. ``cat`` inside ``cats``) is not found.
    """

    # Frequent function words that are not indexed; they barely narrow a look-up.
    _STOPWORDS = frozenset({'a', 'as', 'of', 'for', 'not', 'by', 'be'})
    # Word tokens; punctuation such as '.' and ',' is never part of a token.
    _TOKEN_RE = re.compile(r'\w+')

    def __init__(self, *args, **kwargs):
        # Must exist before super().__init__(), which calls self._load_data().
        self.lookup_table: Dict[str, Dict[str, Set[int]]] = {}
        super().__init__(*args, **kwargs)

    @classmethod
    def _tokenize(cls, text) -> List[str]:
        """Lower-case *text*, split it into word tokens and drop stop-words; non-strings yield []."""
        if not isinstance(text, str):
            return []
        return [tok for tok in cls._TOKEN_RE.findall(text.lower()) if tok not in cls._STOPWORDS]

    def _load_data(self):
        """Load the DataFrames (see the parent) and build the per-file token index."""
        super(CNQueryEdgeLexicInMem, self)._load_data()

        for path, memory in self.memory.items():
            index = self.lookup_table.setdefault(path, {})
            for label, sent in tqdm(memory['text'].items(), total=len(memory), desc='sentences'):
                for tok in self._tokenize(sent):
                    index.setdefault(tok, set()).add(label)

    def _lookup_in_file(self, words: List[str], path: pathlib.Path, keep_score: bool = False, add_words: bool = False) \
            -> List[tuple]:
        """Return the rows of the cached file at *path* whose text contains all *words*.

        Same contract and result shape as the parent, ``([words,] text[, score])``. The token
        index is used to pick candidate rows first. If the query has no indexable token (only
        stop-words or punctuation), this falls back to the parent's full scan.

        Returns:
            One tuple per matching row, ordered by row label (file order for the default
            ``read_csv`` index).
        """
        lut = self.lookup_table[str(path)]
        memory = self.memory[str(path)]

        tokens = {tok for w in words for tok in self._tokenize(w)}
        if not tokens:
            return super()._lookup_in_file(words, path, keep_score=keep_score, add_words=add_words)

        postings = []
        for tok in tokens:
            rows = lut.get(tok)
            if not rows:
                logger.debug('look up %s resulted in 0 rows (unknown token %r).', words, tok)
                return []
            postings.append(rows)
        # Intersecting from the smallest posting set keeps the work proportional to the rarest token.
        postings.sort(key=len)
        candidates = postings[0].intersection(*postings[1:])

        # Row labels come from the language-filtered frame and are not positions, so .loc is required.
        candidate_rows = memory.loc[sorted(candidates)]
        # The index is token based; re-apply the parent's substring test so multi-word
        # query strings (e.g. "ice cream") must still appear verbatim.
        mask = pd.Series(
            [isinstance(text, str) and all(w in text for w in words) for text in candidate_rows['text']],
            index=candidate_rows.index,
            dtype=bool,
        )
        matches: pd.DataFrame = candidate_rows[mask]

        logger.debug('look up %s resulted in %d rows.', words, len(matches))
        if matches.empty:
            return []
        logger.debug('First match is: %s', matches.iloc[0])

        prefix = (words,) if add_words else ()
        texts = matches['text'].tolist()
        if keep_score:
            scores = matches['score'].tolist()
            return [(*prefix, text, float(score)) for text, score in zip(texts, scores)]
        return [(*prefix, text) for text in texts]
