import logging
import os
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import hydra
import IPython
import numpy as np
import omegaconf
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class DatasetSplitter:
    """Split a dataframe into train / eval / test partitions.

    Split sizes are read from ``config.split``: index 1 is the eval size and
    index 2 is the test size. Index 0 is not read by this class.
    """

    def __init__(self, config: omegaconf.dictconfig.DictConfig):
        self.config: omegaconf.dictconfig.DictConfig = config

    def _naive_split(self, df_with_context, splits=None):
        """Randomly split ``df_with_context`` row-wise.

        Args:
            df_with_context: Frame whose rows are distributed over the
                partitions.
            splits: Indexable of split sizes; ``splits[2]`` (test) and
                ``splits[1]`` (eval) are forwarded to ``train_test_split`` as
                ``test_size``. Defaults to ``self.config.split``.

        Returns:
            ``(df_eval, df_test, df_train)``. ``df_eval`` / ``df_test`` are
            ``None`` when the corresponding size is not positive.
        """
        if splits is None:
            splits = self.config.split

        df_train_eval, df_test = df_with_context, None
        if splits[2] > 0:
            df_train_eval, df_test = train_test_split(
                df_with_context,
                test_size=splits[2],
            )

        # NOTE: the eval size is applied to what remains after the test rows
        # were removed, so a fractional ``splits[1]`` is relative to that
        # remainder rather than to the full frame.
        df_train, df_eval = df_train_eval, None
        if splits[1] > 0:
            df_train, df_eval = train_test_split(
                df_train_eval,
                test_size=splits[1],
            )

        return df_eval, df_test, df_train

    @staticmethod
    def _rows_for_groups(df: pd.DataFrame,
                         groups: Optional[pd.DataFrame]) -> Optional[pd.DataFrame]:
        """Return the rows of ``df`` referenced by the ``'indecies'`` column of ``groups``.

        ``None`` (a disabled partition) is passed through unchanged, and an
        empty selection yields an empty frame with the columns of ``df``.
        """
        if groups is None:
            return None
        if len(groups) == 0:
            # np.concatenate refuses an empty sequence of arrays.
            return df.iloc[0:0]
        return df.loc[np.concatenate(groups['indecies'].values).ravel()]

    def _groupby_fact_split(self, df_with_context: pd.DataFrame, split_key: str = 'label'):
        """Split so that all rows sharing a ``split_key`` value share a partition.

        The frame is grouped by ``split_key``, the groups (not the individual
        rows) are distributed with ``_naive_split``, and each partition is
        expanded back to the original rows.

        NOTE(review): despite the method name, grouping happens on
        ``split_key`` (default ``'label'``) - confirm that this is the
        intended grouping column.

        Args:
            df_with_context: Frame to split.
            split_key: Column whose distinct values define the groups.

        Returns:
            ``(df_eval, df_test, df_train)``; ``df_eval`` / ``df_test`` are
            ``None`` when the corresponding split size is not positive.
        """
        logger.info(f'splitting dataset based on split key:{split_key}')
        # One record per distinct key value, holding the index labels of the
        # rows that carry it.
        df_facts = pd.DataFrame([
            pd.Series({'indecies': group.index.values, 'label': key})
            for key, group in df_with_context.groupby(split_key)
        ])

        ind_eval, ind_test, ind_train = self._naive_split(df_facts)

        df_train = self._rows_for_groups(df_with_context, ind_train)
        df_test = self._rows_for_groups(df_with_context, ind_test)
        df_eval = self._rows_for_groups(df_with_context, ind_eval)

        return df_eval, df_test, df_train


