from functools import partial
from typing import Callable, Type, Tuple, Dict
import omegaconf
import logging
import numpy as np
import pandas as pd
from multiprocessing import Pool, current_process
from tqdm import tqdm

# from Lexicalization.CNQueryLexic import CNQueryEdgeLexicPandas, CNQueryEdgeLexicInMem
from Lexicalization.RuleBasedLexic import RuleBasedEdgeLexic
from Lexicalization.RuleBasedMaskedLexic import RuleBasedMaskedEdgeLexic

logger = logging.getLogger(__name__)


class ParallelLexicWrapper:
    """Pick a lexicalizer class from the config and run it over a corpus.

    The wrapper stores a zero-argument factory (``self.lexic_gen``) rather
    than a lexicalizer instance, so a fresh lexicalizer is built each time
    one is needed.
    """

    def __init__(self, config: omegaconf.DictConfig):
        """Resolve ``config.lexic`` to a lexicalizer factory.

        Args:
            config: Configuration object. It is passed unchanged as the
                first argument to the selected lexicalizer class. When it
                has no ``lexic`` key, or ``lexic`` is ``None``,
                ``RuleBasedEdgeLexic`` is used.

        Raises:
            NotImplementedError: If ``config.lexic`` names one of the
                ``CNQueryEdgeLexicPandas`` / ``CNQueryEdgeLexicInMem``
                lexicalizers, which are not available in this module.
            ValueError: If ``config.lexic`` is any other unknown name.
        """
        self.config = config

        has_lexic = ('lexic' in self.config) and (self.config.lexic is not None)
        if not has_lexic:
            # No explicit choice: fall back to the rule-based lexicalizer.
            self.lexic_gen = partial(RuleBasedEdgeLexic, self.config)
            return

        lexic_name = self.config.lexic
        if lexic_name == 'CNQueryEdgeLexicPandas':
            # NotImplementedError is the exception class; the NotImplemented
            # singleton is not callable and would raise TypeError instead.
            raise NotImplementedError('CNQueryEdgeLexicPandas')
        elif lexic_name == 'CNQueryEdgeLexicInMem':
            raise NotImplementedError('CNQueryEdgeLexicInMem')
        elif lexic_name == 'Base' or lexic_name == 'CNQueryEdgeLexic':
            self.lexic_gen = partial(RuleBasedEdgeLexic, self.config)
        elif lexic_name == 'RuleBasedMaskedEdgeLexic':
            self.lexic_gen = partial(RuleBasedMaskedEdgeLexic, self.config)
        else:
            raise ValueError(f'Invalid lexic: {self.config.lexic}')

    def lexicalize_corpus(self, df: pd.Series, n_cores: int = 4) -> pd.DataFrame:
        """Lexicalize ``df`` with a freshly built lexicalizer.

        Only single-process execution is supported at the moment. Note that
        the default ``n_cores=4`` therefore raises; callers have to pass
        ``n_cores=1`` or ``n_cores=None`` explicitly.

        Args:
            df: The corpus handed to the lexicalizer's ``convert_corpus``.
            n_cores: Number of worker processes. Must be ``1`` or ``None``.

        Returns:
            Whatever the lexicalizer's ``convert_corpus`` returns for ``df``.

        Raises:
            NotImplementedError: If ``n_cores`` is neither ``1`` nor ``None``.
        """
        if n_cores is not None and n_cores != 1:
            raise NotImplementedError('only 1 core for know')

        logger.info('Running lexicalization without Pool')
        return self.lexic_gen().convert_corpus(df)

    def _lexicalize_corpus_parallel(self, df: pd.Series, n_cores: int) -> pd.DataFrame:
        """Split ``df`` into ``n_cores`` chunks and lexicalize them in a Pool.

        NOTE: this path is not wired into ``lexicalize_corpus``, which still
        rejects ``n_cores > 1``. It applies the lexicalizer's ``convert``
        element-wise, whereas the single-process path calls
        ``convert_corpus`` on the whole input.
        """
        logger.info('Spliting df')
        df_split = np.array_split(df, n_cores)

        par_func = partial(ParallelLexicWrapper._apply_col_wrapper,
                           lexic_gen=self.lexic_gen)

        # The context manager releases the workers even if map/concat raises.
        with Pool(n_cores) as pool:
            logger.info('Running Pool of workers')

            logger.info('Concatenating results')
            df_out = pd.concat(pool.map(par_func, df_split))

            logger.info('Closing Pool')
            pool.close()
            pool.join()
        return df_out

    @staticmethod
    def _generate_converter(lexic_gen: Callable) -> Callable[[Tuple[str, str, Dict[str, str]]], Tuple[str, float]]:
        """Build a lexicalizer via ``lexic_gen`` and return its bound ``convert``."""
        lexic = lexic_gen()
        return lexic.convert

    @staticmethod
    def _apply_col_wrapper(small_df: pd.Series, lexic_gen: Callable):
        """Apply a fresh lexicalizer's ``convert`` to every element of ``small_df``.

        Intended as the per-chunk worker function for a process pool. Only
        the process whose name ends in ``-1`` shows a tqdm progress bar; all
        others use a plain ``apply``.

        Raises:
            ValueError: If the current process name does not end in
                ``-<integer>`` (presumably the case outside a pool worker,
                e.g. in the main process -- TODO confirm).
        """
        func = lexic_gen().convert
        # The numeric suffix of the process name is used as the worker index.
        myname = int(current_process().name.split('-')[-1])
        if myname == 1:
            # tqdm.pandas() registers ``progress_apply`` on pandas objects.
            tqdm.pandas(desc=f"lexicalize_{current_process()}")
            return small_df.progress_apply(func=func)
        else:
            return small_df.apply(func=func)

