"""
General class implementing Supervised Contrastive loss.

Dependencies: losses.py
"""
from transformers import Trainer
import torch
import torch.nn.functional as FF
import torch.nn as nn
from losses import SupConLoss
# Default device for tensors moved by this module. Prefer CUDA, but fall back
# to CPU so the module remains usable on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SupCsTrainer(Trainer):
    """Trainer that optimises a supervised contrastive (SupCon) objective.

    Each batch is encoded several times: once with the model as configured,
    and once more per entry of ``w_drop_out`` with the model's dropout
    probabilities temporarily overridden. The pooled outputs of these passes
    are stacked as "views", L2-normalised and handed to ``SupConLoss``
    together with the labels.
    """

    # Dropout probabilities used for the additional forward passes; one extra
    # view is produced per entry.
    w_drop_out = [0.0]
    # Temperature passed to SupConLoss.
    temperature_s = 0.05
    # Dropout probability the model is assumed to be configured with. It is
    # the value dropout is reset to after every override.
    # NOTE(review): confirm this matches the model config in use.
    default_dropout = 0.1

    def set_views(self, w_drop_out, temperature):
        """Configure the extra dropout views and the loss temperature.

        Args:
            w_drop_out: Iterable of dropout probabilities; ``compute_loss``
                runs one additional forward pass per entry.
            temperature: Temperature passed to ``SupConLoss``.
        """
        self.w_drop_out = w_drop_out
        self.temperature_s = temperature

    def get_feature(self, model, eval_dataset):
        """Encode ``eval_dataset`` and return pooled features with labels.

        The model is switched to eval mode (and left in it) and every batch
        of the evaluation dataloader is run without gradient tracking.

        Args:
            model: Model whose output exposes ``pooler_output``.
            eval_dataset: Dataset forwarded to ``self.get_eval_dataloader``.
                Each batch must contain a ``"labels"`` entry.

        Returns:
            Tuple ``(features, labels)``: the per-batch ``pooler_output``
            tensors and the per-batch labels, each concatenated along
            dimension 0. Labels are returned as delivered by the dataloader
            (they are not moved to the model's device).
        """
        model.eval()

        # Run on whatever device the model actually lives on instead of
        # assuming CUDA; use the module-level default only when the model has
        # no parameters to inspect.
        try:
            target_device = next(model.parameters()).device
        except StopIteration:
            target_device = device

        logits_total = []
        labels_total = []
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        for inputs in eval_dataloader:
            # Labels are not a model input here; strip them before forwarding.
            labels = inputs.pop("labels")
            if hasattr(inputs, "to"):
                # Batch container with its own ``.to`` (as the original code
                # assumed).
                inputs = inputs.to(target_device)
            else:
                # Plain mapping of tensors: move each tensor individually.
                inputs = {
                    key: value.to(target_device) if hasattr(value, "to") else value
                    for key, value in inputs.items()
                }
            with torch.no_grad():
                logits = model(**inputs).pooler_output
            logits_total.append(logits)
            labels_total.append(labels)

        logits_total = torch.cat(logits_total, 0)
        labels_total = torch.cat(labels_total, 0)
        print("logtis_total.size():",logits_total.size())
        print("labels_total.size():",labels_total.size())
        return logits_total, labels_total

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the supervised contrastive loss for one batch.

        Args:
            model: Model whose output exposes ``pooler_output``.
            inputs: Mapping of model inputs containing a ``"labels"`` entry.
                The labels are popped, so ``inputs`` is mutated.
            return_outputs: When true, also return the output of the first
                (unmodified-dropout) forward pass.
            num_items_in_batch: Accepted for compatibility with ``Trainer``
                versions that pass this keyword; unused here.

        Returns:
            ``loss`` or, if ``return_outputs`` is true, ``(loss, output)``.
        """
        labels = inputs.pop("labels")

        # First view: the model with its currently configured dropout.
        output = model(**inputs)
        views = [output.pooler_output.unsqueeze(1)]

        # One extra view per requested dropout probability.
        for p_dpr in self.w_drop_out:
            override = p_dpr != self.default_dropout
            if override:
                model = self.set_dropout_mf(model, w=p_dpr)
            try:
                views.append(model(**inputs).pooler_output.unsqueeze(1))
            finally:
                # Always restore the base dropout, even if the forward pass
                # raised, so the model is never left in an overridden state.
                if override:
                    model = self.set_dropout_mf(model, w=self.default_dropout)

        # Stack views on dim 1 and L2-normalise the feature dimension.
        # Assumes pooler_output is (batch, hidden) - TODO confirm.
        logits = FF.normalize(torch.cat(views, 1), p=2, dim=2)

        loss_fn = SupConLoss(temperature=self.temperature_s)
        loss = loss_fn(logits, labels)

        # The original returned the undefined name ``outputs`` here, which
        # raised NameError whenever return_outputs was true.
        return (loss, output) if return_outputs else loss

    def set_dropout_mf(self, model, w):
        """Set every encoder dropout probability of ``model`` to ``w``.

        Updates the embedding dropout and, for each layer in
        ``encoder.layer``, the self-attention, attention-output and
        layer-output dropouts. If ``model`` has a ``module`` attribute
        (presumably a parallel wrapper - verify against caller), the wrapped
        module is modified instead.

        Args:
            model: Model (or wrapper) exposing ``embeddings.dropout`` and
                ``encoder.layer``.
            w: Dropout probability to assign.

        Returns:
            The same ``model`` object that was passed in.
        """
        core = model.module if hasattr(model, 'module') else model
        core.embeddings.dropout.p = w
        for layer in core.encoder.layer:
            layer.attention.self.dropout.p = w
            layer.attention.output.dropout.p = w
            layer.output.dropout.p = w
        return model
