# coding: utf-8

import torch
from torch import nn
from modules import GraphConv

# Module-level default device: the current CUDA device when CUDA is available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class GraphConvModel(nn.Module):
    ''' Aggregate graph information with a stack of ``GraphConv`` layers.

    Args:
        in_features: feature width of ``x`` passed to ``forward``.
        hidden_features: feature width produced by each layer (assuming
            ``GraphConv(a, b)`` maps ``a`` input features to ``b`` output
            features — TODO confirm against ``modules.GraphConv``).
        n_layers: number of stacked layers; values below 1 still build one layer.
    '''
    def __init__(self, in_features=100, hidden_features=200, n_layers=1):
        super(GraphConvModel, self).__init__()
        # The first layer maps in_features -> hidden_features. Every later layer
        # consumes the previous layer's output, so its input width must be
        # hidden_features; using in_features there would break n_layers > 1
        # whenever in_features != hidden_features.
        layers = [GraphConv(in_features, hidden_features)]
        layers += [GraphConv(hidden_features, hidden_features) for _ in range(n_layers - 1)]
        self.conv_layer_list = nn.ModuleList(layers)

    def forward(self, x, edge_index, from_mask, to_mask):
        ''' Apply the layers in order.

        ``edge_index``, ``from_mask`` and ``to_mask`` are passed unchanged to
        every layer; only ``x`` is updated between layers.
        '''
        for conv_layer in self.conv_layer_list:
            x = conv_layer(x, edge_index, from_mask, to_mask)
        return x

class HFet(nn.Module):
    ''' ELMo-based mention encoder with mention attention and context attention.

    ``forward`` returns, for every mention, the concatenation of an
    attention-pooled mention vector and an attention-pooled context vector
    (width ``2 * elmo_dim``).
    '''

    def __init__(self,
                 label_size,
                 elmo_option,
                 elmo_weight,
                 elmo_dropout=.5,
                 repr_dropout=.2,
                 dist_dropout=.5,
                 latent_size=0,
                 svd=None,
                 ):
        ''' Build the ELMo encoder and the attention / dropout layers.

        Args:
            label_size: number of labels; only stored on the instance here.
            elmo_option: ELMo options file passed to allennlp's ``Elmo``.
            elmo_weight: ELMo weight file passed to allennlp's ``Elmo``.
            elmo_dropout: dropout used inside ``Elmo``.
            repr_dropout: dropout on the concatenated output representation.
            dist_dropout: dropout on the relative-distance input ``dist``.
            latent_size: unused in this class; kept for interface compatibility.
            svd: unused in this class; kept for interface compatibility.
        '''
        super(HFet, self).__init__()
        # Local import: only constructing an HFet requires allennlp.
        from allennlp.modules.elmo import Elmo

        self.label_size = label_size
        # One ELMo output representation (read back as index 0 in forward_nn).
        self.elmo = Elmo(elmo_option, elmo_weight, 1,
                         dropout=elmo_dropout)
        self.elmo_dim = self.elmo.get_output_dim()

        self.attn_dim = 1
        self.attn_inner_dim = self.elmo_dim
        # Mention attention: score_t = w_o . tanh(W_m h_t)
        self.men_attn_linear_m = nn.Linear(self.elmo_dim, self.attn_inner_dim, bias=False)
        self.men_attn_linear_o = nn.Linear(self.attn_inner_dim, self.attn_dim, bias=False)
        # Context attention: score_t = w_o . tanh(W_c h_t + W_m m + w_d d_t)
        self.ctx_attn_linear_c = nn.Linear(self.elmo_dim, self.attn_inner_dim, bias=False)
        self.ctx_attn_linear_m = nn.Linear(self.elmo_dim, self.attn_inner_dim, bias=False)
        self.ctx_attn_linear_d = nn.Linear(1, self.attn_inner_dim, bias=False)
        self.ctx_attn_linear_o = nn.Linear(self.attn_inner_dim,
                                           self.attn_dim, bias=False)
        # Dropout on the final concatenated representation
        self.repr_dropout = nn.Dropout(p=repr_dropout)

        # Not used inside this class; presumably consumed by training code — confirm.
        self.criterion = nn.MultiLabelSoftMarginLoss()
        self.mse = nn.MSELoss()
        # Dropout on the relative position (distance) feature
        self.dist_dropout = nn.Dropout(p=dist_dropout)

    @staticmethod
    def _masked_pool(scores, mask, values):
        ''' Softmax ``scores`` over dim 1 with positions where ``mask`` is 0
        pushed to ~zero weight, then return the weighted sum of ``values``.

        ``scores`` is (B, T, 1), ``mask`` is (B, T) with 1 = keep, ``values`` is
        (B, T, D); the result is (B, D).
        '''
        scores = scores + (1.0 - mask.unsqueeze(-1)) * -10000.0
        weights = scores.softmax(1)
        return (values * weights).sum(1)

    def forward_nn(self, inputs, men_mask, ctx_mask, dist, gathers):
        ''' Encode each mention as [mention vector ; context vector].

        Args:
            inputs: ELMo input ids; shape observed during development was
                (160, 85, 50) — presumably (n_sentences, seq_len, max_chars), confirm.
            men_mask: (n_mentions, seq_len) float mask, 1 on mention tokens.
            ctx_mask: (n_mentions, seq_len) float mask, 1 on context tokens.
            dist: (n_mentions, seq_len) relative-position feature (float).
            gathers: 1-D long tensor; mention i uses ELMo row ``gathers[i]``.

        Returns:
            (n_mentions, 2 * elmo_dim) tensor after ``repr_dropout``.
        '''
        elmo_outputs = self.elmo(inputs)['elmo_representations'][0]
        # Row i becomes the encoding of sentence gathers[i]; several mentions may
        # share one encoded sentence.
        elmo_outputs = elmo_outputs.index_select(0, gathers)

        men_scores = self.men_attn_linear_o(self.men_attn_linear_m(elmo_outputs).tanh())
        men_repr = self._masked_pool(men_scores, men_mask, elmo_outputs)

        dist = self.dist_dropout(dist)
        ctx_hidden = (self.ctx_attn_linear_c(elmo_outputs) +
                      self.ctx_attn_linear_m(men_repr.unsqueeze(1)) +   # broadcast over tokens
                      self.ctx_attn_linear_d(dist.unsqueeze(2))).tanh()
        ctx_scores = self.ctx_attn_linear_o(ctx_hidden)
        ctx_repr = self._masked_pool(ctx_scores, ctx_mask, elmo_outputs)

        final_repr = torch.cat([men_repr, ctx_repr], dim=1)
        return self.repr_dropout(final_repr)

    def forward(self, inputs, labels, men_mask, ctx_mask, dist, gathers):
        ''' Return ``forward_nn``'s mention representations.

        ``labels`` is accepted for interface compatibility but is not used here.
        '''
        return self.forward_nn(inputs, men_mask, ctx_mask, dist, gathers)

