from torch.utils.data import DataLoader, WeightedRandomSampler
from utils.training.validation_utils import val_epoch_output
from utils.training.train_model import train_output_layer
from evaluation_script import load_evaluation_objects
from utils.model_utils import load_model
from data.data_loader import DataSet
from main import load_dataset
import torch
import yaml


def create_data_loader(parameters: dict, batch_size: int, data_split: str = 'validation') -> DataLoader:
    """Build a ``DataLoader`` for the dataset described by ``parameters``.

    Args:
        parameters: Dataset configuration, forwarded unchanged to ``load_dataset``.
        batch_size: Number of samples per batch.
        data_split: ``'validation'`` yields a sequential, unshuffled loader. Any
            other value yields a loader backed by a ``WeightedRandomSampler``
            in which each sample is weighted by the reciprocal of the value
            ``get_label_frequencies`` reports for its label.

    Returns:
        The configured ``DataLoader``.
    """
    # load the requested split; only the first of the three returned objects is used
    dataset, _, _ = load_dataset(parameters)

    # validation data is iterated in its stored order
    if data_split == 'validation':
        return DataLoader(dataset, batch_size, shuffle=False)

    # weight every class by the reciprocal of its reported frequency
    inverse_freq = {label: 1. / count for label, count in dataset.get_label_frequencies().items()}

    # one weight per sample, looked up through the sample's label
    per_sample = torch.DoubleTensor([inverse_freq[label] for label in dataset.get_labels()])

    # balanced sampling: draw as many samples per pass as there are sample weights
    balanced_sampler = WeightedRandomSampler(per_sample, len(per_sample))
    return DataLoader(dataset, batch_size, sampler=balanced_sampler)


if __name__ == '__main__':
    # Script entry point: loads a stored model, builds the validation and
    # train loaders, and hands everything to train_output_layer.
    import os

    parameter_file = 'parameters/civil_comments/heureka/output_layer.yaml'
    # load the parameters; decode explicitly as UTF-8 instead of relying on
    # the platform's locale encoding
    with open(parameter_file, 'r', encoding='utf-8') as f:
        parameters = yaml.safe_load(f)

    # extract important parameters
    num_epochs = parameters['training']['num_epochs']
    exp_path = parameters['experiment']['base_path']
    model_path = parameters['model']['path']

    # load the model
    model = load_model(**parameters['model']['kwargs'])
    # NOTE(review): depending on the torch version, torch.load may unpickle
    # arbitrary objects -- only point model.path at trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=parameters['model']['kwargs']['device']))

    # load the data
    batch_size = parameters['data']['batch_size']
    val_loader = create_data_loader(parameters['data']['validation'], batch_size, data_split='validation')
    train_loader = create_data_loader(parameters['data']['train'], batch_size, data_split='train')

    # load additional objects used for model evaluation
    class_names = ['Normal', 'Toxic']
    # Join instead of concatenating so the path is correct whether or not
    # base_path ends with a separator; the empty last component keeps the
    # trailing separator the original string had.
    tb_path = os.path.join(exp_path, 'tb', 'calibrated', '')
    metrics, ps, writer = load_evaluation_objects(tb_path, class_names)

    # train the output layers
    # NOTE(review): the meaning of the literal 1 is defined by
    # train_output_layer and is not visible here -- confirm against its signature.
    train_output_layer(num_epochs, 1, model, train_loader, val_loader, metrics, writer, ps, model_path)
