from utils.training.validation_utils import val_epoch_output, val_epoch_standard
from evaluation_script import create_data_loader, load_evaluation_objects
from utils.model_utils import MetricModel
from data.data_loader import DataSelector
from utils.model_utils import load_model
from torch.utils.data import DataLoader
from typing import Union
import pandas as pd
import pickle
import torch
import yaml
import os


def _read_index_file(path: str) -> list:
    """Read one integer row index per line from *path*, skipping blank lines."""
    with open(path, 'r') as f:
        # int() tolerates surrounding whitespace/newlines, but not an empty line
        return [int(line) for line in f if line.strip()]


def load_indices(index_dir_labeled: str, index_dir_unlabeled: str) -> tuple:
    """Load the row indices of the labeled and (optionally) unlabeled samples.

    Args:
        index_dir_labeled: directory containing ``train_indices.txt``. It is
            concatenated directly with the file name, so it must end with a
            path separator.
        index_dir_unlabeled: directory containing ``remaining_indices.txt``
            (same trailing-separator requirement), or ``None`` to skip loading
            unlabeled indices.

    Returns:
        ``(idx_labeled, idx_unlabeled)`` -- lists of ints, where
        ``idx_unlabeled`` is ``None`` when ``index_dir_unlabeled`` is ``None``.
    """
    # load the indices of the unlabeled data samples (if requested)
    idx_unlabeled = None
    if index_dir_unlabeled is not None:
        idx_unlabeled = _read_index_file(index_dir_unlabeled + 'remaining_indices.txt')

    # load the indices of the labeled data samples
    idx_labeled = _read_index_file(index_dir_labeled + 'train_indices.txt')
    return idx_labeled, idx_unlabeled


def select_unlabeled_data(df: pd.DataFrame, parameters: dict) -> pd.DataFrame:
    """Filter *df* through a configured ``DataSelector``.

    Args:
        df: the unlabeled data frame to sub-select.
        parameters: mapping with an ``'init'`` entry (keyword arguments for the
            ``DataSelector`` constructor) and a ``'load'`` entry (extra keyword
            arguments forwarded to ``DataSelector.load_data``).

    Returns:
        The first of the three values returned by ``load_data``; the other
        two are discarded.
    """
    init_kwargs = parameters['init']
    selector = DataSelector(**init_kwargs)

    # load_data yields three values; only the selected frame is needed here
    selected, _second, _third = selector.load_data(df=df, **parameters['load'])
    return selected


def _rescale_half_range(values: pd.Series, offset: float, constant_value: float) -> pd.Series:
    """Min-max scale *values* onto ``[offset, offset + 0.5]``.

    A group without spread cannot be min-max scaled (it would divide by
    zero), so every entry then becomes *constant_value*.
    """
    if values.empty:
        return values.astype(float)
    low = values.min()
    span = values.max() - low
    if span == 0:
        return pd.Series(constant_value, index=values.index, dtype=float)
    return 0.5 * (values - low) / span + offset


def prepare_labeled_data(df_labeled: pd.DataFrame) -> pd.DataFrame:
    """Normalize the ``'Toxicity'`` column into soft labels in ``[0, 1]``.

    Rows with ``Toxicity > 0`` are min-max scaled onto ``[0.5, 1.0]``, and
    rows with ``Toxicity <= 0`` onto ``[0.0, 0.5]``. A group whose values are
    all equal maps to ``1.0`` (positive) or ``0.0`` (negative). An empty
    group is simply omitted. Rows whose Toxicity is NaN satisfy neither
    comparison and are therefore dropped.

    Args:
        df_labeled: frame with a numeric ``'Toxicity'`` column. It is not
            modified.

    Returns:
        A new frame with the positive rows first, then the negative rows.
        The original index labels are kept.
    """
    # split the data into positive and negative labels
    df_pos = df_labeled.loc[df_labeled['Toxicity'] > 0]
    df_neg = df_labeled.loc[df_labeled['Toxicity'] <= 0]

    # assign positionally (to_numpy) so duplicate index labels cannot break alignment,
    # and use assign() to get fresh frames instead of writing into slices
    df_pos = df_pos.assign(
        Toxicity=_rescale_half_range(df_pos['Toxicity'], 0.5, 1.0).to_numpy())
    df_neg = df_neg.assign(
        Toxicity=_rescale_half_range(df_neg['Toxicity'], 0.0, 0.0).to_numpy())

    # concatenate only non-empty parts; an empty input yields an empty frame
    parts = [part for part in (df_pos, df_neg) if not part.empty]
    return pd.concat(parts) if parts else df_pos


def load_data(parameters: dict, delimiter: str = '\t') -> tuple:
    """Read, index-filter and prepare the labeled and unlabeled training data.

    Reads both CSV files named under ``parameters['data']['load']``, keeps
    only the rows listed in the index files, and normalizes the labeled
    targets with ``prepare_labeled_data``. The unlabeled rows are then passed
    through ``select_unlabeled_data`` with ``parameters['selector']``.

    Args:
        parameters: full configuration dictionary.
        delimiter: column separator of both input files.

    Returns:
        ``(df_labeled, df_unlabeled)``.
    """
    load_cfg = parameters['data']['load']
    unlabeled_csv, labeled_csv = load_cfg['unlabeled'], load_cfg['labeled']
    labeled_idx_dir, unlabeled_idx_dir = load_cfg['indices_labeled'], load_cfg['indices_unlabeled']

    unlabeled_frame = pd.read_csv(unlabeled_csv, sep=delimiter)
    labeled_frame = pd.read_csv(labeled_csv, sep=delimiter)

    # restrict each frame to its configured row positions (unlabeled is optional)
    labeled_rows, unlabeled_rows = load_indices(labeled_idx_dir, unlabeled_idx_dir)
    if unlabeled_rows is not None:
        unlabeled_frame = unlabeled_frame.iloc[unlabeled_rows]
    labeled_frame = prepare_labeled_data(labeled_frame.iloc[labeled_rows])

    unlabeled_frame = select_unlabeled_data(unlabeled_frame, parameters['selector'])
    return labeled_frame, unlabeled_frame


