import torch
import numpy as np
from evaluation import prepare_task_input, evaluate_embedding
from tqdm import tqdm

def training(train_loader, learner, table, loss_table, args):
    """Run the training loop for ``args.max_iter + 1`` iterations (0..max_iter).

    Batches are drawn from ``train_loader``. When the loader is exhausted, a
    fresh iterator is created, so training continues across passes over the
    data until the iteration budget is used up.

    Every ``args.print_freq`` iterations, and on the final iteration, the
    model is evaluated via ``evaluate_embedding(learner.model, table, args, i)``
    and then switched back to train mode.

    Every ``args.print_freq * 10`` iterations, and on the final iteration
    (never on iteration 0), the following happens:

    * the mean of each accumulated loss is appended as a row to ``loss_table``;
    * the accumulators are reset;
    * the table is printed.

    Args:
        train_loader: Iterable of batches that also supports ``len()``.
        learner: Object exposing ``.model`` and ``.forward(feats)``.
            ``forward`` must return a mapping with the keys
            ``'Instance-CL_loss'``, ``'clustering_loss'`` and ``'total_loss'``.
            Each value must be convertible with ``float()``; presumably these
            are Python floats or scalar tensors (TODO confirm).
        table: Forwarded unchanged to ``evaluate_embedding``.
        loss_table: Object with ``add_row(list)`` that is printable.
            Presumably a PrettyTable; verify against the caller.
        args: Namespace with at least ``max_iter`` and ``print_freq``. It is
            also forwarded to ``prepare_task_input`` and ``evaluate_embedding``.

    Returns:
        None.

    Raises:
        StopIteration: if ``train_loader`` yields no batches at all.
    """
    print('\n={}/{}=Iterations/Batches'.format(args.max_iter, len(train_loader)))
    learner.model.train()
    instance_loss, cluster_loss, total_loss = [], [], []

    # Create the iterator up front. The previous version relied on the first
    # ``next`` raising UnboundLocalError inside a bare ``except:``, which also
    # hid genuine loader errors and KeyboardInterrupt.
    train_loader_iter = iter(train_loader)
    # np.arange is kept so ``i`` (forwarded to evaluate_embedding) keeps its
    # original numpy integer type.
    for i in tqdm(np.arange(args.max_iter + 1)):
        try:
            batch = next(train_loader_iter)
        except StopIteration:
            # Loader exhausted: start another pass over the data.
            train_loader_iter = iter(train_loader)
            batch = next(train_loader_iter)

        feats, _ = prepare_task_input(learner.model, batch, args, is_contrastive=True)
        losses = learner.forward(feats)
        # Store plain Python numbers. If a loss comes back as a scalar tensor,
        # float() detaches it, so the autograd graph is not kept alive until
        # the next flush. For values that are already floats this is a no-op.
        instance_loss.append(float(losses['Instance-CL_loss']))
        cluster_loss.append(float(losses['clustering_loss']))
        total_loss.append(float(losses['total_loss']))

        if (i % args.print_freq == 0) or (i == args.max_iter):
            evaluate_embedding(learner.model, table, args, i)
            # evaluate_embedding may have switched the model to eval mode.
            learner.model.train()

        if i != 0 and ((i % (args.print_freq * 10) == 0) or (i == args.max_iter)):
            loss_table.add_row([
                torch.tensor(instance_loss).mean(),
                torch.tensor(cluster_loss).mean(),
                torch.tensor(total_loss).mean(),
            ])
            instance_loss, cluster_loss, total_loss = [], [], []
            print(loss_table)
    return None



             