import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from util import *
import math


def hard_sigmoid(x):
    """
    Element-wise hard sigmoid: the line ``0.2 * x + 0.5`` clipped to [0, 1].
    Mirrors Theano's implementation, see
    https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279
    """
    scaled = 0.5 + 0.2 * x
    # threshold(-v, -1, -1) caps -v from below at -1, i.e. caps v from above at 1
    upper_clipped = -F.threshold(-scaled, -1, -1)
    # then floor at 0
    return F.threshold(upper_clipped, 0, 0)


def average_pooling(state, lens):
    """
    Mean of RNN hidden states over the time axis.

    The sum runs over every time step, so padded steps must be zero for the
    result to be the true mean; the divisor is the real sequence length.

    :param state: [batch, len, state]
    :param lens: [batch]
    :return: [batch, state]
    """
    divisor = lens.float().unsqueeze(-1)  # [batch, 1]
    return state.sum(dim=1) / divisor

def last_pooling(state, lens):
    """
    Return the hidden state at the last valid time step of each sequence.

    :param state: [batch, len, state]
    :param lens: [batch] sequence lengths; a length of 0 selects index -1,
        i.e. the final (padded) position, as Python indexing wraps
    :return: [batch, state] where row b is ``state[b, lens[b] - 1]``
    """
    # Build the index tensors on whatever device ``state`` lives on (CPU, CUDA
    # or any other backend) instead of special-casing CUDA.
    rows = torch.arange(state.size(0), device=state.device)
    cols = (lens - 1).to(device=state.device, dtype=torch.long)
    return state[rows, cols]

def mask_lengths(lengths, max_len=None):
    """
    Build a float padding mask from sequence lengths.

    :param lengths: [batch_size] indicates lengths of sequence
    :param max_len: width of the mask; when falsy (None or 0) the largest
        value in ``lengths`` is used
    :return: [batch_size, max_len] float tensor on ``lengths.device``: ones for
        positions within the lengths, zeros for exceeding positions

    [4,2] -> [[1,1,1,1],
              [1,1,0,0]]
    """
    if not max_len:
        max_len = int(lengths.max().item())
    # Create the position row directly on the target device rather than
    # allocating a CPU LongTensor and copying it over on every call.
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # [1, max_len]
    return (positions < lengths.unsqueeze(1)).float()

# def mask_lengths(lengths, max_len=None):
#     """
#
#     :param lengths: [batch_size] indicates lengths of sequence
#     :return: [batch_size, max_len] ones for within the lengths zeros for exeeding lengths
#
#     [4,2] -> [[1,1,1,1]
#               ,[1,1,0,0]]
#     """
#     device = lengths.device
#     if not max_len:
#         max_len = torch.max(lengths).item()
#     idxes = torch.arange(0,max_len,out=torch.LongTensor(max_len)).unsqueeze(0).to(device)
#     masks = idxes>=lengths.unsqueeze(1)
#     return masks


