#!/usr/bin/python3
#-*- coding:utf-8 -*-
############################
#File Name: classifier.py
#Created Time:
############################
import torch
import torch.nn as nn
from models.mlp import MLP, ACT, MLP1
import torch.nn.functional as F
from torch.autograd import Variable

class Base(nn.Module):
    """Flat baseline: one MLP over four concatenated document representations.

    For each of ``embeds[0]`` .. ``embeds[3]`` the trailing ``num_classes``
    rows are dropped (presumably the label-node rows -- confirm against the
    graph encoder) and the remaining rows are pooled per document with
    ``tfidf``. The four pooled matrices are concatenated feature-wise and
    scored by ``global_cls``.
    """

    def __init__(self,input_dim=300,hidden_dim=300,classes=[8,137]):
        super(Base,self).__init__()
        # Number of trailing rows stripped from every embedding matrix.
        self.num_classes = classes[-1]
        # Four pooled representations are concatenated, hence 4*hidden_dim.
        self.global_cls = MLP(input_dim=4*hidden_dim,output_dim=classes[-1])

    def forward(self,embeds,tfidf,labels=None,flag=None):
        """Return ``{"g": logits}`` for the documents described by ``tfidf``.

        ``labels`` and ``flag`` are accepted but not read by this method.
        """
        word_rows = slice(None, -self.num_classes)
        pooled = [tfidf.mm(embeds[level][word_rows, :]) for level in range(4)]
        features = torch.cat(pooled, dim=1)
        return {"g": self.global_cls(features)}


class LabelAttention(nn.Module):
    """Classifier that appends a label-attention context to the document vector.

    The last ``num_classes`` rows of ``embeds[1]`` are used as label
    embeddings; the rows before them are pooled per document with ``tfidf``.
    """

    def __init__(self,input_dim=300,hidden_dim=300,classes=[8,137]):
        super(LabelAttention,self).__init__()
        self.num_classes = classes[-1]
        # Two-layer scorer over [document ; label context].
        self.w1 = nn.Linear(2*input_dim,hidden_dim)
        self.w2 = nn.Linear(hidden_dim,self.num_classes)
        # label_w and com are constructed but not read by forward().
        self.label_w = nn.Linear(input_dim,self.num_classes)
        self.act = nn.LeakyReLU()
        self.com = ACT(2*hidden_dim,hidden_dim)

    def forward(self,embeds,tfidf):
        """Return ``{"g": logits}`` with one row of class scores per document."""
        graph_out = embeds[1]
        label_emb = graph_out[-self.num_classes:, :]
        doc_vec = tfidf.mm(graph_out[:-self.num_classes, :])

        # Attention over labels; scores are divided by 10 before the softmax.
        weights = torch.softmax(doc_vec.mm(label_emb.t()) / 10, dim=1)
        label_context = weights.mm(label_emb)

        joint = torch.cat((doc_vec, label_context), dim=1)
        hidden = self.act(self.w1(joint))
        return {"g": self.w2(hidden)}




class HMCLab(nn.Module):
    """Two-level hierarchical classifier with an additional global head.

    A document vector is pooled from ``embeds[1]`` with ``tfidf`` and passed
    through a chain of ``ACT`` blocks; each stage feeds its own ``MLP`` head.

    Args:
        input_dim: feature size of the pooled document vector.
        hidden_dim: output size of every ``ACT`` block.
        classes: cumulative class counts; ``classes[0]`` first-level classes,
            ``classes[1] - classes[0]`` second-level classes and
            ``classes[-1]`` outputs for the global head.
    """
    def __init__(self,input_dim=300,hidden_dim=300,classes=[9,137]):
        super(HMCLab,self).__init__()
        self.classes = classes
        self.local_1 = MLP(input_dim=hidden_dim,hidden_dim=hidden_dim,output_dim=classes[0])
        self.local_2 = MLP(hidden_dim,hidden_dim,classes[1]-classes[0])
        self.globalc = MLP(hidden_dim,hidden_dim,classes[-1])
        self.high_1 = ACT(input_dim,hidden_dim)
        self.high_2 = ACT(2*hidden_dim,hidden_dim)
        # forward() calls high_g; it was commented out here, so every forward
        # pass raised AttributeError. Restored with its original dimensions.
        self.high_g = ACT(2*hidden_dim,hidden_dim)
        # NOTE(review): the sibling HMC class strips classes[-1] trailing rows,
        # this class strips sum(classes) -- confirm which matches the graph.
        self.num_classes = sum(classes)

    def forward(self,embeds,tfidf):
        """Return ``[logits1, logits2, logitsG]`` for the given documents."""
        # Pool the rows of embeds[1] that precede the trailing num_classes rows.
        x = tfidf.mm(embeds[1][:-self.num_classes,:])

        # First-level hidden state and logits.
        AG1 = self.high_1(x)
        logits1 = self.local_1(AG1)

        # Second-level hidden state sees the first-level state and the input.
        AG2 = self.high_2(torch.cat((AG1,x),dim=1))
        logits2 = self.local_2(AG2)

        # Global hidden state sees the second-level state and the input.
        AGG = self.high_g(torch.cat((AG2,x),dim=1))
        logitsG = self.globalc(AGG)

        return [logits1, logits2, logitsG]



class HMC(nn.Module):
    """Two-level hierarchical classifier with a global head.

    The document vector pooled from ``embeds[1]`` runs through three ``ACT``
    stages (level 1, level 2, global); each stage has its own ``MLP`` head.
    """

    def __init__(self,input_dim=300,hidden_dim=300,classes=[9,137]):
        super(HMC,self).__init__()
        self.classes = classes
        # Output heads: level 1, level 2 (classes added at that level), global.
        self.local_1 = MLP(input_dim=hidden_dim,hidden_dim=hidden_dim,output_dim=classes[0])
        self.local_2 = MLP(hidden_dim,hidden_dim,classes[1]-classes[0])
        self.globalc = MLP(hidden_dim,hidden_dim,classes[1])
        # Hidden-state stages feeding the heads above.
        self.high_1 = ACT(input_dim,hidden_dim)
        self.high_2 = ACT(2*hidden_dim,hidden_dim)
        self.high_g = ACT(2*hidden_dim,hidden_dim)
        # Number of trailing rows of embeds[1] excluded from pooling.
        self.num_classes = classes[-1]

    def forward(self,embeds,tfidf):
        """Return a dict with level-1 ('1'), level-2 ('2') and global ('g') logits."""
        doc_vec = tfidf.mm(embeds[1][:-self.num_classes, :])

        # Level 1 works on the pooled document vector alone.
        state_1 = self.high_1(doc_vec)
        level_1 = self.local_1(state_1)

        # Level 2 combines the level-1 state with the document vector.
        state_2 = self.high_2(torch.cat((state_1, doc_vec), dim=1))
        level_2 = self.local_2(state_2)

        # The global stage combines both level states.
        state_g = self.high_g(torch.cat((state_2, state_1), dim=1))
        overall = self.globalc(state_g)

        return {'1': level_1, '2': level_2, 'g': overall}


class MHMC(nn.Module):
    """Hierarchical classifier that reads one embedding matrix per level.

    Level 1 is scored from the document vector pooled out of ``embeds[0]``,
    level 2 from the one pooled out of ``embeds[1]``, and the global head
    from their concatenation.
    """

    def __init__(self,input_dim=300,hidden_dim=300,classes=[9,137]):
        super(MHMC,self).__init__()
        self.classes = classes
        self.local_1 = MLP(hidden_dim,hidden_dim,classes[0])
        self.local_2 = MLP(hidden_dim,hidden_dim,classes[1]-classes[0])
        # Global head consumes both pooled vectors side by side.
        self.high_g = MLP(2*hidden_dim,hidden_dim,classes[-1])
        # Number of trailing rows excluded from pooling in each matrix.
        self.num_classes = classes[-1]

    def forward(self,embeds,tfidf):
        """Return a dict with level-1 ('1'), level-2 ('2') and global ('g') logits."""
        keep = slice(None, -self.num_classes)
        doc_l1, doc_l2 = (tfidf.mm(embeds[level][keep, :]) for level in (0, 1))

        level_1 = self.local_1(doc_l1)
        level_2 = self.local_2(doc_l2)
        overall = self.high_g(torch.cat((doc_l1, doc_l2), dim=1))

        return {'1': level_1, '2': level_2, 'g': overall}

