import pandas as pd
import os.path


def split_data(df: pd.DataFrame, num_labeled: int) -> tuple:
    """
    Partition a data frame into a labeled head and an unlabeled tail.

    Parameters
    ----------
    df: Raw data frame; row order decides which samples count as labeled.
    num_labeled: Number of leading rows treated as labeled samples.

    Returns
    -------
    data_labeled: The first ``num_labeled`` rows of ``df`` (positional slice).
    data_unlabeled: All rows of ``df`` after the first ``num_labeled``.
    """
    # positional cut points: everything before the cut is labeled, the rest is not
    head = slice(None, num_labeled)
    tail = slice(num_labeled, None)
    return df.iloc[head], df.iloc[tail]


def threshold_based_selection(df: pd.DataFrame, pos_threshold: float = 0.9, neg_threshold: float = 0.1) -> pd.DataFrame:
    """
    Keep only rows whose ``Toxicity`` score is confidently high or low.

    Parameters
    ----------
    df: Raw Dataframe with a ``Toxicity`` column, from which samples are sub-selected.
    pos_threshold: Rows with ``Toxicity >= pos_threshold`` are kept as positives.
    neg_threshold: Rows with ``Toxicity <= neg_threshold`` are kept as negatives.

    Returns
    -------
    df: The matching rows in their original order (NaN scores match neither bound).
    """
    toxicity = df['Toxicity']
    is_confident_pos = toxicity.ge(pos_threshold)
    is_confident_neg = toxicity.le(neg_threshold)
    return df.loc[is_confident_pos | is_confident_neg]


def ratio_based_selection(df: pd.DataFrame, pos_ratio: float = 0.01, neg_ratio: float = 0.01,
                          boundary: float = 0.5) -> pd.DataFrame:
    """
    Select the most confident fraction of negative and positive samples.

    Rows with ``Toxicity <= boundary`` form the negative group and rows with
    ``Toxicity > boundary`` form the positive group. From each group the
    ``int(ratio * group_size)`` most extreme rows are kept: the lowest scores
    among the negatives and the highest scores among the positives. Rows whose
    ``Toxicity`` is NaN belong to neither group and are dropped.

    Parameters
    ----------
    df: Raw Dataframe, from which certain samples should be sub-selected.
        Must contain a numeric ``Toxicity`` column.
    pos_ratio: Non-negative fraction of the positive samples to keep.
    neg_ratio: Non-negative fraction of the negative samples to keep.
    boundary: Score separating negatives (<=) from positives (>). Defaults to 0.5.

    Returns
    -------
    df: Selected negatives (ascending score) followed by selected positives
        (descending score), with the original index labels preserved.

    Raises
    ------
    ValueError: If ``pos_ratio`` or ``neg_ratio`` is negative.
    """
    # a negative ratio yields a negative count, and iloc[:-k] would then silently
    # select almost the whole group instead of nothing
    if pos_ratio < 0 or neg_ratio < 0:
        raise ValueError('pos_ratio and neg_ratio must be non-negative')

    scores = df['Toxicity']

    # filter before sorting so only the relevant group is sorted; a stable sort
    # makes the choice among tied scores deterministic
    df_neg = df.loc[scores <= boundary].sort_values(by=['Toxicity'], ascending=True, kind='mergesort')
    df_neg = df_neg.iloc[:int(neg_ratio * len(df_neg))]

    df_pos = df.loc[scores > boundary].sort_values(by=['Toxicity'], ascending=False, kind='mergesort')
    df_pos = df_pos.iloc[:int(pos_ratio * len(df_pos))]

    # concatenate the selected data samples
    return pd.concat([df_neg, df_pos])


if __name__ == '__main__':

    # path of the data file and the root directory for the selected outputs
    data_path = '/data1/flo/data/paper2/measuring/pseudo_labels/external/computed/hatexplain/homophobia.tsv'
    output_path = '/data1/flo/data/paper2/measuring/pseudo_labels/external/selected/hatexplain/'
    # file name without directory and extension, used to name the outputs
    data_name = data_path.split('/')[-1].split('.')[0]

    # create the per-method output directories; DataFrame.to_csv does not create
    # missing parent directories, so both sub-folders must exist before writing
    threshold_dir = os.path.join(output_path, 'threshold')
    ratio_dir = os.path.join(output_path, 'ratio')
    os.makedirs(threshold_dir, exist_ok=True)
    os.makedirs(ratio_dir, exist_ok=True)

    # (positive, negative) thresholds used for threshold based selection
    thresholds = (0.80, 0.20)

    # (positive, negative) ratios used for ratio based selection
    ratios = (0.1, 0.1)

    # the first num_labeled rows are treated as labeled data
    num_labeled = 200

    # read the raw data frame
    df = pd.read_csv(data_path, sep='\t')

    # divide labeled and unlabeled data
    df_labeled, df_unlabeled = split_data(df, num_labeled)

    # select samples from the unlabeled part with both selection methods
    df_threshold = threshold_based_selection(df_unlabeled, *thresholds)
    df_ratio = ratio_based_selection(df_unlabeled, *ratios)

    # prepend the labeled data to each selection
    df_threshold = pd.concat([df_labeled, df_threshold])
    df_ratio = pd.concat([df_labeled, df_ratio])

    # store the data; the file names record only the positive threshold / ratio
    threshold_file = data_name + '_threshold_' + str(thresholds[0]) + '.tsv'
    ratio_file = data_name + '_ratio_' + str(ratios[0]) + '.tsv'
    df_threshold.to_csv(os.path.join(threshold_dir, threshold_file), sep='\t', index=False)
    df_ratio.to_csv(os.path.join(ratio_dir, ratio_file), sep='\t', index=False)