class Encoder(nn.Module):
    ''' Get the embeddings from either labels or from sentences.

    An index ``i`` below ``sent_idx_bias`` is a label id embedded with
    ``label_embeddings``; an index at or above it refers to entry
    ``i - sent_idx_bias`` of ``sentence_dataset``, which is encoded with
    ``sentence_encoder``.
    '''
    def __init__(self, sent_idx_bias, sentence_encoder, sentence_dataset, label_embeddings):
        super(Encoder, self).__init__()
        self.sent_idx_bias = sent_idx_bias
        self.sentence_encoder = sentence_encoder
        # Indexed with a list of ints; the result is unpacked as the encoder's arguments.
        self.sentence_dataset = sentence_dataset
        self.label_embeddings = label_embeddings
        # Width of every output row; sentence encodings are assumed to match it.
        self.hidden_dim = label_embeddings.embedding_dim

    def forward(self, indices):
        ''' Embed a 1-D tensor of mixed label / sentence indices.

        Args:
            indices: 1-D integer tensor.

        Returns:
            Tensor of shape [len(indices), hidden_dim]; row ``r`` embeds ``indices[r]``.
        '''
        lbl_inds = self.get_mask_ind(indices < self.sent_idx_bias)
        sent_inds = self.get_mask_ind(indices >= self.sent_idx_bias)

        # The label lookup always runs (possibly on an empty selection). Allocate the
        # output on the device of its result rather than a global default, so the
        # module also works when kept on CPU while CUDA is available.
        lbl_embeds = self.label_embeddings(indices[lbl_inds])
        embeds = torch.empty([indices.size(0), self.hidden_dim], device=lbl_embeds.device)
        embeds[lbl_inds] = lbl_embeds

        if sent_inds.numel() == 0:
            # No sentences in this batch: do not run the encoder on an empty batch.
            return embeds

        # The same sentence is often referenced several times; encode each distinct
        # sentence once (in first-occurrence order) and scatter the results back.
        sent_ids = (indices[sent_inds] - self.sent_idx_bias).tolist()
        unique_ids = list(dict.fromkeys(sent_ids))
        position = {sid: pos for pos, sid in enumerate(unique_ids)}
        restore = [position[sid] for sid in sent_ids]

        unique_sent_embeds = self.sentence_encoder(*self.sentence_dataset[unique_ids])
        embeds[sent_inds] = unique_sent_embeds[restore]
        return embeds

    def get_mask_ind(self, mask):
        ''' Retrieve the indices (ascending) of the non-zero elements of a 1-D mask. '''
        return mask.nonzero(as_tuple=True)[0]

