from utils.training.train_model import standard_training, metric_training
from torch.utils.data import DataLoader, WeightedRandomSampler
from utils.experiment import init_experiment
from data.data_loader import DataSelector
from utils.model_utils import load_model
from data.data_loader import DataSet
from torch.optim import Adam
from copy import deepcopy
from typing import Union
import torch
import time
import yaml
import sys


def load_dataset(selector_params: dict, path: Union[str, None] = None) -> tuple:
    """Load data through a DataSelector and wrap it in a dataset object.

    Args:
        selector_params: Selector configuration. The keys read here are
            'init' (keyword arguments for the DataSelector constructor),
            'load' (keyword arguments for its load_data method), and
            'data_key' / 'label_key' (forwarded to DataSelector.create_dataset).
        path: Optional file path. If given, the global indices returned by
            the selector are written to it, one index per line.

    Returns:
        A 3-tuple (dataset, global_indices, total_num_samples), where the last
        two entries are passed through unchanged from DataSelector.load_data.
    """
    # create a data selector object
    data_selector = DataSelector(**selector_params['init'])

    # load the data
    df, global_indices, total_num_samples = data_selector.load_data(**selector_params['load'])

    # create a DataSet object
    dset = DataSelector.create_dataset(df, selector_params['data_key'], selector_params['label_key'])

    # (optionally) store the global indices of the loaded samples, one index per line
    if path is not None:
        with open(path, 'w') as f:
            f.write('\n'.join(str(x) for x in global_indices))
    return dset, global_indices, total_num_samples
    

def create_data_loaders(dset_train, dset_val, parameters: dict, selector_params_train_ref: dict, dataset_name: str) -> tuple:
    """Build the train, validation and reference data loaders.

    Args:
        dset_train: Training dataset. It must provide get_label_frequencies(),
            get_labels(), sub_sample_data() and support len().
        dset_val: Validation dataset.
        parameters: Full parameter dictionary. The batch sizes are read from
            parameters['data'][<split>]['data_loader'][dataset_name]['batch_size']
            for the splits 'train' and 'validation'.
        selector_params_train_ref: Selector parameters of the reference set.
            Only ['load']['max_num'] is read here.
        dataset_name: Name of the dataset, used as key into `parameters`.

    Returns:
        A 3-tuple (train_loader, val_loader, ref_loader).
    """
    # batch size used for training
    train_bs = parameters['data']['train']['data_loader'][dataset_name]['batch_size']

    # every class is weighted by the inverse of its label frequency
    inverse_freq = {label: 1. / count for label, count in dset_train.get_label_frequencies().items()}

    # every sample inherits the weight of its class
    per_sample = torch.DoubleTensor([inverse_freq[label] for label in dset_train.get_labels()])

    # the weighted sampler draws as many samples per epoch as there are
    # training samples, which is used for balanced sampling
    balanced_sampler = WeightedRandomSampler(per_sample, len(per_sample))
    train_loader = DataLoader(dset_train, train_bs, sampler=balanced_sampler)

    # the validation data is iterated in a fixed order
    val_bs = parameters['data']['validation']['data_loader'][dataset_name]['batch_size']
    val_loader = DataLoader(dset_val, val_bs, shuffle=False)

    # the reference set (used for computing the nearest neighbors in the
    # validation epoch) is an independent copy of the training set, so that
    # sub-sampling it does not touch the training data
    reference = deepcopy(dset_train)

    # (optionally) shrink the reference set to at most `limit` samples
    limit = selector_params_train_ref['load']['max_num']
    if limit is not None and limit < len(reference):
        reference.sub_sample_data(limit)

    # the reference loader shares the validation batch size
    ref_loader = DataLoader(reference, val_bs, shuffle=False)
    return train_loader, val_loader, ref_loader


def get_arguments() -> tuple:
    """Read the positional command line arguments of the script.

    The arguments are taken from sys.argv[1] to sys.argv[5]:
        1. file, in which all training parameters are stored
        2. name of the dataset
        3. device, on which the model should be stored
        4. label, which is used for sub-selecting the data ('None' to disable)
        5. number of labeled data to include ('None' to disable)

    Returns:
        A 5-tuple (parameter_file, dataset_name, device, sub_sample_label,
        num_labeled). All entries are strings, except that the last two are
        None when the literal string 'None' was passed.

    Raises:
        IndexError: If fewer than five arguments were given.
    """
    def _none_or_value(raw):
        # the literal string 'None' on the command line means "not set"
        return None if raw == 'None' else raw

    args = sys.argv
    parameter_file = args[1]
    dataset_name = args[2]
    device = args[3]
    sub_sample_label = _none_or_value(args[4])
    num_labeled = _none_or_value(args[5])
    return parameter_file, dataset_name, device, sub_sample_label, num_labeled


