import sys
sys.path.append('..')
import fewshot_re_kit
import torch
from torch import autograd, optim, nn
from torch.autograd import Variable
from torch.nn import functional as F

class Proto(fewshot_re_kit.framework.FewShotREModel):
    """Prototypical network for few-shot relation classification.

    A class prototype is the mean over the K support instances of the
    concatenated head/tail entity states.  Each prototype is then shifted by
    a gated mix of two encodings of the relation text before the queries are
    scored against it (negative squared Euclidean distance, or dot product
    when ``dot`` is set).
    """

    def __init__(self, sentence_encoder, dot=False, relation_encoder=None, N=5, Q=1,
                 hidden_size=768):
        """Create the gating and projection layers.

        Args:
            sentence_encoder: Encoder passed to the base class; ``forward``
                reads it back as ``self.sentence_encoder`` (presumably stored
                there by the base-class constructor -- confirm in
                ``fewshot_re_kit.framework``).
            dot: If true, similarity is a dot product; otherwise the negative
                squared Euclidean distance is used.
            relation_encoder: Optional separate encoder for the relation
                text.  When falsy, ``sentence_encoder`` encodes it instead.
            N: Unused by the active layers; kept for call compatibility.  It
                only sized the disabled ``t_key`` / ``h_key`` parameters.
            Q: Unused by the active layers; see ``N``.
            hidden_size: Width of a single encoder output vector.  It must
                match what the encoders emit, because ``forward`` reshapes
                with ``hidden_size * 2`` and the gates take
                ``hidden_size * 3`` inputs.  The default keeps the previously
                hard-coded value.
        """
        fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder)
        self.drop = nn.Dropout()
        self.dot = dot

        self.relation_encoder = relation_encoder
        self.hidden_size = hidden_size

        # NOTE: the layers used by ``global_atten_entity`` are intentionally
        # not created here (the experiment is disabled).  To revive it, the
        # model needs:
        #   linear_t, linear_h: nn.Linear(hidden_size * 2, hidden_size)
        #   t_key, h_key:       nn.Parameter of shape
        #                       (4 * N * Q, hidden_size, hidden_size)
        # Creating them unconditionally would change the parameter set and
        # the random initialisation of the layers below.

        # Projects the fused relation vector to the prototype width.
        self.linear = nn.Linear(self.hidden_size, self.hidden_size * 2)
        # One scalar gate per relation encoding; input is
        # [prototype (2 * hidden) ; relation encoding (hidden)].
        self.gate_gol = nn.Linear(self.hidden_size * 3, 1)
        self.gate_loc = nn.Linear(self.hidden_size * 3, 1)

    @staticmethod
    def _attend_entity(state, keys, sequence_outputs):
        """Pool ``sequence_outputs`` with attention from ``state``; return [state ; pooled]."""
        # (batch, dim, 1) column vector used to score every position of ``keys``.
        column = state.view(state.shape[0], 1, -1).permute(0, 2, 1)
        # Softmax runs over dim 1, the sequence axis of ``keys``.
        weights = torch.softmax(torch.tanh(torch.matmul(keys, column)), 1)
        weights = weights.expand_as(sequence_outputs)
        # Kept as a mean (not a sum) over the sequence axis, as in the
        # original experiment.
        pooled = torch.mean(weights * sequence_outputs, 1)
        return torch.cat((state, pooled), -1)

    def global_atten_entity(self, h_state, t_state, sequence_outputs, rel_vec=None, rel_gol=None):
        """Enrich head/tail entity states with attention-pooled sentence context.

        Currently not called by ``forward``.  It depends on layers that
        ``__init__`` does not create (see the note there), so they must be
        attached to the model before this method is used.

        Args:
            h_state: Head-entity states; first dimension is the batch.
            t_state: Tail-entity states; first dimension is the batch.
            sequence_outputs: Per-token encoder outputs, indexed as a 3-D
                tensor (assumed (batch, max_len, dim) per the original
                comments).
            rel_vec: Optional 2-D relation vectors.  When given, attention
                keys are derived from them through ``t_key`` / ``h_key``
                instead of from ``sequence_outputs``.
            rel_gol: Unused; kept for signature compatibility.

        Returns:
            The transformed head and tail states concatenated on the last
            dimension (head first).

        Raises:
            AttributeError: If the required layers have not been defined.
        """
        needed = ['linear_t', 'linear_h']
        if rel_vec is not None:
            needed.extend(['t_key', 'h_key'])
        missing = [name for name in needed if not hasattr(self, name)]
        if missing:
            # Fail before doing any work, with a message that says what to do.
            raise AttributeError(
                'global_atten_entity requires the attribute(s) %s, which '
                'Proto.__init__ does not create; define them on the model '
                'before calling this method' % ', '.join(missing))

        if rel_vec is not None:
            m, n = rel_vec.shape
            # Repeat each relation vector along the sequence axis, then map it
            # through the learned key tensors.
            rel_seq = rel_vec.view(m, 1, n).expand(m, sequence_outputs.shape[1], n)
            t_keys = torch.bmm(rel_seq, self.t_key)
            h_keys = torch.bmm(rel_seq, self.h_key)
        else:
            t_keys = h_keys = sequence_outputs

        t_state = self.linear_t(self._attend_entity(t_state, t_keys, sequence_outputs))
        h_state = self.linear_h(self._attend_entity(h_state, h_keys, sequence_outputs))

        return torch.cat((h_state, t_state), -1)

    def __dist__(self, x, y, dim):
        """Return the similarity of ``x`` and ``y`` reduced over ``dim``.

        Dot product when ``self.dot`` is set, otherwise the negative squared
        Euclidean distance (so that larger always means more similar).
        """
        if self.dot:
            return (x * y).sum(dim)
        else:
            return -(torch.pow(x - y, 2)).sum(dim)

    def __batch_dist__(self, S, Q):
        """Score every query in ``Q`` against every prototype in ``S``.

        ``S`` gets a new axis 1 and ``Q`` a new axis 2 so they broadcast
        against each other; the similarity is reduced over axis 3.
        """
        return self.__dist__(S.unsqueeze(1), Q.unsqueeze(2), 3)

    def forward(self, support, query, rel_txt, N, K, total_Q, visual=None):
        '''
        Score query instances against relation-aware class prototypes.

        support: Inputs of the support set.
        query: Inputs of the query set.
        rel_txt: Inputs describing the relations; encoded by
            `relation_encoder` if set, otherwise by `sentence_encoder`
            (called with `cat=False`).
        N: Num of classes
        K: Num of instances for each class in the support set
        total_Q: Num of query instances per batch element (used to reshape
            the encoded queries)
        visual: If truthy, return tensors for inspection instead of logits.

        Returns `(logits, pred)`, where `logits` has N + 1 columns in its
        last dimension and `pred` is the arg-max over them for every query.
        When `visual` is truthy, returns `(support, rel_rep, pred)` instead:
        the relation-augmented prototypes, the projected relation
        representation, and the predictions.
        '''
        # Encoder and dropout call order is kept exactly as before so that
        # any randomness inside them is consumed in the same sequence.
        if self.relation_encoder:
            rel_gol, rel_loc = self.relation_encoder(rel_txt)
        else:
            rel_gol, rel_loc = self.sentence_encoder(rel_txt, cat=False)

        # Average the second relation encoding over dim 1
        # ([B*N, D] afterwards, per the original annotation).
        rel_loc = torch.mean(rel_loc, 1)

        # The third encoder output (per-token states) is not used here.
        support_h, support_t, _ = self.sentence_encoder(support)  # (B * N * K, D) each, per the original annotation
        query_h, query_t, _ = self.sentence_encoder(query)  # (B * total_Q, D) each, per the original annotation

        support = self.drop(torch.cat((support_h, support_t), -1))
        query = self.drop(torch.cat((query_h, query_t), -1))

        pair_dim = self.hidden_size * 2
        support = support.view(-1, N, K, pair_dim)  # (B, N, K, 2D)
        query = query.view(-1, total_Q, pair_dim)  # (B, total_Q, 2D)

        # Prototypical Networks: one prototype per class, the mean over K.
        support = torch.mean(support, 2)

        # Flatten prototypes to one row per relation so they line up with the
        # relation encodings, then compute one scalar gate per encoding.
        flat_protos = support.view(rel_gol.shape[0], -1)
        gate_gol = torch.sigmoid(self.gate_gol(torch.cat((flat_protos, rel_gol), -1)))
        gate_loc = torch.sigmoid(self.gate_loc(torch.cat((flat_protos, rel_loc), -1)))
        rel_rep = rel_loc * gate_loc + rel_gol * gate_gol

        # Project the fused relation vector to prototype width and add it.
        rel_rep = self.linear(rel_rep.view(-1, N, rel_gol.shape[1]))
        support = support + rel_rep

        logits = self.__batch_dist__(support, query)  # (B, total_Q, N)
        # Append an (N + 1)-th column that is strictly below every real
        # logit, so it never wins the arg-max (original note: "Ignore NA
        # policy").
        minn, _ = logits.min(-1)
        logits = torch.cat([logits, minn.unsqueeze(2) - 1], 2)  # (B, total_Q, N + 1)
        _, pred = torch.max(logits.view(-1, N + 1), 1)

        if visual:
            return support, rel_rep, pred
        return logits, pred

    
    
    