class SuperCls0(nn.Module):
    """Two-level classifier that also scores documents against label rows.

    Besides the per-level and global logits, ``forward`` returns dot-product
    scores between the level-1 document vector and linearly mapped label
    rows (``sc1``, ``sc2``) and the output of ``lab_cls`` on the label rows.
    Rows ``[0, vocab_size)`` of each embedding matrix are pooled by ``tfidf``;
    the rows after them are treated as label rows, where
    ``vocab_size = tfidf.shape[1]``.
    """

    def __init__(self,input_dim=300,hidden_dim=300,classes=[9,137]):
        super(SuperCls0,self).__init__()
        self.classes = classes
        self.local_1 = MLP(hidden_dim,hidden_dim,classes[0])
        self.local_2 = MLP(hidden_dim,hidden_dim,classes[1]-classes[0])
        self.lab_cls = MLP(hidden_dim,hidden_dim,classes[-1])
        # Linear maps applied to label rows before the dot-product scoring.
        self.mapp1 = nn.Linear(hidden_dim,hidden_dim)
        self.mapp2 = nn.Linear(hidden_dim,hidden_dim)
        self.high_g = MLP(2*hidden_dim,hidden_dim*4,classes[-1])
        self.num_classes = classes[-1]
        # Only dropout1 is applied in forward(); dropout2 is constructed but unused.
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)

    def forward(self,embeds,tfidf):
        """Return logits ('1', '2', 'g'), label scores ('sc1', 'sc2') and 'lab_loss'."""
        n_words = tfidf.shape[1]
        level1_end = n_words + self.classes[0]
        word_rows = slice(None, -self.num_classes)

        # Level-1 document vector from the first embedding matrix.
        graph_0 = self.dropout1(embeds[0])
        doc_1 = tfidf.mm(graph_0[word_rows, :])

        # Level-2 vector is an equal blend of the second matrix's pooling and doc_1.
        graph_1 = self.dropout1(embeds[1])
        doc_2 = 0.5*tfidf.mm(graph_1[word_rows, :]) + 0.5*doc_1

        level_1 = self.local_1(doc_1)
        score_1 = doc_1.mm(self.mapp1(graph_0[n_words:level1_end, :]).t())

        level_2 = self.local_2(doc_2)
        # NOTE(review): the level-2 label scores use doc_1, not doc_2 -- kept
        # as in the original; confirm this is intended.
        score_2 = doc_1.mm(self.mapp2(graph_1[level1_end:, :]).t())

        overall = self.high_g(torch.cat((doc_1, doc_2), dim=1))

        label_out = self.lab_cls(graph_1[n_words:, :])

        return {"1": level_1,
                "2": level_2,
                "g": overall,
                'sc1': score_1,
                'sc2': score_2,
                'lab_loss': label_out}

class SuperCls2(nn.Module):
    """Two-level classifier that runs the per-level document vectors through an LSTM.

    The document vectors pooled from ``embeds[0]`` and ``embeds[1]`` form a
    length-2 sequence for a 2-layer LSTM. The LSTM output at step 0 feeds the
    level-0 head, step 1 feeds the level-1 head, and both outputs
    (concatenated, then highway-gated) feed the global head.

    Args:
        input_dim: feature size of the pooled document vectors (LSTM input).
        hidden_dim: LSTM hidden size.
        classes: cumulative class counts per level; ``classes[-1]`` is the
            number of trailing rows excluded from pooling and the size of
            the global output.
    """
    def __init__(self,input_dim=300,hidden_dim=300,classes=[9,137]):
        super(SuperCls2,self).__init__()
        self.classes = classes
        self.local_0 = MLP(input_dim=hidden_dim,output_dim=classes[0])
        self.local_1 = MLP(input_dim=hidden_dim,output_dim=classes[1]-classes[0])
        # lab_cls, mapp1, mapp2 and dropout2 do not contribute to the values
        # returned by forward(); they are kept so existing checkpoints still load.
        self.lab_cls = MLP(input_dim=input_dim,output_dim=classes[-1])
        self.mapp1 = nn.Linear(hidden_dim,hidden_dim)
        self.mapp2 = nn.Linear(hidden_dim,hidden_dim)
        self.high_g = MLP(input_dim=2*hidden_dim,output_dim=classes[-1])
        self.num_classes = classes[-1]
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        self.layer_size = 2
        self.hidden_size = hidden_dim
        self.lstm = nn.LSTM(input_dim,
                            hidden_dim,
                            self.layer_size,
                            dropout=0.5)
        self.h_layers = 1
        self.linears_g = nn.ModuleList([nn.Linear(2*hidden_dim,2*hidden_dim) for idx in range(self.h_layers)])
        self.linears_t = nn.ModuleList([nn.Linear(2*hidden_dim,2*hidden_dim) for idx in range(self.h_layers)])

    def _highway_layer(self,input_,num_layers=1,act=nn.LeakyReLU()):
        """Apply ``num_layers`` highway layers to ``input_`` and return the result.

        Each layer computes ``t*g + (1-t)*x`` with ``g = act(W_g x)`` and
        ``t = sigmoid(W_t x)``. Layers are chained: layer ``k+1`` reads the
        output of layer ``k`` (previously every layer read ``input_`` and only
        the last result survived). ``num_layers=0`` returns ``input_`` unchanged.
        """
        output = input_
        for idx in range(num_layers):
            g = act(self.linears_g[idx](output))
            t = torch.sigmoid(self.linears_t[idx](output))
            output = t*g +(1.-t)*output
        return output

    def forward(self,embeds,tfidf):
        """Return ``{"0": level-0 logits, "1": level-1 logits, "g": global logits}``."""
        batch_size = tfidf.shape[0]

        # Pool the rows preceding the trailing num_classes rows of each matrix.
        embeds_0 = self.dropout1(embeds[0])
        rp0 = tfidf.mm(embeds_0[:-self.num_classes,:])

        embeds_1 = self.dropout1(embeds[1])
        rp1 = tfidf.mm(embeds_1[:-self.num_classes,:])

        # Sequence of length 2 for the LSTM: seq x batch x dim.
        inputs = torch.stack((rp0,rp1))

        # Zero initial states created on the inputs' own device/dtype so the
        # module runs on CPU as well as GPU (previously hard-wired to .cuda()).
        h0 = inputs.new_zeros(self.layer_size, batch_size, self.hidden_size)
        c0 = inputs.new_zeros(self.layer_size, batch_size, self.hidden_size)
        lstm_outputs, _ = self.lstm(inputs,(h0,c0))

        # Per-level heads read the LSTM output of their own time step.
        logits0 = self.local_0(lstm_outputs[0])
        logits1 = self.local_1(lstm_outputs[1])

        # Global head reads both steps through the highway gate.
        logitsG = self._highway_layer(torch.cat((lstm_outputs[0],lstm_outputs[1]),dim=1))
        logitsG = self.high_g(logitsG)

        return {"0":logits0,
                "1":logits1,
                "g":logitsG}

class SuperCls3(nn.Module):
    """Two-level classifier combining an LSTM over levels with label attention.

    The document vectors pooled from ``embeds[0]`` and ``embeds[1]`` form a
    length-2 LSTM sequence. At each level the LSTM output attends over that
    level's label rows; the attended label context is concatenated to the
    LSTM output before the level head. The global head reads the LSTM's
    final hidden state.

    Args:
        input_dim: feature size of the embedding rows (LSTM input size).
        hidden_dim: LSTM hidden size.
        classes: cumulative class counts ``[n_level1, n_level1 + n_level2]``;
            ``classes[-1]`` is the number of trailing rows excluded from pooling.
    """
    def __init__(self,input_dim=300,hidden_dim=512,classes=[9,137]):
        super(SuperCls3,self).__init__()
        self.classes = classes
        # Head inputs are [LSTM output (hidden_dim) ; label contexts (input_dim each)].
        # Sizes are written in terms of both dims so that input_dim != hidden_dim
        # (including this class's own defaults) works; they are unchanged when
        # the two are equal.
        self.local_1 = MLP(input_dim=hidden_dim+input_dim,output_dim=classes[0])
        self.local_2 = MLP(input_dim=hidden_dim+2*input_dim,output_dim=classes[1]-classes[0])
        self.lab_cls = MLP(input_dim=input_dim,output_dim=classes[-1])
        # Map label rows (input_dim) into the LSTM output space (hidden_dim).
        self.mapp1 = nn.Linear(input_dim,hidden_dim)
        self.mapp2 = nn.Linear(input_dim,hidden_dim)
        self.high_g = MLP(input_dim=hidden_dim,output_dim=classes[-1])
        self.num_classes = classes[-1]
        # Only dropout1 is applied in forward(); dropout2 is constructed but unused.
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        self.layer_size = 1
        self.hidden_size = hidden_dim
        # Inter-layer LSTM dropout has no effect with a single layer and only
        # triggers a warning, so it is enabled only for stacked LSTMs.
        self.lstm = nn.LSTM(input_dim,
                            hidden_dim,
                            self.layer_size,
                            dropout=0.5 if self.layer_size > 1 else 0.0)

    def forward(self,embeds,tfidf):
        """Return logits ('1', '2', 'g'), attention weights ('sc1', 'sc2') and 'lab_loss'."""
        vocab_size = tfidf.shape[1]
        batch_size = tfidf.shape[0]
        # Row ranges of the level-1 and level-2 label blocks.
        lvl1 = slice(vocab_size, vocab_size+self.classes[0])
        lvl2 = slice(vocab_size+self.classes[0], vocab_size+self.classes[1])

        # Pool the rows preceding the trailing num_classes rows of each matrix.
        embeds_0 = self.dropout1(embeds[0])
        rp1 = tfidf.mm(embeds_0[:-self.num_classes,:])

        embeds_1 = self.dropout1(embeds[1])
        rp2 = tfidf.mm(embeds_1[:-self.num_classes,:])

        # Sequence of length 2 for the LSTM: seq x batch x dim.
        inputs = torch.stack((rp1,rp2))

        # Zero initial states on the inputs' own device/dtype (previously
        # hard-wired to .cuda(), which broke CPU execution).
        h0 = inputs.new_zeros(self.layer_size, batch_size, self.hidden_size)
        c0 = inputs.new_zeros(self.layer_size, batch_size, self.hidden_size)
        lstm_outputs, (final_hidden_state, _) = self.lstm(inputs,(h0,c0))

        # Level 1: attend over level-1 label rows, then classify.
        labels_1 = embeds_0[lvl1,:]
        sc1 = torch.softmax(lstm_outputs[0].mm(self.mapp1(labels_1).t()),dim=1)
        ct1 = sc1.mm(labels_1)
        logits1 = self.local_1(torch.cat((lstm_outputs[0],ct1),dim=1))

        # Level 2: attend over level-2 label rows; the head also sees ct1.
        labels_2 = embeds_1[lvl2,:]
        sc2 = torch.softmax(lstm_outputs[1].mm(self.mapp2(labels_2).t()),dim=1)
        ct2 = sc2.mm(labels_2)
        logits2 = self.local_2(torch.cat((lstm_outputs[1],ct1,ct2),dim=1))

        # Global head on the last layer's final hidden state. Indexing the layer
        # axis (instead of squeeze()) keeps the batch axis when batch_size == 1.
        logitsG = self.high_g(final_hidden_state[-1])

        lab_loss = self.lab_cls(embeds_1[vocab_size:,:])

        return {"1":logits1,"2":logits2,"g":logitsG, 'sc1': sc1, 'sc2':sc2, 'lab_loss':lab_loss}


