from .train_utils import train_epoch_metric, train_epoch_output, train_epoch_standard
from .validation_utils import val_epoch_metric, val_epoch_output, val_epoch_standard
from .metric_utils import ClassificationMetrics, PredictionStats
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.reducers import ThresholdReducer
from pytorch_metric_learning.miners import TripletMarginMiner
from pytorch_metric_learning.losses import TripletMarginLoss
from ..model_utils import MetricModel, StandardModel
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from typing import Iterator
import torch.nn as nn
import pickle
import torch
import math


def extract_train_utils(parameters: dict, exp_dir: str) -> tuple:
    """
    Extracts pre-defined properties and objects, which are required for the training procedure.

    Parameters
    ----------
    parameters: Parameter dict; reads the 'training' keys 'accumulation_steps', 'num_epochs', 'num_iter'
        and the whole 'metrics' section (which must contain 'class_names').
    exp_dir: Directory of the current experiment. It is concatenated directly with the sub-paths, so it
        is expected to end with a path separator.

    Returns
    -------
    Tuple of 8 elements, in this order:
    accumulation_steps: Number of steps for gradient accumulation.
    num_train_iter: Number of iterations in each train epoch.
    train_logger_metric: Tensorboard writer logging into '<exp_dir>tb/backbone/'.
    train_logger_output: Tensorboard writer logging into '<exp_dir>tb/calibrated/'.
    class_names: Names of the classes of the classification problem.
    num_epochs: Number of training epochs.
    metrics: ClassificationMetrics object built from the 'metrics' section of the parameters.
    ps: PredictionStats object, which computes statistics over model predictions.
    """

    # get the desired number of gradient accumulation steps
    accumulation_steps = parameters['training']['accumulation_steps']

    # number of training epochs
    num_epochs = parameters['training']['num_epochs']

    # number of iterations per training epoch
    num_train_iter = parameters['training']['num_iter']

    # extract the class names
    class_names = parameters['metrics']['class_names']

    # initialize a ClassificationMetrics object
    metrics = ClassificationMetrics(**parameters['metrics'])

    # initialize two loggers for the training iterations: one for the backbone (metric learning)
    # stage and one for the calibrated output stage
    train_logger_metric = SummaryWriter(exp_dir + 'tb/backbone/')
    train_logger_output = SummaryWriter(exp_dir + 'tb/calibrated/')

    # initialize a PredictionStats object
    ps = PredictionStats(class_names)
    return accumulation_steps, num_train_iter, train_logger_metric, train_logger_output, class_names, num_epochs, metrics, ps


def extract_metric_properties(parameters: dict) -> tuple:
    """
    Collects the triplet-learning hyper-parameters from the 'training' section of the parameter dict.

    Parameters
    ----------
    parameters: Parameter dict; its 'training' entry must provide 'num_neighbors', 'var_threshold',
        'triplet_type' and 'margin'.

    Returns
    -------
    Tuple ``(num_neighbors, var_threshold, triplet_type, margin)``.

    Raises
    ------
    KeyError: If one of the required keys is missing (checked in the order listed above).
    """
    training_cfg = parameters['training']
    wanted_keys = ('num_neighbors', 'var_threshold', 'triplet_type', 'margin')
    return tuple(training_cfg[key] for key in wanted_keys)