def extract_parameters(parameters: dict, sub_sample_label: Union[str, None], num_labeled: Union[str, None], dataset_name: str) -> tuple:
    """Pick the selector and training parameters for one dataset and apply overrides.

    Note that the train selector parameters are NOT copied: the overrides are
    written into the nested dictionaries of `parameters` itself, so they are
    also visible to every later user of `parameters`.

    Args:
        parameters: Full parameter dictionary. Read are
            parameters['data'][<split>]['selector'][dataset_name] for the
            splits 'validation', 'train_ref' and 'train', and
            parameters['training'].
        sub_sample_label: If not None, stored as ['load']['sub_sample_label']
            of the train selector parameters.
        num_labeled: If not None, converted with int() and stored as
            ['load']['max_num'] of the train selector parameters.
        dataset_name: Name of the dataset, used as key into `parameters`.

    Returns:
        A 4-tuple (selector_params_train, selector_params_train_ref,
        selector_params_validation, training_params).

    Raises:
        ValueError: If `num_labeled` is not None and not parseable by int().
    """
    # get the selector parameters
    selector_params_validation = parameters['data']['validation']['selector'][dataset_name]
    selector_params_train_ref = parameters['data']['train_ref']['selector'][dataset_name]
    selector_params_train = parameters['data']['train']['selector'][dataset_name]
    training_params = parameters['training']

    # apply the (optional) overrides to the train selector only
    if num_labeled is not None:
        selector_params_train['load']['max_num'] = int(num_labeled)
    if sub_sample_label is not None:
        selector_params_train['load']['sub_sample_label'] = sub_sample_label
    return selector_params_train, selector_params_train_ref, selector_params_validation, training_params


# noinspection PyTupleAssignmentBalance
def main():
    """Run one training experiment configured via the command line.

    Steps: read the command line arguments (get_arguments), create the
    experiment directory (init_experiment), load the validation and train
    data, write the train / validation / remaining indices and the adapted
    parameters into the experiment directory, build the data loaders, load the
    model onto 'cuda:<device>' and start either metric or standard training,
    depending on parameters['training']['type'].
    """
    parameter_file, dataset_name, device, sub_sample_label, num_labeled = get_arguments()

    # wait before creating the directory
    # NOTE(review): presumably staggers runs that are started in parallel on
    # different devices, so they do not create the directory at the same time
    # -- confirm against init_experiment
    time.sleep(3 * int(device))

    # create the directory for the current experiment
    parameters, exp_dir, model_path = init_experiment(parameter_file)

    # get the selector parameters
    params_extracted = extract_parameters(parameters, sub_sample_label, num_labeled, dataset_name)
    selector_params_train, selector_params_train_ref, selector_params_validation, training_params = params_extracted

    # load the data
    validation_set, _, _ = load_dataset(selector_params_validation, exp_dir + 'validation_indices.txt')
    train_set, global_train_indices, total_num_samples = load_dataset(selector_params_train, exp_dir + 'train_indices.txt')

    # store all indices that are not part of the train set; the train indices
    # are put into a set once, so that each membership test is O(1) instead of
    # a linear scan (assumes the indices are hashable scalars such as ints)
    train_index_set = set(global_train_indices)
    with open(exp_dir + 'remaining_indices.txt', 'w') as f:
        f.write('\n'.join(str(x) for x in range(total_num_samples) if x not in train_index_set))

    # create data loaders for training and validation
    train_loader, val_loader, ref_loader = create_data_loaders(train_set, validation_set, parameters,
                                                               selector_params_train_ref, dataset_name)
    print(len(train_loader), len(ref_loader), len(val_loader))

    # store the adapted parameters in the current experiment directory
    file_name = parameter_file.split('/')[-1]
    with open(exp_dir + file_name, 'w') as f:
        yaml.dump(parameters, f)

    # get the desired model parameters
    model_params = parameters['model']

    # load the model in a parameter efficient configuration
    model = load_model(device='cuda:' + device, **model_params['kwargs'])

    # create an optimizer (learning rate and weight decay are fixed here and
    # not read from the parameter file)
    optimizer = Adam(model.parameters(), lr=0.00003, weight_decay=0.01)

    # start the model training
    if training_params['type'] == 'metric':
        metric_training(parameters, exp_dir, model, optimizer, train_loader,
                        ref_loader, val_loader, model_path)
    else:
        print('Standard Training')
        standard_training(parameters, exp_dir, model, optimizer, train_loader, val_loader, model_path)
    
    
# start the training only when this file is executed as a script (not when it is imported)
if __name__ == '__main__':
    main()