class SuperClsF(nn.Module):
    """Four-level hierarchical classifier with a highway-gated global head.

    Level ``k`` pools a document vector from ``embeds[k]`` and scores it with
    its own head. For every level the class also returns a softmax-normalised
    cosine-similarity matrix of that level's label rows and the output of a
    label head applied to those rows.

    Args:
        input_dim: feature size of the embedding rows fed to the label heads.
        hidden_dim: feature size expected by the level heads.
        classes: cumulative class counts, one entry per level. Four entries
            are required (``classes[3]`` is read in ``__init__``); the default
            value has only two, so it raises ``IndexError`` for this class.
    """
    def __init__(self,input_dim=300,hidden_dim=300,classes=[9,137]):
        super(SuperClsF,self).__init__()
        self.classes = classes
        # One document head per level; level k outputs the classes added at level k.
        self.local_0 = MLP(input_dim=1*hidden_dim,output_dim=classes[0])
        self.local_1 = MLP(input_dim=1*hidden_dim,output_dim=classes[1]-classes[0])
        self.local_2 = MLP(input_dim=1*hidden_dim,output_dim=classes[2]-classes[1])
        self.local_3 = MLP(input_dim=1*hidden_dim,output_dim=classes[3]-classes[2])

        # One label head per level, applied to that level's label rows.
        self.lab_cls0 = MLP(input_dim=input_dim,output_dim=classes[0])
        self.lab_cls1 = MLP(input_dim=input_dim,output_dim=classes[1]-classes[0])
        self.lab_cls2 = MLP(input_dim=input_dim,output_dim=classes[2]-classes[1])
        self.lab_cls3 = MLP(input_dim=input_dim,output_dim=classes[3]-classes[2])

        # mapp1/mapp2 are constructed but not read by forward().
        self.mapp1 = nn.Linear(hidden_dim,hidden_dim)
        self.mapp2 = nn.Linear(hidden_dim,hidden_dim)
        self.high_g = MLP(input_dim=4*hidden_dim,output_dim=classes[-1])
        self.num_classes = classes[-1]
        self.dropout0 = nn.Dropout(0.3)
        self.dropout1 = nn.Dropout(0.3)
        self.dropout2 = nn.Dropout(0.3)
        self.dropout3 = nn.Dropout(0.3)
        self.layer_size = 1
        self.hidden_size = hidden_dim
        self.h_layers = 1
        self.linears_g = nn.ModuleList([nn.Linear(4*hidden_dim,4*hidden_dim) for idx in range(self.h_layers)])
        self.linears_t = nn.ModuleList([nn.Linear(4*hidden_dim,4*hidden_dim) for idx in range(self.h_layers)])

    def _highway_layer(self,input_,num_layers=1,act=nn.LeakyReLU()):
        """Apply ``num_layers`` highway layers to ``input_`` and return the result.

        Each layer computes ``t*g + (1-t)*x`` with ``g = act(W_g x)`` and
        ``t = sigmoid(W_t x)``. Layers are chained: layer ``k+1`` reads the
        output of layer ``k`` (previously every layer read ``input_`` and only
        the last result survived). ``num_layers=0`` returns ``input_`` unchanged.
        """
        output = input_
        for idx in range(num_layers):
            g = act(self.linears_g[idx](output))
            t = torch.sigmoid(self.linears_t[idx](output))
            output = t*g +(1.-t)*output
        return output

    def forward(self,embeds,tfidf):
        """Score the documents in ``tfidf`` at all four levels.

        Returns a dict with keys ``"0"``..``"3"`` (level logits),
        ``"c_0"``..``"c_3"`` (label similarity matrices, softmax over dim 0),
        ``"l_0"``..``"l_3"`` (label-head outputs) and ``"g"`` (global logits).
        """
        vocab_size = tfidf.shape[1]
        # Label rows of level k sit at vocab_size + [bounds[k], bounds[k+1]).
        bounds = (0, self.classes[0], self.classes[1], self.classes[2], self.classes[3])
        dropouts = (self.dropout0, self.dropout1, self.dropout2, self.dropout3)
        heads = (self.local_0, self.local_1, self.local_2, self.local_3)
        label_heads = (self.lab_cls0, self.lab_cls1, self.lab_cls2, self.lab_cls3)
        # Softmax temperature per level; level 0 uses the raw similarities.
        temperatures = (None, 1.1, 1.1, 1.1)

        doc_reps, logits, corms, label_rows = [], [], [], []
        for level in range(4):
            dropped = dropouts[level](embeds[level])

            # Document vector: pool the rows preceding the trailing num_classes rows.
            rp = tfidf.mm(dropped[:-self.num_classes,:])
            doc_reps.append(rp)
            logits.append(heads[level](rp))

            # Cosine similarity between this level's label rows, softmaxed over dim 0.
            rows = dropped[vocab_size+bounds[level]:vocab_size+bounds[level+1],:]
            label_rows.append(rows)
            lab = F.normalize(rows)
            corm = lab.mm(lab.t())
            if temperatures[level] is not None:
                corm = corm/temperatures[level]
            corms.append(torch.softmax(corm,dim=0))

        # Global logits from all four document vectors through the highway gate.
        logitsG = self._highway_layer(torch.cat(doc_reps,dim=1))
        logitsG = self.high_g(logitsG)

        lab_losses = [label_heads[level](label_rows[level]) for level in range(4)]

        out = {}
        for level in range(4):
            out[str(level)] = logits[level]
        for level in range(4):
            out["c_%d" % level] = corms[level]
        for level in range(4):
            out["l_%d" % level] = lab_losses[level]
        out["g"] = logitsG
        return out



