from functools import partial
import random
from random import shuffle
from typing import Tuple, List, Generator
# from itertools import combinations
import itertools

import hashlib

import IPython
import os
import pandas as pd
import numpy as np
import hydra
import omegaconf
import pathlib

import torch
import warnings
from scipy.sparse import csr_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from tqdm import tqdm
import logging
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import normalize
import pytorch_lightning as pl

from AdvFiltering import AdvFiltering
from BasicBenchmark import DatasetSplitter
from MyEmbedder import MyEmbedder
from Utils import my_df_flatmap

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class AdvFilteringMC(AdvFiltering):
    """Adversarial filtering over a multiple-choice (MC) benchmark.

    Examples are processed in consecutive groups of ``_N_CHOICES`` rows, one
    group per question.  Loading, embedding, the filtering algorithm itself and
    saving are delegated to methods that are not defined in this class
    (expected to be provided by ``AdvFiltering``).
    """

    # Number of candidate answers per question; every reshape below relies on
    # the examples being laid out in consecutive groups of this size.
    _N_CHOICES = 4

    def __init__(self, config: omegaconf.dictconfig.DictConfig):
        """Forward ``config`` unchanged to the ``AdvFiltering`` constructor."""
        super().__init__(config)

    def run(self):
        """Filter the MC benchmark and save what is left.

        Steps:
            1. Load the flat example table and embed its ``text`` column.
            2. Concatenate the embeddings of each group of ``_N_CHOICES`` rows
               into one row per question and derive the index of the row whose
               ``label`` equals 1 in every group.
            3. Run ``run_af_algorithm`` to obtain the question indices to drop.
            4. Rebuild the question/choices table, attach the answers, drop the
               filtered questions and save both an unsplit and a split version.

        Raises:
            ValueError: if the example table is empty or not a multiple of
                ``_N_CHOICES`` rows long, if a group does not contain exactly
                one row with ``label == 1``, or if the number of embedded
                questions does not match the number of labels.
        """
        n_choices = self._N_CHOICES

        output_path = pathlib.Path(self.config.results_path)
        output_path.mkdir(parents=True, exist_ok=True)

        df_examples = self._load_mc_bechmark()
        # Kept in place so the positional index is 0..n-1 on the loaded frame itself.
        df_examples.reset_index(drop=True, inplace=True)

        n_examples = len(df_examples)
        if n_examples == 0 or n_examples % n_choices != 0:
            raise ValueError(
                f'Expected a non-empty multiple of {n_choices} examples, got {n_examples}'
            )

        np_embd = self._compute_all_embeddings(df_examples['text'])
        d_embd = np_embd.shape[1]
        # One row per question: the embeddings of its choices side by side.
        np_embd = np_embd.reshape([-1, n_choices * d_embd])

        np_is_correct = df_examples['label'].values.reshape([-1, n_choices]) == 1
        n_correct = np_is_correct.sum(axis=1)
        if (n_correct != 1).any():
            # Comparing only totals would let a group with two positives hide a
            # group with none and silently shift every following answer.
            bad_groups = np.flatnonzero(n_correct != 1)
            raise ValueError(
                f'Every question needs exactly one example with label == 1; '
                f'offending question indices: {bad_groups.tolist()}'
            )
        # With exactly one positive per row, argmax is the column of that positive.
        np_label = np.argmax(np_is_correct, axis=1)

        if np_embd.shape[0] != np_label.shape[0]:
            raise ValueError(
                f'Got {np_embd.shape[0]} embedded questions but {np_label.shape[0]} labels'
            )

        np_removed_idx = self.run_af_algorithm(np_embd, np_label)

        df_bench_orig = self._recreate_mc_benchmark(df_examples)
        df_bench_orig['answer'] = np_label

        df_bench = df_bench_orig.drop(index=np_removed_idx)

        self.split_and_save(df_bench, do_split=False)
        self.split_and_save(df_bench, do_split=True)

    @staticmethod
    def _sanity_check_mc(pd_s: pd.Series) -> bool:
        """Check that the strings in ``pd_s`` form one well-formed MC question.

        Every string is split on ``'?'``; the first piece is taken as the
        question and the second as the answer.  The group is valid when all
        strings share one question and there are exactly ``_N_CHOICES``
        distinct answers.

        Args:
            pd_s: the strings belonging to one question.

        Returns:
            ``True`` if the group is valid, ``False`` otherwise (including an
            empty group or a string without an answer part).
        """
        questions = set()
        answers = set()
        for text in pd_s:
            pieces = text.split('?')
            if len(pieces) < 2:
                # No answer part at all: cannot be a valid question/answer string.
                return False
            questions.add(pieces[0])
            answers.add(pieces[1])

        return len(questions) == 1 and len(answers) == AdvFilteringMC._N_CHOICES

    def _recreate_mc_benchmark(self, df_data: pd.DataFrame) -> pd.DataFrame:
        """Build the question/choices table from the CSV at ``config.mc_path``.

        Args:
            df_data: currently unused; kept so existing callers keep working.
                (An earlier, now removed, variant rebuilt the table from
                ``df_data['text']`` and cached it as a CSV file.)

        Returns:
            One row per CSV row, as produced by ``_sents_to_mc``.
        """
        df_bech = pd.read_csv(self.config.mc_path)
        df_text = df_bech.apply(axis=1, func=self._sents_to_mc)
        return df_text

    @staticmethod
    def _sents_to_mc(mc_ins: pd.Series) -> pd.Series:
        """Turn one benchmark row into a question string plus its choices.

        Args:
            mc_ins: a row with a ``'question'`` entry, a ``'type'`` entry that is
                either ``'disabling'`` or ``'enabling'``, and entries
                ``'choice_0'``, ``'choice_1'``, ... for the candidate answers.

        Returns:
            A series with ``'question'`` (the original text followed by a
            type-specific prompt) and ``'choices_0'``, ``'choices_1'``, ...

        Raises:
            ValueError: if ``mc_ins['type']`` is neither supported value.
        """
        fact = mc_ins['question']

        if mc_ins['type'] == 'disabling':
            q = ' What makes this impossible? '
        elif mc_ins['type'] == 'enabling':
            q = ' What makes this possible? '
        else:
            raise ValueError(mc_ins['type'])

        # Count the consecutive choice_<i> entries that are actually present
        # instead of assuming a fixed number of non-choice columns.
        n_cs = 0
        while f'choice_{n_cs}' in mc_ins.index:
            n_cs += 1

        return pd.Series({
            'question': fact+q,
            **{f'choices_{i}': mc_ins[f'choice_{i}'] for i in range(n_cs)}
        })

    def _get_clf_y_pred(self, clf, x_test: np.ndarray, x_train: np.ndarray, y_train: np.ndarray, oh_enc = None) \
            -> np.ndarray:
        """Fit ``clf`` on per-choice examples and predict one choice per test question.

        Each input row is split back into ``_N_CHOICES`` equally sized pieces
        (presumably the per-choice embeddings concatenated in ``run`` — confirm
        against the caller), so the classifier is trained as a binary scorer
        over single choices.

        Args:
            clf: estimator exposing ``fit(X=..., y=...)`` and ``predict_proba``.
            x_test: 2-D array, one row per test question.
            x_train: 2-D array, one row per training question.
            y_train: 1-D array with the answer of every training question.
            oh_enc: fitted encoder exposing ``transform`` and ``categories_``
                (e.g. a scikit-learn ``OneHotEncoder``).  Required despite the
                ``None`` default.

        Returns:
            For every test question, the position (within the encoder's
            categories) whose per-choice score ``predict_proba(...)[:, 1]`` is
            highest.

        Raises:
            ValueError: if ``oh_enc`` is ``None``.
        """
        if oh_enc is None:
            raise ValueError('oh_enc must be a fitted one-hot encoder, got None')

        n_choices = self._N_CHOICES

        y_train_oh = oh_enc.transform(y_train.reshape([-1, 1]))
        if hasattr(y_train_oh, 'toarray'):
            # Encoders configured for sparse output return a matrix that cannot
            # be flattened to 1-D directly.
            y_train_oh = y_train_oh.toarray()
        # One binary target per (question, category) pair, in row-major order.
        y_train_oh = y_train_oh.reshape([-1])

        x_train_flat = x_train.reshape([-1, x_train.shape[1] // n_choices])
        x_test_flat = x_test.reshape([-1, x_test.shape[1] // n_choices])
        with warnings.catch_warnings():
            # Silence every warning raised while fitting/predicting.
            warnings.simplefilter("ignore")
            clf.fit(X=x_train_flat, y=y_train_oh)
            y_pred_probs = clf.predict_proba(x_test_flat)
            y_pred = np.argmax(
                y_pred_probs[:, 1].reshape([-1, len(oh_enc.categories_[0])]),
                axis=1
            )
        return y_pred


@hydra.main(config_path='../Configs/AdvFilteringConfig.yaml')
def main(config: omegaconf.dictconfig.DictConfig):
    """Hydra entry point: build an ``AdvFilteringMC`` from ``config`` and run it."""
    # Disabled guard, kept for reference:
    # assert config.method == 'mc-knn', f'wrong method selected {config.method}'
    mc_filtering = AdvFilteringMC(config)
    mc_filtering.run()


# Call the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
