from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

from copy import deepcopy
import gc
import torch
from sklearn.metrics import accuracy_score
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig
from .SimCSECons_CDA import SimCSE
from .LossUtil import *


class BASE(nn.Module):
    '''
        Shared base class for the few-shot models in this file.

        Provides pairwise distance helpers, episode label re-indexing,
        one-hot encoding and an accuracy metric.

        @param args: configuration object; ``args.way`` (number of classes
            per episode) is read by ``reidx_y`` and ``_label2onehot``.
    '''
    def __init__(self, args):
        super(BASE, self).__init__()
        self.args = args

    def _compute_l2(self, XS, XQ):
        '''
            Compute the pairwise l2 distance
            @param XS (support x): support_size x ebd_dim
            @param XQ (query x): query_size x ebd_dim

            @return dist: query_size x support_size

        '''
        # Broadcast to query_size x support_size x ebd_dim, then reduce the
        # embedding axis.
        diff = XS.unsqueeze(0) - XQ.unsqueeze(1)
        dist = torch.norm(diff, dim=2)

        return dist

    def _compute_cos(self, XS, XQ):
        '''
            Compute the pairwise cos distance (1 - cosine similarity)
            @param XS (support x): support_size x ebd_dim
            @param XQ (query x): query_size x ebd_dim

            @return dist: query_size x support_size

        '''
        # (1, S, 1, D) @ (Q, 1, D, 1) -> (Q, S, 1, 1): one dot product per
        # (query, support) pair.
        dot = torch.matmul(
                XS.unsqueeze(0).unsqueeze(-2),
                XQ.unsqueeze(1).unsqueeze(-1)
                )
        dot = dot.squeeze(-1).squeeze(-1)

        scale = (torch.norm(XS, dim=1).unsqueeze(0) *
                 torch.norm(XQ, dim=1).unsqueeze(1))

        # Lower-bound the norm product so zero vectors do not divide by zero.
        scale = torch.clamp(scale, min=1e-8)

        dist = 1 - dot/scale

        return dist

    def reidx_y(self, YS, YQ):
        '''
            Map the labels into 0,..., way - 1
            @param YS: support labels, batch_size
            @param YQ: query labels, batch_size

            @return YS_new: batch_size
            @return YQ_new: batch_size

            @raise ValueError: if support and query do not contain exactly
                the same ``args.way`` distinct classes.
        '''
        unique1, inv_S = torch.unique(YS, sorted=True, return_inverse=True)
        unique2, inv_Q = torch.unique(YQ, sorted=True, return_inverse=True)

        if len(unique1) != len(unique2):
            raise ValueError(
                'Support set classes are different from the query set')

        if len(unique1) != self.args.way:
            raise ValueError(
                'Support set classes are different from the number of ways')

        # Compare element-wise: summing the differences would let mismatches
        # cancel out (e.g. classes [0, 3] vs [1, 2]).
        if bool((unique1 != unique2).any().item()):
            raise ValueError(
                'Support set classes are different from the query set classes')

        Y_new = torch.arange(start=0, end=self.args.way, dtype=unique1.dtype,
                device=unique1.device)

        return Y_new[inv_S], Y_new[inv_Q]

    def _label2onehot(self, Y):
        '''
            Convert re-indexed labels (values in 0,..., way - 1) to one-hot
            rows.
            @param Y: batch_size

            @return Y_onehot: batch_size * ways (float)
        '''
        # The identity matrix is built on demand on Y's device; the class
        # does not keep a cached copy.
        I_way = torch.eye(self.args.way, dtype=torch.float, device=Y.device)
        Y_onehot = F.embedding(Y, I_way)

        return Y_onehot

    @staticmethod
    def compute_acc(pred, true):
        '''
            Compute the accuracy.
            @param pred: batch_size * num_classes
            @param true: batch_size

            @return acc: python float, fraction of rows whose argmax over
                dim 1 equals the true label
        '''
        return torch.mean((torch.argmax(pred, dim=1) == true).float()).item()


    
