from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from collections import defaultdict
from sklearn import model_selection
import matplotlib.pyplot as plt
from collections import Counter
from typing import Union, List
from copy import deepcopy
import pandas as pd
import numpy as np
import random
import torch
import math
import os


class DataSet(Dataset):
    """In-memory dataset holding a list of samples and, optionally, their labels.

    Args:
        data: The samples.
        labels: One label per sample, or None for an unlabeled dataset.
    """

    def __init__(self, data: list, labels: Union[None, list]):
        super().__init__()
        self.labels = labels
        self.data = data

    def get_labels(self):
        """Return the stored labels (None for an unlabeled dataset)."""
        return self.labels

    def get_label_frequencies(self):
        """Return a dict mapping each label to its number of occurrences.

        An unlabeled dataset yields an empty dict.
        """
        if self.labels is None:
            return {}
        return dict(Counter(self.labels))

    def sub_sample_data(self, num_samples: int) -> None:
        """Randomly keep ``num_samples`` samples (without replacement), in place.

        Works for labeled and unlabeled datasets; data and labels stay aligned.

        Raises:
            AssertionError: If ``num_samples`` exceeds the dataset size.
        """
        assert num_samples <= len(self), "Sample size too large"

        # draw positions from the data so that unlabeled datasets (labels=None) work too
        indices = random.sample(range(len(self.data)), num_samples)
        if self.labels is not None:
            self.labels = [self.labels[idx] for idx in indices]
        self.data = [self.data[idx] for idx in indices]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> Union[tuple, object]:
        """Return ``(sample, label)`` if labels exist, otherwise just the sample."""
        if self.labels is not None:
            return self.data[idx], self.labels[idx]
        return self.data[idx]