def load_pretrained_model(parameters: dict) -> tuple:
    """Build the model, load its pretrained weights and unpickle its calibrator.

    Args:
        parameters: configuration whose ``'model'`` entry provides ``'device'``,
            ``'kwargs'`` (forwarded to ``load_model``), ``'path'`` (saved
            state dict) and ``'calibrator'`` (pickled calibrator file).

    Returns:
        ``(model, calibrator)``. The model is the object built by
        ``load_model`` with the saved weights loaded. The calibrator is
        whatever object the pickle file contains.
    """
    model_params = parameters['model']
    device = model_params['device']

    # load a raw model on the configured device
    model = load_model(device=device, **model_params['kwargs'])

    # map_location remaps saved tensors onto the configured device
    # NOTE(security): torch.load and pickle.load can execute arbitrary code
    # from crafted files -- only point 'path'/'calibrator' at trusted files.
    model.load_state_dict(torch.load(model_params['path'], map_location=device))

    with open(model_params['calibrator'], 'rb') as f:
        calibrator = pickle.load(f)
    return model, calibrator


def compute_pseudo_labels(parameters: dict, df: pd.DataFrame) -> tuple:
    """Run the pretrained model over *df* to obtain pseudo labels.

    Args:
        parameters: configuration with ``'data_loader'`` (``data_key``,
            ``label_key``, ``batch_size``), ``'model'`` (see
            ``load_pretrained_model``; ``kwargs['standard']`` selects the
            validation routine) and ``'output'`` (``path``, ``class_names``).
        df: data to pseudo-label.

    Returns:
        ``(predictions, targets, data_list)`` as produced by
        ``val_epoch_standard`` (when ``kwargs['standard']`` is truthy) or
        ``val_epoch_output`` (otherwise).
    """
    # extract the data loader properties
    loader_params = parameters['data_loader']
    data_key, label_key = loader_params['data_key'], loader_params['label_key']
    batch_size = loader_params['batch_size']

    # create a data set object
    dset = DataSelector.create_dataset(df, data_key, label_key)

    # collapse label 2 into label 1 (presumably merging two positive classes
    # into a binary target -- confirm); datasets may come without labels
    if dset.labels is not None:
        dset.labels = [x if x != 2 else 1 for x in dset.labels]

    # shuffle=False keeps outputs in the row order of df
    data_loader = DataLoader(dset, batch_size, shuffle=False)

    # load a pretrained model
    model, calibrator = load_pretrained_model(parameters)

    # load other objects, needed for model evaluation and pseudo labeling
    output_path = parameters['output']['path'] + 'pseudo_labels/'
    class_names = parameters['output']['class_names']
    metrics, ps, writer = load_evaluation_objects(output_path, class_names)

    # evaluate the model with the routine matching the model variant
    validate = val_epoch_standard if parameters['model']['kwargs']['standard'] else val_epoch_output
    _, _, predictions, targets, data_list = validate(model, data_loader, metrics, writer, 0, ps, calibrator)
    return predictions, targets, data_list

    
if __name__ == '__main__':

    # path of the parameter file
    parameter_file = 'parameters/pseudo_labels.yaml'

    # load the parameters (safe_load: no arbitrary Python object construction)
    with open(parameter_file, 'r') as f:
        parameters = yaml.safe_load(f)

    # directory in which the results should be stored; exist_ok avoids the
    # check-then-create race of a separate os.path.exists test
    output_path = parameters['output']['path']
    os.makedirs(output_path, exist_ok=True)

    # load the data
    df_labeled, df_unlabeled = load_data(parameters)

    # compute the pseudo labels for the unlabeled data
    predictions, _, data_list = compute_pseudo_labels(parameters, df_unlabeled)

    # create a dictionary containing data, pseudo labels and correct labels;
    # data_list/predictions are assumed to follow df_unlabeled's row order
    # (the loader is built with shuffle=False)
    data_dict_u = {
        'Text': data_list,
        'Correct_Targets': df_unlabeled['label'],
        'Toxicity': predictions.detach().cpu().tolist(),
    }

    # create data frames for labeled and unlabeled data
    df_unlabeled = pd.DataFrame.from_dict(data_dict_u)[['Text', 'Toxicity', 'Correct_Targets']]
    df_labeled['Correct_Targets'] = df_labeled['Toxicity'].tolist()
    df_labeled = df_labeled[['Text', 'Toxicity', 'Correct_Targets']]

    print(df_unlabeled)

    # create a common dataset
    df_out = pd.concat([df_labeled, df_unlabeled], ignore_index=True)
    print(df_out)

    # store the data
    output_name = parameters['output']['output_name']
    df_out.to_csv(output_path + output_name, sep='\t', index=False)
