import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class ATLoss(nn.Module):
    """Threshold-class ranking loss for multi-label classification.

    Class index 0 is treated as a threshold ("TH") class.  ``forward``
    combines two terms: one that ranks every positive class against the
    threshold class, and one that ranks the threshold class against every
    non-positive class.  ``get_label`` decodes logits by keeping the classes
    whose logit is strictly greater than the threshold logit.

    NOTE(review): the "AT" in the name presumably stands for "adaptive
    thresholding" -- confirm against the project documentation.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): `loss_gate`, `pos_weight` and `dropout` are not read
        # by any method of this class.  They are kept (and created in the
        # same order) so attribute access and state dicts stay unchanged.
        self.loss_gate = nn.Linear(3, 1, bias=False)
        self.loss_gate.weight.data.uniform_(0.8, 1.2)
        # Per-class weights for 97 classes, increasing with class index.
        self.pos_weight = torch.ones(97)
        self.pos_weight[3:10] = 1.3
        self.pos_weight[10:20] = 1.5
        self.pos_weight[20:] = 2
        self.dropout = nn.Dropout(0.2)

    def forward(self, logits, labels):
        """Compute the combined loss and its two components.

        Args:
            logits: 2-D tensor of class scores, one row per example, with
                the threshold class in column 0.
            labels: 2-D multi-hot tensor with the same shape as ``logits``.
                Column 0 is ignored (it is zeroed on an internal copy); the
                caller's tensor is not modified.

        Returns:
            A tuple ``(loss, loss1, loss2)`` of scalar tensors where
            ``loss1`` is the mean positive-vs-threshold term, ``loss2`` is
            the mean threshold-vs-negative term and ``loss = loss1 + loss2``.

        The number of rows must be a perfect square ``num_ent ** 2``;
        otherwise the ``view`` used for diagonal masking fails.
        NOTE(review): rows are presumably all ordered (head, tail) entity
        pairs in row-major order, so the masked diagonal corresponds to
        self-pairs -- confirm against the caller.
        """
        # Work on a copy: clearing the threshold column must not leak into
        # the tensor owned by the caller.
        labels = labels.clone()
        labels[:, 0] = 0.0

        # One-hot indicator of the threshold class, in the labels' dtype.
        th_label = torch.zeros_like(labels)
        th_label[:, 0] = 1.0

        p_mask = labels + th_label  # positive classes plus the threshold
        n_mask = 1 - labels         # non-positive classes plus the threshold

        num_ex, _ = labels.size()
        num_ent = int(np.sqrt(num_ex))

        # Rank positive classes to TH: for each class, a two-way softmax
        # between that class and the threshold logit.  Classes outside
        # `p_mask` are pushed to a very negative value and are multiplied
        # by a zero label afterwards.
        logit1 = logits - (1 - p_mask) * 1e30
        th_logit = logit1[:, :1].expand_as(logit1)
        pairwise = torch.stack([logit1, th_logit], dim=1)
        loss1 = -(F.log_softmax(pairwise, dim=1)[:, 0, :] * labels).sum(1)

        # Rank TH to negative classes: softmax over the threshold and all
        # non-positive classes, keeping the threshold's log-probability.
        logit2 = logits - (1 - n_mask) * 1e30
        loss2 = -(F.log_softmax(logit2, dim=-1) * th_label).sum(1)

        # Zero the negative term on the diagonal of the
        # (num_ent, num_ent) grid of examples.
        off_diagonal = 1 - torch.eye(
            num_ent, dtype=logits.dtype, device=logits.device
        )
        loss2 = (loss2.view(num_ent, num_ent) * off_diagonal).view(-1)

        # Masked entries still count in the denominator of the mean.
        loss1_mean = loss1.mean()
        loss2_mean = loss2.mean()
        loss = 1.0 * loss1_mean + 1.0 * loss2_mean
        return loss, loss1_mean, loss2_mean

    def get_label(self, logits, num_labels=-1):
        """Decode logits into a multi-hot prediction tensor.

        Args:
            logits: 2-D tensor of class scores with the threshold class in
                column 0.
            num_labels: if positive, additionally keep only classes whose
                logit is among the ``num_labels`` largest of their row.
                Values larger than the number of classes are clamped to it.

        Returns:
            A tensor shaped and typed like ``logits`` holding 1.0 for every
            class whose logit is strictly greater than the threshold logit
            (and passes the optional top-k filter).  Column 0 is 1.0 exactly
            for rows where no other class was selected.
        """
        th_logit = logits[:, 0].unsqueeze(1)
        mask = logits > th_logit
        if num_labels > 0:
            # topk raises when k exceeds the class count; with k equal to
            # the class count the filter keeps everything, so clamp.
            k = min(num_labels, logits.size(1))
            top_v, _ = torch.topk(logits, k, dim=1)
            kth_value = top_v[:, -1].unsqueeze(1)
            mask = (logits >= kth_value) & mask
        output = torch.zeros_like(logits)
        output[mask] = 1.0
        # Fall back to the threshold class when nothing else was selected.
        output[:, 0] = (output.sum(1) == 0.).to(logits)
        return output