def sample_gumbel(shape, device, eps=1e-20):
    """
    Draw i.i.d. standard Gumbel(0, 1) noise via the inverse CDF ``-log(-log(U))``.

    :param shape: output shape (tuple of ints or ``torch.Size``)
    :param device: device on which the noise is created
    :param eps: small constant added inside both logarithms to avoid ``log(0)``
    :return: tensor of the given shape on ``device``
    """
    # Sample directly on the target device (no host->device copy). The old
    # ``Variable`` wrapper is a no-op on modern PyTorch and is dropped.
    uniform = torch.rand(shape, device=device)
    return -torch.log(-torch.log(uniform + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
    """
    Relaxed (soft) Gumbel-softmax sample over the last axis.

    :param logits: [*, n_class] unnormalised scores
    :param temperature: softmax temperature; smaller values give sharper samples
    :return: [*, n_class] probabilities
    """
    noisy = logits + sample_gumbel(logits.size(), logits.device)
    return F.softmax(noisy / temperature, dim=-1)


def gumbel_softmax(logits, temperature):
    """
    Straight-through Gumbel-softmax.

    The forward value is a hard one-hot vector at the arg-max of a relaxed
    Gumbel-softmax sample; gradients flow through the soft sample.

    input: [*, n_class]
    return: [*, n_class] one-hot vector
    """
    soft = gumbel_softmax_sample(logits, temperature)
    winner = soft.max(dim=-1)[1]
    hard = torch.zeros_like(soft).scatter_(-1, winner.unsqueeze(-1), 1)
    # straight-through estimator: value of ``hard``, gradient of ``soft``
    return soft + (hard - soft).detach()

def hard_softmax(logits):
    """
    One-hot encode the arg-max of ``logits`` along the last axis.

    :param logits: [*, n_class] tensor of any rank >= 1
    :return: [*, n_class, 1] tensor with 1.0 at the arg-max class, 0.0 elsewhere
    """
    layer_prob = F.softmax(logits, -1)
    _, k = layer_prob.max(-1)
    # ``scatter_`` needs an index with the same rank as its target;
    # ``k.unsqueeze(-1)`` satisfies that for every rank (``k.view(-1, 1)`` only
    # matched 2-D input).
    one_hot = logits.new_zeros(logits.size()).scatter_(-1, k.unsqueeze(-1), 1.0)
    return one_hot.unsqueeze(-1)

def dotprod_selfatt_lengths(states, contexts, lengths):
    """
    Pool ``states`` with length-masked dot-product attention against one query.

    :param states: [batch, lengths, hidden]
    :param contexts: [hidden] the shared query vector
    :param lengths: [batch]
    :return: [batch, hidden] attention-weighted sum of the states
    """
    query = contexts[None, None]                            # [1, 1, hidden]
    scores = (states * query).sum(dim=2)                    # [batch, lengths]
    weights = softmax_with_len(scores, lengths)[:, :, None]  # [batch, lengths, 1]
    return (weights * states).sum(dim=1)


def dotprod_catt_lengths(states, contexts, lengths):
    """
    Per-class dot-product attention pooling over the time axis.

    :param states: [batch, lengths, hidden]
    :param contexts: [class, hidden] one query vector per class
    :param lengths: [batch]
    :return: [batch, class, hidden]; NOTE(review): the final ``squeeze()`` also
        drops the batch and/or class axis whenever it has size 1
    """
    queries = contexts.unsqueeze(0).unsqueeze(1)          # [1, 1, class, hidden]
    states = states.unsqueeze(2)                          # [batch, lengths, 1, hidden]
    att_logit = torch.sum(torch.mul(states, queries), 3)  # [batch, lengths, class]
    # Size the mask by the (padded) time axis so it always lines up with ``states``.
    masks = mask_lengths(lengths, att_logit.size(1))      # [batch, lengths]
    # Masked softmax over time. This equals exp(logit) * mask / sum(...), but
    # softmax subtracts the max logit first, so large scores cannot overflow exp.
    att_logit = att_logit.masked_fill(masks.unsqueeze(2) == 0, float('-inf'))
    att_score = F.softmax(att_logit, dim=1).unsqueeze(3)  # [batch, lengths, class, 1]
    aligned = att_score * states                          # [batch, lengths, class, hidden]
    return torch.sum(aligned, dim=1).squeeze()


def softmax_with_len(logits, lengths, maxlen=None):
    """
    Softmax over the last axis restricted to the first ``lengths[b]`` positions.

    :param logits: [batch, lengths] unnormalised scores
    :param lengths: [batch] number of valid positions per row
    :param maxlen: mask width forwarded to ``mask_lengths`` (``max(lengths)``
        when falsy); it must equal ``logits.size(-1)``
    :return: [batch, lengths] weights summing to 1 over the valid positions and
        exactly 0 on padding (a row of length 0 comes out all-zero)
    """
    masks = mask_lengths(lengths, maxlen)
    # Push padded logits to the most negative finite value *before* the softmax.
    # Normalising over all positions and renormalising afterwards lets large
    # padded logits underflow every valid probability to 0. A finite fill (not
    # -inf) keeps all-padding rows finite so they still end up as zeros below.
    fill = torch.finfo(logits.dtype).min
    smx = F.softmax(logits.masked_fill(masks == 0, fill), -1)
    masked = masks * smx
    normalizing_factor = masked.sum(dim=-1, keepdim=True)
    return masked / (normalizing_factor + 1e-15)


def softmax_with_len_3d(logits, lengths):
    """
    Row-wise softmax over a square score matrix with length masking.

    Entry (b, i, j) is kept only when both i and j are below ``lengths[b]``.
    Each kept row is normalised over its valid columns. Every masked entry is
    exactly 0, including whole rows with ``i >= lengths[b]``.

    :param logits: [batch, l, l] unnormalised scores
    :param lengths: [batch]
    :return: [batch, l, l] attention weights
    """
    _, _, width = logits.size()  # unpacking also rejects non-3-D input
    masks = mask_lengths(lengths, width)                # [batch, l]
    masks = masks.unsqueeze(-1) * masks.unsqueeze(-2)   # [batch, l, l]
    # Mask before normalising so large padded scores cannot underflow the valid
    # ones. The fill is finite so fully masked rows stay finite (and become 0 below).
    fill = torch.finfo(logits.dtype).min
    smx = F.softmax(logits.masked_fill(masks == 0, fill), -1)
    masked = masks * smx
    normalizing_factor = masked.sum(dim=-1, keepdim=True)
    return masked / (normalizing_factor + 1e-15)




def reorder_sequence(x, index):
    """
    Undo a batch permutation: ``result[index[i]] = x[i]``.

    ``run_rnn`` passes the ``sort_index`` it used to sort the batch by length,
    so the rows come back in the caller's original order.

    :param x: tensor whose first axis is the (permuted) batch; any rank >= 1
    :param index: [batch] long tensor holding a permutation of ``range(batch)``
    :return: tensor shaped like ``x``
    """
    restored = torch.empty_like(x)
    # Indexing only the first axis works for any rank (``[index, :, :]`` demanded 3-D).
    restored[index] = x
    return restored


def run_rnn(x, lengths, rnn):
    """
    Run a batch-first RNN over a padded batch of variable-length sequences.

    The batch is sorted by decreasing length (as ``pack_padded_sequence``
    expects), packed, run through ``rnn``, unpacked and restored to the
    caller's batch order.

    :param x: [batch, len, input] padded input
    :param lengths: [batch] valid length of each sequence
    :param rnn: module called as ``rnn(packed, None)`` that returns
        ``(packed_output, hidden)`` -- presumably a batch_first nn.RNN/GRU/LSTM
    :return: [batch, max(lengths), hidden] outputs, zero-padded past each length
    """
    sorted_lengths, sort_index = lengths.sort(0, True)
    # index_select needs the index on x's device; lengths may live elsewhere.
    sort_index = sort_index.to(x.device)
    x_sorted = x.index_select(0, sort_index)
    # pack_padded_sequence requires lengths on the CPU.
    packed_input = nn.utils.rnn.pack_padded_sequence(x_sorted, sorted_lengths.cpu(), batch_first=True)
    packed_output, _ = rnn(packed_input, None)
    out_rnn, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
    return reorder_sequence(out_rnn, sort_index)


def reverse_sequence(x, lengths):
    """
    Reverse each row of a padded batch within its own valid length.

    Row b becomes ``x[b, L-1], ..., x[b, 0]`` followed by the padded steps in
    reverse order (``x[b, max_len-1], ..., x[b, L]``), where ``L = lengths[b]``.
    This matches Python's wrap-around for the negative indices of padded steps.

    :param x: [batch, max_len, ...] padded batch
    :param lengths: [batch] integer lengths, 0 <= length <= max_len
    :return: tensor shaped like ``x``
    """
    batch_size, max_len = x.size(0), x.size(1)
    steps = torch.arange(max_len, device=lengths.device)
    # ``remainder`` maps negative indices (padded steps) exactly as Python
    # negative indexing would, so a single gather replaces the per-row loop.
    cols = (lengths.unsqueeze(1) - 1 - steps).remainder(max_len).to(x.device)  # [batch, max_len]
    rows = torch.arange(batch_size, device=x.device).unsqueeze(1)               # [batch, 1]
    return x[rows, cols]


class CL(nn.Module):
    """Conditional layer: modulates ``words`` with a scale (gamma) and a shift
    (beta) computed from ``context``.

    ``modulation`` selects the update applied ``num_layers`` times:

    * ``'resnet'``  -- ``x = leaky_relu(x + x * gamma + beta)``
    * ``'linear'``  -- ``x = tanh(x * gamma + beta)``; with ``average_pooling``
      it is ``x * (gamma + 1) + beta`` without the tanh
    * ``'highway'`` -- ``x = tanh(u * x + (1 - u) * new)`` with a sigmoid gate
      ``u`` and ``new = x * gamma + beta`` (``x * (gamma + 1) + beta`` with
      ``average_pooling``)

    Any other value returns ``words`` unchanged. Layer ``i > 0`` computes its
    gamma/beta (and gate) from layer ``i - 1``'s values, not from the raw
    context; in 'linear' mode those values pass through ``relu6`` first.
    """

    def __init__(self, context_size, state_size, num_layers=1, modulation='resnet', average_pooling=False):
        super(CL, self).__init__()
        print('CL', modulation)
        self.context_size = context_size
        self.state_size = state_size
        self.num_layers = num_layers
        self.modulation = modulation
        self.average_pooling = average_pooling

        g_layers = []
        b_layers = []
        u_layers = []
        for i in range(num_layers):
            # Layer 0 reads the context; deeper layers read the previous layer's outputs.
            in_size = context_size if i == 0 else state_size
            g = nn.Linear(in_size, state_size)
            b = nn.Linear(in_size, state_size)
            # 5/3 is the recommended Xavier gain for tanh
            nn.init.xavier_normal_(g.weight.data, gain=5.0 / 3)
            nn.init.xavier_normal_(b.weight.data, gain=5.0 / 3)
            nn.init.constant_(g.bias, val=0.0)
            nn.init.constant_(b.bias, val=0.0)
            g_layers.append(g)
            b_layers.append(b)
            if self.modulation == 'highway':
                u = nn.Linear(in_size, state_size)
                nn.init.xavier_normal_(u.weight.data, gain=1)
                nn.init.constant_(u.bias, val=0.0)
                u_layers.append(u)
        self.g_layers = nn.ModuleList(g_layers)
        self.b_layers = nn.ModuleList(b_layers)
        # Always empty (their normalisation layers are disabled); kept so the
        # attributes still exist for any external code that reads them.
        self.gn_layers = nn.ModuleList()
        self.bn_layers = nn.ModuleList()
        if self.modulation == 'highway':
            self.u_layers = nn.ModuleList(u_layers)
            self.un_layers = nn.ModuleList()

    def forward(self, context, words):
        """
        :param context: [batch, context_size] when ``average_pooling`` is set (it
            is repeated over the time axis); otherwise its projections must
            broadcast against ``words``, e.g. [batch, len, context_size]
        :param words: [batch, len, state_size]
        :return: [batch, len, state_size]
        """
        x = words
        if self.average_pooling:
            # Give every time step its own copy of the context, so gamma, beta and
            # the highway gate already have the same shape as ``words``.
            context = context.unsqueeze(1).repeat(1, words.size(1), 1)
        # The pooled branches used to add ``self.gamma_biases`` (never created,
        # so they raised AttributeError) to an extra ``unsqueeze(1)``-ed gamma
        # that no longer lined up with the repeated context. That bias had been
        # initialised to zeros for 'resnet' and ones for 'linear'/'highway', so
        # it is reproduced below as a constant offset of 0 or 1.
        if self.modulation == 'resnet':
            gamma = context
            beta = context
            for i in range(self.num_layers):
                gamma = F.leaky_relu(self.g_layers[i](gamma))
                beta = F.leaky_relu(self.b_layers[i](beta))
                residual = x * gamma + beta
                x = F.leaky_relu(x + residual)
        elif self.modulation == 'linear':
            gamma = context
            beta = context
            for i in range(self.num_layers):
                gamma = self.g_layers[i](gamma)
                beta = self.b_layers[i](beta)
                if self.average_pooling:
                    x = x * (gamma + 1) + beta
                else:
                    x = torch.tanh(x * gamma + beta)
                # the clipped values condition the next layer
                gamma = F.relu6(gamma)
                beta = F.relu6(beta)
        elif self.modulation == 'highway':
            uc = gc = bc = context
            for i in range(self.num_layers):
                gamma = torch.tanh(self.g_layers[i](gc))
                beta = torch.tanh(self.b_layers[i](bc))
                update = torch.sigmoid(self.u_layers[i](uc))
                if self.average_pooling:
                    new = x * (gamma + 1) + beta
                else:
                    new = x * gamma + beta
                # gate: keep ``update`` of the old state, take the rest from ``new``
                x = torch.tanh(update * x + (1 - update) * new)
                uc, gc, bc = update, gamma, beta
        return x


class CLN(nn.Module):
    """Conditional layer normalisation.

    Each word vector is normalised over its feature axis (mean / unbiased std)
    and then scaled and shifted by ``gamma`` / ``beta``. A stack of linear layers
    produces them from ``context``, optionally concatenated with the words
    themselves when ``cat`` is true.
    """

    def __init__(self, context_size, state_size, num_layers=1, cat=True, average_pooling=False, eps=1e-5):
        super(CLN, self).__init__()
        self.context_size = context_size
        self.state_size = state_size
        self.num_layers = num_layers
        self.cat = cat
        self.eps = eps
        self.average_pooling = average_pooling
        # Not used in forward; kept so existing state_dicts still load.
        self.gamma_biases = Parameter(torch.ones(state_size * num_layers))
        self.layers = nn.Sequential()
        for i in range(num_layers):
            # The layers are chained in an nn.Sequential, so every layer after
            # the first consumes the previous layer's ``state_size * 2`` output.
            in_features = cat * state_size + context_size if i == 0 else state_size * 2
            self.layers.add_module('layer_{}'.format(i), nn.Linear(in_features, state_size * 2))

    def forward(self, context, words, seq_lengths):
        """
        :param context: [batch, hidden] when ``average_pooling`` is set;
            otherwise it must broadcast against ``words``, e.g. [batch, len, hidden]
        :param words: [batch, len, state_size]
        :param seq_lengths: [batch]; accepted for signature parity, not used here
        :return: [batch, len, state_size]
        """
        if self.cat:
            if self.average_pooling:
                context = context.unsqueeze(1).repeat(1, words.size(1), 1)
            context = torch.cat([context, words], -1)

        mean = words.mean(-1, keepdim=True)
        std = words.std(-1, keepdim=True)
        gamma, beta = self.layers(context).chunk(2, -1)
        x = (words - mean) / (std + self.eps)  # [batch, len, state_size]
        if self.average_pooling and not self.cat:
            # context was [batch, hidden], so gamma/beta need a time axis
            return x * gamma.unsqueeze(1) + beta.unsqueeze(1)
        return x * gamma + beta


class CEN(nn.Module):
    """Conditional normalisation over the time axis.

    Each feature is normalised by the per-sequence mean and unbiased standard
    deviation taken over the valid (unpadded) time steps. The result is then
    scaled and shifted by ``gamma`` / ``beta``, which a ReLU6-activated stack of
    linear layers produces from ``context`` (optionally concatenated with the
    words when ``cat`` is true). Padded steps of the output are zeroed.
    """

    def __init__(self, context_size, state_size, num_layers=1, cat=True, eps=1e-5, average_pooling=False):
        super(CEN, self).__init__()
        self.context_size = context_size
        self.state_size = state_size
        self.num_layers = num_layers
        self.cat = cat
        self.eps = eps
        self.average_pooling = average_pooling
        # Not used in forward; kept so existing state_dicts still load.
        self.gamma_biases = Parameter(torch.ones(state_size))
        self.layers = nn.Sequential()
        for i in range(num_layers):
            in_features = cat * state_size + context_size if i == 0 else state_size * 2
            self.layers.add_module('layer_{}'.format(i), nn.Linear(in_features, state_size * 2))
            self.layers.add_module('relu_{}'.format(i), nn.ReLU6())

    def forward(self, context, words, seq_lengths):
        """
        :param context: [batch, hidden] when ``average_pooling`` is set;
            otherwise it must broadcast against ``words``, e.g. [batch, len, hidden]
        :param words: [batch, len, embedding]
        :param seq_lengths: [batch] valid length of each sequence
        :return: [batch, len, embedding], zero at padded time steps
        """
        if self.cat:
            if self.average_pooling:
                context = context.unsqueeze(1).repeat(1, words.size(1), 1)
            context = torch.cat([context, words], -1)

        # Size the mask by the padded time axis so it always matches ``words``.
        mask = mask_lengths(seq_lengths.long(), words.size(1)).unsqueeze(-1)  # [batch, len, 1]
        lengths = seq_lengths.unsqueeze(1).float()                             # [batch, 1]
        # The mean uses only valid steps, like the variance below, so padded
        # steps contribute nothing even when they are not zero.
        mean = torch.sum(words * mask, dim=1) / lengths                         # [batch, embedding]
        deviation = (words - mean.unsqueeze(1)) * mask
        # Unbiased variance. Clamping the divisor makes length-1 sequences give 0
        # instead of 0/0 = NaN.
        variance = torch.sum(deviation ** 2, dim=1) / (lengths - 1).clamp(min=1)
        # A tiny floor keeps sqrt's gradient finite for constant sequences.
        std = torch.sqrt(variance.clamp(min=self.eps ** 2))

        gamma, beta = self.layers(context).chunk(2, -1)
        x = (words - mean.unsqueeze(1)) / (std.unsqueeze(1) + self.eps)  # [batch, len, embedding]
        if self.average_pooling and not self.cat:
            x = x * gamma.unsqueeze(1) + beta.unsqueeze(1)
        else:
            x = x * gamma + beta
        return x * mask


class ACL(nn.Module):
    """Attention-conditioned layer.

    Every word attends over the ``context`` sequence:

    * ``att_type='mul'`` uses scaled dot products.
    * ``att_type='multiaspect'`` uses an MLP over [word, context, difference, product].

    The attended context is optionally concatenated with the raw context
    (``cat``) and then conditions a ``CL`` layer that modulates ``words``.
    """

    def __init__(self, context_size, state_size, maxlen, num_layers=1, att_type='mul', modulation=False, cat=False):
        print('ACl ', modulation, cat)
        super(ACL, self).__init__()
        # With ``cat`` the conditioning vector is [attended context, raw context].
        self.cl = CL(context_size + (cat * context_size), state_size, num_layers, modulation)
        self.context_size = context_size
        self.state_size = state_size
        self.num_layers = num_layers
        self.modulation = modulation
        self.maxlen = maxlen  # stored only; not used by this class
        self.att_type = att_type
        # Learnable scalar multiplier on the multiplicative attention scores.
        self.scale = Parameter(torch.Tensor([1.0]))
        self.cat = cat
        # Projects words into the context space so the two can be compared.
        self.state_transform = nn.Sequential(nn.Linear(state_size, context_size))
        for module in self.state_transform:
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity='linear')
                nn.init.constant_(module.bias, val=0.0)
        if att_type == 'multiaspect':
            self.score_cal = nn.Sequential(nn.Linear(context_size * 4, context_size), nn.Tanh(),
                                           nn.Linear(context_size, 1))
            nn.init.kaiming_normal_(self.score_cal[0].weight)
            nn.init.kaiming_normal_(self.score_cal[-1].weight)

    def multiplicative_att(self, context, words, lens):
        """
        Scaled dot-product attention from each word over the context sequence.

        :param context: [batch, lens, hidden]
        :param words: [batch, lens, embedding]
        :param lens: [batch]
        :return: [batch, lens, hidden], or [batch, lens, 2 * hidden] when ``cat``
        """
        words_projected = self.state_transform(words)  # [batch, lens, hidden]
        affinity = torch.bmm(context, words_projected.transpose(1, 2)) * self.scale  # [batch, ctx, word]
        # Transpose so each row holds one word's scores over the context positions.
        att = softmax_with_len_3d(affinity.transpose(1, 2), lens)  # [batch, word, ctx]
        attentionized_context = torch.bmm(att, context)
        if self.cat:
            attentionized_context = torch.cat([attentionized_context, context], -1)
        return attentionized_context

    def multiaspect_attention(self, context, words, lens):
        """
        MLP-scored attention from each word over the context sequence.

        :param context: [batch, lens, hidden]
        :param words: [batch, lens, embedding]
        :param lens: [batch]
        :return: [batch, lens, hidden], or [batch, lens, 2 * hidden] when ``cat``
        """
        l = context.size(1)
        words_projected = self.state_transform(words)   # [batch, lens, hidden]
        words_expanded = words_projected.unsqueeze(1)   # [batch, 1, word, hidden]
        context_expanded = context.unsqueeze(2)         # [batch, ctx, 1, hidden]

        diff = words_expanded - context_expanded
        mult = words_expanded * context_expanded
        aspect = torch.cat([words_expanded.repeat(1, l, 1, 1), context_expanded.repeat(1, 1, l, 1), diff, mult], -1)
        affinity = self.score_cal(aspect).squeeze(-1)   # [batch, ctx, word]
        att = softmax_with_len_3d(affinity.transpose(1, 2), lens)  # [batch, word, ctx]
        # Mix context rows with each word's row-normalised weights (att @ context),
        # as multiplicative_att does. The former bmm(context^T, att)^T equals
        # att^T @ context, which weighted by att's un-normalised columns.
        attentionized_context = torch.bmm(att, context)
        if self.cat:
            attentionized_context = torch.cat([attentionized_context, context], -1)
        return attentionized_context

    def forward(self, context, words, lens):
        """
        :param context: [batch, lens, hidden]
        :param words: [batch, lens, embedding]
        :param lens: [batch]
        :return: output of the wrapped ``CL`` layer for ``words``
        :raises ValueError: if ``att_type`` is neither 'mul' nor 'multiaspect'
        """
        if self.att_type == 'mul':
            att_context = self.multiplicative_att(context, words, lens)
        elif self.att_type == 'multiaspect':
            att_context = self.multiaspect_attention(context, words, lens)
        else:
            raise ValueError('unknown att_type: {!r}'.format(self.att_type))
        return self.cl(att_context, words)

class WC_Softmax(nn.Module):
    """Container for two learnable scalar weights, ``cw`` and ``ww``.

    Both weights are 1-element parameters initialised to ``init_val``.
    ``forward`` takes no input and hands them back unnormalised (no softmax is
    applied despite the class name).
    """

    def __init__(self, init_val=1.0):
        super(WC_Softmax, self).__init__()
        start = [init_val]
        # Registration order (cw first, then ww) is kept for parameters()/state_dict.
        self.cw = Parameter(torch.Tensor(start))
        self.ww = Parameter(torch.Tensor(start))

    def forward(self):
        """Return the raw weights as the pair ``(cw, ww)``."""
        pair = (self.cw, self.ww)
        return pair

class LC_Softmax(nn.Module):
    """Learnable per-feature mixing weights for two sources.

    Holds two ``[state_size, 1]`` logit columns (both zero at start). ``forward``
    applies a softmax across the pair for every feature, so the two returned
    ``[state_size]`` vectors sum element-wise to 1.
    """

    def __init__(self, state_size):
        super(LC_Softmax, self).__init__()
        self.l1 = Parameter(torch.zeros(state_size, 1))
        self.l2 = Parameter(torch.zeros(state_size, 1))

    def forward(self):
        """Return ``(w1, w2)``, each of shape ``[state_size]``, with ``w1 + w2 == 1``."""
        # Place the two logits side by side: [state_size, 2]
        paired_logits = torch.stack([self.l1.squeeze(-1), self.l2.squeeze(-1)], dim=-1)
        weights = F.softmax(paired_logits, dim=-1)
        first, second = weights.unbind(dim=-1)
        return first, second

class Char_embedding(nn.Module):
    """Character-level CNN encoder producing one feature vector per word.

    The characters of each word are embedded and passed through one ``Conv2d``
    per kernel width in ``Ks`` (``conv_channels`` filters each, full embedding
    width, ReLU). Each is then max-pooled over the character axis, and the
    results are concatenated. Every word therefore maps to a vector of size
    ``conv_channels * len(Ks)``.

    Index ``n_chars`` is the padding index; its embedding row is kept at zero.
    """

    def __init__(self, embed_size, n_chars, conv_channels=100, Ks=None):
        super(Char_embedding, self).__init__()
        # ``None`` replaces the previous mutable default ``[5]``; the effective
        # default kernel width is unchanged.
        if Ks is None:
            Ks = [5]
        self.embed_size = embed_size
        self.conv_channels = conv_channels
        self.Ks = list(Ks)
        self.char_embed = nn.Embedding(n_chars + 1, embed_size, n_chars)
        self.convs1 = nn.ModuleList(
            [nn.Conv2d(1, self.conv_channels, (K, embed_size)) for K in self.Ks])
        self.reset_weights()

    def reset_weights(self):
        """Re-initialise the embedding and conv weights (Xavier normal)."""
        nn.init.xavier_normal_(self.char_embed.weight)
        # xavier_normal_ overwrites the whole table, including the padding row
        # that nn.Embedding initialises to zero. Restore it so that padded
        # character slots contribute zeros to the convolutions.
        with torch.no_grad():
            self.char_embed.weight[self.char_embed.padding_idx].fill_(0)
        for conv in self.convs1:
            # gain sqrt(2) matches the ReLU applied after each convolution
            nn.init.xavier_normal_(conv.weight, gain=math.sqrt(2))

    def forward(self, x):
        """
        :param x: [batch, word, char] character indices (``n_chars`` = padding).
            The char dimension must be at least ``max(Ks)``, because the
            convolutions are unpadded.
        :return: [batch, word, conv_channels * len(Ks)] per-word features
        """
        x = self.char_embed(x)  # [batch, word, char, embedding]
        b_size, w_size, c_size, e_size = x.size()
        x = x.view(b_size * w_size, c_size, e_size)
        x = x.unsqueeze(1)  # [batch * word, 1, char, embedding]

        # Each conv spans the full embedding width, so squeeze that axis away.
        convolutioned = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]
        # Max over all remaining character positions.
        pooled = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in convolutioned]
        last_tensor = torch.cat(pooled, 1)
        last_tensor = last_tensor.view(b_size, w_size, -1)

        return last_tensor

