#!/usr/bin/python3
#-*- coding:utf-8 -*-
############################
#File Name: classifier.py
#Created Time:
############################
import torch
import torch.nn as nn
from models.mlp import ACT
import torch.nn.functional as F

class MyClassifier(nn.Module):
    """Document classifier over TF-IDF-pooled node embeddings.

    ``forward`` pools the non-label rows of five embedding tables with a
    TF-IDF matrix and feeds their concatenation through a two-layer MLP to
    produce logits. It also scores the pooled representations of tables 0
    and 1 against the trailing ``num_classes`` rows of those tables. Those
    rows are presumably class/label node embeddings; confirm against the
    code that builds ``embeds``.
    """

    # Number of embedding tables consumed by ``forward``. ``w1`` is sized for
    # the concatenation of this many pooled representations.
    _NUM_VIEWS = 5

    def __init__(self, input_dim=300, hidden_dim=300, classes=(8,)):
        """Build the layers.

        Args:
            input_dim: Feature width (column count) of every embedding table
                passed to ``forward``.
            hidden_dim: Width of the MLP hidden layer.
            classes: Sequence whose LAST element is the number of output
                classes; earlier elements are ignored. The default is a
                tuple to avoid a shared mutable default; lists still work.
        """
        super().__init__()
        self.num_classes = classes[-1]
        self.w1 = nn.Linear(self._NUM_VIEWS * input_dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, self.num_classes)
        # NOTE: wrp2, wrp1, com, map3, dropout1, dropout2, g_1 and t_1 are not
        # used by ``forward``. They are kept, in their original creation
        # order, so state_dict keys and parameter-initialisation order are
        # unchanged and existing checkpoints keep loading.
        self.wrp2 = nn.Linear(hidden_dim, self.num_classes)
        self.wrp1 = nn.Linear(hidden_dim, hidden_dim)
        self.label_w = nn.Linear(input_dim, self.num_classes)
        self.act = nn.LeakyReLU()
        # ACT is a project module (models.mlp); its argument semantics are
        # not visible here.
        self.com = ACT(2 * input_dim, input_dim)
        # map1/map2 project label rows of width input_dim. Their output is
        # multiplied with pooled representations of width input_dim, so both
        # the input and the output must be input_dim. Sizing them with
        # hidden_dim only worked when input_dim == hidden_dim.
        self.map1 = nn.Linear(input_dim, input_dim)
        self.map2 = nn.Linear(input_dim, input_dim)
        self.map3 = nn.Linear(input_dim, input_dim)
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        self.g_1 = nn.Linear(2 * hidden_dim, 2 * hidden_dim)
        self.t_1 = nn.Linear(2 * hidden_dim, 2 * hidden_dim)

    def forward(self, embeds, tfidf):
        """Score a batch of documents.

        Args:
            embeds: Indexable collection of at least five 2-D tensors, each
                with ``input_dim`` columns and the same row count ``rows``.
                In each table the last ``num_classes`` rows are the label
                rows; the remaining rows are pooled by ``tfidf``.
            tfidf: 2-D tensor of shape (batch, rows - num_classes).

        Returns:
            dict with keys
                ``'g'``: MLP logits, shape (batch, num_classes);
                ``'1'``: pooled table-0 reps times ``map1``-projected
                    table-0 label rows, shape (batch, num_classes);
                ``'2'``: same for table 1 using ``map2``;
                ``'lab'``: ``label_w`` applied to the table-1 label rows,
                    shape (num_classes, num_classes).
        """
        n_lab = self.num_classes
        # TF-IDF-weighted pooling of the non-label rows of each table.
        pooled = [
            tfidf.mm(embeds[k][:-n_lab, :]) for k in range(self._NUM_VIEWS)
        ]
        label0 = embeds[0][-n_lab:, :]
        label1 = embeds[1][-n_lab:, :]

        sc0 = pooled[0].mm(self.map1(label0).t())
        sc1 = pooled[1].mm(self.map2(label1).t())
        logits = self.w2(self.act(self.w1(torch.cat(pooled, dim=1))))
        lab = self.label_w(label1)
        return {'g': logits, '1': sc0, '2': sc1, 'lab': lab}