class SuperClsF1(nn.Module):
    """Four-level hierarchical classifier that feeds each level's prediction forward.

    Level 0 scores the document vector pooled from ``embeds[0]``. Every later
    level ``k`` fuses its own pooled vector with a label context built from the
    previous level: ``sigmoid(logits_{k-1})`` times that level's label rows.
    For each level the class also returns a softmax-normalised cosine-similarity
    matrix of the label rows and the output of a label head on those rows.

    Args:
        input_dim: feature size of the embedding rows.
        hidden_dim: feature size expected by the level heads. The fusion
            layers are sized with ``input_dim`` while their inputs are sized
            with ``hidden_dim``, so the two are expected to be equal.
        classes: cumulative class counts, one entry per level. Four entries
            are required (``classes[3]`` is read in ``__init__``); the default
            value has only two, so it raises ``IndexError`` for this class.
    """
    def __init__(self,input_dim=300,hidden_dim=300,classes=[9,137]):
        super(SuperClsF1,self).__init__()
        self.classes = classes
        # One document head per level; level k outputs the classes added at level k.
        self.local_0 = MLP(input_dim=1*hidden_dim,output_dim=classes[0])
        self.local_1 = MLP(input_dim=1*hidden_dim,output_dim=classes[1]-classes[0])
        self.local_2 = MLP(input_dim=1*hidden_dim,output_dim=classes[2]-classes[1])
        self.local_3 = MLP(input_dim=1*hidden_dim,output_dim=classes[3]-classes[2])

        # One label head per level, applied to that level's label rows.
        self.lab_cls0 = MLP(input_dim=input_dim,output_dim=classes[0])
        self.lab_cls1 = MLP(input_dim=input_dim,output_dim=classes[1]-classes[0])
        self.lab_cls2 = MLP(input_dim=input_dim,output_dim=classes[2]-classes[1])
        self.lab_cls3 = MLP(input_dim=input_dim,output_dim=classes[3]-classes[2])

        # fc_k fuses [mapped document vector ; mapped label context] for level k.
        self.fc_1 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_2 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_3 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())

        # wc_k maps the label context coming from level k.
        self.wc0 = nn.Linear(hidden_dim,hidden_dim)
        self.wc1 = nn.Linear(hidden_dim,hidden_dim)
        self.wc2 = nn.Linear(hidden_dim,hidden_dim)

        # wf_k maps the document vector of level k+1 before fusion.
        self.wf0 = nn.Linear(hidden_dim,hidden_dim)
        self.wf1 = nn.Linear(hidden_dim,hidden_dim)
        self.wf2 = nn.Linear(hidden_dim,hidden_dim)

        # mapp1/mapp2 are constructed but not read by forward().
        self.mapp1 = nn.Linear(hidden_dim,hidden_dim)
        self.mapp2 = nn.Linear(hidden_dim,hidden_dim)
        self.high_g = MLP(input_dim=4*hidden_dim,output_dim=classes[-1])
        self.num_classes = classes[-1]
        self.dropout0 = nn.Dropout(0.3)
        self.dropout1 = nn.Dropout(0.3)
        self.dropout2 = nn.Dropout(0.3)
        self.dropout3 = nn.Dropout(0.3)
        self.layer_size = 1
        self.hidden_size = hidden_dim
        self.h_layers = 1
        self.linears_g = nn.ModuleList([nn.Linear(4*hidden_dim,4*hidden_dim) for idx in range(self.h_layers)])
        self.linears_t = nn.ModuleList([nn.Linear(4*hidden_dim,4*hidden_dim) for idx in range(self.h_layers)])

    def _highway_layer(self,input_,num_layers=1,act=nn.LeakyReLU()):
        """Apply ``num_layers`` highway layers to ``input_`` and return the result.

        Each layer computes ``t*g + (1-t)*x`` with ``g = act(W_g x)`` and
        ``t = sigmoid(W_t x)``. Layers are chained: layer ``k+1`` reads the
        output of layer ``k`` (previously every layer read ``input_`` and only
        the last result survived). ``num_layers=0`` returns ``input_`` unchanged.
        """
        output = input_
        for idx in range(num_layers):
            g = act(self.linears_g[idx](output))
            t = torch.sigmoid(self.linears_t[idx](output))
            output = t*g +(1.-t)*output
        return output

    def forward(self,embeds,tfidf):
        """Score the documents in ``tfidf`` at all four levels.

        Returns a dict with keys ``"0"``..``"3"`` (level logits),
        ``"c_0"``..``"c_3"`` (label similarity matrices, softmax over dim 0),
        ``"l_0"``..``"l_3"`` (label-head outputs) and ``"g"`` (global logits).
        """
        vocab_size = tfidf.shape[1]
        # Label rows of level k sit at vocab_size + [bounds[k], bounds[k+1]).
        bounds = (0, self.classes[0], self.classes[1], self.classes[2], self.classes[3])
        dropouts = (self.dropout0, self.dropout1, self.dropout2, self.dropout3)
        heads = (self.local_0, self.local_1, self.local_2, self.local_3)
        label_heads = (self.lab_cls0, self.lab_cls1, self.lab_cls2, self.lab_cls3)
        # Fusion modules indexed by the level that supplies the label context.
        doc_maps = (self.wf0, self.wf1, self.wf2)
        ctx_maps = (self.wc0, self.wc1, self.wc2)
        fusers = (self.fc_1, self.fc_2, self.fc_3)

        doc_reps, logits, corms, label_rows = [], [], [], []
        for level in range(4):
            dropped = dropouts[level](embeds[level])

            # Document vector: pool the rows preceding the trailing num_classes rows.
            rp = tfidf.mm(dropped[:-self.num_classes,:])

            if level > 0:
                # Label context from the previous level: its sigmoid scores
                # weight that level's (dropped-out) label rows.
                lab_ct = torch.sigmoid(logits[level-1]).mm(label_rows[level-1])
                rp = fusers[level-1](torch.cat((doc_maps[level-1](rp),
                                                ctx_maps[level-1](lab_ct)),dim=1))

            doc_reps.append(rp)
            logits.append(heads[level](rp))

            # Cosine similarity between this level's label rows, softmaxed over dim 0.
            rows = dropped[vocab_size+bounds[level]:vocab_size+bounds[level+1],:]
            label_rows.append(rows)
            lab = F.normalize(rows)
            corms.append(torch.softmax(lab.mm(lab.t()),dim=0))

        # Global logits: raw level-0 vector plus the three fused vectors,
        # passed through the highway gate.
        logitsG = self._highway_layer(torch.cat(doc_reps,dim=1))
        logitsG = self.high_g(logitsG)

        lab_losses = [label_heads[level](label_rows[level]) for level in range(4)]

        out = {}
        for level in range(4):
            out[str(level)] = logits[level]
        for level in range(4):
            out["c_%d" % level] = corms[level]
        for level in range(4):
            out["l_%d" % level] = lab_losses[level]
        out["g"] = logitsG
        return out