class Bert(nn.Module):
    '''
        Pretrained transformer encoder followed by a linear classifier.

        @param pretrained_bert: model name or path passed to
            ``AutoModel.from_pretrained``
        @param pool_type: "cls" (first token of the last hidden state) or
            "pooler" (the encoder's ``pooler_output``)
        @param dropout_prob: accepted for interface compatibility; no
            dropout layer is created from it in this class
        @param n_classes: output size of the linear classifier
    '''
    def __init__(self, pretrained_bert, pool_type="cls", dropout_prob=0.3, n_classes=2):
        super().__init__()

        self.encoder = AutoModel.from_pretrained(pretrained_bert)

        self.pool_type = pool_type

        self.n_classes = n_classes
        # Size the classifier from the encoder config instead of assuming a
        # 768-wide model; fall back to 768 when the config does not expose
        # ``hidden_size``.
        config = getattr(self.encoder, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)
        self.fc = nn.Linear(hidden_size, self.n_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        '''
            Encode a batch and classify its pooled representation.

            @return output: pooled sentence embedding
            @return pred_label: classifier logits computed from ``output``

            @raise ValueError: if ``self.pool_type`` is neither "cls" nor
                "pooler"
        '''
        output = self.encoder(input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids)
        if self.pool_type == "cls":
            output = output.last_hidden_state[:, 0]
        elif self.pool_type == "pooler":
            output = output.pooler_output
        else:
            # Without this guard the raw encoder output object would be fed
            # to the linear layer and fail with an unrelated error.
            raise ValueError(
                'Unsupported pool_type {!r}; expected "cls" or "pooler"'.format(
                    self.pool_type))

        pred_label = self.fc(output)

        return output, pred_label


class PROTO_CDA(BASE):
    '''
        Prototypical-network few-shot classifier with auxiliary contrastive
        losses computed from two SimCSE encoders.

        ``self.model`` encodes the original inputs; its third and fourth
        outputs are fed to ``self.model_pos`` as input ids (the "pos" and
        "neg" variants). Queries are scored by negative l2 distance to the
        class prototypes built from the support embeddings.

        @param args: configuration object; ``pretrained_bert`` and ``way``
            are read here.
    '''
    def __init__(self, args):
        super(PROTO_CDA, self).__init__(args)
        self.args = args
        # Kept as a public attribute; the encoders are NOT moved to this
        # device here, the caller is responsible for placement.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = SimCSE(self.args.pretrained_bert, 'cls', 0.3, self.args.way)

        self.model_pos = SimCSE(self.args.pretrained_bert, 'cls', 0.3, self.args.way)

    def _compute_prototype(self, XS, YS):
        '''
            Compute the prototype for each class by averaging over the ebd.

            Rows are grouped by their label, so the result does not depend
            on the support set holding the same number of examples for
            every class.

            @param XS (support x): support_size x ebd_dim
            @param YS (support y): support_size

            @return prototype: way x ebd_dim, rows ordered by ascending label
        '''
        # class_idx[k] is the rank of YS[k] among the distinct labels.
        _, class_idx = torch.unique(YS, sorted=True, return_inverse=True)

        prototype = [
            torch.mean(XS[class_idx == i], dim=0, keepdim=True)
            for i in range(self.args.way)
        ]

        return torch.cat(prototype, dim=0)

    @staticmethod
    def _interleave(first, second):
        '''
            Interleave two (n x dim) tensors row-wise into a (2n x dim)
            tensor: first[0], second[0], first[1], second[1], ...
        '''
        merged = torch.cat([first, second], dim=1)
        return merged.view(2 * first.shape[0], first.shape[1])

    def _encode_with_aux_losses(self, input_ids, attention_mask, segment_ids):
        '''
            Encode one set (support or query) and compute its three
            auxiliary losses.

            @return emb: embeddings from the first pass through self.model
            @return emb_pos: self.model_pos embeddings of the "pos" input ids
            @return loss_cons: compute_loss over two passes of self.model on
                the same inputs
            @return loss_counter_pos: compute_loss over (emb, emb_pos) pairs
            @return loss_triplet: triplet margin loss with the original /
                pos / neg classifier outputs as anchor / positive / negative
        '''
        emb, pred_label, input_ids_pos, input_ids_neg, *_ = self.model(
            input_ids, attention_mask, segment_ids)

        # Second pass over the same inputs through the same encoder.
        emb_2, *_ = self.model(input_ids, attention_mask, segment_ids)

        # Pass the embeddings' own device rather than self.device, which is
        # chosen from CUDA availability and may not be where the model lives.
        # NOTE(review): assumes compute_loss's third argument is the device
        # for tensors it creates - confirm in LossUtil.
        device = emb.device
        loss_cons = compute_loss(self._interleave(emb, emb_2), 0.05, device)

        emb_pos, pred_label_pos, *_ = self.model_pos(
            input_ids_pos, attention_mask, segment_ids)
        _, pred_label_neg, *_ = self.model_pos(
            input_ids_neg, attention_mask, segment_ids)

        loss_counter_pos = compute_loss(
            self._interleave(emb, emb_pos), 0.05, device)

        loss_triplet = F.triplet_margin_loss(
            pred_label, pred_label_pos, pred_label_neg, margin=1.0, p=2)

        return emb, emb_pos, loss_cons, loss_counter_pos, loss_triplet

    def forward(self, support, query):
        '''
            Run one few-shot episode.

            @param support: dict with keys 'label', 'text', 'ret_mask' and
                'ret_type_ids'
            @param query: dict with the same keys as ``support``

            @return acc: query accuracy of the prototype classifier
            @return loss: query cross-entropy plus the auxiliary losses
        '''
        s_label_ids = support['label']
        s_input_ids = support['text']
        s_attention_mask = support['ret_mask']
        s_segment_ids = support['ret_type_ids']

        q_label_ids = query['label']
        q_input_ids = query['text']
        q_attention_mask = query['ret_mask']
        q_segment_ids = query['ret_type_ids']

        YS, YQ = self.reidx_y(s_label_ids, q_label_ids)

        # ---- support set ----
        (XS, XS_pos, loss_cons_s, loss_counter_direct_pos_s,
         loss_cons_triple_s) = self._encode_with_aux_losses(
            s_input_ids, s_attention_mask, s_segment_ids)

        prototype = self._compute_prototype(XS, YS)
        prototype_pos = self._compute_prototype(XS_pos, YS)

        # ---- query set ----
        (XQ, _, loss_cons_q, loss_counter_direct_pos_q,
         loss_cons_triple_q) = self._encode_with_aux_losses(
            q_input_ids, q_attention_mask, q_segment_ids)

        # Closer prototype -> larger logit.
        q_pred = -self._compute_l2(prototype, XQ)

        # Re-encode the original query inputs with model_pos, additionally
        # passing the mean support embedding as a fourth argument.
        output_ori_spt_mean = torch.mean(XS, dim=0)
        q_output_ori_newpos, *_ = self.model_pos(
            q_input_ids, q_attention_mask, q_segment_ids, output_ori_spt_mean)

        q_pred_pos = -self._compute_l2(prototype_pos, q_output_ori_newpos)

        # NOTE(review): the inputs here are label / argmax index tensors, so
        # no gradient reaches the encoders through this term.
        loss_qry_cons_pos = F.triplet_margin_loss(
            YQ.unsqueeze(1),
            torch.max(q_pred, 1)[1].unsqueeze(1),
            torch.max(q_pred_pos, 1)[1].unsqueeze(1),
            margin=1.0, p=2)

        loss_q = F.cross_entropy(q_pred, YQ)
        # NOTE(review): a cross-entropy on q_pred_pos used to be computed
        # here but was never added to the total; confirm whether it should be.

        loss = (loss_q + loss_cons_s + loss_counter_direct_pos_s
                + loss_cons_triple_s + loss_cons_q
                + loss_counter_direct_pos_q + loss_cons_triple_q
                + loss_qry_cons_pos)

        acc = BASE.compute_acc(q_pred, YQ)

        return acc, loss