class Word_embedding(nn.Module):
    """Word embedding table of ``n_words + 1`` rows; index ``n_words`` is padding.

    If ``glove_filename`` is given, weights are loaded from that cached file.
    If loading it fails, they are built from ``data/glove.840B.300d.txt`` via
    ``load_glove`` and then cached. ``fine_tune`` controls whether the loaded
    weights receive gradients. Otherwise the table is randomly initialised
    according to ``init`` (``'kaiming'`` or ``'xavier'``; any other value keeps
    nn.Embedding's default init), and the padding row is kept at zero.
    """

    def __init__(self, embed_size, n_words, glove_filename=None, dic=None, fine_tune=True, init='kaiming'):
        super(Word_embedding, self).__init__()
        self.embed_x = nn.Embedding(n_words + 1, embed_size, n_words)
        if glove_filename:
            print('use glove')
            try:
                glove_vectors = load_file(glove_filename)
            except Exception:
                # Cache missing/unreadable: rebuild it from the raw GloVe file.
                # (Narrowed from a bare ``except:`` so Ctrl-C still interrupts.)
                print('no glove files')
                glove_vectors = load_glove('data/glove.840B.300d.txt', dic)
                save_file(glove_filename, glove_vectors)
            # assumes glove_vectors is a numpy array of shape [n_words + 1, embed_size]
            self.embed_x.weight.data.copy_(torch.from_numpy(glove_vectors))
            self.embed_x.weight.requires_grad = fine_tune
        else:
            if init == 'kaiming':
                nn.init.kaiming_normal_(self.embed_x.weight.data, nonlinearity='linear')
            elif init == 'xavier':
                nn.init.xavier_normal_(self.embed_x.weight.data)
            # The random inits above overwrite the zero padding row that
            # nn.Embedding creates; restore it (no-op for the default init).
            with torch.no_grad():
                self.embed_x.weight[n_words].zero_()

    def forward(self, x):
        """Look up embeddings for index tensor ``x``; output adds a trailing embed_size dim."""
        return self.embed_x(x)