class SuperClsFL(nn.Module):
    '''
    Four-level hierarchical classifier whose levels are chained through one
    shared LSTM cell.

    For every level k (0..3) the document representation is
    ``tfidf.mm(word rows of embeds[k])``.  It is concatenated with a label
    context (zeros at level 0, otherwise the softmax of the previous level's
    logits times the previous level's label rows) and pushed through the
    shared ``nn.LSTMCell``; the resulting hidden state feeds that level's
    local head.  The four hidden states are concatenated, passed through a
    highway layer and ``high_g`` to produce the global logits.

    Args:
        input_dim: width of the label rows fed to the ``lab_cls*`` heads.
        hidden_dim: LSTM hidden size; the LSTM input is ``2*hidden_dim``
            wide, so the embedding width has to equal ``hidden_dim``.
        classes: cumulative class boundaries per level; level k owns the
            label rows ``classes[k-1]:classes[k]`` (level 0 starts at 0) and
            ``classes[-1]`` is the total number of labels.  This class
            indexes ``classes[0]`` .. ``classes[3]``, so at least four
            entries are required (the two-entry default is too short).
    '''
    def __init__(self,input_dim=300,hidden_dim=300,classes=[9,137]):
        super(SuperClsFL,self).__init__()
        self.classes = classes

        # Per-level heads applied to the LSTM hidden state of that level.
        self.local_0 = MLP(input_dim=1*hidden_dim,output_dim=classes[0])
        self.local_1 = MLP(input_dim=1*hidden_dim,output_dim=classes[1]-classes[0])
        self.local_2 = MLP(input_dim=1*hidden_dim,output_dim=classes[2]-classes[1])
        self.local_3 = MLP(input_dim=1*hidden_dim,output_dim=classes[3]-classes[2])

        # Per-level heads applied directly to the label rows ("l_k" outputs).
        self.lab_cls0 = MLP(input_dim=input_dim,output_dim=classes[0])
        self.lab_cls1 = MLP(input_dim=input_dim,output_dim=classes[1]-classes[0])
        self.lab_cls2 = MLP(input_dim=input_dim,output_dim=classes[2]-classes[1])
        self.lab_cls3 = MLP(input_dim=input_dim,output_dim=classes[3]-classes[2])

        # One cell shared by all levels; input = [document rep, label context].
        self.lstm = nn.LSTMCell(2*hidden_dim,hidden_dim)

        # Global head over the concatenation of the four hidden states.
        self.high_g = MLP(input_dim=4*hidden_dim,output_dim=classes[-1])
        self.num_classes = classes[-1]
        self.dropout0 = nn.Dropout(0.3)
        self.dropout1 = nn.Dropout(0.3)
        self.dropout2 = nn.Dropout(0.3)
        self.dropout3 = nn.Dropout(0.3)
        self.layer_size = 1
        self.hidden_size = hidden_dim
        self.h_layers = 1
        self.linears_g = nn.ModuleList([nn.Linear(4*hidden_dim,4*hidden_dim) for idx in range(self.h_layers)])
        self.linears_t = nn.ModuleList([nn.Linear(4*hidden_dim,4*hidden_dim) for idx in range(self.h_layers)])

    def _highway_layer(self,input_,num_layers=1,act=nn.LeakyReLU()):
        '''
        Apply ``num_layers`` highway layers: ``t*act(G(x)) + (1-t)*x`` with
        ``t = sigmoid(T(x))``.

        Each layer consumes the previous layer's output (the earlier version
        re-read ``input_`` every iteration and so only ever returned the last
        layer applied to the raw input).  ``num_layers=0`` returns ``input_``
        unchanged.  ``num_layers`` must not exceed ``len(self.linears_g)``.
        '''
        output = input_
        for idx in range(num_layers):
            g = act(self.linears_g[idx](output))
            t = torch.sigmoid(self.linears_t[idx](output))
            output = t*g + (1.-t)*output
        return output

    def _label_rows(self,layer_embeds,vocab_size,level):
        '''Return the rows of ``layer_embeds`` that hold the labels of ``level``.'''
        start = vocab_size + (self.classes[level-1] if level > 0 else 0)
        stop = vocab_size + self.classes[level]
        return layer_embeds[start:stop,:]

    @staticmethod
    def _label_correlation(rows):
        '''Cosine-similarity matrix of the label rows, softmax-ed over dim 0.'''
        unit = F.normalize(rows)
        return torch.softmax(unit.mm(unit.t()),dim=0)

    def forward(self,embeds,tfidf,label=None,flag=None):
        '''
        Args:
            embeds: indexable with 0..3; each entry is a 2-D tensor whose
                first ``tfidf.shape[1]`` rows are multiplied by ``tfidf`` and
                whose rows from ``tfidf.shape[1]`` on are read as label rows.
                NOTE(review): the original comment calls these GCN outputs.
            tfidf: 2-D (batch, vocab) matrix.
            label, flag: accepted for signature parity with the sibling
                classifiers; not used here.

        Returns:
            dict with the local logits ``"0"``..``"3"``, the label
            correlation matrices ``"c_0"``..``"c_3"``, the label-row head
            outputs ``"l_0"``..``"l_3"`` and the global logits ``"g"``.
        '''
        vocab_size = tfidf.shape[1]
        batch_size = tfidf.shape[0]

        dropouts = (self.dropout0,self.dropout1,self.dropout2,self.dropout3)
        local_heads = (self.local_0,self.local_1,self.local_2,self.local_3)
        label_heads = (self.lab_cls0,self.lab_cls1,self.lab_cls2,self.lab_cls3)

        hidden = cell = None
        states, logits, label_rows, corms = [], [], [], []
        for level, (dropout, local_head) in enumerate(zip(dropouts,local_heads)):
            layer = dropout(embeds[level])
            doc = tfidf.mm(layer[:-self.num_classes,:])
            if level == 0:
                # Zero initial state / context, created on the same device
                # and dtype as the data instead of a hard-coded .cuda().
                hidden = doc.new_zeros((batch_size,self.hidden_size))
                cell = doc.new_zeros((batch_size,self.hidden_size))
                context = doc.new_zeros((batch_size,self.hidden_size))
            else:
                # Soft label context from the previous level's prediction.
                context = torch.softmax(logits[-1],dim=1).mm(label_rows[-1])
            hidden, cell = self.lstm(torch.cat((doc,context),dim=1),(hidden,cell))

            rows = self._label_rows(layer,vocab_size,level)
            states.append(hidden)
            logits.append(local_head(hidden))
            label_rows.append(rows)
            corms.append(self._label_correlation(rows))

        logitsG = self._highway_layer(torch.cat(states,dim=1))
        logitsG = self.high_g(logitsG)

        lab_losses = [head(rows) for head, rows in zip(label_heads,label_rows)]

        return {"0":logits[0],
                "1":logits[1],
                "2":logits[2],
                "3":logits[3],
                "c_0":corms[0],
                "c_1":corms[1],
                "c_2":corms[2],
                "c_3":corms[3],
                "l_0":lab_losses[0],
                "l_1":lab_losses[1],
                "l_2":lab_losses[2],
                "l_3":lab_losses[3],
                "g":logitsG}





