# -*- coding: utf-8 -*-
# @Time    : 2021/6/10
# @Author  : kaka
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel
import torch.nn.functional as F
import numpy as np


class SimCSE(nn.Module):
    """BERT classifier that also builds attention-guided counterfactual views.

    ``forward`` encodes the input once and derives a per-token attention
    distribution from it. It samples token masks from that distribution with
    Gumbel-softmax, then re-encodes two masked copies of the input:

    * the *negative* view zeroes the ids of tokens sampled from the attention
      logits, so high-attention tokens are the likeliest to be removed;
    * the *positive* view zeroes the ids of tokens sampled from the reversed
      logits, so low-attention tokens are the likeliest to be removed.

    Each view is pooled with its surviving attention weights and classified by
    the shared ``fc`` head.

    Args:
        pretrained: model name or path handed to ``BertConfig.from_pretrained``
            and ``BertModel.from_pretrained``.
        pool_type: ``"cls"`` or ``"pooler"``. Validated and stored as
            ``self.pool_type``; ``forward`` does not currently read it.
        dropout_prob: overrides both ``attention_probs_dropout_prob`` and
            ``hidden_dropout_prob`` of the loaded config.
        n_classes: output size of the linear classification head.
    """

    def __init__(self, pretrained='bert-base-uncased', pool_type="cls", dropout_prob=0.3,  n_classes=2):
        super().__init__()
        conf = BertConfig.from_pretrained(pretrained)
        conf.attention_probs_dropout_prob = dropout_prob
        conf.hidden_dropout_prob = dropout_prob
        self.dropout_prob = dropout_prob
        self.encoder = BertModel.from_pretrained(pretrained, config=conf)
        assert pool_type in ["cls", "pooler"], "invalid pool_type: %s" % pool_type
        self.pool_type = pool_type

        # Size every head from the encoder config instead of hard-coding 768,
        # so checkpoints with another hidden size (e.g. bert-large) also work.
        hidden = conf.hidden_size

        self.n_classes = n_classes
        self.fc = nn.Linear(hidden, self.n_classes)

        # Counterfactual contrastive generation module.
        self.max_seq_length = 128  # public attribute; not read inside this class

        # Scoring layers used by the "mul" scorer when no distribution_vec is given.
        self.query_proj_s = nn.Linear(hidden, hidden, bias=False)
        self.key_proj_s = nn.Linear(hidden, hidden, bias=False)
        # torch.rand(...) before uniform_ is redundant, but keeping it means
        # seeded initialisation consumes the RNG exactly as before.
        self.bias_s = nn.Parameter(torch.rand(hidden).uniform_(-0.1, 0.1))
        self.score_proj_s = nn.Linear(hidden, 1)

        # Scoring layers used by the "mul" scorer when a distribution_vec is given.
        self.query_proj = nn.Linear(hidden, hidden, bias=False)
        self.key_proj = nn.Linear(hidden, hidden, bias=False)
        self.bias = nn.Parameter(torch.rand(hidden).uniform_(-0.1, 0.1))
        self.score_proj = nn.Linear(hidden, 1)

        # "add": dot product of every token with the position-0 hidden state.
        # "mul": tanh(key_proj(.) + query_proj(tokens) + bias) -> score_proj.
        # (These labels read opposite to the usual additive/multiplicative
        # attention names; they are kept unchanged for compatibility.)
        self.att_type = "mul"

    def forward(self, input_ids, attention_mask, token_type_ids, distribution_vec=None):
        """Classify the input and two attention-masked counterfactual copies of it.

        Args:
            input_ids, attention_mask, token_type_ids: passed unchanged to the
                BERT encoder for all three encoder passes. ``input_ids`` is
                assumed to be (batch, seq_len) — TODO confirm against caller.
            distribution_vec: optional key vector for the "mul" scorer.
                Presumably a 1-D tensor of size ``hidden_size``; for that shape
                the unsqueeze/repeat/permute below yields (batch, 1, hidden).
                Confirm against caller. Ignored when ``att_type == "add"``.

        Returns:
            A 9-tuple ``(out_dense_ori, pred_label_ori, input_ids_pos,
            input_ids_neg, out_dense_pos, pred_label_pos, out_dense_neg,
            pred_label_neg, loss_mask_reg)``:
            the position-0 hidden state and its logits for the original input,
            the two masked id tensors, the attention-pooled representation and
            logits for each masked view, and a scalar regulariser on the
            sampled masks.

        Raises:
            ValueError: if ``self.att_type`` is neither ``"add"`` nor ``"mul"``.
        """
        output = self.encoder(input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids)
        last_hid = output.last_hidden_state
        out_dense_ori = last_hid[:, 0]
        pred_label_ori = self.fc(out_dense_ori)

        attention = self._token_attention(last_hid, out_dense_ori, distribution_vec)

        # Strategy 4: use the sharpened attention as Gumbel-softmax logits.
        # NOTE(review): the softmax covers every position, padding included,
        # and position 0 can be masked as well — confirm this is intended.
        attention = attention * 100
        attention_reverse = 1 - attention
        t_mask, t_mask_reverse = self._sample_masks(attention, attention_reverse)

        # Each relaxed sample is close to one-hot (tau=0.1), so a summed value
        # above 0.9 roughly means "drawn at least once".
        neg_mask = (t_mask > 0.9).long()  # high-attention tokens -> negative view
        pos_mask = (t_mask_reverse > 0.9).long()  # low-attention tokens -> positive view

        # Masked positions get token id 0.
        input_ids_pos = torch.mul(input_ids, 1 - pos_mask)
        input_ids_neg = torch.mul(input_ids, 1 - neg_mask)

        # Pooling weights: scaled attention with masked positions zeroed,
        # given a singleton middle axis for bmm.
        new_attention_pos = torch.mul(attention, 1 - pos_mask).unsqueeze(1)
        new_attention_neg = torch.mul(attention, 1 - neg_mask).unsqueeze(1)

        out_dense_pos, pred_label_pos = self._encode_and_pool(
            input_ids_pos, attention_mask, token_type_ids, new_attention_pos)
        out_dense_neg, pred_label_neg = self._encode_and_pool(
            input_ids_neg, attention_mask, token_type_ids, new_attention_neg)

        # Sparsity term: the smaller the masks, the smaller this norm.
        loss_mask_norm = torch.norm(t_mask, p=2) + torch.norm(t_mask_reverse, p=2)

        # Divergence term: the two masks should differ, so the loss is the
        # reciprocal of their KL divergence. The clamp changes nothing unless
        # the KL falls below 1e-8, where 1/KL would otherwise reach inf.
        logp_x = F.log_softmax(t_mask, dim=-1)
        p_y = F.softmax(t_mask_reverse, dim=-1)
        kl = F.kl_div(logp_x, p_y, reduction='sum')
        loss_mask_dist = 1 / torch.clamp(kl, min=1e-8)

        loss_mask_reg = 0.2 * (0.01 * loss_mask_norm + 0.1 * loss_mask_dist)
        return (out_dense_ori, pred_label_ori, input_ids_pos, input_ids_neg,
                out_dense_pos, pred_label_pos, out_dense_neg, pred_label_neg,
                loss_mask_reg)

    def _token_attention(self, last_hid, cls_out, distribution_vec=None, temperature=1):
        """Return a softmax over token positions, one row per batch item.

        ``last_hid`` is indexed as (batch, seq_len, hidden) and ``cls_out`` as
        its position-0 slice (batch, hidden).

        Raises:
            ValueError: if ``self.att_type`` is neither ``"add"`` nor ``"mul"``.
        """
        h_t = cls_out.unsqueeze(1)  # (batch, 1, hidden)
        if self.att_type == "add":
            # squeeze(-1) rather than squeeze(): a bare squeeze() also drops the
            # batch axis when batch == 1, which made softmax(dim=1) fail.
            attn_weights = torch.bmm(last_hid, h_t.permute(0, 2, 1)).squeeze(-1)
            return F.softmax(attn_weights / temperature, dim=1)
        if self.att_type == "mul":
            if distribution_vec is None:
                score = self.score_proj_s(torch.tanh(
                    self.key_proj_s(h_t) + self.query_proj_s(last_hid) + self.bias_s))
            else:
                dist = distribution_vec.unsqueeze(1).repeat(last_hid.shape[0], 1, 1).permute(0, 2, 1)
                score = self.score_proj(torch.tanh(
                    self.key_proj(dist) + self.query_proj(last_hid) + self.bias))
            return F.softmax(score.squeeze(-1) / temperature, dim=-1)
        raise ValueError("unsupported att_type: %s" % self.att_type)

    @staticmethod
    def _sample_masks(logits, logits_reverse, n_samples=10, tau=0.1):
        """Sum ``n_samples`` relaxed Gumbel-softmax draws for each set of logits.

        Draws alternate between the two inputs, matching the original RNG order.
        """
        t_mask = F.gumbel_softmax(logits, tau=tau)
        t_mask_reverse = F.gumbel_softmax(logits_reverse, tau=tau)
        for _ in range(n_samples - 1):
            t_mask = t_mask + F.gumbel_softmax(logits, tau=tau)
            t_mask_reverse = t_mask_reverse + F.gumbel_softmax(logits_reverse, tau=tau)
        return t_mask, t_mask_reverse

    def _encode_and_pool(self, input_ids, attention_mask, token_type_ids, weights):
        """Re-encode ``input_ids``, pool tokens with ``weights`` via bmm, and classify."""
        hidden_states = self.encoder(input_ids,
                                     attention_mask=attention_mask,
                                     token_type_ids=token_type_ids).last_hidden_state
        pooled = torch.bmm(weights, hidden_states).squeeze(1)
        return pooled, self.fc(pooled)
