import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BartForConditionalGeneration
from utils import *


def Model_Mapping(model_type):
    """Return the pretrained BART backbone and a matching square projection layer.

    The variant is selected by substring match on ``model_type``; ``'base'`` is
    checked before ``'large'``, so a string containing both selects the base
    model.

    Args:
        model_type: String containing ``'base'`` or ``'large'``.

    Returns:
        A tuple ``(bart, linear)`` where ``bart`` is a
        ``BartForConditionalGeneration`` loaded from ``facebook/bart-base`` or
        ``facebook/bart-large`` and ``linear`` is a ``torch.nn.Linear`` of size
        768x768 or 1024x1024 respectively.

    Raises:
        ValueError: If ``model_type`` contains neither ``'base'`` nor
            ``'large'``. Previously this case silently returned ``None`` and
            failed later when the caller unpacked the result.
    """
    if 'base' in model_type:
        checkpoint, hidden_size = "facebook/bart-base", 768
    elif 'large' in model_type:
        checkpoint, hidden_size = "facebook/bart-large", 1024
    else:
        raise ValueError(
            "Unsupported model_type {!r}: expected a name containing "
            "'base' or 'large'".format(model_type)
        )
    # Backbone is loaded before the projection is created, as in the original
    # one-line return, so RNG consumption order is unchanged.
    return BartForConditionalGeneration.from_pretrained(checkpoint), torch.nn.Linear(hidden_size, hidden_size)