class SuperCls(nn.Module):
    '''
    Two-level hierarchical classifier whose levels are chained through one
    shared LSTM cell.

    Level 0 feeds ``[tfidf.mm(word rows of embeds[0]), zeros]`` to the LSTM
    cell; level 1 feeds ``[tfidf.mm(word rows of embeds[1]), label context]``
    where the label context is a softmax-weighted sum of the level-0 label
    rows (weights from the gold labels when ``flag == "train"``, otherwise
    from the sharpened level-0 logits).  Both hidden states are concatenated,
    passed through a highway layer and ``high_g`` for the global logits.

    Args:
        input_dim: width of the label rows fed to the ``lab_cls*`` heads; the
            LSTM input is ``2*input_dim`` wide.
        hidden_dim: LSTM hidden size.  The zero padding used at level 0 is
            ``hidden_dim`` wide, so ``input_dim`` and ``hidden_dim`` have to
            match for the LSTM input width to be consistent.
        classes: cumulative class boundaries; level 0 owns label rows
            ``0:classes[0]``, level 1 owns ``classes[0]:classes[1]`` and
            ``classes[-1]`` is the total number of labels.

    Several submodules created in ``__init__`` are not used by ``forward``;
    they are kept so parameter names and initialisation order stay stable.
    '''
    def __init__(self,input_dim=300,hidden_dim=300,classes=[9,137]):
        super(SuperCls,self).__init__()
        self.classes = classes
        self.local_0 = MLP(input_dim=1*hidden_dim,output_dim=classes[0])
        self.local_1 = MLP(input_dim=1*hidden_dim,output_dim=classes[1]-classes[0])
        self.lab_cls0 = MLP(input_dim=input_dim,output_dim=classes[0])
        self.lab_cls1 = MLP(input_dim=input_dim,output_dim=classes[1]-classes[0])
        self.lab_clsg = MLP(input_dim=input_dim,output_dim=classes[-1])
        self.mapp1 = nn.Linear(hidden_dim,hidden_dim)
        self.mapp0 = nn.Linear(hidden_dim,hidden_dim)
        self.wc0 = nn.Linear(hidden_dim,hidden_dim)
        self.wc1 = nn.Linear(hidden_dim,hidden_dim)
        self.wf0 = nn.Linear(hidden_dim,hidden_dim)
        self.wf1 = nn.Linear(hidden_dim,hidden_dim)
        self.high_g = MLP(input_dim=2*hidden_dim,output_dim=classes[-1])
        self.num_classes = classes[-1]
        self.dropout0 = nn.Dropout(0.5)
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        self.dropoutg = nn.Dropout(0.5)
        self.dropout0_ = nn.Dropout(0.5)
        self.dropout1_ = nn.Dropout(0.5)
        self.dropoutg_ = nn.Dropout(0.5)
        self.layer_size = 1
        self.hidden_size = hidden_dim
        self.lstm = nn.LSTMCell(2*input_dim,hidden_dim)
        self.h_layers = 1
        self.linears_g = nn.ModuleList([nn.Linear(2*hidden_dim,2*hidden_dim) for idx in range(self.h_layers)])
        self.linears_t = nn.ModuleList([nn.Linear(2*hidden_dim,2*hidden_dim) for idx in range(self.h_layers)])
        self.fc_0 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_1 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_2 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                  nn.LeakyReLU())

        self.g_1 = nn.Linear(hidden_dim,hidden_dim)
        self.t_1 = nn.Linear(hidden_dim,hidden_dim)
        # ws[indx] = (query, key, value) projections used by lab_reps.
        self.ws = nn.ModuleList([nn.ModuleList([nn.Linear(hidden_dim,hidden_dim) for _ in range(3)]) for i in range(2)])

    def _highway_layer(self,input_,num_layers=1,act=nn.LeakyReLU()):
        '''
        Apply ``num_layers`` highway layers: ``t*act(G(x)) + (1-t)*x`` with
        ``t = sigmoid(T(x))``.

        Each layer consumes the previous layer's output (the earlier version
        re-read ``input_`` every iteration).  ``num_layers=0`` returns
        ``input_`` unchanged.  ``num_layers`` must not exceed
        ``len(self.linears_g)``.
        '''
        output = input_
        for idx in range(num_layers):
            g = act(self.linears_g[idx](output))
            t = torch.sigmoid(self.linears_t[idx](output))
            output = t*g + (1.-t)*output
        return output

    @staticmethod
    def _hierarchical_violation(parent_scores,child_scores):
        '''
        Penalty for child scores that exceed the score of their parent.

        For every parent column ``i`` the children are the columns
        ``left_index:left_index+step`` of ``child_scores``.  The penalty is
        ``0.1 * mean_over_batch(sum_over_children(max(child-parent, 0)**2))``
        summed over all parents.

        Args:
            parent_scores: 2-D tensor, one column per entry of the table
                below (9 columns).
            child_scores: 2-D tensor with at least 128 columns.

        Returns:
            0-dim tensor holding the total penalty.

        NOTE(review): the hard-coded table is dataset specific and skips
        child column 105 (88+17 = 105, the next group starts at 106) --
        confirm whether that gap is intended.
        '''
        index_list = [(0,15),(15,37),(52,20),(72,9),(81,7),(88,17),(106,14),(120,5),(125,3)]
        violation_losses = 0.0
        for i, (left_index, step) in enumerate(index_list):
            # Keep a trailing axis so the parent broadcasts over its children.
            current_parent_scores = parent_scores[:,i].unsqueeze(1)
            current_child_scores = child_scores[:,left_index:left_index+step]
            margin = torch.clamp(current_child_scores-current_parent_scores,min=0)
            losses = torch.mean(torch.sum(margin**2,dim=1))
            violation_losses = violation_losses + 0.1*losses
        return violation_losses

    def lab_reps(self,lab_vec,indx):
        '''
        Self-attention over label vectors using the ``indx``-th
        (query, key, value) projection triple in ``self.ws``.

        Scores are divided by the fixed constant 10 before the row-wise
        softmax.  Returns a tensor with one attended vector per label row.
        '''
        WQ = self.ws[indx][0](lab_vec) # num x hidden
        WK = self.ws[indx][1](lab_vec)
        WV = self.ws[indx][2](lab_vec)
        QK = WQ.matmul(WK.t())/10 # num x num
        QK = torch.softmax(QK,dim=1)
        V  = QK.matmul(WV)
        return V

    def _label_rows(self,layer_embeds,vocab_size,level):
        '''Return the rows of ``layer_embeds`` that hold the labels of ``level``.'''
        start = vocab_size + (self.classes[level-1] if level > 0 else 0)
        stop = vocab_size + self.classes[level]
        return layer_embeds[start:stop,:]

    @staticmethod
    def _label_correlation(rows):
        '''Cosine-similarity matrix of the label rows, softmax-ed over dim 0.'''
        unit = F.normalize(rows)
        return torch.softmax(unit.mm(unit.t()),dim=0)

    def forward(self,embeds,tfidf,labels=None,flag="test"):
        '''
        Args:
            embeds: indexable with 0 and 1; each entry is a 2-D tensor whose
                first ``tfidf.shape[1]`` rows are multiplied by ``tfidf`` and
                whose following rows are read as label rows.
                NOTE(review): the original comment calls these GCN outputs.
            tfidf: 2-D (batch, vocab) matrix.
            labels: gold label matrix, required when ``flag == "train"``;
                columns ``:classes[0]`` are soft-maxed, so a floating point
                tensor is assumed -- TODO confirm against the caller.
            flag: ``"train"`` builds the level-1 label context from
                ``labels`` (teacher forcing); any other value uses the
                model's own level-0 prediction.

        Returns:
            dict with local logits ``"0"``/``"1"``, global logits ``"g"``,
            label correlation matrices ``corm0``/``corm1``, raw
            document-label scores ``sc0``/``sc1``, label-row head outputs
            ``lab_loss0``/``lab_loss1`` and ``div_loss`` (mean squared
            row-wise cosine similarity between the two embedding tensors).

        Raises:
            ValueError: if ``flag == "train"`` and ``labels`` is None.
        '''
        teacher_forcing = flag == "train"
        if teacher_forcing and labels is None:
            raise ValueError('labels are required when flag == "train"')

        vocab_size = tfidf.shape[1]
        batch_size = tfidf.shape[0]

        # ---- level 0 ----
        embeds_0 = self.dropout0(embeds[0])
        rows0 = self._label_rows(embeds_0,vocab_size,0)
        doc0 = tfidf.mm(embeds_0[:-self.num_classes,:])
        sc0 = doc0.mm(self.mapp0(rows0).t())

        # Zero initial state / padding on the data's own device and dtype
        # instead of a hard-coded .cuda().
        hx = doc0.new_zeros((batch_size,self.hidden_size))
        cx = doc0.new_zeros((batch_size,self.hidden_size))
        pad = doc0.new_zeros((batch_size,self.hidden_size))

        rp0, cx = self.lstm(torch.cat((doc0,pad),dim=1),(hx,cx))
        logits0 = self.local_0(rp0)
        corm0 = self._label_correlation(rows0)

        # ---- level 1 ----
        embeds_1 = self.dropout1(embeds[1])
        rows1 = self._label_rows(embeds_1,vocab_size,1)
        doc1 = tfidf.mm(embeds_1[:-self.num_classes,:])
        sc1 = doc1.mm(self.mapp1(rows1).t())

        if teacher_forcing:
            level0_weights = torch.softmax(labels[:,:self.classes[0]],dim=1)
        else:
            # The factor 5 sharpens the predicted level-0 distribution.
            level0_weights = torch.softmax(logits0*5,dim=1)
        lab_ct0 = level0_weights.mm(rows0)

        rp1, cx = self.lstm(torch.cat((doc1,lab_ct0),dim=1),(rp0,cx))
        logits1 = self.local_1(rp1)
        corm1 = self._label_correlation(rows1)

        # ---- global head ----
        logitsG = self._highway_layer(torch.cat((rp0,rp1),dim=1))
        logitsG = self.high_g(logitsG)

        lab_loss0 = self.lab_cls0(rows0)
        lab_loss1 = self.lab_cls1(rows1)

        div_loss = torch.mean(torch.cosine_similarity(embeds_0,embeds_1,dim=1)**2)

        return {"0":logits0,
                "1":logits1,
                "g":logitsG,
                'corm0':corm0,
                'corm1':corm1,
                'sc0': sc0,
                'sc1': sc1,
                'lab_loss0':lab_loss0,
                'lab_loss1':lab_loss1,
                'div_loss':div_loss}



