from functools import partial
from random import shuffle
from typing import Tuple, List
from itertools import combinations
import hashlib
import pandas as pd
import numpy as np
import hydra
import omegaconf
import pathlib
from tqdm import tqdm
import logging
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize
import pytorch_lightning as pl

from BasicBenchmark import BasicBenchmark
from MyEmbedder import MyEmbedder

# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(__name__)


class MCBenchmark(BasicBenchmark):
    """Build a multiple-choice benchmark out of submitted (fact, condition) rows.

    Every row of the submission supplies a question (``fact``), its correct
    answer (``condition``) and a ``type``. Wrong answers ("negatives") are
    drawn from other rows, either by embedding similarity (``'cosine'``) or
    from rows sharing the fact but differing in type (``'simple'``).

    NOTE(review): ``self.config`` and the ``_check_output_path``,
    ``_read_submission``, ``_groupby_fact_split`` and ``_naive_split`` helpers
    are not defined in this class; presumably they come from
    ``BasicBenchmark`` -- confirm there.
    """

    def __init__(self, config: omegaconf.dictconfig.DictConfig):
        """Store the config (via the base class) and prepare the KNN index.

        Args:
            config: Run configuration; ``config.mc_knn.n_neighbors`` sets how
                many neighbours the KNN index returns per query.
        """
        super().__init__(config)
        self.knn = NearestNeighbors(n_neighbors=self.config.mc_knn.n_neighbors, metric="euclidean")

    def create_benchmark(self):
        """Run the full pipeline: read, embed, sample negatives, save.

        Raises:
            ValueError: If ``config.neg_sampling`` is neither ``'cosine'`` nor
                ``'simple'``.
        """
        self._check_output_path()
        # Rows with any missing value are dropped before embedding, so the
        # embedding matrix lines up with ``df_batch`` row by row.
        df_batch = self._read_submission().dropna(axis=0)
        np_embd = self.create_or_load_embeddings(df_batch['condition'])
        if self.config.neg_sampling == 'cosine':
            df_mc_bench = self.consine_distance_sampling(df_batch, np_embd)
        elif self.config.neg_sampling == 'simple':
            df_mc_bench = self.simple_negative_sampling(df_batch, np_embd)
        else:
            raise ValueError(self.config.neg_sampling)

        # NOTE(review): filtering step currently disabled; kept for reference.
        # self.aflite_filtering(df_mc_bench, df_batch, np_embd)

        self.split_and_save(df_mc_bench, do_split=False)

    def split_and_save(self, df_bench, do_split=True):
        """Write the benchmark to ``config.results_path`` as CSV.

        Args:
            df_bench: The benchmark frame to persist.
            do_split: When true, write ``train.csv``, ``eval.csv`` and
                ``test.csv``; otherwise write everything to ``all.csv``.
        """
        output_path = pathlib.Path(self.config.results_path)
        if do_split:
            df_train, df_eval, df_test = self._split_dataset(df_bench)
            logger.warning(f"Partitions: {df_train.shape}, {df_eval.shape}, {df_test.shape}")
            df_train.to_csv(output_path / 'train.csv', index=False)
            df_eval.to_csv(output_path / 'eval.csv', index=False)
            df_test.to_csv(output_path / 'test.csv', index=False)
        else:
            logger.warning(f"Partition: {df_bench.shape}")
            df_bench.to_csv(output_path / 'all.csv', index=False)

    def _split_dataset(self, df: pd.DataFrame) -> \
            Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Split ``df`` into (train, eval, test) frames.

        Uses the grouped split keyed on ``'question'`` when
        ``config.groupby_fact`` is truthy, the naive split otherwise.

        NOTE(review): both helpers are unpacked as (eval, test, train) --
        confirm that matches their actual return order.
        """
        if self.config.groupby_fact:
            df_eval, df_test, df_train = self._groupby_fact_split(df, split_key='question')
        else:
            df_eval, df_test, df_train = self._naive_split(df)

        return df_train, df_eval, df_test

    def _build_mc_records(self, row, negatives: List[str]) -> List[dict]:
        """Turn one source row plus its negatives into benchmark records."""
        return [
            {
                'question': row['fact'],
                'type': row['type'],
                **{f'choice_{idx}': choice for idx, choice in enumerate(choices)},
                'answer': y_label,
            }
            for choices, y_label in self.return_all_choice_combinations(
                positive=row['condition'], negatives_list=negatives)
        ]

    def consine_distance_sampling(self, df: pd.DataFrame, np_embd: np.ndarray) -> pd.DataFrame:
        """Sample negatives from the nearest neighbours in embedding space.

        For each row, the neighbours' conditions whose ``fact`` differs from
        the row's own fact are candidate negatives; at most
        ``config.mc_knn.n_neighbors_to_use`` of them (closest first) are kept.

        Args:
            df: Frame with ``fact``, ``condition`` and ``type`` columns.
            np_embd: One embedding per row of ``df``, in the same order.

        Returns:
            Frame with ``question``, ``type``, ``choice_<k>`` and ``answer``
            columns (empty when no combination could be formed).

        Raises:
            ValueError: If ``np_embd`` and ``df`` differ in number of rows.
        """
        if len(np_embd) != len(df):
            # zip() below would silently truncate and misalign rows otherwise
            # (e.g. when a stale cached embedding file is picked up).
            raise ValueError(
                f'embeddings ({len(np_embd)}) and rows ({len(df)}) are misaligned')

        logger.info(f'Computing the KNNs with k={self.knn.n_neighbors}')
        # For unit vectors ||a - b||^2 = 2 - 2 * cos(a, b), so Euclidean KNN
        # ranks by cosine similarity -- but only if BOTH the indexed points
        # and the queries are L2-normalized.
        np_unit = normalize(np_embd, norm='l2')
        self.knn.fit(X=np_unit)
        neighbors = self.knn.kneighbors(X=np_unit, return_distance=False)

        n_to_use = self.config.mc_knn.n_neighbors_to_use
        mc_list = []
        for (_, row), neighs in tqdm(zip(df.iterrows(), neighbors),
                                     total=len(df), desc='negative sampling'):
            df_neighs = df.iloc[neighs]
            # A neighbour stating the same fact (including the row itself)
            # cannot serve as a wrong answer.
            df_negs = df_neighs[df_neighs['fact'] != row['fact']]
            negatives = df_negs['condition'].values.tolist()[:n_to_use]
            mc_list.extend(self._build_mc_records(row, negatives))

        return pd.DataFrame(mc_list)

    def simple_negative_sampling(self, df: pd.DataFrame, np_embd: np.ndarray) -> pd.DataFrame:
        """Sample negatives from rows with the same fact but a different type.

        Args:
            df: Frame with ``fact``, ``condition`` and ``type`` columns.
            np_embd: Unused; accepted so both sampling strategies share one
                call signature.

        Returns:
            Frame with ``question``, ``type``, ``choice_<k>`` and ``answer``
            columns (empty when no combination could be formed).
        """
        mc_list = []
        for _, row in tqdm(df.iterrows(), total=len(df), desc='negative sampling'):
            df_same_q_diff_typ = df[(df['fact'] == row['fact']) & (df['type'] != row['type'])]
            negatives = df_same_q_diff_typ['condition'].values.tolist()
            mc_list.extend(self._build_mc_records(row, negatives))

        return pd.DataFrame(mc_list)

    def return_all_choice_combinations(self, positive: str, negatives_list: List[str]) \
            -> List[Tuple[List[str], int]]:
        """Build every choice set made of the positive plus a negative combination.

        Args:
            positive: The correct choice.
            negatives_list: Candidate wrong choices.

        Returns:
            One ``(choices, answer_index)`` pair per combination of
            ``config.mc_knn.n_negative_choices`` negatives; ``choices`` is
            randomly shuffled and ``answer_index`` locates ``positive`` in it.
            Empty when there are fewer negatives than required.
        """
        n_negatives = self.config.mc_knn.n_negative_choices
        output = []
        for _negs in combinations(negatives_list, n_negatives):
            choices = [positive, *_negs]
            shuffle(choices)
            # NOTE: if a negative has the same text as the positive, this is
            # the first matching position; the choice set is then ambiguous.
            output.append((choices, choices.index(positive)))

        return output

    def create_or_load_embeddings(self, resps):
        """Return the embedding matrix for ``resps``, computing it if not cached.

        The cache file lives under ``config.results_path`` and is named after
        a hash of ``config.batch_path`` only, so it is NOT invalidated when
        the file's content changes while its path stays the same.

        Args:
            resps: Series of texts to embed (its ``.values`` are passed to the
                embedder's dataloader).

        Returns:
            The array loaded from the cache file.
        """
        output_path = pathlib.Path(self.config.results_path)

        _hash_id = hashlib.sha256(self.config.batch_path.encode()).hexdigest()[0:20]
        embd_path = output_path / 'embeddings_{}.npy'.format(_hash_id)
        if not embd_path.exists():
            logger.info('Creating embeddings')
            # presumably MyEmbedder writes its output to ``tmp_path`` during
            # the test loop -- the file is read back from there below.
            _module = MyEmbedder(self.config, tmp_path=embd_path)
            trainer = pl.Trainer(
                gradient_clip_val=0,
                gpus=self.config.gpus,
                show_progress_bar=True,
                accumulate_grad_batches=8,
                max_epochs=1,
                min_epochs=1,
                val_check_interval=1,
                weights_summary='top',
                num_sanity_val_steps=2,
                resume_from_checkpoint=None,
            )

            trainer.test(_module, test_dataloaders=[_module.get_dataloader(resps.values)])

        logger.info('Reading embeddings')
        np_embd = np.load(embd_path)
        return np_embd



@hydra.main(config_path='../Configs/BenchmarkConvertorConfig.yaml')
def main(config: omegaconf.dictconfig.DictConfig):
    """Hydra entry point: validate the selected method and build the benchmark.

    Args:
        config: Configuration composed by Hydra from the YAML file above.

    Raises:
        ValueError: If ``config.method`` is not ``'mc-knn'``.
    """
    # Explicit check instead of ``assert`` so the validation still runs
    # when Python is started with -O (which strips assert statements).
    if config.method != 'mc-knn':
        raise ValueError(f'wrong method selected {config.method}')
    MCBenchmark(config).create_benchmark()


# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