class ClarET(nn.Module):
    """BART-based model that returns several training losses from one forward call.

    The module wraps the seq2seq backbone returned by ``Model_Mapping`` together
    with a linear projection (``self.trans``) and computes:

    * an infilling (generation) loss,
    * a triplet-margin contrastive loss between a pooled encoder state and the
      encoder states of positive / negative pieces,
    * a classification loss, and
    * optionally a tagging loss (only when ``tag_input_ids`` is supplied).
    """

    def __init__(self, model_type, dropout_prob=0.1, margin=0.5):
        """Build the backbone, projection layer and loss modules.

        Args:
            model_type: Passed to ``Model_Mapping`` to select the backbone.
            dropout_prob: Dropout probability applied before the projection.
            margin: Margin of the triplet loss.
        """
        super().__init__()
        self.model, self.trans = Model_Mapping(model_type)

        self.dropout = nn.Dropout(p=dropout_prob)
        self.act = nn.GELU()  # not used in forward(); kept so the attribute still exists
        self.margin = margin
        self.triplet_loss = nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance(), margin=margin)

        self.init_weight(self.trans)

    def init_weight(self, layer):
        """Xavier-uniform initialise ``layer.weight`` and zero ``layer.bias`` in place."""
        torch.nn.init.xavier_uniform_(layer.weight)
        torch.nn.init.constant_(layer.bias, 0)

    def _contrastive_loss(self, encoder_hidden_states, infilling_mask_loc,
                          positive_piece_input_ids, positive_piece_attention_mask,
                          negative_piece_input_ids, negative_piece_attention_mask):
        """Triplet loss between the pooled infilling state and the piece states.

        Args:
            encoder_hidden_states: ``encoder_last_hidden_state`` of the infilling
                pass; only the first ``infilling_mask_loc.size(0)`` rows are used.
            infilling_mask_loc: Per-token weights; broadcast over the hidden
                dimension and summed over dim 1 to pool the encoder states.
                (The original comments label it ``b,l`` - TODO confirm.)
            positive_piece_input_ids / positive_piece_attention_mask: Encoder
                inputs for the positive pieces, one row per anchor.
            negative_piece_input_ids / negative_piece_attention_mask: 3-D
                tensors; the first two dimensions are flattened before encoding
                and ``size(1)`` is the number of negatives per anchor.

        Returns:
            The value returned by ``self.triplet_loss``.
        """
        seq_batch = int(infilling_mask_loc.size(0))
        anchor = encoder_hidden_states[:seq_batch, ...]
        # Weighted sum over the sequence dimension using the mask locations.
        anchor = torch.sum(anchor * infilling_mask_loc.unsqueeze(-1), dim=1)
        anchor = self.trans(self.dropout(anchor)).unsqueeze(1)

        positive_piece_num = int(positive_piece_input_ids.size(0))
        negative_piece_num = int(negative_piece_input_ids.size(1))

        # Merge the first two dimensions so the negatives can be encoded in the
        # same batch as the positives.
        flat_negative_ids = negative_piece_input_ids.view(
            negative_piece_input_ids.size(0) * negative_piece_input_ids.size(1),
            negative_piece_input_ids.size(2),
        )
        flat_negative_mask = negative_piece_attention_mask.view(
            negative_piece_attention_mask.size(0) * negative_piece_attention_mask.size(1),
            negative_piece_attention_mask.size(2),
        )

        piece_input_ids = torch.cat([positive_piece_input_ids, flat_negative_ids], 0)
        piece_attention_mask = torch.cat([positive_piece_attention_mask, flat_negative_mask], 0)
        # Represent each piece by the encoder state at sequence position 0.
        piece_hidden_states = self.model(
            input_ids=piece_input_ids, attention_mask=piece_attention_mask
        ).encoder_last_hidden_state[:, 0, :]

        positive_states = piece_hidden_states[:positive_piece_num, :].unsqueeze(1)
        negative_states = piece_hidden_states[positive_piece_num:, :].view(
            anchor.size(0), negative_piece_num, -1
        )
        return self.triplet_loss(anchor, positive_states, negative_states)

    def forward(self, infilling_input_ids, infilling_attention_mask, infilling_labels, infilling_mask_loc,
                positive_piece_input_ids, positive_piece_attention_mask,
                negative_piece_input_ids, negative_piece_attention_mask,
                class_input_ids, class_attention_mask, class_labels,
                tag_input_ids=None, tag_attention_mask=None, tag_labels=None,
                ablation='none', soft_label=0, eval_mode=False):
        """Compute the training losses for one batch.

        When ``tag_input_ids`` is ``None`` the infilling and classification
        passes forward ``soft_label`` to the backbone and ``ablation`` selects
        which losses are computed without gradient tracking:

        * ``'wok'``   - contrastive loss under ``no_grad``
        * ``'woct'``  - classification loss under ``no_grad``
        * ``'woall'`` - both of the above

        When ``tag_input_ids`` is given, ``ablation`` and ``soft_label`` are
        ignored and an additional tagging loss is computed.

        NOTE(review): ``soft_label`` is passed as a keyword to ``self.model``;
        presumably the backbone's forward has been extended to accept it -
        confirm against the installed ``transformers`` version.

        Args:
            eval_mode: Currently unused; kept for interface compatibility.

        Returns:
            ``(g_loss, k_loss, ct_loss)`` when ``tag_input_ids`` is ``None``,
            otherwise ``(g_loss, k_loss, c_loss, t_loss)``.
        """
        # Identity test: ``==`` against a tensor relies on operator fallback.
        if tag_input_ids is None:
            ################## infilling loss ##################
            outputs = self.model(input_ids=infilling_input_ids, attention_mask=infilling_attention_mask,
                                 labels=infilling_labels, soft_label=soft_label)
            g_loss = outputs.loss

            ################## contrastive loss ##################
            # Only ever turn gradients off here; an enclosing no_grad() is
            # respected because the current mode is and-ed in.
            skip_k_grad = ablation == 'wok' or ablation == 'woall'
            with torch.set_grad_enabled(torch.is_grad_enabled() and not skip_k_grad):
                k_loss = self._contrastive_loss(
                    outputs.encoder_last_hidden_state, infilling_mask_loc,
                    positive_piece_input_ids, positive_piece_attention_mask,
                    negative_piece_input_ids, negative_piece_attention_mask,
                )

            ################## classification loss ##################
            skip_ct_grad = ablation == 'woct' or ablation == 'woall'
            with torch.set_grad_enabled(torch.is_grad_enabled() and not skip_ct_grad):
                class_outputs = self.model(input_ids=class_input_ids, attention_mask=class_attention_mask,
                                           labels=class_labels, soft_label=soft_label)
                ct_loss = class_outputs.loss

            return g_loss, k_loss, ct_loss

        ################## infilling loss ##################
        outputs = self.model(input_ids=infilling_input_ids, attention_mask=infilling_attention_mask,
                             labels=infilling_labels)
        g_loss = outputs.loss

        ################## contrastive loss ##################
        k_loss = self._contrastive_loss(
            outputs.encoder_last_hidden_state, infilling_mask_loc,
            positive_piece_input_ids, positive_piece_attention_mask,
            negative_piece_input_ids, negative_piece_attention_mask,
        )

        ################## classification loss ##################
        class_outputs = self.model(input_ids=class_input_ids, attention_mask=class_attention_mask,
                                   labels=class_labels)
        c_loss = class_outputs.loss

        ################## tagging loss ##################
        tag_outputs = self.model(input_ids=tag_input_ids, attention_mask=tag_attention_mask, labels=tag_labels)
        t_loss = tag_outputs.loss

        return g_loss, k_loss, c_loss, t_loss