from utils.training.validation_utils import val_epoch_output, val_epoch_standard
from utils.training.metric_utils import ClassificationMetrics, PredictionStats
from torch.utils.tensorboard import SummaryWriter
from utils.model_utils import load_model
from torch.utils.data import DataLoader
from main import load_dataset
import pickle
import torch
import yaml


def create_data_loader(parameters: dict, data_split: str = 'validation') -> DataLoader:
    """Build a non-shuffling DataLoader for one data split of the measuring task.

    Args:
        parameters: Parsed parameter file. Must contain
            ``parameters['data'][data_split]['selector']['measuring']`` (passed
            to ``load_dataset``) and
            ``parameters['data'][data_split]['data_loader']['measuring']['batch_size']``.
        data_split: Key of the split inside ``parameters['data']``
            (``'validation'`` by default).

    Returns:
        A ``DataLoader`` over the first dataset returned by ``load_dataset``,
        with the configured batch size and ``shuffle=False``.
    """
    split_cfg = parameters['data'][data_split]
    selector_cfg = split_cfg['selector']['measuring']
    loader_batch_size = split_cfg['data_loader']['measuring']['batch_size']

    # load_dataset yields three values; only the dataset itself is needed here
    dataset, _, _ = load_dataset(selector_cfg)

    return DataLoader(dataset, batch_size=loader_batch_size, shuffle=False)


def load_evaluation_objects(output_path: str, class_names: list) -> tuple:
    """Create the helper objects needed to evaluate a model.

    Args:
        output_path: Log directory handed to the TensorBoard ``SummaryWriter``.
        class_names: Class labels passed to both metric helpers.

    Returns:
        A tuple ``(metrics, prediction_stats, writer)`` where ``metrics`` is a
        ``ClassificationMetrics``, ``prediction_stats`` is a ``PredictionStats``
        and ``writer`` is the ``SummaryWriter``. The caller owns ``writer`` and
        should close it when done.
    """
    tb_writer = SummaryWriter(output_path)
    # tuple elements are evaluated left to right: metrics first, then stats
    return ClassificationMetrics(class_names), PredictionStats(class_names), tb_writer


if __name__ == '__main__':
    # Evaluate a trained measuring model on the validation split and persist
    # the beta calibration module produced by the validation routine.
    import os

    # path of the parameter file
    parameter_file = 'parameters/measuring/heureka/measuring.yaml'

    # path of the model
    model_path = '/data1/flo/models/measuring/baselines/baseline_standard_balanced/07_25_2023__14_22_03/model.pt'

    # directory in which the results should be stored
    output_path = '/data1/flo/models/measuring/baselines/baseline_standard_balanced/07_25_2023__14_22_03/'

    # load the parameters
    with open(parameter_file, encoding='utf-8') as f:
        parameters = yaml.safe_load(f)

    # create a data loader for the validation data
    data_loader = create_data_loader(parameters)

    # load the model; fall back to CPU instead of crashing when no GPU is present
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = load_model(device=device, **parameters['model']['kwargs'])
    model.load_state_dict(torch.load(model_path, map_location=device))
    # ensure dropout / batch-norm layers run in inference mode; this is a no-op
    # if the validation helper already calls eval() itself
    model.eval()

    # load additional objects used for model evaluation
    class_names = ['Normal', 'Toxic']
    metrics, ps, writer = load_evaluation_objects(output_path, class_names)

    # pick the validation routine that matches the training regime
    if parameters['training']['type'] == 'standard':
        val_epoch = val_epoch_standard
    else:
        val_epoch = val_epoch_output

    try:
        # only the second return value (the calibrator) is needed here
        _, bc, _, _, _ = val_epoch(model, data_loader, metrics, writer, 0, ps)
    finally:
        # flush pending TensorBoard events and release the event file
        writer.close()

    # store the beta calibration module
    with open(os.path.join(output_path, 'calibrator.pckl'), 'wb') as f:
        pickle.dump(bc, f)
