#!/usr/bin/python3
#-*- coding:utf-8 -*-
############################
#File Name: new_hmc.py
#Created Time:
############################


import torch
import torch.nn as nn
from models.mlp import MLP, ACT, MLP1
import torch.nn.functional as F
from torch.autograd import Variable

class n_SuperCls(nn.Module):
    '''
    Two-level hierarchical classifier on top of per-level node embeddings.

    For every level ``k`` the forward pass builds a document representation
    ``tfidf @ embeds[k][:-num_classes]``.  Level 0 is classified from that
    representation alone; level 1 is classified from its own representation
    concatenated with a softmax-weighted mixture of the level-0 label
    embeddings.  A "global" head classifies all labels from a highway layer
    applied to the concatenated per-level representations.

    Args:
        input_dim: feature width given to the label heads (``lab_cls*``).
        hidden_dim: feature width given to the document heads (``local_*``,
            ``high_g``) and to the highway layers.
            NOTE(review): ``forward`` feeds rows of the same embedding
            matrices to both kinds of heads, so presumably
            ``input_dim == hidden_dim == embedding width`` -- confirm.
        classes: cumulative label counts per level.  ``classes[0]`` is the
            number of level-0 labels, ``classes[1] - classes[0]`` the number
            of level-1 labels and ``classes[-1]`` the total number of labels.
            The sequence is copied, so later changes by the caller do not
            affect the model.

    Raises:
        ValueError: if ``classes`` has fewer than two entries.
    '''
    def __init__(self,input_dim=300,hidden_dim=300,classes=(9,137)):
        super(n_SuperCls,self).__init__()
        # Copy: never share the default or a caller-owned sequence.
        classes = list(classes)
        if len(classes) < 2:
            raise ValueError(
                "n_SuperCls needs at least 2 cumulative class counts, got %d"
                % len(classes))
        self.classes = classes

        # NOTE: several sub-modules below (lab_clsg, mapp*, wc*, wf*, lstm,
        # fc_*, g_1, t_1, ws) are not used by forward().  They are kept, in
        # the original creation order, so that state_dict keys and seeded
        # parameter initialisation stay unchanged.
        self.local_0 = MLP(input_dim=1*hidden_dim,output_dim=classes[0])
        self.local_1 = MLP(input_dim=2*hidden_dim,output_dim=classes[1]-classes[0])
        self.lab_cls0 = MLP(input_dim=input_dim,output_dim=classes[0])
        self.lab_cls1 = MLP(input_dim=input_dim,output_dim=classes[1]-classes[0])
        self.lab_clsg = MLP(input_dim=input_dim,output_dim=classes[-1])
        self.mapp1 = nn.Linear(hidden_dim,hidden_dim)
        self.mapp0 = nn.Linear(hidden_dim,hidden_dim)
        self.wc0 = nn.Linear(hidden_dim,hidden_dim)
        self.wc1 = nn.Linear(hidden_dim,hidden_dim)
        self.wf0 = nn.Linear(hidden_dim,hidden_dim)
        self.wf1 = nn.Linear(hidden_dim,hidden_dim)
        self.high_g = MLP(input_dim=2*hidden_dim,output_dim=classes[-1])
        self.num_classes = classes[-1]
        self.dropout = nn.Dropout(0.5)
        self.layer_size = 1
        self.hidden_size = hidden_dim
        self.lstm = nn.LSTMCell(2*input_dim,hidden_dim)
        self.h_layers = 1
        self.linears_g = nn.ModuleList([nn.Linear(2*hidden_dim,2*hidden_dim) for _ in range(self.h_layers)])
        self.linears_t = nn.ModuleList([nn.Linear(2*hidden_dim,2*hidden_dim) for _ in range(self.h_layers)])
        self.fc_0 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_1 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_2 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                  nn.LeakyReLU())

        self.g_1 = nn.Linear(hidden_dim,hidden_dim)
        self.t_1 = nn.Linear(hidden_dim,hidden_dim)
        self.ws = nn.ModuleList([nn.ModuleList([nn.Linear(hidden_dim,hidden_dim) for _ in range(3)]) for i in range(2)])

    def _highway_layer(self,input_,num_layers=1,act=nn.LeakyReLU()):
        '''
        Apply ``num_layers`` stacked highway layers to ``input_``.

        Each layer computes ``t*g + (1-t)*x`` with ``g = act(W_g x)`` and
        ``t = sigmoid(W_t x)``, where ``x`` is the previous layer's output
        (the raw input for the first layer).  With ``num_layers == 0`` the
        input is returned unchanged.

        Raises:
            IndexError: if ``num_layers`` exceeds the number of highway
                layers created in ``__init__`` (``self.h_layers``).
        '''
        output = input_
        for idx in range(num_layers):
            g = act(self.linears_g[idx](output))
            t = torch.sigmoid(self.linears_t[idx](output))
            # Feed the result forward so that stacked layers really stack.
            output = t*g + (1.-t)*output
        return output

    def lab_reps(self,lab_vec,indx):
        '''
        Single-head self-attention over the rows of ``lab_vec``.

        Uses the ``indx``-th (query, key, value) triple of ``self.ws`` and
        returns ``softmax(Q K^T, dim=1) V``; the scores are not scaled.
        '''
        WQ = self.ws[indx][0](lab_vec) # num x hidden
        WK = self.ws[indx][1](lab_vec)
        WV = self.ws[indx][2](lab_vec)
        QK = torch.softmax(WQ.matmul(WK.t()),dim=1) # num x num
        return QK.matmul(WV)

    @staticmethod
    def _check_mode(flag,labels):
        '''Reject unknown ``flag`` values and "train" without ``labels``.'''
        if flag not in ("test","train"):
            raise ValueError("flag must be 'test' or 'train', got %r" % (flag,))
        if flag == "train" and labels is None:
            raise ValueError("labels are required when flag == 'train'")

    @staticmethod
    def _label_correlation(label_embeds,scale=1.0):
        '''Column-wise softmax of the scaled cosine similarity between labels.'''
        unit = F.normalize(label_embeds)
        return torch.softmax(unit.mm(unit.t())*scale,dim=0)

    @staticmethod
    def _parent_context(parent_logits,parent_label_embeds,labels,lo,hi,flag):
        '''
        Mix the parent level's label embeddings into one vector per sample.

        The mixing weights are a softmax over the parent logits ("test") or
        over the label columns ``labels[:, lo:hi]`` ("train").
        '''
        if flag == "train":
            # assumes labels is a floating-point tensor (softmax does not
            # accept integer dtypes) -- TODO confirm against the caller
            weights = torch.softmax(labels[:,lo:hi],dim=1)
        else:
            weights = torch.softmax(parent_logits,dim=1)
        return weights.mm(parent_label_embeds)

    def forward(self,embeds,tfidf,labels=None,flag="test"):
        '''
        Compute local, global and label-level outputs for one batch.

        Args:
            embeds: indexable collection with at least two node-embedding
                matrices.  All but the last ``num_classes`` rows are
                multiplied by ``tfidf``; the label rows are read starting at
                row ``vocab_size``.  The original comment says these are
                GCN outputs -- not verifiable from this file.
            tfidf: ``(batch, vocab_size)`` document-term weights.
            labels: ``(batch, >= classes[0])`` label matrix; only read when
                ``flag == "train"``, where it replaces the level-0
                predictions as mixing weights.
            flag: ``"test"`` or ``"train"``.  This only selects the source
                of the parent-level weights; dropout is controlled by the
                module's train/eval mode as usual.

        Returns:
            dict with keys ``"0"``, ``"1"`` (per-level logits), ``"g"``
            (global logits), ``"corm0"``, ``"corm1"`` (label correlation
            matrices), ``"lab_loss0"``, ``"lab_loss1"`` (label-head outputs
            on the label embeddings) and the constant placeholders
            ``"sc0": 0`` and ``"sc1": 1``.

        Raises:
            ValueError: on an unknown ``flag`` or on ``flag == "train"``
                without ``labels``.
        '''
        self._check_mode(flag,labels)

        vocab_size = tfidf.shape[1]
        n_top, n_upto1 = self.classes[0], self.classes[1]

        # ---- level 0 ----
        embeds_0 = self.dropout(embeds[0])
        rp0 = tfidf.mm(embeds_0[:-self.num_classes,:])
        logits0 = self.local_0(rp0)
        lab_emb0 = embeds_0[vocab_size:vocab_size+n_top,:]
        corm0 = self._label_correlation(lab_emb0,1.1)

        # ---- level 1 (conditioned on level 0) ----
        embeds_1 = self.dropout(embeds[1])
        rp1 = tfidf.mm(embeds_1[:-self.num_classes,:])
        lab_ct0 = self._parent_context(logits0,lab_emb0,labels,0,n_top,flag)
        logits1 = self.local_1(torch.cat((rp1,lab_ct0),dim=1))
        lab_emb1 = embeds_1[vocab_size+n_top:vocab_size+n_upto1,:]
        corm1 = self._label_correlation(lab_emb1,1.1)

        # ---- global head over both document representations ----
        rpG = self._highway_layer(torch.cat((rp0,rp1),dim=1))
        logitsG = self.high_g(rpG)

        # ---- classify the label embeddings themselves ----
        lab_loss0 = self.lab_cls0(lab_emb0)
        lab_loss1 = self.lab_cls1(lab_emb1)

        # NOTE(review): 'sc0'/'sc1' are constant placeholders; presumably
        # kept so callers that still read these keys do not break.
        return {"0":logits0,
                "1":logits1,
                "g":logitsG,
                'corm0':corm0,
                'corm1':corm1,
                'sc0': 0,
                'sc1': 1,
                'lab_loss0':lab_loss0,
                'lab_loss1':lab_loss1}