def init_metric_loss(parameters: dict) -> tuple:
    """
    Initializes all loss components required for metric learning.

    Parameters
    ----------
    parameters: Parameter dict; its 'training' section must provide 'num_neighbors', 'var_threshold',
        'triplet_type' and 'margin'.

    Returns
    -------
    Tuple of 4 elements, in this order:
    loss_fun: TripletMarginLoss using cosine similarity and a ThresholdReducer(low=0).
    mining_fun: TripletMarginMiner used for selecting the triplets that are used for training.
    num_neighbors: Value of parameters['training']['num_neighbors'], passed through unchanged.
    var_threshold: Value of parameters['training']['var_threshold'], passed through unchanged
        (the caller in this module forwards it as ``pca_var``).
    """

    # extract the properties for metric learning
    num_neighbors, var_threshold, triplet_type, margin = extract_metric_properties(parameters)

    # initialize the loss function; the same distance object is shared with the miner so both
    # measure similarity identically
    distance = CosineSimilarity()
    # per pytorch-metric-learning docs, ThresholdReducer averages only the losses above `low`,
    # i.e. zero-loss triplets are excluded from the mean
    reducer = ThresholdReducer(low=0)
    loss_fun = TripletMarginLoss(margin=margin, distance=distance, reducer=reducer)

    # initialize the triplet miner
    mining_fun = TripletMarginMiner(margin=margin, distance=distance, type_of_triplets=triplet_type)
    return loss_fun, mining_fun, num_neighbors, var_threshold


def train_standard_model(num_epochs: int, num_train_iter: int, acc_steps: int, model: StandardModel, train_loader: Iterator,
                         val_loader: Iterator, optimizer: Adam, metrics: ClassificationMetrics,
                         writer: SummaryWriter, ps: PredictionStats, output_path: str) -> StandardModel:
    """
    Trains a standard model and restores the weights of the epoch with the best validation f1-score.

    Parameters
    ----------
    num_epochs: Number of training epochs.
    num_train_iter: Number of iterations in each train epoch.
    acc_steps: Number of steps for gradient accumulation.
    model: Model to train (updated in place).
    train_loader: Loader of the training data.
    val_loader: Loader of the validation data.
    optimizer: Optimizer used for training.
    metrics: Object computing the validation metrics.
    writer: Tensorboard writer used during validation.
    ps: PredictionStats object used during validation.
    output_path: Prefix of the checkpoint path; the checkpoint is written to '<output_path>model.pt'.

    Returns
    -------
    model: The model with the best checkpoint of this run loaded. If no checkpoint was written in this
        run (e.g. num_epochs == 0 or no f1-score exceeded -inf, such as all NaN), the model is returned
        in its current state instead of reloading a possibly stale or missing file.
    """
    checkpoint_path = output_path + 'model.pt'

    # initialize the f1-score for the best model
    f1_score_best = -math.inf
    # remembers whether this run actually wrote a checkpoint, so we never reload a stale 'model.pt'
    # left over from a previous run (or crash because none exists)
    checkpoint_saved = False

    # train the model on the desired number of training epochs
    for epoch in range(num_epochs):
        # display the current epoch
        print('Epoch ' + str(epoch))

        # train the model
        train_epoch_standard(num_train_iter, acc_steps, model, train_loader, optimizer)

        # evaluate the model; only the f1-score is needed here
        f1_score, _, _, _, _ = val_epoch_standard(model, val_loader, metrics, writer, epoch, ps)

        # store the model, if it's better than previous versions
        if f1_score > f1_score_best:
            f1_score_best = f1_score
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

    # reload the best model of this run
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
    return model
        

