# coding: utf-8

import torch
import numpy as np
from torch import nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops
from collections import defaultdict

# Default device for this module: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class SelfAttention(nn.Module):
    ''' Self attention module

    Scores every position of a padded sequence with a small MLP, converts the
    scores into weights with a length-masked softmax and returns the weighted
    sum of the inputs over the sequence axis.
    '''
    def __init__(self, in_features, n_hidden=100):
        '''
        Args:
            in_features <int>: size of the last axis of the input
            n_hidden <int>: width of the hidden layer of the scoring MLP
        '''
        super(SelfAttention, self).__init__()

        self.key_linear = nn.Sequential(
            nn.Linear(in_features, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.LeakyReLU(inplace=True),
            nn.Linear(n_hidden, 1)
        )

    def forward(self, x, lengths):
        '''
        Args:
            x <Tensor> of shape [batch_size, seq_len, in_features]
            lengths <LongTensor>: shape [batch_size], the length of each input
        Returns:
            <Tensor> of shape [batch_size, in_features], the attention-weighted
            sum of x over the sequence axis
        '''
        batch_size, seq_len = x.size(0), x.size(1)
        # nn.BatchNorm1d takes axis 1 of a 3-D input as the channel axis, so
        # feeding [batch, seq_len, n_hidden] directly only works when
        # seq_len == n_hidden. Score the positions as a flat batch of rows so
        # the normalisation runs over the n_hidden features instead.
        flat_logits = self.key_linear(x.reshape(-1, x.size(2)))
        logits = flat_logits.view(batch_size, seq_len)
        # generate weights
        weights = self.masked_softmax(logits, lengths)
        # weighted summation
        return (x * weights.unsqueeze(2)).sum(dim=1)

    def masked_softmax(self, matrix, lengths):
        ''' Softmax on a given length

        Args:
            matrix <Tensor> of shape [batch_size, seq_len] or
                [batch_size, seq_len, 1]
            lengths <LongTensor>: shape [batch_size], indicates the length of sequence
                of each sample.
        Returns:
            <Tensor> of shape [batch_size, seq_len]: the softmax weights.
            Positions at or beyond a sample's length get weight 0.
        '''
        if matrix.dim() == 3:
            # Drop only the trailing singleton axis; a bare squeeze() would
            # also remove a batch or sequence axis of size 1.
            matrix = matrix.squeeze(2)
        seq_len = matrix.size(1)

        # Build the mask on the input's own device rather than a global one.
        positions = torch.arange(seq_len, device=matrix.device).unsqueeze(0)
        mask = positions < lengths.to(matrix.device).view(-1, 1)

        exp_scores = torch.exp(matrix) * mask
        # The small constant keeps a zero-length row at all-zero weights
        # instead of dividing 0 by 0.
        return exp_scores / (1e-10 + exp_scores.sum(dim=1, keepdim=True))

class AttentionWeight(nn.Module):
    ''' Bilinear scorer that yields one unnormalized attention value per row.

    Holds a single bias-free square linear map and scores a pair
    (x_i, x_j) as the dot product of x_i with the mapped x_j.
    '''
    def __init__(self, in_features):
        super().__init__()
        self.linear = nn.Linear(in_features=in_features,
                                out_features=in_features,
                                bias=False)

    def forward(self, x_i, x_j):
        '''
        Args:
            x_i, x_j <Tensor>: row-aligned inputs whose last axis has
                in_features entries
        Returns:
            <Tensor>: the per-row score, summed over axis 1 with that axis
            kept as size 1
        '''
        projected = self.linear(x_j)
        return (x_i * projected).sum(dim=1, keepdim=True)

class GraphConv(MessagePassing):
    ''' Graph convolution layer

    Message passing over a graph with two kinds of nodes (label and
    sentence). Each of the four edge types (label-label, sent-sent,
    label-sent, sent-label) has its own attention scorer and its own fully
    connected layer. Messages are summed (aggr='add') and combined with the
    node's own features in the update step.
    '''
    def __init__(self, in_features, hidden_features, dropout=0.5):
        '''
        Args:
            in_features <int>: size of the node feature vectors
            hidden_features <int>: size of each message vector
            dropout <float>: accepted for interface compatibility; no
                dropout layer is currently built from it
        '''
        super(GraphConv, self).__init__(aggr='add')
        self.hidden_features = hidden_features

        self.l_l_attn = AttentionWeight(in_features)
        self.s_s_attn = AttentionWeight(in_features)
        self.l_s_attn = AttentionWeight(in_features)
        self.s_l_attn = AttentionWeight(in_features)

        self.l_l_layer = self._sigmoid_fc(in_features, hidden_features)
        self.s_s_layer = self._sigmoid_fc(in_features, hidden_features)
        self.l_s_layer = self._sigmoid_fc(in_features, hidden_features)
        self.s_l_layer = self._sigmoid_fc(in_features, hidden_features)

        self.update_layer = self._sigmoid_fc(hidden_features + in_features, in_features)

        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _sigmoid_fc(n_in, n_out):
        ''' Build one Linear layer followed by a Sigmoid. '''
        return nn.Sequential(
            nn.Linear(n_in, n_out),
            nn.Sigmoid()
        )

    def forward(self, x, edge_index, from_mask, to_mask):
        '''
        Args:
            x <Tensor>: node features, one row per node
            edge_index <LongTensor>: edge list handed to propagate()
            from_mask, to_mask <BoolTensor>: one entry per edge (they are
                indexed alongside the per-edge features in message());
                True means that end of the edge is a label node, False a
                sentence node
        Returns:
            the result of propagate(), i.e. the output of update()
        '''
        n_nodes = x.size(0)
        return self.propagate(edge_index=edge_index, size=(n_nodes, n_nodes),
                              x=x, from_mask=from_mask, to_mask=to_mask)

    def message(self, x_i, x_j, from_mask, to_mask, edge_index_i, edge_index_j):
        '''
        mask digit == True means the node is label, False means it is sentence

        For every edge type, computes the attention value of each edge,
        normalises it within the edges that share the same edge_index_i
        entry, squashes it with a sigmoid and uses it to scale the output of
        that edge type's fully connected layer applied to x_j.
        '''
        n_edges = x_j.size(0)
        # Allocate on the same device/dtype as the inputs instead of a global
        # device, so the layer works wherever its inputs live.
        attention_digits = x_j.new_empty([n_edges, 1])
        hidden_vec = x_j.new_empty([n_edges, self.hidden_features])

        # The four masks partition the edges, so every row of the two
        # buffers above is written exactly once.
        branches = (
            # both are label nodes
            (from_mask & to_mask, self.l_l_attn, self.l_l_layer),
            # both are sent nodes
            (~from_mask & ~to_mask, self.s_s_attn, self.s_s_layer),
            # from is label, to is sent
            (from_mask & ~to_mask, self.l_s_attn, self.l_s_layer),
            # from is sent, to is label
            (~from_mask & to_mask, self.s_l_attn, self.s_l_layer),
        )

        for edge_mask, attn, layer in branches:
            inds = self.get_mask_ind(edge_mask)
            neighbours = x_j[inds]
            raw_digits = attn(x_i[inds], neighbours)
            attention_digits[inds] = self._group_norm(raw_digits, edge_index_i[inds])
            # fc layer
            hidden_vec[inds] = layer(neighbours)

        # restrict to 0-1
        attention_digits = self.sigmoid(attention_digits)

        return attention_digits * hidden_vec

    @staticmethod
    def _group_norm(digits, centre):
        ''' Divide each row of digits by the sum of the rows sharing its centre id.

        Args:
            digits <Tensor> of shape [n, 1]
            centre <LongTensor> of shape [n]: group id of each row
        Returns:
            <Tensor> of shape [n, 1]. A group whose values sum to zero gives
            a division by zero, as in a plain per-group division.
        '''
        uniq, group_ids = torch.unique(centre, return_inverse=True)
        group_sums = digits.new_zeros([uniq.size(0), digits.size(1)])
        group_sums.index_add_(0, group_ids, digits)
        return digits / group_sums[group_ids]

    def update(self, aggr_out, x):
        ''' Concatenate the aggregated messages with the node features and
        pass them through the update layer. '''
        input_ = torch.cat([aggr_out, x], dim=1)
        return self.update_layer(input_)

    def get_mask_ind(self, mask):
        ''' Retrieve the non-zero element's indices of a 1-D mask, in ascending order. '''
        return torch.nonzero(mask, as_tuple=False).view(-1)

class HierarchicalLossNormalization(nn.Module):
    ''' The hierarchical loss normalization module mentioned in paper
    Neural Fine-Grained Entity Type Classification
            with Hierarchy-Aware Loss

    Holds a square label-transformation matrix T built from an ontology
    file: T[l, l] = 1 for every label l and T[a, l] = beta for every
    ancestor a of l that is itself listed in the ontology.
    '''
    def __init__(self, ontology_fn, beta=0.3):
        '''
        Args:
            ontology_fn <str>: path of a text file with one label per line
            beta <float>: weight given to an ancestor of a label
        '''
        super(HierarchicalLossNormalization, self).__init__()
        with open(ontology_fn) as f:
            lbls = f.read().strip().split('\n')
        self.lbl2id = {lbl: idx for idx, lbl in enumerate(lbls)}
        self.lbl_transform = self.create_hierarchy(lbls, beta)

    def forward(self, preds, mode='mul'):
        ''' Input the probability distribution of prediction, output the updated distribution.

        Args:
            preds <Tensor> of shape [n_samples, n_labels]
            mode <str>: 'sum' or 'mul'
        Raises:
            ValueError: if mode is neither 'sum' nor 'mul'
        '''
        if mode == 'sum':
            return self.forward_sum(preds)
        if mode == 'mul':
            return self.forward_mul(preds)
        # An unknown mode used to fall through and return None silently.
        raise ValueError("mode must be 'sum' or 'mul', got {!r}".format(mode))

    def forward_sum(self, preds):
        ''' Sum all the probabilities of A's ancestors '''
        outputs = torch.matmul(preds, self._transform_like(preds))
        # keep the result inside (0, 1]
        outputs.clamp_(min=1e-10, max=1.0)
        return outputs

    def forward_mul(self, preds):
        ''' Multiply all the probabilities of A's ancestors '''
        transform = self._transform_like(preds)
        # weighted[m, a, b] = preds[m, a] * T[a, b]
        weighted = torch.einsum('ma,ab->mab', preds, transform)
        # Which (a, b) pairs take part in the product is decided by the
        # hierarchy alone. Deriving the mask from the product instead would
        # drop every prediction that is exactly 0 and turn an all-zero
        # column into exp(0) = 1.
        in_hierarchy = (transform != 0).unsqueeze(0)
        log_terms = weighted.clamp_min(1e-10).log() * in_hierarchy
        # product over a, computed as exp(sum of logs)
        return log_terms.sum(dim=1).exp()

    def create_hierarchy(self, lbls, beta=0.3):
        ''' Build the label-transformation matrix as a float tensor.

        Args:
            lbls <list of str>: the labels of the ontology
            beta <float>: weight given to an ancestor of a label
        Returns:
            <Tensor> of shape [n_labels, n_labels] on the module-level device
        '''
        hier_mat = self._hierarchy_matrix(lbls, beta)
        # move to device
        return torch.tensor(hier_mat, dtype=torch.float, device=device, requires_grad=False)

    def _hierarchy_matrix(self, lbls, beta):
        ''' Build the label-transformation matrix as a numpy array. '''
        lbl_set = set(lbls)
        n_labels = len(self.lbl2id)
        hier_mat = np.zeros([n_labels, n_labels])
        for lbl, col in self.lbl2id.items():
            # self
            hier_mat[col, col] = 1
            if lbl not in lbl_set:
                continue
            # Prefixes start at two '/'-separated parts, so '/a/b'
            # (parts '', 'a', 'b') has the ancestor '/a', while a label
            # without a leading slash such as 'a/b' has none.
            splits = lbl.split('/')
            for i in range(2, len(splits)):
                ancestor = '/'.join(splits[:i])
                if ancestor in lbl_set:
                    hier_mat[self.lbl2id[ancestor], col] = beta
        return hier_mat

    def _transform_like(self, preds):
        ''' Return the transformation matrix on the same device as preds,
        moving (and keeping) it there on first use. '''
        if self.lbl_transform.device != preds.device:
            self.lbl_transform = self.lbl_transform.to(preds.device)
        return self.lbl_transform





if __name__ == '__main__':
    # Disabled manual check of AttentionWeight:
    #   attweight = AttentionWeight(20)
    #   a = torch.zeros([100, 20])
    #   a[[0, 1, 2]] = 1
    #   print(a)
    #   print(attweight(a, a))
    #
    # Disabled manual check of GraphConv (two stacked calls, one mask entry per edge):
    #   graph_conv = GraphConv(2, 4)
    #   x = torch.tensor([[2, 1], [5, 6], [3, 7], [12, 0]], dtype=torch.float32)
    #   y = torch.tensor([0, 1, 0, 1], dtype=torch.float)
    #   edge_index = torch.tensor([[0, 1, 0],
    #                              [1, 0, 3]], dtype=torch.long)
    #   masks = dict(from_mask=torch.tensor([False, True, False]),
    #                to_mask=torch.tensor([True, True, False]))
    #   out = graph_conv(x, edge_index, **masks)
    #   out = graph_conv(out, edge_index, **masks)
    #   print(x, out)

    # Smoke test: constructing the module reads the ontology file and builds
    # the label-transformation matrix.
    hier = HierarchicalLossNormalization(
        ontology_fn='data/ontology/onto_ontology.txt'
    )