class n_SuperClsF(nn.Module):
    '''
    Four-level hierarchical classifier on top of per-level node embeddings.

    For every level ``k`` the forward pass builds a document representation
    ``tfidf @ embeds[k][:-num_classes]``.  Level 0 is classified from that
    representation alone; every deeper level is classified from its own
    representation concatenated with a softmax-weighted mixture of the
    parent level's label embeddings.  A "global" head classifies all labels
    from a highway layer applied to the four concatenated representations.

    Args:
        input_dim: feature width given to the label heads (``lab_cls*``).
        hidden_dim: feature width given to the document heads (``local_*``,
            ``high_g``) and to the highway layers.
            NOTE(review): ``forward`` feeds rows of the same embedding
            matrices to both kinds of heads, so presumably
            ``input_dim == hidden_dim == embedding width`` -- confirm.
        classes: cumulative label counts per level; at least four entries
            are required (``classes[k] - classes[k-1]`` labels on level
            ``k``, ``classes[-1]`` labels in total).  The sequence is
            copied.  The default value is kept only for signature
            compatibility: it has two entries and is therefore rejected.

    Raises:
        ValueError: if ``classes`` has fewer than four entries.
    '''
    def __init__(self,input_dim=300,hidden_dim=300,classes=(9,137)):
        super(n_SuperClsF,self).__init__()
        # Copy: never share the default or a caller-owned sequence.
        classes = list(classes)
        if len(classes) < 4:
            # Previously surfaced as a bare IndexError half-way through
            # construction; fail early with an actionable message instead.
            raise ValueError(
                "n_SuperClsF needs at least 4 cumulative class counts, got %d"
                % len(classes))
        self.classes = classes

        # NOTE: several sub-modules below (lab_clsg, mapp*, wc*, wf*, lstm,
        # fc_*, g_1, t_1, ws) are not used by forward().  They are kept, in
        # the original creation order, so that state_dict keys and seeded
        # parameter initialisation stay unchanged.
        self.local_0 = MLP(input_dim=1*hidden_dim,output_dim=classes[0])
        self.local_1 = MLP(input_dim=2*hidden_dim,output_dim=classes[1]-classes[0])
        self.local_2 = MLP(input_dim=2*hidden_dim,output_dim=classes[2]-classes[1])
        self.local_3 = MLP(input_dim=2*hidden_dim,output_dim=classes[3]-classes[2])
        self.lab_cls0 = MLP(input_dim=input_dim,output_dim=classes[0])
        self.lab_cls1 = MLP(input_dim=input_dim,output_dim=classes[1]-classes[0])
        self.lab_cls2 = MLP(input_dim=input_dim,output_dim=classes[2]-classes[1])
        self.lab_cls3 = MLP(input_dim=input_dim,output_dim=classes[3]-classes[2])
        self.lab_clsg = MLP(input_dim=input_dim,output_dim=classes[-1])
        self.mapp1 = nn.Linear(hidden_dim,hidden_dim)
        self.mapp0 = nn.Linear(hidden_dim,hidden_dim)
        self.wc0 = nn.Linear(hidden_dim,hidden_dim)
        self.wc1 = nn.Linear(hidden_dim,hidden_dim)
        self.wf0 = nn.Linear(hidden_dim,hidden_dim)
        self.wf1 = nn.Linear(hidden_dim,hidden_dim)
        self.high_g = MLP(input_dim=4*hidden_dim,output_dim=classes[-1])
        self.num_classes = classes[-1]
        self.dropout = nn.Dropout(0.5)
        self.layer_size = 1
        self.hidden_size = hidden_dim
        self.lstm = nn.LSTMCell(2*input_dim,hidden_dim)
        self.h_layers = 1
        self.linears_g = nn.ModuleList([nn.Linear(4*hidden_dim,4*hidden_dim) for _ in range(self.h_layers)])
        self.linears_t = nn.ModuleList([nn.Linear(4*hidden_dim,4*hidden_dim) for _ in range(self.h_layers)])
        self.fc_0 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_1 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_2 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                  nn.LeakyReLU())

        self.g_1 = nn.Linear(hidden_dim,hidden_dim)
        self.t_1 = nn.Linear(hidden_dim,hidden_dim)
        self.ws = nn.ModuleList([nn.ModuleList([nn.Linear(hidden_dim,hidden_dim) for _ in range(3)]) for i in range(2)])

    def _highway_layer(self,input_,num_layers=1,act=nn.LeakyReLU()):
        '''
        Apply ``num_layers`` stacked highway layers to ``input_``.

        Each layer computes ``t*g + (1-t)*x`` with ``g = act(W_g x)`` and
        ``t = sigmoid(W_t x)``, where ``x`` is the previous layer's output
        (the raw input for the first layer).  With ``num_layers == 0`` the
        input is returned unchanged.

        Raises:
            IndexError: if ``num_layers`` exceeds the number of highway
                layers created in ``__init__`` (``self.h_layers``).
        '''
        output = input_
        for idx in range(num_layers):
            g = act(self.linears_g[idx](output))
            t = torch.sigmoid(self.linears_t[idx](output))
            # Feed the result forward so that stacked layers really stack.
            output = t*g + (1.-t)*output
        return output

    def lab_reps(self,lab_vec,indx):
        '''
        Single-head self-attention over the rows of ``lab_vec``.

        Uses the ``indx``-th (query, key, value) triple of ``self.ws`` and
        returns ``softmax(Q K^T / 10, dim=1) V``.
        '''
        WQ = self.ws[indx][0](lab_vec) # num x hidden
        WK = self.ws[indx][1](lab_vec)
        WV = self.ws[indx][2](lab_vec)
        QK = torch.softmax(WQ.matmul(WK.t())/10,dim=1) # num x num
        return QK.matmul(WV)

    @staticmethod
    def _check_mode(flag,labels):
        '''Reject unknown ``flag`` values and "train" without ``labels``.'''
        if flag not in ("test","train"):
            raise ValueError("flag must be 'test' or 'train', got %r" % (flag,))
        if flag == "train" and labels is None:
            raise ValueError("labels are required when flag == 'train'")

    @staticmethod
    def _label_correlation(label_embeds):
        '''Column-wise softmax of the cosine similarity between labels.'''
        unit = F.normalize(label_embeds)
        return torch.softmax(unit.mm(unit.t()),dim=0)

    @staticmethod
    def _parent_context(parent_logits,parent_label_embeds,labels,lo,hi,flag):
        '''
        Mix the parent level's label embeddings into one vector per sample.

        The mixing weights are a softmax over the parent logits ("test") or
        over the label columns ``labels[:, lo:hi]`` ("train").
        '''
        if flag == "train":
            # assumes labels is a floating-point tensor (softmax does not
            # accept integer dtypes) -- TODO confirm against the caller
            weights = torch.softmax(labels[:,lo:hi],dim=1)
        else:
            weights = torch.softmax(parent_logits,dim=1)
        return weights.mm(parent_label_embeds)

    def forward(self,embeds,tfidf,labels=None,flag="test"):
        '''
        Compute local, global and label-level outputs for one batch.

        Args:
            embeds: indexable collection with at least four node-embedding
                matrices.  All but the last ``num_classes`` rows are
                multiplied by ``tfidf``; the label rows are read starting at
                row ``vocab_size``.
            tfidf: ``(batch, vocab_size)`` document-term weights.
            labels: ``(batch, >= classes[2])`` label matrix; only read when
                ``flag == "train"``, where it replaces the parent-level
                predictions as mixing weights.
            flag: ``"test"`` or ``"train"``.  This only selects the source
                of the parent-level weights; dropout is controlled by the
                module's train/eval mode as usual.

        Returns:
            dict with keys ``"0"``..``"3"`` (per-level logits), ``"g"``
            (global logits), ``"c_0"``..``"c_3"`` (label correlation
            matrices), ``"l_0"``..``"l_3"`` (label-head outputs on the
            label embeddings) and the constant placeholders ``"sc0": 0``
            and ``"sc1": 1``.

        Raises:
            ValueError: on an unknown ``flag`` or on ``flag == "train"``
                without ``labels``.
        '''
        self._check_mode(flag,labels)

        vocab_size = tfidf.shape[1]
        # bounds[k]:bounds[k+1] is the label-column range of level k.
        bounds = [0] + list(self.classes[:4])
        local_heads = (self.local_0,self.local_1,self.local_2,self.local_3)
        label_heads = (self.lab_cls0,self.lab_cls1,self.lab_cls2,self.lab_cls3)

        doc_reps, logits, corms, label_embeds = [], [], [], []
        for level in range(4):
            lo, hi = bounds[level], bounds[level+1]
            level_embeds = self.dropout(embeds[level])
            rp = tfidf.mm(level_embeds[:-self.num_classes,:])
            lab_emb = level_embeds[vocab_size+lo:vocab_size+hi,:]
            if level == 0:
                head_in = rp
            else:
                # Condition on the parent level: predicted (test) or
                # ground-truth (train) parent labels weight its embeddings.
                parent_ct = self._parent_context(
                    logits[-1],label_embeds[-1],labels,bounds[level-1],lo,flag)
                head_in = torch.cat((rp,parent_ct),dim=1)
            logits.append(local_heads[level](head_in))
            corms.append(self._label_correlation(lab_emb))
            doc_reps.append(rp)
            label_embeds.append(lab_emb)

        # Global head over all four document representations.
        logitsG = self.high_g(self._highway_layer(torch.cat(doc_reps,dim=1)))

        # Classify the label embeddings themselves.
        lab_losses = [head(emb) for head, emb in zip(label_heads,label_embeds)]

        # NOTE(review): 'sc0'/'sc1' are constant placeholders; presumably
        # kept so callers that still read these keys do not break.
        return {"0":logits[0],
                "1":logits[1],
                "2":logits[2],
                "3":logits[3],
                "g":logitsG,
                'c_0':corms[0],
                'c_1':corms[1],
                'c_2':corms[2],
                'c_3':corms[3],
                'sc0': 0,
                'sc1': 1,
                'l_0':lab_losses[0],
                'l_1':lab_losses[1],
                'l_2':lab_losses[2],
                'l_3':lab_losses[3]
               }