class Classifier(nn.Module):
    ''' Pairwise (sample, label) scorer.

    Sample and label features are batch-normalised separately, concatenated,
    and passed through Linear -> BatchNorm1d -> LeakyReLU -> Linear, giving one
    logit per pair.
    '''
    def __init__(self, sample_dim, label_dim, n_hidden=100):
        super(Classifier, self).__init__()
        self.sample_bn = nn.BatchNorm1d(sample_dim)
        self.label_bn = nn.BatchNorm1d(label_dim)

        self.pre_layer = nn.Sequential(
            nn.Linear(sample_dim + label_dim, n_hidden),
        )
        self.batchnorm_activation = nn.Sequential(
            nn.BatchNorm1d(n_hidden),
            nn.LeakyReLU()
        )
        self.post_layer = nn.Sequential(
            nn.Linear(n_hidden, 1)
        )

    def forward(self, sample_feature, label_feature):
        ''' Return prediction logits of shape [B, k, 1].

        Accepted input layouts:
          * sample [B, sample_dim], label [B, k, label_dim]: each sample is
            compared with its own k label features.
          * sample [B, k, sample_dim], label [B, label_dim]: each label is
            compared with its own k sample features.
          * sample [B, k, sample_dim], label [B, k, label_dim]: already aligned.
          * sample [B, sample_dim], label [k, label_dim]: every sample is
            compared with every label.

        Raises:
            ValueError: for any other combination of ranks.
        '''
        # Dispatch on rank, not on size(0): in the all-pairs case the number of
        # samples may coincidentally equal the number of labels.
        sample_rank, label_rank = sample_feature.dim(), label_feature.dim()
        if sample_rank == 2 and label_rank == 3:
            k = label_feature.size(1)
            sample_feature = sample_feature.unsqueeze(1).repeat(1, k, 1)
        elif sample_rank == 3 and label_rank == 2:
            k = sample_feature.size(1)
            label_feature = label_feature.unsqueeze(1).repeat(1, k, 1)
        elif sample_rank == 2 and label_rank == 2:
            n_samples = sample_feature.size(0)
            n_labels = label_feature.size(0)
            sample_feature = sample_feature.unsqueeze(1).repeat(1, n_labels, 1)
            label_feature = label_feature.unsqueeze(0).repeat(n_samples, 1, 1)
        elif not (sample_rank == 3 and label_rank == 3):
            raise ValueError('Unsupported feature ranks: sample %d-D, label %d-D'
                             % (sample_rank, label_rank))

        # BatchNorm1d normalises dim 1, so move the feature axis there and back.
        sample_feature = self.sample_bn(sample_feature.transpose(1, 2)).transpose(1, 2)
        label_feature = self.label_bn(label_feature.transpose(1, 2)).transpose(1, 2)

        hidden = self.pre_layer(torch.cat([sample_feature, label_feature], dim=-1))
        hidden = self.batchnorm_activation(hidden.transpose(1, 2))
        return self.post_layer(hidden.transpose(1, 2))

class HFetClassifier(nn.Module):
    ''' Score samples against labels by dot product plus a low-rank latent term.

    Args:
        sample_dim: width of the sample features (must equal the label feature width).
        label_dim: accepted for interface compatibility; not used here.
        n_lbls: number of labels produced by the latent projection.
        hierlossnorm_ontology: when truthy, outputs are passed through
            ``modules.HierarchicalLossNormalization`` built from it.
    '''
    def __init__(self, sample_dim, label_dim, n_lbls, hierlossnorm_ontology=None):
        super(HFetClassifier, self).__init__()
        import math
        # Bottleneck width of the latent projection: floor(sqrt(n_lbls)).
        bottleneck = int(math.sqrt(n_lbls))
        self.latent_layer = nn.Sequential(
            nn.Linear(sample_dim, bottleneck, bias=False),
            nn.Linear(bottleneck, n_lbls, bias=False),
        )
        # Learnable weight of the latent term, initialised to 0.1.
        self.latent_scalar = nn.Parameter(torch.tensor([.1]))

        self.use_hierlossnorm = bool(hierlossnorm_ontology)
        if self.use_hierlossnorm:
            from modules import HierarchicalLossNormalization
            self.hierlossnorm = HierarchicalLossNormalization(hierlossnorm_ontology)

    def forward(self, sample_feature, label_feature, moredim=False):
        ''' Return sigmoid scores.

        With ``moredim=False``: sample [B, dim], label [n_lbls, dim] -> [B, n_lbls],
        adding ``latent_scalar * latent_layer(sample)`` before the sigmoid.
        With ``moredim=True``: sample [B, k, dim], label [B, dim] -> [B, k],
        plain dot products without the latent term.
        '''
        if not moredim:
            scores = torch.einsum('ab,cb->ac', sample_feature, label_feature)
            scores = scores + self.latent_scalar * self.latent_layer(sample_feature)
        else:
            scores = torch.einsum('abc,ac->ab', sample_feature, label_feature)

        probs = torch.sigmoid(scores)
        return self.hierlossnorm(probs) if self.use_hierlossnorm else probs
