from functools import partial
from random import shuffle
from typing import Tuple, List
from itertools import combinations
import hashlib

import IPython
import pandas as pd
import numpy as np
import hydra
import omegaconf
import pathlib
from tqdm import tqdm
import logging
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize
import pytorch_lightning as pl

from BasicBenchmark import BasicBenchmark
from MyEmbedder import MyEmbedder

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class GenBenchmark(BasicBenchmark):
    """Convert submission rows into a generation-style benchmark.

    Every input row becomes a question prompt with one reference answer;
    rows that share a prompt are merged into a single benchmark row with
    the columns ``prompt``, ``refs_0``, ``refs_1``, ...
    """

    def __init__(self, config: omegaconf.dictconfig.DictConfig):
        """Pass ``config`` unchanged to the ``BasicBenchmark`` initializer."""
        super().__init__(config)

    def create_benchmark(self):
        """Build the benchmark table and hand it to ``split_and_save``.

        Relies on ``_check_output_path`` and ``_read_submission``, which are
        not defined in this class (presumably inherited from
        ``BasicBenchmark`` -- confirm there).
        """
        self._check_output_path()

        # Rows with any missing value are dropped before conversion.
        submissions = self._read_submission()
        complete_rows = submissions.dropna(axis=0)
        generated = complete_rows.apply(self._mcq_to_gen, axis=1)

        # One benchmark row per distinct prompt; each reference answer of
        # that prompt is stored in its own numbered ``refs_<i>`` entry.
        bench_rows = []
        for prompt, group in generated.groupby('prompt'):
            row = {'prompt': prompt}
            for i, ref in enumerate(group['refs'].values):
                row[f'refs_{i}'] = ref
            bench_rows.append(pd.Series(row))

        self.split_and_save(pd.DataFrame(bench_rows), do_split=True)

    def _mcq_to_gen(self, pd_s: pd.Series) -> pd.Series:
        """Turn one submission row into a ``prompt``/``refs`` pair.

        Args:
            pd_s: Row providing the ``type``, ``fact`` and ``condition``
                entries.

        Returns:
            Series with ``prompt`` (the fact followed by the question that
            matches the row type) and ``refs`` (the row's condition).

        Raises:
            ValueError: If ``type`` is neither ``'disabling'`` nor
                ``'enabling'``; the offending value is the error argument.
        """
        row_type = pd_s['type']
        if row_type == 'disabling':
            question = ' What makes this impossible? '
        elif row_type == 'enabling':
            question = ' What makes this possible? '
        else:
            raise ValueError(row_type)

        prompt = pd_s['fact'] + question
        return pd.Series({'prompt': prompt, 'refs': pd_s['condition']})

    def split_and_save(self, df_bench, do_split=True):
        """Write the benchmark as CSV files under ``config.results_path``.

        Args:
            df_bench: Benchmark table to persist.
            do_split: When true, the table is partitioned and written to
                ``train.csv``, ``eval.csv`` and ``test.csv``; otherwise the
                whole table goes to ``all.csv``.
        """
        results_dir = pathlib.Path(self.config.results_path)

        if not do_split:
            logger.warning(f"Partition: {df_bench.shape}")
            df_bench.to_csv(results_dir / 'all.csv', index=False)
            return

        df_train, df_eval, df_test = self._split_dataset(df_bench)
        logger.warning(f"Partitions: {df_train.shape}, {df_eval.shape}, {df_test.shape}")
        partitions = (
            ('train.csv', df_train),
            ('eval.csv', df_eval),
            ('test.csv', df_test),
        )
        for file_name, frame in partitions:
            frame.to_csv(results_dir / file_name, index=False)

    def _split_dataset(self, df: pd.DataFrame) -> \
            Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Partition ``df`` and return ``(train, eval, test)``.

        The strategy is chosen by ``config.groupby_fact``: a split keyed on
        the ``prompt`` column when set, a naive split otherwise.
        """
        if self.config.groupby_fact:
            parts = self._groupby_fact_split(df, split_key='prompt')
        else:
            parts = self._naive_split(df)

        # NOTE(review): both helpers are unpacked as (eval, test, train);
        # confirm this matches the order they actually return.
        df_eval, df_test, df_train = parts
        return df_train, df_eval, df_test


@hydra.main(config_path='../Configs/BenchmarkConvertorConfig.yaml')
def main(config: omegaconf.dictconfig.DictConfig):
    """Hydra entry point: build the generation benchmark from ``config``.

    Args:
        config: Configuration supplied by Hydra; its ``method`` entry must
            be ``'gen'``.

    Raises:
        ValueError: If ``config.method`` is not ``'gen'``.
    """
    # Checked explicitly instead of with ``assert``: assertions are removed
    # under ``python -O``, which would let a wrong method run unnoticed.
    if config.method != 'gen':
        raise ValueError(f'wrong method selected {config.method}')
    GenBenchmark(config).create_benchmark()


# Run the Hydra entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