def train_backbone_model(num_epochs: int, num_train_iter: int, class_names: str, acc_steps: int, model: MetricModel,
                         loss_fun: TripletMarginLoss, mining_fun: TripletMarginMiner, train_loader: Iterator,
                         ref_loader: Iterator, val_loader: Iterator, optimizer: Adam, num_neighbors: int,
                         metrics: ClassificationMetrics, writer: SummaryWriter, pca_var: float,
                         ps: PredictionStats, output_path: str) -> MetricModel:
    """
    Trains the backbone with triplet-margin metric learning and restores the weights of the epoch with
    the best validation f1-score.

    Parameters
    ----------
    num_epochs: Number of training epochs.
    num_train_iter: Number of iterations in each train epoch.
    class_names: Class names forwarded to the validation routine (presumably a list of names despite
        the `str` annotation — TODO confirm).
    acc_steps: Number of steps for gradient accumulation.
    model: Metric model to train (updated in place).
    loss_fun: Triplet-margin loss.
    mining_fun: Triplet-margin miner selecting the training triplets.
    train_loader: Loader of the training data.
    ref_loader: Loader of the reference data used during validation.
    val_loader: Loader of the validation data.
    optimizer: Optimizer used for training.
    num_neighbors: Number of neighbors forwarded to the validation routine.
    metrics: Object computing the validation metrics.
    writer: Tensorboard writer used during validation.
    pca_var: Variance threshold forwarded to the validation routine.
    ps: PredictionStats object used during validation.
    output_path: Prefix of the checkpoint path; the checkpoint is written to '<output_path>model.pt'.

    Returns
    -------
    model: The model with the best checkpoint of this run loaded. If no checkpoint was written in this
        run (e.g. num_epochs == 0 or no f1-score exceeded -inf, such as all NaN), the model is returned
        in its current state instead of reloading a possibly stale or missing file.
    """
    checkpoint_path = output_path + 'model.pt'

    # initialize the f1-score for the best model
    f1_score_best = -math.inf
    # remembers whether this run actually wrote a checkpoint, so we never reload a stale 'model.pt'
    # left over from a previous run (or crash because none exists)
    checkpoint_saved = False

    # train the model on the desired number of training epochs
    for epoch in range(num_epochs):

        # display the current epoch
        print('Epoch ' + str(epoch))

        # train the model
        train_epoch_metric(num_train_iter, acc_steps, model, loss_fun, mining_fun, train_loader, optimizer)

        # validate the model
        f1_score = val_epoch_metric(model, ref_loader, val_loader, num_neighbors, metrics, class_names, writer, epoch, pca_var, ps)

        # store the model, if it's better than previous versions
        if f1_score > f1_score_best:
            f1_score_best = f1_score
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

    # reload the best model of this run
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
    return model

            
def train_output_layer(num_epochs: int, train_iter: int, model: MetricModel, train_loader: Iterator, val_loader: Iterator,
                       metrics: ClassificationMetrics, writer: SummaryWriter, ps: PredictionStats, output_path: str) -> MetricModel:
    """
    Trains only the classifier head of a metric model (backbone frozen) with cross-entropy, saving the
    best model and its calibrator, and restores the best weights at the end.

    Parameters
    ----------
    num_epochs: Number of training epochs.
    train_iter: Number of iterations in each train epoch.
    model: Metric model with `backbone` and `classifier` sub-modules (updated in place; the backbone
        parameters are left with requires_grad=False).
    train_loader: Loader of the training data.
    val_loader: Loader of the validation data.
    metrics: Object computing the validation metrics.
    writer: Tensorboard writer used during validation.
    ps: PredictionStats object used during validation.
    output_path: Prefix of the output paths; writes '<output_path>model.pt' and
        '<output_path>calibrator.pckl' whenever the validation f1-score improves.

    Returns
    -------
    model: The model with the best checkpoint of this run loaded (consistent with the other trainers in
        this module). If no checkpoint was written in this run, the model is returned as is.
    """
    checkpoint_path = output_path + 'model.pt'

    # NOTE(review): train mode is set only once here; whether val_epoch_output / train_epoch_output
    # toggle the mode themselves cannot be seen from this module — confirm.
    model.train()

    # create an optimizer for the classifier head only
    optimizer = Adam(model.classifier.parameters(), lr=1e-2, weight_decay=0.01)

    # freeze the backbone model
    for param in model.backbone.parameters():
        param.requires_grad = False

    # initialize the f1-score for the best model
    f1_score_best = -math.inf
    # remembers whether this run wrote a checkpoint, so the final reload never uses a stale file
    checkpoint_saved = False

    # initialize the loss function
    loss_fun = nn.CrossEntropyLoss()

    # train the model on the desired number of training epochs
    for epoch in range(num_epochs):

        # display the current epoch
        print('Epoch ' + str(epoch))

        # train the model
        train_epoch_output(train_iter, model, loss_fun, train_loader, optimizer)

        # validate the model; bc is the beta calibration module produced by the validation routine
        f1_score, bc, _, _, _ = val_epoch_output(model, val_loader, metrics, writer, epoch, ps, bc=None)

        # store the model, if it's better than previous versions
        if f1_score > f1_score_best:

            # update the currently achieved best f1 score
            f1_score_best = f1_score

            # save the model
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

            # save the beta calibration module
            with open(output_path + 'calibrator.pckl', 'wb') as f:
                pickle.dump(bc, f)

    # reload the best model so the in-memory weights match the saved checkpoint/calibrator pair
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
    return model


