# Directory holding the prepared train/test CSVs. Loaders build file paths by
# plain string concatenation (BASE_PATH + filename), so the trailing '/' matters.
BASE_PATH = ('./Probing_Experiments/'
             'Adding_Classification_Heads_In_Between/'
             'prepared_data/')

import pandas as pd
from sklearn.model_selection import train_test_split

def _drop_excluded_stormfront_labels(df):
    """Return only the rows whose 'label' is neither 'relation' nor 'idk/skip'."""
    # NaN labels are kept, matching the original pair of `!=` filters.
    return df[~df['label'].isin(['relation', 'idk/skip'])]


def _map_toxigen_labels(df):
    """Replace numeric label values 0/1 with 'NonHate'/'Hate' in place and return df."""
    df['label'] = df['label'].replace({0: 'NonHate', 1: 'Hate'})
    return df


# Dataset name -> (train CSV, test CSV, column rename map or None).
# File names are resolved relative to BASE_PATH. Renames use errors='raise',
# so identity entries such as {'label': 'label'} deliberately assert that the
# column exists in both CSVs.
_DATASET_SPECS = {
    'olid_taska': ('olid_train_A.csv', 'olid_test_A.csv', None),
    'olid_taskb': ('olid_train_B.csv', 'olid_test_B.csv', None),
    'olid_taskc': ('olid_train_C.csv', 'olid_test_C.csv', None),
    'abuseval': ('abuseval_train.csv', 'abuseval_test.csv', None),
    'anatomy_of_hate': ('anatomy_of_hate_train.csv',
                        'anatomy_of_hate_test.csv', None),
    'davidson': ('davidson_train.csv', 'davidson_test.csv',
                 {'final_posts': 'text', 'final_labels': 'label'}),
    'dynabench_label': ('df_dynabench_label_detection_train.csv',
                        'df_dynabench_label_detection_test.csv', None),
    'dynabench_type': ('df_dynabench_type_detection_train.csv',
                       'df_dynabench_type_detection_test.csv',
                       {'type': 'label'}),
    'hatexplain_label': ('df_hateXplain_train_label_pred.csv',
                         'df_hateXplain_test_label_pred.csv',
                         {'final_posts': 'text', 'final_labels': 'label'}),
    'hatexplain_target': ('df_hateXplain_train_target_pred.csv',
                          'df_hateXplain_test_target_pred.csv',
                          {'final_posts': 'text', 'final_targets': 'label'}),
    'latent_hatred_labels': ('df_latenthatred_labels_train.csv',
                             'df_latenthatred_labels_test.csv',
                             {'final_posts': 'text', 'final_labels': 'label'}),
    'latent_hatred_implicit_class': (
        'df_latenthatred_implicit_class_train.csv',
        'df_latenthatred_implicit_class_test.csv',
        {'final_posts': 'text', 'final_implicit_class': 'label'}),
    'stormfront': ('df_stormfront_train.csv', 'df_stormfront_test.csv',
                   {'sentence': 'text', 'label': 'label'}),
    'waseem': ('df_waseem_train.csv', 'df_waseem_test.csv',
               {'final_posts': 'text', 'final_labels': 'label'}),
    'founta': ('founta_train.csv', 'founta_test.csv',
               {'full_text': 'text', 'label': 'label'}),
    'gab': ('gab_train.csv', 'gab_test.csv',
            {'final_posts': 'text', 'final_labels': 'label'}),
    'hasoc_english_task_1': ('hasoc_english_task_1_train.csv',
                             'hasoc_english_task_1_test.csv',
                             {'task_1': 'label'}),
    'hasoc_english_task_2': ('hasoc_english_task_2_train.csv',
                             'hasoc_english_task_2_test.csv',
                             {'task_2': 'label'}),
    'offenseval': ('offenseval_train.csv', 'offenseval_test.csv',
                   {'implicit_explicit': 'label'}),
    'toxigen_group': ('toxigen_group_train.csv', 'toxigen_group_test.csv',
                      {'generation': 'text', 'group': 'label'}),
    'toxigen_label': ('toxigen_label_train.csv', 'toxigen_label_test.csv',
                      {'generation': 'text', 'prompt_label': 'label'}),
}

# Extra per-dataset cleanup, applied to the train and test frames separately
# after renaming and before the validation split.
_POSTPROCESSORS = {
    'stormfront': _drop_excluded_stormfront_labels,
    'toxigen_label': _map_toxigen_labels,
}


def _load_from_spec(dset_name):
    """Read, rename, post-process and split the dataset registered as dset_name."""
    train_file, test_file, columns = _DATASET_SPECS[dset_name]
    df_train = pd.read_csv(BASE_PATH + train_file, index_col=False)
    df_test = pd.read_csv(BASE_PATH + test_file, index_col=False)
    if columns is not None:
        df_train = df_train.rename(columns=columns, errors='raise')
        df_test = df_test.rename(columns=columns, errors='raise')
    postprocess = _POSTPROCESSORS.get(dset_name)
    if postprocess is not None:
        df_train = postprocess(df_train)
        df_test = postprocess(df_test)
    # Validation comes from the training CSV only; the test CSV is returned whole.
    df_train, df_validation = train_test_split(df_train, test_size=0.25, random_state=42)
    return df_train, df_validation, df_test


def load_custom_ds(dset_name):
    """Load a prepared dataset as ``(train, validation, test)`` DataFrames.

    Both CSVs are read from ``BASE_PATH`` and columns are renamed per
    ``_DATASET_SPECS``. Rows are then cleaned where ``_POSTPROCESSORS`` has
    an entry. Finally, the training CSV is split 75/25 into train/validation
    with ``random_state=42``.

    Args:
        dset_name: A key of ``_DATASET_SPECS``, e.g. ``'olid_taska'``,
            ``'davidson'`` or ``'toxigen_label'``.

    Returns:
        Tuple ``(df_train, df_validation, df_test)``.

    Raises:
        FileNotFoundError: If one of the CSVs is missing under ``BASE_PATH``.
        KeyError: If a column named in the rename map is absent
            (``errors='raise'``).

    NOTE(review): an unrecognized ``dset_name`` falls through this branch and
    the function returns ``None`` implicitly, so callers unpacking three values
    get a TypeError. Consider raising ValueError once all callers are checked.
    """
    if dset_name in _DATASET_SPECS:
        return _load_from_spec(dset_name)