from pytorch_metric_learning.miners import TripletMarginMiner
from pytorch_metric_learning.losses import TripletMarginLoss
from ..model_utils import MetricModel, StandardModel
from torch.nn import CrossEntropyLoss
from torch.optim import Optimizer
from typing import Iterator
from tqdm import tqdm
import torch
import sys


def extract_train_features(train_loader: Iterator, model: MetricModel) -> tuple:
    """
    Fetch one training batch and map its inputs to feature space.

    Parameters
    ----------
    train_loader: Data loader (or iterator) for the training data; each item is a
        ``(data, labels)`` pair.
    model: MetricModel whose ``feature_extractor`` computes the feature representations.

    Returns
    -------
    features: Output of ``model.feature_extractor`` for the batch's input data.
    labels: Labels corresponding to the features, returned unchanged.
    """
    # NOTE: ``iter(train_loader)`` is created anew on every call. If train_loader is a
    # re-iterable (e.g. a list, or presumably a DataLoader) this takes the *first* batch
    # of a fresh pass each time; if it is already an iterator, it is advanced by one.
    batch_data, batch_labels = next(iter(train_loader))
    return model.feature_extractor(batch_data), batch_labels


def train_epoch_standard(num_iter: int, accumulation_steps: int, model: StandardModel,
                         train_loader: Iterator, optimizer: Optimizer) -> None:
    """
    Train a StandardModel for one epoch consisting of ``num_iter`` optimizer updates.

    Before each update, gradients are accumulated over ``accumulation_steps`` batches.

    Parameters
    ----------
    num_iter: Number of optimizer updates performed in this epoch.
    accumulation_steps: Number of batches whose gradients are accumulated per update.
    model: StandardModel; its ``compute_loss(data, labels)`` returns the loss to backpropagate.
    train_loader: Data loader for the training data; each item is a ``(data, labels)`` pair.
    optimizer: Optimizer that updates the model parameters.
    """

    # training mode
    model.train()

    # range(num_iter) stops by itself after num_iter iterations, so no explicit break is needed
    for _ in tqdm(range(num_iter), 'Train Epoch', total=num_iter, file=sys.stdout):

        # clear old gradients once per update, so they accumulate over the inner loop
        optimizer.zero_grad()

        for _step in range(accumulation_steps):

            # NOTE(review): a fresh iterator is created on every call; for a re-iterable
            # loader this relies on it producing a different batch per pass (e.g. shuffling).
            data, labels = next(iter(train_loader))

            # convert plain Python label lists into a long tensor
            if isinstance(labels, list):
                labels = torch.LongTensor(labels)

            loss = model.compute_loss(data, labels)

            # NOTE(review): losses are not divided by accumulation_steps, so the gradients
            # of the batches are summed, not averaged - confirm this is intended.
            loss.backward()

        # update the model
        optimizer.step()

        # flush the output so the progress bar is shown promptly
        sys.stdout.flush()


def train_epoch_metric(num_iter: int, accumulation_steps: int, model: MetricModel, loss_func: TripletMarginLoss,
                       mining_func: TripletMarginMiner, train_loader: Iterator, optimizer: Optimizer) -> None:
    """
    Train a MetricModel with a mined triplet loss for one epoch of ``num_iter`` updates.

    Before each update, gradients are accumulated over ``accumulation_steps`` batches.

    Parameters
    ----------
    num_iter: Number of optimizer updates performed in this epoch.
    accumulation_steps: Number of batches whose gradients are accumulated per update.
    model: MetricModel whose ``feature_extractor`` produces the embeddings.
    loss_func: Metric loss, called as ``loss_func(features, labels, indices_tuple)``.
    mining_func: Miner, called as ``mining_func(features, labels)``; its result is passed to the loss.
    train_loader: Data loader for the training data; each item is a ``(data, labels)`` pair.
    optimizer: Optimizer that updates the model parameters.
    """
    # training mode
    model.train()

    # range(num_iter) stops by itself after num_iter iterations, so no explicit break is needed
    for _ in tqdm(range(num_iter), 'Train Epoch', total=num_iter, file=sys.stdout):

        # clear old gradients once per update, so they accumulate over the inner loop
        optimizer.zero_grad()

        for _step in range(accumulation_steps):

            # extract the features and corresponding labels for the next batch of training data
            features, labels = extract_train_features(train_loader, model)

            # mine the tuples to use, then compute the loss on them
            indices_tuple = mining_func(features, labels)
            loss = loss_func(features, labels, indices_tuple)

            # NOTE(review): losses are not divided by accumulation_steps, so the gradients
            # of the batches are summed, not averaged - confirm this is intended.
            loss.backward()

        # update the model
        optimizer.step()

        # flush the output so the progress bar is shown promptly
        sys.stdout.flush()


def train_epoch_output(train_iter: int, model: MetricModel, loss_fun: CrossEntropyLoss,
                       train_loader: Iterator, optimizer: Optimizer) -> None:
    """
    Train the classifier head of a MetricModel for one epoch of ``train_iter`` updates.

    Each update runs one batch through ``model.feature_extractor`` and then
    ``model.classifier``. The classification loss is backpropagated, and the
    optimizer takes one step.

    Parameters
    ----------
    train_iter: Number of optimizer updates performed in this epoch.
    model: MetricModel providing ``feature_extractor`` and ``classifier``.
    loss_fun: Classification loss, called as ``loss_fun(predictions, labels)``.
    train_loader: Data loader for the training data; each item is a ``(data, labels)`` pair.
    optimizer: Optimizer; which parameters are updated depends on how it was constructed.
    """

    # training mode
    model.train()

    # range(train_iter) stops by itself after train_iter iterations, so no explicit break is needed
    for _ in tqdm(range(train_iter), 'Train Epoch', total=train_iter, file=sys.stdout):

        # clear old gradients
        optimizer.zero_grad()

        # extract the features and corresponding labels for the next batch of training data
        features, labels = extract_train_features(train_loader, model)

        # move the labels to the device the features live on
        labels = labels.to(features.device)

        # forward pass through the classifier head
        predictions = model.classifier(features)

        # compute the loss and backpropagate the error
        loss = loss_fun(predictions, labels)
        loss.backward()

        # update the model parameters
        optimizer.step()

        # flush the output so the progress bar is shown promptly
        sys.stdout.flush()