class SuperClsL(nn.Module):
    '''
    Three-level hierarchical classifier whose levels are chained through one
    shared LSTM cell.

    Every level builds its document representation from ``embeds[0]`` (each
    level applies its own dropout module to it), concatenates a label
    context (zeros at level 0, otherwise a softmax-weighted sum of the
    previous level's label rows) and runs the shared ``nn.LSTMCell``.  The
    three hidden states are concatenated, passed through a highway layer and
    ``high_g`` for the global logits.

    NOTE(review): all three levels read ``embeds[0]`` whereas the sibling
    ``SuperCls`` reads ``embeds[0]`` and ``embeds[1]``; behaviour is kept
    as-is here -- confirm that this is intentional.

    Args:
        input_dim: width of the label rows fed to the ``lab_cls*`` heads.
        hidden_dim: LSTM hidden size; the LSTM input is ``2*hidden_dim``
            wide, so the embedding width has to equal ``hidden_dim``.
        classes: cumulative class boundaries per level; level k owns label
            rows ``classes[k-1]:classes[k]`` (level 0 starts at 0) and
            ``classes[-1]`` is the total number of labels.  This class
            indexes ``classes[0]`` .. ``classes[2]``, so at least three
            entries are required (the two-entry default is too short).

    Several submodules created in ``__init__`` are not used by ``forward``;
    they are kept so parameter names and initialisation order stay stable.
    '''
    def __init__(self,input_dim=300,hidden_dim=300,classes=[9,137]):
        super(SuperClsL,self).__init__()
        self.classes = classes
        self.local_0 = MLP(input_dim=1*hidden_dim,output_dim=classes[0])
        self.local_1 = MLP(input_dim=1*hidden_dim,output_dim=classes[1]-classes[0])
        self.local_2 = MLP(input_dim=1*hidden_dim,output_dim=classes[2]-classes[1])
        self.lab_cls0 = MLP(input_dim=input_dim,output_dim=classes[0])
        self.lab_cls1 = MLP(input_dim=input_dim,output_dim=classes[1]-classes[0])
        self.lab_cls2 = MLP(input_dim=input_dim,output_dim=classes[2]-classes[1])
        self.lab_clsg = MLP(input_dim=input_dim,output_dim=classes[-1])
        self.mapp1 = nn.Linear(hidden_dim,hidden_dim)
        self.mapp0 = nn.Linear(hidden_dim,hidden_dim)
        self.wc0 = nn.Linear(hidden_dim,hidden_dim)
        self.wc1 = nn.Linear(hidden_dim,hidden_dim)
        self.wf0 = nn.Linear(hidden_dim,hidden_dim)
        self.wf1 = nn.Linear(hidden_dim,hidden_dim)
        self.high_g = MLP(input_dim=3*hidden_dim,output_dim=classes[-1])
        self.num_classes = classes[-1]
        self.dropout0 = nn.Dropout(0.5)
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        self.dropoutg = nn.Dropout(0.5)
        self.dropout0_ = nn.Dropout(0.5)
        self.dropout1_ = nn.Dropout(0.5)
        self.dropoutg_ = nn.Dropout(0.5)
        self.layer_size = 1
        self.hidden_size = hidden_dim
        self.h_layers = 1
        self.linears_g = nn.ModuleList([nn.Linear(3*hidden_dim,3*hidden_dim) for idx in range(self.h_layers)])
        self.linears_t = nn.ModuleList([nn.Linear(3*hidden_dim,3*hidden_dim) for idx in range(self.h_layers)])
        self.lstm = nn.LSTMCell(2*hidden_dim,hidden_dim)
        self.fc_0 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_1 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_2 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                  nn.LeakyReLU())

        self.g_1 = nn.Linear(hidden_dim,hidden_dim)
        self.t_1 = nn.Linear(hidden_dim,hidden_dim)
        # ws[indx] = (query, key, value) projections used by lab_reps.
        self.ws = nn.ModuleList([nn.ModuleList([nn.Linear(hidden_dim,hidden_dim) for _ in range(3)]) for i in range(2)])

    def _highway_layer(self,input_,num_layers=1,act=nn.LeakyReLU()):
        '''
        Apply ``num_layers`` highway layers: ``t*act(G(x)) + (1-t)*x`` with
        ``t = sigmoid(T(x))``.

        Each layer consumes the previous layer's output (the earlier version
        re-read ``input_`` every iteration).  ``num_layers=0`` returns
        ``input_`` unchanged.  ``num_layers`` must not exceed
        ``len(self.linears_g)``.
        '''
        output = input_
        for idx in range(num_layers):
            g = act(self.linears_g[idx](output))
            t = torch.sigmoid(self.linears_t[idx](output))
            output = t*g + (1.-t)*output
        return output

    @staticmethod
    def _hierarchical_violation(parent_scores,child_scores):
        '''
        Penalty for child scores that exceed the score of their parent.

        For every parent column ``i`` the children are the columns
        ``left_index:left_index+step`` of ``child_scores``.  The penalty is
        ``0.1 * mean_over_batch(sum_over_children(max(child-parent, 0)**2))``
        summed over all parents.

        Args:
            parent_scores: 2-D tensor, one column per entry of the table
                below (9 columns).
            child_scores: 2-D tensor with at least 128 columns.

        Returns:
            0-dim tensor holding the total penalty.

        NOTE(review): the hard-coded table is dataset specific and skips
        child column 105 (88+17 = 105, the next group starts at 106) --
        confirm whether that gap is intended.
        '''
        index_list = [(0,15),(15,37),(52,20),(72,9),(81,7),(88,17),(106,14),(120,5),(125,3)]
        violation_losses = 0.0
        for i, (left_index, step) in enumerate(index_list):
            # Keep a trailing axis so the parent broadcasts over its children.
            current_parent_scores = parent_scores[:,i].unsqueeze(1)
            current_child_scores = child_scores[:,left_index:left_index+step]
            margin = torch.clamp(current_child_scores-current_parent_scores,min=0)
            losses = torch.mean(torch.sum(margin**2,dim=1))
            violation_losses = violation_losses + 0.1*losses
        return violation_losses

    def lab_reps(self,lab_vec,indx):
        '''
        Self-attention over label vectors using the ``indx``-th
        (query, key, value) projection triple in ``self.ws``.

        Scores are divided by the fixed constant 10 before the row-wise
        softmax.  Returns a tensor with one attended vector per label row.
        '''
        WQ = self.ws[indx][0](lab_vec) # num x hidden
        WK = self.ws[indx][1](lab_vec)
        WV = self.ws[indx][2](lab_vec)
        QK = WQ.matmul(WK.t())/10 # num x num
        QK = torch.softmax(QK,dim=1)
        V  = QK.matmul(WV)
        return V

    def _label_rows(self,layer_embeds,vocab_size,level):
        '''Return the rows of ``layer_embeds`` that hold the labels of ``level``.'''
        start = vocab_size + (self.classes[level-1] if level > 0 else 0)
        stop = vocab_size + self.classes[level]
        return layer_embeds[start:stop,:]

    @staticmethod
    def _label_correlation(rows):
        '''Cosine-similarity matrix of the label rows, softmax-ed over dim 0.'''
        unit = F.normalize(rows)
        return torch.softmax(unit.mm(unit.t()),dim=0)

    def forward(self,embeds,tfidf,labels=None,flag="test"):
        '''
        Args:
            embeds: indexable with 0; ``embeds[0]`` is a 2-D tensor whose
                first ``tfidf.shape[1]`` rows are multiplied by ``tfidf`` and
                whose following rows are read as label rows.
                NOTE(review): the original comment calls these GCN outputs.
            tfidf: 2-D (batch, vocab) matrix.
            labels: gold label matrix, required when ``flag == "train"``;
                column slices of it are soft-maxed, so a floating point
                tensor is assumed -- TODO confirm against the caller.
            flag: ``"train"`` builds the label contexts from ``labels``
                (teacher forcing); any other value uses the model's own
                predictions from the previous level.

        Returns:
            dict with local logits ``"0"``..``"2"``, global logits ``"g"``,
            label correlation matrices ``corm0``..``corm2``, label-row head
            outputs ``lab_loss0``..``lab_loss2`` and ``div_loss`` (mean
            squared row-wise cosine similarity between the level-0 and
            level-1 dropped-out embeddings).

        Raises:
            ValueError: if ``flag == "train"`` and ``labels`` is None.
        '''
        teacher_forcing = flag == "train"
        if teacher_forcing and labels is None:
            raise ValueError('labels are required when flag == "train"')

        vocab_size = tfidf.shape[1]
        batch_size = tfidf.shape[0]

        # ---- level 0 ----
        embeds_0 = self.dropout0(embeds[0])
        rows0 = self._label_rows(embeds_0,vocab_size,0)
        doc0 = tfidf.mm(embeds_0[:-self.num_classes,:])

        # Zero initial state / padding on the data's own device and dtype
        # instead of a hard-coded .cuda().
        hx = doc0.new_zeros((batch_size,self.hidden_size))
        cx = doc0.new_zeros((batch_size,self.hidden_size))
        pad = doc0.new_zeros((batch_size,self.hidden_size))

        rp0, cx = self.lstm(torch.cat((doc0,pad),dim=1),(hx,cx))
        logits0 = self.local_0(rp0)
        corm0 = self._label_correlation(rows0)

        # ---- level 1 ----
        embeds_1 = self.dropout1(embeds[0])
        rows1 = self._label_rows(embeds_1,vocab_size,1)
        doc1 = tfidf.mm(embeds_1[:-self.num_classes,:])

        if teacher_forcing:
            level0_weights = torch.softmax(labels[:,:self.classes[0]],dim=1)
        else:
            level0_weights = torch.softmax(logits0,dim=1)
        lab_ct0 = level0_weights.mm(rows0)

        rp1, cx = self.lstm(torch.cat((doc1,lab_ct0),dim=1),(rp0,cx))
        logits1 = self.local_1(rp1)
        corm1 = self._label_correlation(rows1)

        # ---- level 2 ----
        embeds_2 = self.dropout2(embeds[0])
        rows2 = self._label_rows(embeds_2,vocab_size,2)
        doc2 = tfidf.mm(embeds_2[:-self.num_classes,:])
        corm2 = self._label_correlation(rows2)

        if teacher_forcing:
            level1_weights = torch.softmax(labels[:,self.classes[0]:self.classes[1]],dim=1)
        else:
            level1_weights = torch.softmax(logits1,dim=1)
        lab_ct1 = level1_weights.mm(rows1)

        rp2, cx = self.lstm(torch.cat((doc2,lab_ct1),dim=1),(rp1,cx))
        logits2 = self.local_2(rp2)

        # ---- global head ----
        rpG = self._highway_layer(torch.cat((rp0,rp1,rp2),dim=1))
        logitsG = self.high_g(rpG)

        lab_loss0 = self.lab_cls0(rows0)
        lab_loss1 = self.lab_cls1(rows1)
        lab_loss2 = self.lab_cls2(rows2)

        div_loss = torch.mean(torch.cosine_similarity(embeds_0,embeds_1,dim=1)**2)

        return {"0":logits0,
                "1":logits1,
                "2":logits2,
                "g":logitsG,
                'corm0':corm0,
                'corm1':corm1,
                'corm2':corm2,
                'lab_loss0':lab_loss0,
                'lab_loss1':lab_loss1,
                'lab_loss2':lab_loss2,
                'div_loss':div_loss}