class DataSelector:
    """Select, relabel and subsample rows of a tabular data set.

    The selection and mapping specifications are dicts with the two optional
    entries ``'hard_labels'`` and ``'soft_labels'``. A specification that is
    None, or an entry that is missing or None, means "no constraint".
    """

    def __init__(self, include_dict: Union[None, dict], exclude_dict: Union[None, dict],
                 tmp_mapping: Union[None, dict], final_mapping: Union[None, dict], keys: list):

        # label mapping after data selection
        self.final_mapping = final_mapping

        # label mapping before data selection
        self.tmp_mapping = tmp_mapping

        # defines which samples should be included into the training procedure
        self.include_dict = include_dict

        # defines which samples should be excluded from the training procedure
        self.exclude_dict = exclude_dict

        # the keys of the data frame which are included in the final data frame
        self.keys = keys

    @staticmethod
    def _get_section(spec: Union[None, dict], name: str) -> Union[None, dict]:
        """Return ``spec[name]``, or None if ``spec`` is None or has no such entry."""
        if spec is None:
            return None
        return spec.get(name)

    @staticmethod
    def _range_mask(column: pd.Series, bounds) -> np.ndarray:
        """Return a boolean array marking the values inside any inclusive range.

        ``bounds`` is either ``[min, max]`` (one range) or
        ``[[min_1, min_2, ...], [max_1, max_2, ...]]`` (several ranges).
        """
        lower, upper = bounds[0], bounds[1]
        if pd.api.types.is_list_like(lower):
            ranges = list(zip(lower, upper))
        else:
            ranges = [(lower, upper)]

        mask = np.zeros(len(column), dtype=bool)
        for min_val, max_val in ranges:
            inside = (min_val <= column) & (column <= max_val)
            mask |= inside.to_numpy(dtype=bool, na_value=False)
        return mask

    @staticmethod
    def _bin_value(value, thresholds: list, labels: list, label_name: str):
        """Map one numeric value to the label of the threshold bin it falls into."""
        if value <= thresholds[0]:
            return labels[0]

        for i, (low, high) in enumerate(zip(thresholds, thresholds[1:]), start=1):
            if low <= value < high:
                return labels[i]

        # an optional extra label covers everything at or above the last threshold
        if value >= thresholds[-1] and len(labels) > len(thresholds):
            return labels[len(thresholds)]

        raise ValueError(
            f"Value {value!r} in column '{label_name}' does not fall into any bin defined by "
            f"thresholds {thresholds}; provide len(thresholds) + 1 labels to cover larger values."
        )

    @classmethod
    def include_hard_labels(cls, df: pd.DataFrame, include_dict: dict) -> pd.DataFrame:
        """Keep only rows whose columns take one of the listed values.

        ``include_dict['hard_labels']`` maps a column name to the allowed values.
        """
        selection_dict = cls._get_section(include_dict, 'hard_labels')
        if selection_dict is not None:

            # for each selection criterion, select the corresponding rows in the data frame
            for label_name, labels in selection_dict.items():
                df = df.loc[df[label_name].isin(labels)]
        return df

    @classmethod
    def exclude_hard_labels(cls, df: pd.DataFrame, exclude_dict: dict) -> pd.DataFrame:
        """Drop rows whose columns take one of the listed values.

        ``exclude_dict['hard_labels']`` maps a column name to the forbidden values.
        """
        discard_dict = cls._get_section(exclude_dict, 'hard_labels')
        if discard_dict is not None:

            # for each exclusion criterion, discard the corresponding rows in the data frame
            for label_name, labels in discard_dict.items():
                df = df.loc[~df[label_name].isin(labels)]
        return df

    @classmethod
    def include_soft_labels(cls, df: pd.DataFrame, include_dict: dict) -> pd.DataFrame:
        """Keep only rows whose numeric columns lie inside the given inclusive ranges.

        ``include_dict['soft_labels']`` maps a column name to either ``[min, max]`` or
        ``[[min_1, ...], [max_1, ...]]``. With several ranges a row is kept if it lies
        in at least one of them; rows are never duplicated and keep their original order.
        """
        selection_dict = cls._get_section(include_dict, 'soft_labels')
        if selection_dict is not None:
            for label_name, include_values in selection_dict.items():
                df = df.loc[cls._range_mask(df[label_name], include_values)]
        return df

    @classmethod
    def exclude_soft_labels(cls, df: pd.DataFrame, exclude_dict: dict) -> pd.DataFrame:
        """Drop rows whose numeric columns lie inside the given inclusive ranges.

        ``exclude_dict['soft_labels']`` uses the same range format as
        :meth:`include_soft_labels`.
        """
        discard_dict = cls._get_section(exclude_dict, 'soft_labels')
        if discard_dict is not None:
            for label_name, exclude_values in discard_dict.items():
                df = df.loc[~cls._range_mask(df[label_name], exclude_values)]
        return df

    @classmethod
    def select_data(cls, df: pd.DataFrame, include_dict: dict, exclude_dict: dict) -> pd.DataFrame:
        """Apply all include filters, then all exclude filters, and return the result."""

        # include datapoints containing certain labels
        df = cls.include_hard_labels(df, include_dict)
        df = cls.include_soft_labels(df, include_dict)

        # exclude datapoints containing certain labels
        df = cls.exclude_hard_labels(df, exclude_dict)
        df = cls.exclude_soft_labels(df, exclude_dict)
        return df

    @classmethod
    def map_hard_labels(cls, df: pd.DataFrame, mapping_dict: dict) -> pd.DataFrame:
        """Replace column values according to ``mapping_dict['hard_labels']``.

        Each entry maps a column name to a ``{old_value: new_value}`` dict which is
        applied with ``Series.map``. The columns of ``df`` are overwritten in place
        and ``df`` is returned; pass a copy to keep the original untouched.
        """
        hard_mapping = cls._get_section(mapping_dict, 'hard_labels')
        if hard_mapping is not None:
            for label_name, mapping in hard_mapping.items():
                df[label_name] = df[label_name].map(mapping)
        return df

    @classmethod
    def map_soft_labels(cls, df: pd.DataFrame, mapping_dict: dict) -> pd.DataFrame:
        """Discretize numeric columns according to ``mapping_dict['soft_labels']``.

        Each entry maps a column name to ``{'thresholds': [...], 'labels': [...]}``.
        A value ``v`` receives ``labels[0]`` if ``v <= thresholds[0]``, and
        ``labels[i + 1]`` if ``thresholds[i] <= v < thresholds[i + 1]``. If one more
        label than thresholds is given, the last label is used for
        ``v >= thresholds[-1]``. The columns of ``df`` are overwritten in place and
        ``df`` is returned.

        Raises:
            ValueError: If a value (e.g. NaN, or a value at or above the last
                threshold without an extra label) falls into no bin.
        """
        soft_mapping = cls._get_section(mapping_dict, 'soft_labels')
        if soft_mapping is not None:
            for label_name, mapping in soft_mapping.items():

                # extract the mapping parameters
                thresholds = mapping['thresholds']
                labels = mapping['labels']

                # map every value of the column and store the results
                df[label_name] = [
                    cls._bin_value(value, thresholds, labels, label_name)
                    for value in df[label_name].tolist()
                ]
        return df

    @classmethod
    def map_labels(cls, df: pd.DataFrame, mapping_dict: dict) -> pd.DataFrame:
        """Apply the hard-label mapping followed by the soft-label mapping."""

        # map the hard labels
        df = cls.map_hard_labels(df, mapping_dict)

        # map the soft labels
        df = cls.map_soft_labels(df, mapping_dict)
        return df

    @classmethod
    def count_class_labels(cls, df: pd.DataFrame, label: str):
        """Return ``(counts, ratios)`` dicts for the values of column ``label``."""
        counts = df[label].value_counts().to_dict()
        total_count = sum(counts.values())
        ratios = {key: float(value) / total_count for key, value in counts.items()}
        return counts, ratios

    @classmethod
    def get_max_counts(cls, df: pd.DataFrame, label: str, ratios_new: Union[dict, None]):
        """Compute the largest per-class sample counts that respect ``ratios_new``.

        The target count of a class is ``ratio * len(df)``. If any class has fewer
        rows than its target, all targets are scaled down by a common factor so that
        every class fits. Without ``ratios_new`` the existing class counts are used.
        ``df`` is not modified.

        Returns:
            ``(counts, ratios, total)``: integer counts per class, the ratios that
            these counts actually realize, and the sum of the counts.

        Raises:
            KeyError: If a class with a positive ratio does not occur in the data.
        """
        counts_org, _ = cls.count_class_labels(df, label)

        if ratios_new is not None:
            num_rows = len(df)
            requested = {key: ratio * num_rows for key, ratio in ratios_new.items()}

            missing = [key for key, amount in requested.items()
                       if amount > 0 and key not in counts_org]
            if missing:
                raise KeyError(f"Classes {missing} have a positive target ratio "
                               f"but do not occur in column '{label}'.")

            # shrink all targets by the most restrictive class, if any class is too small
            shrink = [counts_org[key] / float(amount) for key, amount in requested.items()
                      if amount > counts_org.get(key, 0)]
            scale = min(shrink) if shrink else 1.0

            # counts must be integers, since they are used as sample sizes
            counts_new = {key: int(round(scale * amount)) for key, amount in requested.items()}
        else:
            counts_new = {key: int(value) for key, value in counts_org.items()}

        total = sum(counts_new.values())
        ratios_out = {key: (float(value) / total if total else 0.0)
                      for key, value in counts_new.items()}
        return counts_new, ratios_out, total

    @classmethod
    def get_min_total_num(cls, df: pd.DataFrame, label: str, ratio_list: List[dict]):
        """Return the smallest achievable total sample count over all ratio settings."""
        return min([cls.get_max_counts(df, label, ratios)[2] for ratios in ratio_list])

    @classmethod
    def subsample_data(cls, df: pd.DataFrame, label: str, ratios: Union[None, dict], max_num: Union[None, int] = None):
        """Randomly draw rows so that the classes of column ``label`` follow ``ratios``.

        Args:
            df: The data to sample from (not modified).
            label: Name of the class column.
            ratios: Target ratio per class, or None to keep the existing counts.
            max_num: Optional upper bound for the total number of rows; the per-class
                counts are scaled and rounded, so the result may differ slightly.

        Returns:
            ``(df_new, global_indices)``: the sampled rows, grouped by class, and
            their row positions in ``df``.

        Raises:
            AssertionError: If ``max_num`` exceeds the achievable total.
        """
        counts_new, _, total = cls.get_max_counts(df, label, ratios)

        if max_num is not None:
            assert max_num <= total, "Maximal Number provided is to large for the provided data and the provided class ratio."
            sub_sample_ratio = float(max_num) / total if total else 0.0
            counts_new = {key: int(round(sub_sample_ratio * value))
                          for key, value in counts_new.items()}

        global_indices = list()
        for key, value in counts_new.items():

            # row positions of the current class inside df
            positions = np.flatnonzero((df[label] == key).to_numpy(dtype=bool, na_value=False)).tolist()
            global_indices.extend(random.sample(positions, value))

        df_new = df.iloc[global_indices, :]
        return df_new, global_indices

    def load_data(self, data_path: str, sub_sample_label: Union[None, str], ratios: Union[None, dict] = None,
                  max_num: Union[None, int] = None, keep_tmp_mapping: bool = True, delimiter: str = '\t',
                  df: Union[None, pd.DataFrame] = None) -> tuple:
        """Load, filter, optionally subsample and relabel the data.

        Args:
            data_path: Path of the delimited text file; ignored if ``df`` is given.
            sub_sample_label: Class column used for subsampling, or None.
            ratios: Target class ratios for subsampling, or None.
            max_num: Optional upper bound for the number of subsampled rows.
            keep_tmp_mapping: If True, the subsampled frame keeps the temporary
                label mapping; otherwise the unmapped rows are returned.
            delimiter: Column separator of the file.
            df: An already loaded data frame; it is not modified.

        Returns:
            ``(df, global_indices, total_num_samples)``: the resulting frame, the row
            positions (within the filtered data) of the returned rows, and the number
            of rows before any filtering.

        Raises:
            ValueError: If ``ratios`` is given without ``sub_sample_label``.
        """

        # read the delimited file, unless a data frame was handed in directly
        if df is None:
            df = pd.read_csv(data_path, sep=delimiter)
        total_num_samples = len(df)

        # select the data, as specified in the dictionaries
        df = self.select_data(df, self.include_dict, self.exclude_dict)

        global_indices = list(range(df.shape[0]))

        # optionally select a subset of data
        if (sub_sample_label is not None) or (ratios is not None):
            if sub_sample_label is None:
                raise ValueError("'sub_sample_label' must be provided when 'ratios' is given.")

            # compute a first label mapping (on a copy, so that df keeps the raw labels)
            if self.tmp_mapping is not None:
                df_tmp = self.map_labels(df.copy(), self.tmp_mapping)
            else:
                df_tmp = df

            df_new, global_indices = self.subsample_data(df_tmp, sub_sample_label, ratios, max_num)

            if keep_tmp_mapping:
                df = df_new
            else:
                df = df.iloc[global_indices, :]

        # map the labels on a copy, so that a frame supplied by the caller is left untouched
        if self.final_mapping is not None:
            df = self.map_labels(df.copy(), self.final_mapping)

        if self.keys is not None:
            df = df[self.keys]
        return df, global_indices, total_num_samples

    @classmethod
    def create_dataset(cls, df: pd.DataFrame, data_key: str, label_key: Union[None, str]) -> "DataSet":
        """Build a ``DataSet`` from column ``data_key`` and, optionally, ``label_key``."""

        # extract the data
        data = df[data_key].tolist()

        # extract the labels, if they exist
        labels = None if label_key is None else df[label_key].tolist()

        # create a dataset object and return it
        return DataSet(data, labels)