class WC_embedding(nn.Module):
    """Word + character embeddings fused by a two-layer highway network.

    Word vectors come from ``Word_embedding`` and per-word character features
    from ``Char_embedding``. The two are concatenated and passed through two
    highway layers, ``x <- g * x + (1 - g) * relu(H(x))`` with
    ``g = sigmoid(G(x))``.
    The output feature size is ``word_embed_size + char_feature_size``, where
    ``char_feature_size = self.chars.conv_channels * len(self.chars.Ks)``.
    ``char_embed_size`` is only the width of the character lookup table.
    """

    def __init__(self, word_embed_size, n_words, char_embed_size, n_chars, glove_filename=None, dic=None, fine_tune=True):
        super(WC_embedding, self).__init__()
        self.chars = Char_embedding(char_embed_size, n_chars)
        self.words = Word_embedding(word_embed_size, n_words, glove_filename, dic, fine_tune)
        # The char CNN emits conv_channels * len(Ks) features per word, not
        # char_embed_size, so size the highway from the actual output width.
        char_feature_size = self.chars.conv_channels * len(self.chars.Ks)
        hidden_size = word_embed_size + char_feature_size
        self.highway = nn.ModuleList()
        for _ in range(2):
            # One Linear produces both the transform H(x) and the gate G(x).
            self.highway.append(nn.Linear(hidden_size, hidden_size * 2))
        for layer in self.highway:
            nn.init.xavier_normal_(layer.weight, gain=math.sqrt(2))
            nn.init.constant_(layer.bias, val=0)
        # (The former ``self.projection`` init referenced a layer that was
        # never created and made construction always fail; it is removed.)

    def forward(self, x_w, x_c):
        """
        :param x_w: [batch, word] word indices
        :param x_c: [batch, word, char] character indices
        :return: [batch, word, word_embed_size + char_feature_size]
        """
        x_w = self.words(x_w)
        x_c = self.chars(x_c)
        x = torch.cat([x_w, x_c], -1)
        for layer in self.highway:
            transformed, gate = layer(x).chunk(2, -1)
            gate = torch.sigmoid(gate)  # gate must lie in [0, 1]
            x = gate * x + (1 - gate) * F.relu(transformed)
        # Return the highway output; it was previously computed and discarded.
        return x
    