class SuperClsA(nn.Module):
    '''
    Two-level hierarchical classification head.

    From two embedding matrices (one per hierarchy level) and a tfidf
    document matrix, this module produces:
      * level-0 logits ("0") over the first ``classes[0]`` labels,
      * level-1 logits ("1") over the next ``classes[1]-classes[0]`` labels,
        conditioned on the level-0 prediction,
      * global logits ("g") over all ``classes[-1]`` labels, computed from a
        highway layer over the concatenated level representations,
      * auxiliary tensors (label correlation matrices, document/label
        scores, label-classification logits, a diversity loss).

    Several sub-modules created in ``__init__`` are not used by ``forward``;
    they are kept so that ``state_dict`` keys and the order of random
    initialisation stay compatible with existing checkpoints.
    '''
    def __init__(self,input_dim=300,hidden_dim=300,classes=[9,137]):
        super(SuperClsA,self).__init__()
        # Copy: never keep a reference to the (shared, mutable) default list.
        self.classes = list(classes)
        self.local_0 = MLP(input_dim=1*hidden_dim,output_dim=classes[0])
        self.local_1 = MLP(input_dim=1*hidden_dim,output_dim=classes[1]-classes[0])
        self.lab_cls0 = MLP(input_dim=input_dim,output_dim=classes[0])
        self.lab_cls1 = MLP(input_dim=input_dim,output_dim=classes[1]-classes[0])
        self.lab_clsg = MLP(input_dim=input_dim,output_dim=classes[-1])
        self.mapp1 = nn.Linear(hidden_dim,hidden_dim)
        self.mapp0 = nn.Linear(hidden_dim,hidden_dim)
        self.wc0 = nn.Linear(hidden_dim,hidden_dim)
        self.wc1 = nn.Linear(hidden_dim,hidden_dim)
        self.wf0 = nn.Linear(hidden_dim,hidden_dim)
        self.wf1 = nn.Linear(hidden_dim,hidden_dim)
        self.high_g = MLP(input_dim=2*hidden_dim,output_dim=classes[-1])
        self.num_classes = classes[-1]
        self.dropout0 = nn.Dropout(0.5)
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        self.dropoutg = nn.Dropout(0.5)
        self.dropout0_ = nn.Dropout(0.5)
        self.dropout1_ = nn.Dropout(0.5)
        self.dropoutg_ = nn.Dropout(0.5)
        self.layer_size = 1
        self.hidden_size = hidden_dim
        self.h_layers = 1
        # Gate (t) and transform (g) projections of the highway layer(s).
        self.linears_g = nn.ModuleList([nn.Linear(2*hidden_dim,2*hidden_dim) for idx in range(self.h_layers)])
        self.linears_t = nn.ModuleList([nn.Linear(2*hidden_dim,2*hidden_dim) for idx in range(self.h_layers)])
        self.fc_0 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_1 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                   nn.LeakyReLU())
        self.fc_2 = nn.Sequential(nn.Linear(2*input_dim,input_dim),
                                  nn.LeakyReLU())

        self.g_1 = nn.Linear(hidden_dim,hidden_dim)
        self.t_1 = nn.Linear(hidden_dim,hidden_dim)
        # ws[indx] holds the (query, key, value) projections used by lab_reps.
        self.ws = nn.ModuleList([nn.ModuleList([nn.Linear(hidden_dim,hidden_dim) for _ in range(3)]) for i in range(2)])

    def _highway_layer(self,input_,num_layers=1,act=nn.LeakyReLU()):
        '''
        Apply ``num_layers`` highway layers: ``t*g + (1-t)*x`` with
        ``g = act(linears_g[i](x))`` and ``t = sigmoid(linears_t[i](x))``.

        Layers are chained (each one consumes the previous output), and
        ``num_layers=0`` returns ``input_`` unchanged.  For the default
        ``num_layers=1`` the result is the same as before.
        '''
        output = input_
        for idx in range(num_layers):
            g = act(self.linears_g[idx](output))
            t = torch.sigmoid(self.linears_t[idx](output))
            output = t*g +(1.-t)*output
        return output

    @staticmethod
    def _hierarchical_violation(parent_scores,child_scores,index_list=None,weight=0.1):
        '''
        Penalise child scores that exceed the score of their parent.

        parent_scores: 2-D tensor, column ``i`` is the score of parent ``i``.
        child_scores:  2-D tensor with the same number of rows; the children
                       of parent ``i`` are columns ``left:left+step`` for
                       ``(left, step) = index_list[i]``.
        index_list:    optional list of ``(left, step)`` pairs; defaults to
                       the table hard-coded below.
        weight:        multiplier applied to every per-parent loss.

        Returns the scalar tensor
        ``sum_i weight * mean_rows(sum_cols(relu(child - parent_i)**2))``.
        '''
        if index_list is None:
            # NOTE(review): (88,17) ends at column 104 and the next range
            # starts at 106, so child column 105 belongs to no parent --
            # TODO confirm against the label hierarchy.
            index_list = [(0,15),(15,37),(52,20),(72,9),(81,7),(88,17),(106,14),(120,5),(125,3)]
        violation_losses = parent_scores.new_zeros(())
        for i,(left_index,step) in enumerate(index_list):
            # Keep the column dimension so the parent broadcasts over its children.
            current_parent_scores = parent_scores[:,i:i+1]
            current_child_scores = child_scores[:,left_index:left_index+step]
            margin = torch.clamp(current_child_scores-current_parent_scores,min=0)
            losses = torch.mean(torch.sum(margin**2,dim=1))
            violation_losses = violation_losses + weight*losses
        return violation_losses

    def lab_reps(self,lab_vec,indx):
        '''
        Self-attention over label vectors using projection set ``ws[indx]``.

        Scores are ``Q.K^T / 10`` soft-maxed over dim 1; returns the
        attention-weighted values (one row per input label).
        '''
        WQ = self.ws[indx][0](lab_vec) # num x hidden
        WK = self.ws[indx][1](lab_vec)
        WV = self.ws[indx][2](lab_vec)
        QK = WQ.matmul(WK.t())/10 # num x num
        QK = torch.softmax(QK,dim=1)
        return QK.matmul(WV)

    def forward(self,embeds,tfidf):
        '''
        embeds: sequence whose items 0 and 1 are 2-D tensors; the last
                ``num_classes`` rows are label embeddings and the preceding
                rows are multiplied by ``tfidf`` (so they must number
                ``tfidf.shape[1]``).
        tfidf:  2-D tensor, one row per document.

        Returns a dict with keys "0", "1", "g", "corm0", "corm1", "sc0",
        "sc1", "lab_loss0", "lab_loss1", "div_loss".
        '''
        vocab_size = tfidf.shape[1]
        n0 = self.classes[0]
        n1 = self.classes[1]

        # ---- level 0 ----
        embeds_0 = self.dropout0(embeds[0])
        lab_vec0 = embeds_0[vocab_size:vocab_size+n0,:]
        rp0 = tfidf.mm(embeds_0[:-self.num_classes,:])
        # document-vs-label scores for level 0
        sc0 = rp0.mm(self.mapp0(lab_vec0).t())
        logits0 = self.local_0(rp0)

        # label-label cosine similarities, soft-maxed over dim 0
        lab0 = F.normalize(lab_vec0)
        corm0 = torch.softmax(lab0.mm(lab0.t()),dim=0)

        # ---- level 1 ----
        embeds_1 = self.dropout1(embeds[1])
        lab_vec1 = embeds_1[vocab_size+n0:vocab_size+n1,:]
        rp1 = tfidf.mm(embeds_1[:-self.num_classes,:])
        sc1 = rp1.mm(self.mapp1(lab_vec1).t())

        # level-0 prediction summarised as a mix of level-0 label embeddings
        lab_ct0 = torch.sigmoid(logits0).mm(lab_vec0)
        rp1 = self.fc_1(torch.cat((self.wf0(rp1),self.wc0(lab_ct0)),dim=1))
        logits1 = self.local_1(rp1)

        lab1 = F.normalize(lab_vec1)
        corm1 = torch.softmax(lab1.mm(lab1.t()),dim=0)

        # ---- global ----
        rpG = self._highway_layer(torch.cat((rp0,rp1),dim=1))
        logitsG = self.high_g(rpG)

        # classify the label embeddings themselves
        lab_loss0 = self.lab_cls0(lab_vec0)
        lab_loss1 = self.lab_cls1(lab_vec1)

        # mean squared row-wise cosine similarity between the two embeddings
        div_loss = torch.mean(torch.cosine_similarity(embeds_0,embeds_1,dim=1)**2)

        return {"0":logits0,
                "1":logits1,
                "g":logitsG,
                'corm0':corm0,
                'corm1':corm1,
                'sc0': sc0,
                'sc1': sc1,
                'lab_loss0':lab_loss0,
                'lab_loss1':lab_loss1,
                'div_loss':div_loss}