class BasicBenchmark(DatasetSplitter):
    """Turn a CSV of prompt/response submissions into train/eval/test CSVs.

    Reads ``batch_path``, ``results_path``, ``include_facts`` and
    ``groupby_fact`` from the config (plus ``split`` through
    ``DatasetSplitter``).
    """

    def __init__(self, config: omegaconf.dictconfig.DictConfig):
        # The parent constructor already stores ``config`` on the instance.
        super().__init__(config)

    def create_benchmark(self):
        """Run the whole pipeline: read, clean, expand, split and save."""
        self._check_output_path()
        # Rows whose prompt or response could not be cleaned carry None and
        # are dropped here.
        df = self._read_submission().dropna(axis=0)
        df_bench = my_flatmap(df, func=self._submission_to_test).drop_duplicates()
        self.split_and_save(df_bench)

    def split_and_save(self, df_bench):
        """Split ``df_bench`` and write ``train.csv``, ``eval.csv`` and ``test.csv``.

        A partition that is disabled by the split configuration is written as
        a header-only file instead of crashing.
        """
        df_train, df_eval, df_test = self._split_dataset(df_bench)
        # The splitter reports a disabled partition as None.
        empty = df_bench.iloc[0:0]
        df_eval = empty if df_eval is None else df_eval
        df_test = empty if df_test is None else df_test

        logger.warning(f"Partitions: {df_train.shape}, {df_eval.shape}, {df_test.shape}")
        output_path = pathlib.Path(self.config.results_path)
        df_train.to_csv(output_path / 'train.csv', index=False)
        df_eval.to_csv(output_path / 'eval.csv', index=False)
        df_test.to_csv(output_path / 'test.csv', index=False)

    def _check_output_path(self):
        """Create the results directory (and missing parents) if needed.

        Raises:
            FileExistsError: If the path exists but is not a directory.
        """
        output_path = pathlib.Path(self.config.results_path)
        logger.info(f'Checking the dir path {output_path}')
        output_path.mkdir(parents=True, exist_ok=True)

    def _split_dataset(self, df: pd.DataFrame) -> \
            Tuple[pd.DataFrame, Optional[pd.DataFrame], Optional[pd.DataFrame]]:
        """Split the benchmark rows into ``(train, eval, test)``.

        Only rows with a non-empty ``context`` are split; rows with an empty
        ``context`` are all appended to the train partition.
        """
        no_context = df['context'] == ""
        df_no_context = df[no_context]
        df_with_context = df[~no_context]

        if self.config.groupby_fact:
            df_eval, df_test, df_train = self._groupby_fact_split(df_with_context, split_key='label')
        else:
            df_eval, df_test, df_train = self._naive_split(df_with_context)

        return pd.concat([df_train, df_no_context]), df_eval, df_test

    def _submission_to_test(self, sub: pd.Series) -> List[pd.Series]:
        """Expand one cleaned submission into benchmark rows.

        Always yields a (condition, fact) row labelled 1 for ``'enabling'``
        submissions and 0 otherwise. When ``config.include_facts`` is set, an
        extra row with an empty context and label 1 is added for the fact.
        """
        output = [
            pd.Series({
                "context": sub['condition'],
                "question": sub['fact'],
                "label": 1 if sub['type'] == 'enabling' else 0,
            }),
        ]
        if self.config.include_facts:
            output.append(pd.Series({
                "context": "",
                "question": sub['fact'],
                "label": 1,
            }))
        return output

    def _read_submission(self) -> pd.DataFrame:
        """Load the ``prompt`` / ``response`` columns and clean every row."""
        df = pd.read_csv(self.config.batch_path, usecols=['prompt', 'response'])
        return df.apply(self._process_submission, axis=1)

    def _process_submission(self, sub: pd.Series) -> pd.Series:
        """Merge the cleaned prompt, cleaned response and extra features of one row."""
        return pd.Series({
            **self._cleanup_prompt(sub['prompt']),
            **self._cleanup_response(sub['response']),
            **self._other_features(prompt=sub['prompt'], resp=sub['response'])
        })

    def _other_features(self, prompt: str, resp: str) -> Dict[str, Any]:
        """Hook for additional per-row features; returns none here."""
        return {}

    @staticmethod
    def _cleanup_prompt(prompt: str) -> Dict[str, Optional[str]]:
        """Extract the fact and the prompt type from a raw prompt.

        Returns:
            ``{'fact': ..., 'type': ...}`` where ``type`` is ``'enabling'``
            for " possible?" prompts and ``'disabling'`` for " impossible?"
            prompts. Both values are ``None`` for unrecognised or non-string
            prompts.
        """
        prompt_type = None
        fact = None

        # Missing CSV cells are loaded as NaN (a float); leave them as None so
        # the caller's dropna can discard the row instead of crashing here.
        if isinstance(prompt, str):
            if " possible?" in prompt:
                prompt_type = 'enabling'
                fact = prompt.replace(' When is this possible?', '')
            elif " impossible?" in prompt:
                prompt_type = 'disabling'
                fact = prompt.replace(' When is this impossible?', '')

        return {
            'fact': fact,
            'type': prompt_type,
        }

    @staticmethod
    def _cleanup_response(response: str) -> Dict[str, Optional[str]]:
        """Normalise a raw response into a sentence-like condition.

        The response is lower-cased and stripped, a leading "when " is
        removed, a trailing period is ensured and the first character is
        capitalised.

        Returns:
            ``{'condition': ...}``; the value is ``None`` for non-string,
            empty or "invalid" responses.
        """
        prefix = 'when '
        # Missing CSV cells are loaded as NaN (a float).
        if not isinstance(response, str):
            return {'condition': None}

        cleaned_resp = response.lower().strip()
        if cleaned_resp.startswith(prefix):
            # Drop only the leading prefix; str.replace would also delete
            # every later "when " inside the sentence.
            cleaned_resp = cleaned_resp[len(prefix):].lstrip()

        if not cleaned_resp or cleaned_resp == 'invalid':
            return {'condition': None}

        if not cleaned_resp.endswith('.'):
            cleaned_resp = cleaned_resp + '.'

        return {
            'condition': cleaned_resp[0].upper() + cleaned_resp[1:]
        }


def my_flatmap(df: pd.DataFrame, func: Callable[[pd.Series], List[pd.Series]]) -> pd.DataFrame:
    """Apply ``func`` to every row of ``df`` and flatten the results into one frame.

    Args:
        df: Frame whose rows are passed to ``func`` one at a time.
        func: Maps a row to a list of output rows (possibly empty).

    Returns:
        A new frame built from all output rows, in input order.
    """
    rows = []
    # iterrows() is a generator without a length, so pass the total
    # explicitly to let the progress bar show completion.
    for _, row in tqdm(df.iterrows(), total=len(df)):
        rows.extend(func(row))
    return pd.DataFrame.from_records(rows)


@hydra.main(config_path='../Configs/BenchmarkConvertorConfig.yaml')
def main(config: omegaconf.dictconfig.DictConfig):
    """Build a ``BasicBenchmark`` from the hydra-provided config and run it."""
    benchmark = BasicBenchmark(config)
    benchmark.create_benchmark()


# Run the benchmark conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()
