from functools import reduce
from tqdm import tqdm, tqdm_pandas
import IPython
import pandas as pd
import logging
from typing import Optional, Any, List, Callable, NoReturn, Tuple, Dict
import pathlib
import hydra
import omegaconf

from KGLexicGenerator import KGLexicGenerator

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class DICEClozeGenerator(KGLexicGenerator):
    """Generate cloze prompts from DICE-style ``subject``/``property`` edges.

    Each row read from ``config.graph_path`` has its ``property`` text split
    into a ConceptNet-like ``predicate`` and ``object`` (see ``TRANSLATION``),
    is filtered by ``self.USEFUL_PREDICATE`` and is then passed through the
    lexicalisation / prompt helpers that are not defined in this class
    (``_get_top_edges``, ``lexic``, ``_to_prompt``, ``_write_prompts_to_disk``
    -- presumably provided by ``KGLexicGenerator``; verify there).
    """

    # Maps a leading phrase of the DICE ``property`` text to a relation name.
    # The relation name is lower-cased and prefixed with "/r/" when applied.
    TRANSLATION = {
        # Not translated at the moment:
        # "be": "can be/ReceivesAction",
        "be at": "AtLocation",
        "requires": "HasPrerequisite",
        "be used to": "UsedFor",
        "want": "Desires",
        "has": "HasA",
        "wants to": "causesDesire",
        "be motivated by": "MotivatedByGoal",
        # A property that starts with none of the phrases above is treated as
        # "CapableOf"; this default is reported to be a very frequent case.
        # "": "CapableOf",
    }

    def __init__(self, config: omegaconf.DictConfig):
        """Initialise the parent generator and record the DICE cache paths.

        Args:
            config: Configuration object forwarded to the parent constructor.
        """
        super().__init__(config)
        # ``~`` is not expanded by ``pathlib.Path`` itself; without
        # ``expanduser()`` the ``exists()`` check in ``_check_dice_data``
        # could never succeed.
        self.dice_path = [
            pathlib.Path('~/mowgli-cache/DICE/conceptnet-dice.csv').expanduser(),
            pathlib.Path('~/mowgli-cache/DICE/conceptnet-extended.csv').expanduser(),
        ]
        # NOTE: ``_check_dice_data`` is not run here; ``_read_edges`` only
        # reads ``config.graph_path``.

    def _check_config(self) -> None:
        """Assert that the configuration has the fields this class relies on.

        Raises:
            AssertionError: If a required key is missing, a path does not end
                in ``.csv``, ``parallel``/``max_len`` are not ints, or
                ``predicates`` is empty.
        """
        assert {'graph_path', 'output_path'}.issubset(self.config.keys()), \
            f'Invalid config {self.config}'
        assert pathlib.Path(self.config.graph_path).suffix == '.csv', 'Only supporting csv for DICE'
        assert pathlib.Path(self.config.output_path).suffix == '.csv', 'Only supporting csv for output format'

        assert isinstance(self.config.parallel, int), f'{self.config.parallel}'
        assert isinstance(self.config.max_len, int)

        assert len(self.config.predicates) > 0

    def _check_dice_data(self):
        """Assert that every path in ``self.dice_path`` exists on disk.

        Raises:
            AssertionError: If one of the files is missing.
        """
        for p in self.dice_path:
            assert p.exists(), f'{p} does not exists.'

    def _read_edges(self) -> pd.DataFrame:
        """Read the edge table from ``config.graph_path``.

        Returns:
            The parsed CSV as a data frame.
        """
        # NOTE(review): only the first 1000 rows are read. This looks like a
        # debugging limit -- confirm before running on the full data.
        return pd.read_csv(self.config.graph_path, nrows=1000)

    @staticmethod
    def _match_translation(prop: Any) -> Optional[Tuple[str, str]]:
        """Find the ``TRANSLATION`` entry whose phrase starts *prop*.

        A phrase matches only when it is a prefix of ``str(prop)`` and is
        followed by whitespace or the end of the text, so "want" does not
        match "wants food". When several phrases match, the longest one is
        chosen, which makes the result independent of dictionary order.

        Args:
            prop: The ``property`` value of an edge (converted with ``str``).

        Returns:
            ``(phrase, relation_name)`` for the best match, or ``None`` when
            no phrase matches.
        """
        text = str(prop)
        best: Optional[Tuple[str, str]] = None
        for pat, pred in DICEClozeGenerator.TRANSLATION.items():
            if not text.startswith(pat):
                continue
            rest = text[len(pat):]
            if rest and not rest[0].isspace():
                # The phrase ends in the middle of a word.
                continue
            if best is None or len(pat) > len(best[0]):
                best = (pat, pred)
        return best

    @staticmethod
    def _convert_to_cn_edge(edge: pd.Series) -> pd.Series:
        """Replace the ``property`` field of *edge* by ``predicate``/``object``.

        Args:
            edge: A row that contains at least a ``property`` entry.

        Returns:
            A new series without ``property`` and with two extra entries:
            ``predicate`` ("/r/" + lower-cased relation name) and ``object``.
            When a ``TRANSLATION`` phrase matches, ``object`` is the text
            after that leading phrase with surrounding whitespace removed;
            otherwise the predicate is "/r/capableof" and ``object`` is the
            unchanged ``property`` value.
        """
        new_edge: pd.Series = edge.drop(labels=['property'])
        prop = edge['property']
        match = DICEClozeGenerator._match_translation(prop)

        if match is None:
            # Default case: nothing matched, so the edge is a CapableOf edge.
            new_edge['predicate'] = '/r/capableof'
            new_edge['object'] = prop
            return new_edge

        pat, pred = match
        new_edge['predicate'] = f"/r/{pred.lower()}"
        # Remove only the leading phrase; later occurrences of the same words
        # belong to the object text.
        new_edge['object'] = str(prop)[len(pat):].strip()
        return new_edge

    @staticmethod
    def _is_Capableof(edge: pd.Series) -> bool:
        """Tell whether *edge* falls into the default CapableOf case.

        Uses the same matching rule as ``_convert_to_cn_edge``.

        Args:
            edge: A row that contains a ``property`` entry.

        Returns:
            ``True`` when no ``TRANSLATION`` phrase matches the property.
        """
        return DICEClozeGenerator._match_translation(edge['property']) is None

    def generate(self, materialize: bool = True) -> pd.DataFrame:
        """Run the pipeline: select edges, lexicalise them, build prompts.

        Args:
            materialize: When true, the prompts are also handed to
                ``self._write_prompts_to_disk``.

        Returns:
            The result of applying ``self._to_prompt`` row-wise to the
            lexicalised edges.
        """
        df = self._get_useful_edges()
        logger.info(f'Found {len(df)} edges with given predicates.')

        # Edge selection is delegated to ``_get_top_edges`` (defined outside
        # this class; presumably score based -- verify there).
        top_df = pd.Series(self._get_top_edges(df))

        # Lexicalising: use at least one worker.
        n_cores = max(1, self.config.parallel)
        df_lexic = self.lexic.lexicalize_corpus(top_df, n_cores=n_cores)

        # Converting to prompts, with a progress bar.
        tqdm.pandas(desc="Prompt Gen.")
        df = df_lexic.progress_apply(self._to_prompt, axis=1)

        if materialize:
            self._write_prompts_to_disk(df)

        return df

    def _edge_to_cskg(self, edge: pd.Series) -> Tuple[str, str, Dict[str, Any]]:
        """Turn a converted edge into a ``(subject, object, data)`` triple.

        Args:
            edge: A row with ``subject``, ``object`` and ``score`` entries.

        Returns:
            The subject, the object and a dict of all row entries plus
            ``weight``, which is a copy of ``score``.
        """
        data = edge.to_dict()
        data['weight'] = data['score']
        return edge['subject'], edge['object'], data

    def _get_useful_edges(self) -> pd.DataFrame:
        """Read, convert and filter the edges.

        Returns:
            The row-wise result of ``_edge_to_cskg`` over the edges whose
            predicate is accepted by ``is_usefull_edge``.
            NOTE(review): because ``_edge_to_cskg`` returns a tuple, pandas
            normally yields a Series of tuples here rather than a DataFrame --
            confirm what ``_get_top_edges`` expects.
        """
        df = self._read_edges()

        tqdm.pandas(desc="DICE2CN")
        df_cn = df.progress_apply(DICEClozeGenerator._convert_to_cn_edge, axis=1)

        useful_mask = df_cn.apply(self.is_usefull_edge, axis=1)
        df_usefull = df_cn[useful_mask]
        return df_usefull.apply(self._edge_to_cskg, axis=1)

    def is_usefull_edge(self, edge: pd.Series) -> bool:
        """Tell whether the lower-cased predicate is in ``USEFUL_PREDICATE``.

        Args:
            edge: A row with a ``predicate`` entry.

        Returns:
            ``True`` when the predicate is contained in
            ``self.USEFUL_PREDICATE`` (not defined in this class).
        """
        return str(edge['predicate']).lower() in self.USEFUL_PREDICATE


@hydra.main("../Configs/ClozeGeneratorConfig.yaml")
def main(config: omegaconf.DictConfig) -> Optional[Any]:
    """Build a ``DICEClozeGenerator`` from *config* and run it.

    Args:
        config: Configuration object supplied through the ``hydra.main``
            decorator.

    Returns:
        Always ``0``.
    """
    generator = DICEClozeGenerator(config)
    # Generate the prompts and have them written out as well.
    generator.generate(materialize=True)
    return 0


# Run the entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
