import os
import pickle
from itertools import chain
from typing import List, Tuple, Dict, NoReturn, Union
import networkx as nx
import pandas as pd
import pathlib
import logging
import random
import IPython
import omegaconf
import hydra

# from tqdm import tqdm
from tqdm import tqdm, tqdm_pandas

from Lexicalization.ParallelLexicWrapper import ParallelLexicWrapper
from Utils import CoreQuisiteBase

# Module-level logger named after this module; used by KGLexicGenerator below
# for progress and diagnostic messages.
logger = logging.getLogger(__name__)


class KGLexicGenerator(CoreQuisiteBase):
    """Generate enabling/disabling prompt questions from knowledge-graph edges.

    The pipeline implemented by :meth:`generate` is:

    1. load the graph (or a pickled cache of its filtered edges),
    2. keep the edges whose predicate is listed in ``USEFUL_PREDICATE``,
    3. lexicalize each edge into a sentence through ``ParallelLexicWrapper``,
    4. append a randomly chosen enabling/disabling question to each sentence,
    5. optionally write the prompts to disk, one folder per predicate.

    ``USEFUL_PREDICATE``, ``QUESTIONS`` and ``COND_TYPE`` are not defined in
    this class; presumably they are inherited from ``CoreQuisiteBase``
    (TODO confirm).
    """

    # Sizes of the random prompt samples written next to each predicate CSV.
    _SAMPLE_SIZES = (800, 2000, 8000)
    # Columns holding the generated prompts; every other column is a row key.
    _PROMPT_COLUMNS = ('enabling', 'disabling')

    def __init__(self, config: omegaconf.DictConfig) -> None:
        """Validate ``config`` and build the lexicalization wrapper.

        Args:
            config: Run configuration. Must provide ``graph_path``,
                ``output_path`` and ``parallel``; may provide ``predicates``
                (a list overriding ``USEFUL_PREDICATE``) and ``max_len``.

        Raises:
            AssertionError: If the configuration is invalid.
        """
        self.config = config
        self._check_config()

        if 'predicates' in self.config and self.config.predicates is not None:
            predicates = self.config.predicates
            assert isinstance(predicates, omegaconf.listconfig.ListConfig), \
                f'{type(predicates)}'

            _check_str = [isinstance(s, str) for s in predicates]
            logger.info(f'self.USEFUL_PREDICATE: {self.USEFUL_PREDICATE}')
            # Only call ``.lower()`` on entries that really are strings so a
            # malformed entry fails the assertion below instead of raising an
            # unrelated AttributeError here.
            _check_type = [
                is_str and s.lower() in self.USEFUL_PREDICATE
                for is_str, s in zip(_check_str, predicates)
            ]
            logger.info(f'check_str: {_check_str}, check_type:{_check_type}')
            assert all(_check_str) and all(_check_type), f'{self.config.predicates}'
            logger.warning(f'Updating USEFUL_PREDICATE: {self.config.predicates}')
            # Instance-level override; the class-level default is untouched.
            self.USEFUL_PREDICATE = [p.lower() for p in predicates]

        self.lexic = ParallelLexicWrapper(self.config)

    def _check_config(self) -> None:
        """Assert that the configuration has the required keys and formats.

        Raises:
            AssertionError: If a required key is missing, the graph is not a
                ``.graphml`` file, the output is not a ``.csv`` file, or
                ``parallel`` is not an integer.
        """
        assert {'graph_path', 'output_path', 'parallel'}.issubset(self.config.keys()), \
            f'Invalid config {self.config}'
        assert pathlib.Path(self.config.graph_path).suffix == '.graphml', \
            'Only supporting graphml for KG format'
        assert pathlib.Path(self.config.output_path).suffix == '.csv', \
            'Only supporting csv for output format'

        assert isinstance(self.config.parallel, int), f'{self.config.parallel}'

    def is_usefull_edge(self, edge: Tuple[str, str, Dict[str, Union[str, float]]]) -> bool:
        """Tell whether ``edge`` carries one of the predicates of interest.

        Args:
            edge: ``(source, target, data)`` triple as produced by
                ``graph.edges(data=True)``; ``data`` must contain a
                ``'predicate'`` entry.

        Returns:
            True if the lower-cased predicate is in ``USEFUL_PREDICATE``.
        """
        return str(edge[2]['predicate']).lower() in self.USEFUL_PREDICATE

    def generate(self, materialize: bool = True) -> pd.DataFrame:
        """Run the whole pipeline and return the generated prompts.

        Args:
            materialize: When True, also write the prompts to disk (see
                :meth:`_write_prompts_to_disk`).

        Returns:
            One row per selected edge, as built by :meth:`_to_prompt`. An
            empty frame is returned when no edge matches the predicates.
        """
        useful_edges = self._get_useful_edges()

        logger.info(f'Found {len(useful_edges)} edges with given predicates.')

        top_random_list = self._get_top_edges(useful_edges)
        if not top_random_list:
            # Nothing to lexicalize; the later group-by on 'edge_predicate'
            # would have no such column to work with.
            logger.warning('No edges matched the given predicates; nothing to generate.')
            return pd.DataFrame()

        df_raw = pd.DataFrame({'edges': top_random_list})['edges']
        # Lexicalize all the edges, using at least one worker.
        n_cores = max(self.config.parallel, 1)
        df_lexic = self.lexic.lexicalize_corpus(df_raw, n_cores=n_cores)

        # Registers ``progress_apply`` on pandas objects.
        tqdm.pandas(desc="Prompt Gen.")
        df = df_lexic.progress_apply(self._to_prompt, axis=1)

        if materialize:
            self._write_prompts_to_disk(df)

        return df

    def _dump_func_prefix(self, pred: str, df: pd.DataFrame) -> None:
        """Write all prompts of one predicate plus random samples of them.

        Args:
            pred: Predicate name; a ``/r/`` prefix is stripped before it is
                used as the folder name.
            df: Non-empty prompts frame restricted to ``pred``.

        Raises:
            AssertionError: If ``df`` is empty.
        """
        assert len(df) > 0, f'{df.shape}'
        pred = pred.replace('/r/', '')
        # Keep the frame's own column order (a set difference would make the
        # index level order, and hence the CSV layout, vary between runs).
        non_able_columns = [c for c in df.columns if c not in self._PROMPT_COLUMNS]
        # Long format: one row per (key columns..., 'enabling' | 'disabling').
        formatted_df = df.set_index(non_able_columns).stack()

        folder_path = self._get_folder_path(pred)

        filename = pathlib.Path(self.config.output_path).name
        formatted_df.to_csv(folder_path / filename)

        for n_sample in self._SAMPLE_SIZES:
            self._dump_random_sample(formatted_df, folder_path, n_sample=n_sample)

    def _get_folder_path(self, pred: str) -> pathlib.Path:
        """Return (creating it if needed) the output folder for ``pred``.

        The folder is a sibling of the configured ``output_path`` file.
        """
        folder_path = pathlib.Path(self.config.output_path).parent / pred
        folder_path.mkdir(parents=True, exist_ok=True)
        return folder_path

    def _dump_random_sample(self, formatted_df: pd.Series, folder_path: pathlib.Path,
                            n_sample: int) -> None:
        """Write a random sample of prompts to ``Predicates_<n_sample>.txt``.

        ``n_sample // 2`` keys are drawn and both the enabling and the
        disabling prompt of each key are kept, so at most ``n_sample``
        prompts are written (fewer after de-duplication or when the data is
        smaller than requested).

        Args:
            formatted_df: Stacked prompts whose last index level is the
                prompt kind (``'enabling'`` / ``'disabling'``).
            folder_path: Destination folder.
            n_sample: Target number of prompts.
        """
        # Drop the prompt-kind level (always the last one after ``stack``)
        # rather than a hard-coded position, so the number of key columns may
        # vary.
        key_rows = formatted_df.index.droplevel(-1).to_frame().values.tolist()
        # ``random.sample`` raises when asked for more items than available.
        n_keys = min(n_sample // 2, len(key_rows))
        subset_indx = random.sample(key_rows, k=n_keys)

        pair_subset_indx = list(self.flatmap(
            lambda key: [tuple(key + [kind]) for kind in self._PROMPT_COLUMNS],
            subset_indx
        ))
        # ``stack`` may have dropped a key's missing (None) prompt; only look
        # up the pairs that are actually present.
        present_pairs = [pair for pair in pair_subset_indx if pair in formatted_df.index]
        subset = formatted_df.loc[present_pairs].drop_duplicates()

        subset.to_csv(folder_path / f'Predicates_{n_sample}.txt', header=False, index=False)

    def _write_prompts_to_disk(self, df: pd.DataFrame) -> None:
        """Dump the prompts grouped by their ``edge_predicate`` column."""
        for pred, partial_df in tqdm(df.groupby(by='edge_predicate'), desc='Dumping'):
            self._dump_func_prefix(pred=pred, df=partial_df)

    def _get_top_edges(self, useful_edges) -> List[Tuple[str, str, Dict[str, Union[str, float]]]]:
        """Select the edges to lexicalize, favouring the heaviest ones.

        Without a positive ``max_len`` in the config, all edges are returned
        sorted by descending ``'weight'``. Otherwise ``max_len`` edges are
        drawn at random, with replacement (so duplicates are possible), from
        the ``5 * max_len`` heaviest edges.

        Args:
            useful_edges: ``(source, target, data)`` triples whose ``data``
                has a ``'weight'`` entry.

        Returns:
            The selected edges as a list; empty if ``useful_edges`` is empty.
        """
        sorted_edges = sorted(useful_edges, key=lambda d: d[2]['weight'], reverse=True)
        if not sorted_edges:
            # ``random.choices`` cannot draw from an empty population.
            return []

        has_limit = (
            'max_len' in self.config
            and self.config.max_len is not None
            and self.config.max_len > 0
        )
        if not has_limit:
            return sorted_edges

        buf_len = min(len(sorted_edges), self.config.max_len * 5)
        return random.choices(sorted_edges[:buf_len], k=self.config.max_len)

    def _get_useful_edges(self):
        """Return the edges with a useful predicate, cached on disk.

        On the first call the graph at ``config.graph_path`` is loaded (and
        kept as ``self.kg``), filtered with :meth:`is_usefull_edge`, and the
        result is pickled to ``./useful_edges.pkl``. Later calls read that
        file instead of the graph.

        NOTE(review): the cache file name depends neither on the graph path
        nor on the predicates, so a cache written for another configuration
        is reused as-is; delete the file after changing either.

        Returns:
            List of ``(source, target, data)`` edge triples.
        """
        useful_edges_path = pathlib.Path('./useful_edges.pkl')
        if useful_edges_path.exists():
            logger.info(f"Loading useful edges from {useful_edges_path}")
            # SECURITY: unpickling executes arbitrary code; this file must
            # only ever be one produced by this very method.
            with open(useful_edges_path, 'rb') as fp:
                return pickle.load(fp)

        logger.info(f'Loading KG from {self.config.graph_path}')
        self.kg: nx.MultiDiGraph = nx.read_graphml(self.config.graph_path)
        logger.info('filtering edges based on predicate')
        useful_edges = list(filter(self.is_usefull_edge, self.kg.edges(data=True)))
        logger.info('Generating useful edges cache.')
        with open(useful_edges_path, 'wb') as fp:
            pickle.dump(useful_edges, fp)
        return useful_edges

    @staticmethod
    def _clean_edge_data(edge: Tuple[str, str, Dict[str, str]]) -> pd.Series:
        """Flatten an edge triple into a Series.

        Args:
            edge: ``(source, target, data)`` triple.

        Returns:
            Series with ``subject``, ``object`` and one ``edge_<key>`` entry
            per item of ``data``.
        """
        return pd.Series({
            "subject": edge[0],
            "object": edge[1],
            **{f'edge_{k}': v for k, v in edge[2].items()}
        })

    def _to_prompt(self, lexic: pd.Series) -> pd.Series:
        """Turn one lexicalized edge into its enabling/disabling prompts.

        Args:
            lexic: Row with the entries ``'sentence'``, ``'perplexity'`` and
                ``'edge'`` (the original edge triple).

        Returns:
            Series with the flattened edge data followed by ``enabling``,
            ``disabling`` and ``perplexity``. A prompt is None when
            ``COND_TYPE`` does not allow that kind for the edge's predicate.
        """
        lex_edge = str(lexic['sentence']).capitalize()
        edge = lexic['edge']

        edge_info = KGLexicGenerator._clean_edge_data(edge)

        # Both questions are always drawn, even if one ends up unused, so the
        # consumption of the random stream does not depend on the predicate.
        q_enable = random.choice(KGLexicGenerator.QUESTIONS['enabling'])
        q_disable = random.choice(KGLexicGenerator.QUESTIONS['disabling'])

        cond_type = KGLexicGenerator.COND_TYPE[str(edge[2]['predicate']).lower()]
        enabling = f'{lex_edge}. {q_enable}?' if cond_type in ('both', 'enabling') else None
        disabling = f'{lex_edge}. {q_disable}?' if cond_type in ('both', 'disabling') else None

        return pd.Series({
            **edge_info,
            'enabling': enabling,
            'disabling': disabling,
            'perplexity': lexic['perplexity'],
        })

    @staticmethod
    def flatmap(f, items):
        """Apply ``f`` to every item and lazily chain the resulting iterables."""
        return chain.from_iterable(map(f, items))


@hydra.main("../Configs/ClozeGeneratorConfig.yaml")
def main(config: omegaconf.DictConfig) -> NoReturn:
    """Hydra entry point: build a KGLexicGenerator from ``config`` and run it.

    Args:
        config: Configuration object supplied by the ``hydra.main`` decorator.
    """
    KGLexicGenerator(config).generate()


# Run the generator only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