# noinspection PyTupleAssignmentBalance
def metric_training(parameters: dict, exp_dir: str, model: nn.Module, optimizer: Adam, train_loader: Iterator,
                    ref_loader: Iterator, val_loader: Iterator, model_path: str, output_train_iter: int = 2) -> None:
    """
    Runs the full metric-learning pipeline: triplet training of the backbone, followed by training of
    the (calibrated) output layer on top of the frozen backbone.

    Parameters
    ----------
    parameters: Parameter dict (see extract_train_utils and init_metric_loss for the required keys).
    exp_dir: Directory of the current experiment; tensorboard logs are written below it.
    model: Metric model to train.
    optimizer: Optimizer used for the backbone stage.
    train_loader: Loader of the training data (used by both stages).
    ref_loader: Loader of the reference data used during backbone validation.
    val_loader: Loader of the validation data (used by both stages).
    model_path: Prefix of the output paths for checkpoints (and the calibrator).
    output_train_iter: Number of iterations per epoch when training the output layer. Defaults to 2,
        the value previously hard-coded here.
    """

    # Extracts pre-defined properties and objects, which are required for the training procedure
    train_parameters = extract_train_utils(parameters, exp_dir)
    accumulation_steps, num_train_iter, logger_metric, logger_output, class_names, num_epochs, metrics, ps = train_parameters

    # extract the loss components required for metric learning
    loss_fun, mining_fun, num_neighbors, pca_var = init_metric_loss(parameters)

    # training procedure of the backbone model
    model = train_backbone_model(num_epochs, num_train_iter, class_names, accumulation_steps, model,
                                 loss_fun, mining_fun, train_loader, ref_loader, val_loader, optimizer,
                                 num_neighbors, metrics, logger_metric, pca_var, ps, model_path)

    # training procedure of the output layer (the backbone is frozen inside train_output_layer)
    train_output_layer(num_epochs, output_train_iter, model, train_loader,
                       val_loader, metrics, logger_output, ps, model_path)


def standard_training(parameters: dict, exp_dir: str, model: nn.Module, optimizer: Adam,
                      train_loader: Iterator, val_loader: Iterator, model_path: str) -> None:
    """
    Runs the standard (non-metric) training pipeline.

    Parameters
    ----------
    parameters: Parameter dict (see extract_train_utils for the required keys).
    exp_dir: Directory of the current experiment; tensorboard logs are written below it.
    model: Model to train.
    optimizer: Optimizer used for training.
    train_loader: Loader of the training data.
    val_loader: Loader of the validation data.
    model_path: Prefix of the checkpoint path.
    """
    # only the second logger ('tb/calibrated/') is used for the standard model; the backbone logger and
    # the class names are not needed here
    (acc_steps, num_train_iter, _backbone_logger, calibrated_logger,
     _class_names, num_epochs, metrics, ps) = extract_train_utils(parameters, exp_dir)

    train_standard_model(num_epochs=num_epochs, num_train_iter=num_train_iter, acc_steps=acc_steps,
                         model=model, train_loader=train_loader, val_loader=val_loader,
                         optimizer=optimizer, metrics=metrics, writer=calibrated_logger,
                         ps=ps, output_path=model_path)