class Coattention(nn.Module):
    """Bidirectional co-attention between a key sequence ``k`` and a query ``q``.

    Affinity is ``k @ relu(W q + b)^T`` scaled by a learnable scalar. ``q``
    attends over ``k`` and ``k`` attends over ``q``; both attention maps ignore
    padded positions given by the length tensors. Each side returns
    ``[x, attended, x * attended]`` concatenated on the feature axis.
    """

    def __init__(self, state_size):
        super(Coattention, self).__init__()
        self.state_size = state_size
        self.transform = nn.Sequential(nn.Linear(state_size, state_size), nn.ReLU())
        self.scale = Parameter(torch.Tensor([1.0]))
        self.reset_parameter()

    def reset_parameter(self):
        """Xavier-init the projection (gain sqrt(2) for the following ReLU)."""
        for module in self.transform:
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight, gain=math.sqrt(2))

    def softmax_with_len_3d(self, logits, q_lengths, k_lengths):
        """Softmax over the last axis of ``logits``, restricted to valid positions.

        :param logits: [batch, rows, cols]
        :param q_lengths: [batch] valid length along the last (softmaxed) axis ``cols``
        :param k_lengths: [batch] valid length along the ``rows`` axis
        :return: [batch, rows, cols]; each valid row sums to 1 over its valid
            columns, and padded rows/columns are 0
        """
        # Size masks from the logits so they line up even when the padded
        # length exceeds max(lengths).
        col_masks = self.mask_lengths(q_lengths, logits.size(-1))
        row_masks = self.mask_lengths(k_lengths, logits.size(-2))
        masks = row_masks.unsqueeze(-1) * col_masks.unsqueeze(-2)
        smx = F.softmax(logits, -1)
        masked = masks * smx
        # Renormalising over the valid columns equals a softmax over them alone;
        # the epsilon keeps fully padded rows at 0 instead of NaN.
        normalizing_factor = torch.sum(masked, dim=-1, keepdim=True)
        return masked / (normalizing_factor + 1e-15)

    def mask_lengths(self, lengths, max_len=None):
        """Return a [batch, max_len] float mask: 1 where position < length, else 0.

        If ``max_len`` is falsy, ``max(lengths)`` is used.
        """
        if not max_len:
            max_len = int(torch.max(lengths).item())
        idxes = torch.arange(max_len, device=lengths.device).unsqueeze(0)
        return (idxes < lengths.unsqueeze(1)).float()

    def forward(self, k, q, k_lengths, q_lengths):
        """
        :param k: [batch, len_k, state]
        :param q: [batch, len_q, state]
        :param k_lengths: [batch] valid lengths of ``k``
        :param q_lengths: [batch] valid lengths of ``q``
        :return: (new_q [batch, len_q, 3*state], new_k [batch, len_k, 3*state])
        """
        q_projected = self.transform(q)
        affinity = torch.bmm(k, q_projected.transpose(1, 2)) * self.scale  # [batch, len_k, len_q]
        # q attends over k: rows are q positions, the softmax axis is k.
        att_q = self.softmax_with_len_3d(affinity.transpose(1, 2), k_lengths, q_lengths)
        # k attends over q: rows are k positions, the softmax axis is q.
        att_k = self.softmax_with_len_3d(affinity, q_lengths, k_lengths)
        new_q = torch.bmm(att_q, k)  # [batch, len_q, state]
        new_k = torch.bmm(att_k, q)  # [batch, len_k, state]

        # This fusion can be changed freely.
        new_q = torch.cat([q, new_q, q * new_q], -1)
        new_k = torch.cat([k, new_k, k * new_k], -1)

        return new_q, new_k

def count_parameters(model):
    """Return the total number of scalar entries in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (e.g. frozen embeddings) are skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

if __name__ == '__main__':
    # Smoke test for reverse_sequence: batch of 3 padded sequences (max len 5,
    # 10 features). Time step i holds the constant value i + 1, and positions
    # past each sequence's length stay zero.
    x = torch.zeros(3, 5, 10)
    for batch_idx, seq_len in ((2, 3), (1, 4), (0, 5)):
        steps = [torch.ones(10) * (i + 1) for i in range(seq_len)]
        x[batch_idx, :seq_len, :] = torch.stack(steps, 0)
    lengths = torch.Tensor([5, 4, 3]).long()
    print(x)
    print(reverse_sequence(x, lengths))