from functools import partial
import random
from random import shuffle
from typing import Tuple, List, Generator, Iterable, Any, Optional, Callable, Union
# from itertools import combinations
import itertools

import hashlib

import IPython
import os
import pandas as pd
import numpy as np
import hydra
import omegaconf
import pathlib

import torch
import warnings
from scipy.sparse import csr_matrix
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import logging
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import normalize, OneHotEncoder
import pytorch_lightning as pl

from BackTranslator import Rephraser
from BasicBenchmark import DatasetSplitter
from MyEmbedder import MyEmbedder
from Utils import my_df_flatmap

logger = logging.getLogger(__name__)


class AdvFiltering(DatasetSplitter):
    """Adversarial filtering (AFLite-style) of a multiple-choice benchmark.

    Pipeline implemented by :meth:`run`:

    1. load the MC benchmark from ``config.mc_path`` and flatten every item into
       one labelled ``question + prompt + choice`` text per choice
       (:meth:`mc_to_nli`);
    2. embed every text with ``MyEmbedder``;
    3. repeatedly train an ensemble of logistic-regression probes on random
       subsets of the embeddings and remove the examples they classify most
       reliably (:meth:`aflite_filtering`);
    4. regroup the surviving texts into multiple-choice items and write them as
       CSV files under ``config.results_path``.

    Intermediate artefacts (flattened examples, embeddings, removed indices) are
    cached in ``config.results_path`` under file names derived from a hash of
    their inputs (:meth:`_get_path`).
    """

    def __init__(self, config: omegaconf.dictconfig.DictConfig):
        """Store the configuration.

        :param config: Hydra/OmegaConf config. This class reads ``results_path``,
            ``mc_path``, ``n_mc``, ``gpus``, ``af.{m,n,delta,k,max_rounds}`` and
            ``rephrase.{test,eval}.q`` directly; the whole config is also passed
            to ``DatasetSplitter``, ``MyEmbedder`` and ``Rephraser``.
        """
        super().__init__(config)
        self.config = config

    @staticmethod
    def _my_hash(inp, limit: Optional[int] = 20) -> str:
        """Return the hex SHA-256 digest of ``inp``, truncated to ``limit`` characters.

        :param inp: bytes-like object to hash.
        :param limit: number of leading hex characters to keep; ``None`` keeps all 64.
        """
        digest = hashlib.sha256(inp).hexdigest()
        return digest if limit is None else digest[:limit]

    def _get_path(self, prefix: str, args: Iterable[Any], ext: str = 'npy') -> pathlib.Path:
        """Build a cache-file path in ``results_path`` keyed by a hash of ``args``.

        Each argument is serialised to bytes (NumPy arrays by their raw buffer,
        pandas objects via ``pd.util.hash_pandas_object``, anything else via
        ``str``), hashed on its own, and the concatenated hashes are hashed again.
        The lookup is by exact type, so subclasses fall back to ``str``. Only the
        raw buffer of an ndarray is hashed, so its dtype and shape are not part of
        the key.

        :param prefix: file-name prefix identifying the artefact.
        :param args: values the cached artefact depends on.
        :param ext: file extension without the leading dot.
        :return: ``results_path / f'{prefix}_{hash}.{ext}'`` (the file is not created).
        """
        serializers = {
            # ``tobytes`` replaces the deprecated ``ndarray.tostring`` and yields
            # the same bytes, so previously cached paths stay valid.
            np.ndarray: lambda ar: ar.tobytes(),
            pd.DataFrame: lambda df: pd.util.hash_pandas_object(df).values,
            pd.Series: lambda sr: pd.util.hash_pandas_object(sr).values,
        }

        def _default(obj):
            return str(obj).encode()

        arg_hashes = [self._my_hash(serializers.get(type(arg), _default)(arg)) for arg in args]
        args_hash = self._my_hash(''.join(arg_hashes).encode())
        return pathlib.Path(self.config.results_path) / f'{prefix}_{args_hash}.{ext}'

    def run(self):
        """Execute the whole filtering pipeline and write the resulting CSV files."""
        output_path = pathlib.Path(self.config.results_path)
        output_path.mkdir(parents=True, exist_ok=True)

        df_examples = self._load_mc_bechmark()
        # Keep one row per distinct text (the first occurrence).
        df_examples.drop_duplicates(subset='text', inplace=True)
        # The AF step returns row positions of the embedding matrix; a fresh
        # RangeIndex makes those positions usable as labels for ``drop`` below.
        # Assumes MyEmbedder emits embeddings in input order - TODO confirm.
        df_examples.reset_index(drop=True, inplace=True)
        np_embd = self._compute_all_embeddings(df_examples['text'])
        np_label = df_examples['label'].values
        np_removed_idx = self.run_af_algorithm(np_embd=np_embd, np_label=np_label)

        df_filtered_examples = df_examples.drop(index=np_removed_idx)
        df_bench = self._recreate_mc_benchmark(df_filtered_examples)

        self.split_and_save(df_bench, do_split=False)
        self.split_and_save(df_bench, do_split=True)

    def split_and_save(self, df_bench, do_split=True):
        """Write ``df_bench`` to CSV, optionally split into train/eval/test.

        With ``do_split=False`` the whole benchmark is written to ``all.csv``.
        Otherwise it is split by question (``_groupby_fact_split``) into
        ``train.csv``, ``eval.csv`` and ``test.csv``. The test and eval partitions
        are also rephrased with ``Rephraser`` and saved as
        ``test_rephrased.csv`` / ``eval_rephrased.csv`` when
        ``config.rephrase.test.q`` / ``config.rephrase.eval.q`` are set.
        """
        output_path = pathlib.Path(self.config.results_path)
        if not do_split:
            logger.warning(f"Partition: {df_bench.shape}")
            df_bench.to_csv(output_path / 'all.csv', index=False)
            return

        df_eval, df_test, df_train = self._groupby_fact_split(df_bench, split_key='question')

        logger.warning(f"Partitions: {df_train.shape}, {df_eval.shape}, {df_test.shape}")
        df_train.to_csv(output_path / 'train.csv', index=False)
        df_test.to_csv(output_path / 'test.csv', index=False)
        df_eval.to_csv(output_path / 'eval.csv', index=False)

        if self.config.rephrase.test.q:
            self._rephrase_and_save(df_test, 'test', output_path)
        if self.config.rephrase.eval.q:
            self._rephrase_and_save(df_eval, 'eval', output_path)

    def _rephrase_and_save(self, df_split: pd.DataFrame, split_name: str, output_path: pathlib.Path):
        """Rephrase one partition with ``Rephraser`` and write ``<split_name>_rephrased.csv``."""
        logger.info(f'Rephrasing df_{split_name}')
        rephraser = Rephraser(self.config)
        df_rephrased = rephraser.process(df_split)
        df_rephrased.to_csv(output_path / f'{split_name}_rephrased.csv', index=False)

    def _recreate_mc_benchmark(self, df_data: pd.DataFrame):
        """Regroup labelled ``question? answer`` texts into multiple-choice items.

        Each text is split at its first ``'?'``. The question keeps the ``'?'``,
        and the answer is everything after it, including the leading space. For
        every question, each combination of ``config.n_mc - 1`` negative answers
        is paired with one positive answer (cycling through the positives),
        shuffled with ``random.shuffle``, and emitted as one item. Questions
        without a positive answer produce no items.

        :param df_data: frame with ``text`` and ``label`` (1 = correct, 0 = wrong) columns.
        :return: frame with ``question``, ``choice_0`` ... ``choice_{n_mc-1}`` and
            ``answer`` (position of the correct choice) columns; empty if
            ``df_data`` is empty.
        :raises ValueError: if some text contains no ``'?'``.
        """
        if df_data.empty:
            return pd.DataFrame()

        # Split on the first '?' only, so answers that themselves contain '?'
        # are kept whole instead of being truncated.
        parts = df_data['text'].str.partition('?')
        missing_sep = parts[1] != '?'
        if missing_sep.any():
            raise ValueError(f'{int(missing_sep.sum())} example texts contain no "?" '
                             f'separating question and answer')
        df_split = pd.DataFrame({
            'question': parts[0] + '?',
            'answer': parts[2],
            'label': df_data['label'],
        })

        mc_list = []
        for question, df_q in tqdm(df_split.groupby('question'), desc='Questions Group'):
            neg_answers = df_q.loc[df_q['label'] == 0, 'answer'].tolist()
            pos_answers = df_q.loc[df_q['label'] == 1, 'answer'].tolist()
            neg_combos = itertools.combinations(neg_answers, self.config.n_mc - 1)

            # zip stops at the shorter iterable: one item per negative combination,
            # and none at all when ``pos_answers`` is empty (cycle([]) is empty).
            for negs, pos in zip(neg_combos, itertools.cycle(pos_answers)):
                choices = [pos, *negs]
                shuffle(choices)
                mc_list.append({
                    'question': question,
                    **{f'choice_{i}': choice for i, choice in enumerate(choices)},
                    'answer': choices.index(pos),
                })
        return pd.DataFrame(mc_list)

    def _load_mc_bechmark(self):
        """Return the flattened ``(text, label)`` examples, computing and caching them if needed.

        On a cache miss, the CSV at ``config.mc_path`` is read, every row is
        expanded with :meth:`mc_to_nli` (via ``my_df_flatmap``), exact duplicate
        rows are dropped, and the result is written to the cache CSV. On a cache
        hit, the cached CSV is read back.
        """
        examples_path = self._get_path(prefix='af_bench', args=(self.config.mc_path, ), ext='csv')

        if examples_path.exists():
            logger.info(f'Loading processed mc_benchmark from {examples_path}')
            return pd.read_csv(examples_path, index_col=False)

        logger.info(f'Computing mc_benchmark from {self.config.mc_path}')
        df_bench = pd.read_csv(self.config.mc_path)
        df_examples = my_df_flatmap(df_bench, self.mc_to_nli)

        logger.info(f'Removing duplicates ...')
        prev_len = len(df_examples)
        df_examples = df_examples.drop_duplicates().reset_index(drop=True)
        logger.info(f'Dropped {prev_len - len(df_examples)} duplicates. Remaining ({len(df_examples)})')

        logger.info(f'Dumping intermediate examples: {examples_path}')
        df_examples.to_csv(examples_path, index=False)
        return df_examples

    @staticmethod
    def mc_to_nli(mc_ins: pd.Series) -> Generator[pd.Series, None, None]:
        """Expand one multiple-choice row into one labelled text per choice.

        Each yielded Series has ``text = question + prompt + choice_<i>``, where
        the prompt is ``' What makes this impossible? '`` for ``type ==
        'disabling'`` and ``' What makes this possible? '`` for ``type ==
        'enabling'``. ``label`` is 1 for the choice whose position equals
        ``answer`` and 0 otherwise.

        :param mc_ins: row with ``question``, ``type``, ``answer`` and
            ``choice_0`` ... ``choice_<k>`` fields.
        :raises ValueError: (on first iteration) for an unknown ``type``.
        """
        prompts = {
            'disabling': ' What makes this impossible? ',
            'enabling': ' What makes this possible? ',
        }
        fact = mc_ins['question']
        kind = mc_ins['type']
        if kind not in prompts:
            raise ValueError(kind)
        q = prompts[kind]

        # Count the choice_<i> fields instead of assuming every field other than
        # question/type/answer is a choice, so extra CSV columns do not break this.
        n_cs = sum(1 for key in mc_ins.index if str(key).startswith('choice_'))
        for i in range(n_cs):
            yield pd.Series({
                'text': fact + q + mc_ins[f'choice_{i}'],
                'label': int(i == mc_ins['answer']),
            })

    def _compute_all_embeddings(self, df_data):
        """Return the embedding matrix for ``df_data``, computing it on a cache miss.

        On a cache miss, ``MyEmbedder`` is built with the cache path as
        ``tmp_path`` and run once over ``df_data.values`` via ``pl.Trainer.test``
        under ``torch.no_grad()``. In both cases the result is then loaded from
        the cache path with ``np.load``. Presumably ``MyEmbedder`` writes the
        ``.npy`` file to ``tmp_path``; verify in MyEmbedder.
        """
        embd_path = self._get_path(prefix='af_embedding', args=(df_data, ))
        if not embd_path.exists():
            logger.info(f'Creating embeddings')
            _module = MyEmbedder(self.config, tmp_path=embd_path)
            _module.eval()
            trainer = pl.Trainer(
                gradient_clip_val=0,
                gpus=self.config.gpus,
                show_progress_bar=True,
                accumulate_grad_batches=1,
                max_epochs=1,
                min_epochs=1,
                val_check_interval=1,
                weights_summary='top',
                num_sanity_val_steps=2,
                resume_from_checkpoint=None,
            )
            try:
                with torch.no_grad():
                    trainer.test(_module, test_dataloaders=[_module.get_dataloader(df_data.values)])
            except Exception as e:
                # Log with traceback, then re-raise the original exception unchanged.
                logger.exception(f'Error computing the Embeddings: {e}')
                raise

        logger.info(f'Reading embeddings from {embd_path}')
        return np.load(embd_path)

    def run_af_algorithm(self, np_embd, np_label):
        """Return the indices removed by :meth:`aflite_filtering`, cached by input hash.

        :param np_embd: embedding matrix, one row per example.
        :param np_label: labels aligned with the rows of ``np_embd``.
        :return: array of removed row indices (loaded from the ``.npy`` cache if present).
        """
        filtered_path = self._get_path(prefix='af_filtered', args=(np_embd, np_label))

        if filtered_path.exists():
            logger.info(f'Loading AF results from {filtered_path}')
            return np.load(filtered_path)

        logger.info(f'Running the AF')
        np_removed_idx = self.aflite_filtering(np_embd=np_embd, np_label=np_label)
        np.save(filtered_path, np_removed_idx)
        return np_removed_idx

    def aflite_filtering(self, np_embd: np.ndarray, np_label: np.ndarray) -> np.ndarray:
        """Iteratively remove easily-classified examples and return their row indices.

        Each round scores every remaining example with :meth:`_compute_af_scores`
        (the fraction of its held-out ensemble predictions that were correct).
        It then removes up to ``af.k`` examples whose score exceeds ``af.delta``,
        highest scores first. Filtering stops after ``af.max_rounds`` rounds,
        after a round that removed fewer than 100 examples, or once at most
        ``af.m`` examples remain. Nothing is removed if the input already has at
        most ``af.m`` rows.

        Reference values from an earlier hard-coded version: m=20000, n=64,
        delta=0.75, k=500, max_rounds=20.

        :return: integer array of row indices into ``np_embd``, in removal order.
        """
        af_cfg = self.config.af
        if len(np_embd) <= af_cfg.m:
            return np.array([], dtype=int)

        removed_idxs = []
        df_embd = pd.DataFrame(np_embd)
        df_label = pd.Series(np_label)

        for r in tqdm(range(af_cfg.max_rounds), desc='Rounds'):
            df_scores = self._compute_af_scores(df_embd, df_label, max_idx=len(np_embd))

            # Slicing an Index past its end is safe, so no min() is needed.
            removes = (df_scores[df_scores['score'] > af_cfg.delta]
                       .sort_values(by='score', axis=0, ascending=False)
                       .index[:af_cfg.k])
            removed_idxs.extend(removes)

            logger.info(f'Dropping {len(removes)} rows from Data')
            df_embd.drop(index=removes, inplace=True)
            df_label.drop(index=removes, inplace=True)

            logger.info(f'Finished round {r} of advfiltering.')
            if len(removes) < 100:
                logger.info(f'AF terminated due to small number of removes')
                break
            if len(df_embd) <= af_cfg.m:
                logger.info(f'AF terminated due to small number remaining dataset')
                break
        return np.array(removed_idxs, dtype=int)

    def _compute_af_scores(self, df_embd, df_label, max_idx):
        """Score each example by how often an ensemble of linear probes classifies it correctly.

        ``af.n`` times, the remaining examples are split at random into a
        training part of roughly ``af.m`` rows and a held-out part. A
        class-balanced ``LogisticRegression`` is fitted on the training part and
        evaluated on the held-out part. An example's score is (#times predicted
        correctly) / (#times held out), or 0 if it was never held out; this
        includes rows already removed from ``df_embd``.

        :param df_embd: remaining embeddings, indexed by original row position
            (which must be < ``max_idx``); must have more than ``af.m`` rows.
        :param df_label: labels aligned with ``df_embd``.
        :param max_idx: total number of original rows (length of the result).
        :return: DataFrame with a single ``score`` column, indexed ``0..max_idx-1``.
        """
        df_ens_preds = pd.DataFrame(np.zeros(shape=(max_idx, 1)))
        df_ens_crcts = pd.DataFrame(np.zeros(shape=(max_idx, 1)))

        # Held-out fraction so that about af.m rows are used for training;
        # len(df_embd) does not change inside the loop.
        ratio = 1.0 - float(self.config.af.m) / len(df_embd)
        for i in tqdm(range(self.config.af.n), desc='ensemble iterations'):
            x_train, x_test, y_train, y_test = train_test_split(df_embd, df_label, test_size=ratio)

            clf = LogisticRegression(class_weight='balanced', random_state=57 + i)
            y_pred = pd.Series(
                self._get_clf_y_pred(
                    clf=clf,
                    x_test=x_test.values,
                    x_train=x_train.values,
                    y_train=y_train.values),
                index=x_test.index,
            )

            df_ens_preds += self._ones_sparse(y_test.index, n_mx=max_idx)
            df_ens_crcts += self._ones_sparse(y_test[y_test == y_pred].index, n_mx=max_idx)

        # 0/0 for never-held-out rows yields NaN, which is mapped to a score of 0.
        return (df_ens_crcts / df_ens_preds).fillna(0).rename(columns={0: 'score'})

    def _get_clf_y_pred(self, clf, x_test: np.ndarray, x_train: np.ndarray, y_train: np.ndarray, oh_enc=None) \
            -> np.ndarray:
        """Fit ``clf`` on the training data and return its predictions for ``x_test``.

        Every warning raised during fit/predict is suppressed (for example,
        solver convergence warnings). ``oh_enc`` is accepted for backward
        compatibility and ignored.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            clf.fit(X=x_train, y=y_train)
            return clf.predict(x_test)

    @staticmethod
    def _ones_sparse(idx, n_mx):
        """Return a dense ``(n_mx, 1)`` float array with 1.0 at each row listed in ``idx``.

        Built through ``csr_matrix``, which sums duplicate coordinates, so a row
        listed twice gets 2.0. All other rows are 0.
        """
        return csr_matrix(
            (np.ones(shape=len(idx)),
             (np.array(idx), np.zeros(shape=len(idx)))),
            shape=(n_mx, 1)
        ).toarray()


@hydra.main(config_path='../Configs/AdvFilteringConfig.yaml')
def main(config: omegaconf.dictconfig.DictConfig):
    """Hydra entry point: build an ``AdvFiltering`` pipeline from ``config`` and run it."""
    pipeline = AdvFiltering(config)
    pipeline.run()


if __name__ == '__main__':
    # The hydra.main decorator composes the config and passes it to main().
    main()
