import math
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.modules.rnn import RNNCellBase
import torch
import torch.nn.functional as F
from layers import *


class PCA_RNN(nn.Module):

    """Stack of recurrent cells that read a growing tape of past hidden states.

    At every time step the layer's cell is called as
    ``cell(input_t, memory_tape, length)`` where ``memory_tape`` holds the
    (masked) hidden states produced so far by that layer, laid out as
    ``(batch, steps_so_far, hidden)``.  At the very first step the tape is an
    empty tensor.

    ``cell_class`` is instantiated once per layer with ``input_size``,
    ``hidden_size`` and any extra ``**kwargs``; it must expose
    ``reset_parameters()``.  ``use_bias`` is only stored on the module, it is
    not forwarded to the cells.  ``bidirectional`` only changes the input
    width assumed for layers above the first (``hidden_size * 2``); no
    backward cells are created here.
    """

    def __init__(self, cell_class, input_size, hidden_size, num_layers=1,
                 use_bias=True,bidirectional=False, batch_first=False, dropout=0, **kwargs):
        super(PCA_RNN, self).__init__()
        self.cell_class = cell_class
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_bias = use_bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.num_cell = 2 if self.bidirectional else 1

        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size * self.num_cell
            cell = cell_class(input_size=layer_input_size,
                              hidden_size=hidden_size,
                              **kwargs)
            # Cells are registered as attributes cell_0, cell_1, ...
            setattr(self, 'cell_{}'.format(layer), cell)
        self.dropout_layer = nn.Dropout(dropout)
        self.reset_parameters()

    def get_cell(self, layer):
        """Return the cell registered for ``layer`` (0-based)."""
        return getattr(self, 'cell_{}'.format(layer))

    def reset_parameters(self):
        """Re-initialise every layer's cell via its own ``reset_parameters``."""
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            cell.reset_parameters()

    def _forward_rnn(self, cell, input_, length, hx):
        """Run one layer over a time-major sequence.

        :param cell: callable invoked as ``cell(input_t, memory_tape, length)``
        :param input_: time-major input, indexed as ``input_[time]``
        :param length: per-sequence lengths, compared against the step index
        :param hx: initial hidden state, carried through padded steps
        :return: ``(tape, last_hidden)`` where ``tape`` is batch-major
            ``(batch, time, hidden)``
        """
        device = input_.device
        max_time = input_.size(0)
        # The tape starts out empty and lives on the same device as the input
        # (the previous unconditional ``.cuda()`` broke CPU execution).
        output = torch.Tensor().to(device)
        for time in range(max_time):
            h_next = cell(input_[time], output, length)   # (input_, memory_tape, length)
            mask = (time < length).float().unsqueeze(1).expand_as(h_next)
            mask = mask.to(h_next.device)
            # Past a sequence's end keep the previous hidden state.
            h_next = h_next*mask + hx*(1 - mask)
            step = h_next.unsqueeze(1)
            # Avoid concatenating with the empty placeholder at step 0.
            output = step if time == 0 else torch.cat([output, step], 1)
            hx = h_next
        return output, hx

    def forward(self, input_, length=None, hx=None):
        """Run all layers.

        :param input_: ``(time, batch, feature)``, or ``(batch, time, feature)``
            when ``batch_first`` is set
        :param length: optional per-sequence lengths; defaults to full length
        :param hx: optional initial hidden state ``(batch, hidden)`` shared by
            every layer; defaults to zeros
        :return: ``(output, h_n)`` with ``output`` in the same layout as the
            input and ``h_n`` stacked as ``(num_layers, batch, hidden)``
        """
        if self.batch_first:
            input_ = input_.transpose(0, 1)
        max_time, batch_size, _ = input_.size()
        if length is None:
            length = torch.full((batch_size,), max_time,
                                dtype=torch.long, device=input_.device)
        if hx is None:
            # Same dtype/device as the input; no blanket move to the current GPU.
            hx = input_.new_zeros(batch_size, self.hidden_size)
        h_n = []
        layer_output = None
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            layer_output, layer_h_n = self._forward_rnn(
                cell=cell, input_=input_, length=length, hx=hx)
            next_input = self.dropout_layer(layer_output)
            # _forward_rnn returns a batch-major tape, but the next layer
            # iterates over dim 0 as time, so switch back to time-major.
            if next_input.dim() == 3:
                next_input = next_input.transpose(0, 1)
            input_ = next_input
            h_n.append(layer_h_n)
        output = layer_output
        if not self.batch_first:
            # The tape is batch-major; return time-major to mirror the input.
            output = output.transpose(0, 1)
        h_n = torch.stack(h_n, 0)
        return output, h_n


class PA_RNN(nn.Module):

    """Stack of recurrent cells whose incoming state is an attention context.

    Before every step (except the first) the previous hidden state is replaced
    by an attention-weighted sum over the layer's earlier outputs, with the
    current input (projected by ``self.projection`` and squashed by ``tanh``)
    acting as the query.  The cell is then called as
    ``cell(input_=input_t, hx=state)``.

    ``cell_class`` is instantiated once per layer with ``input_size``,
    ``hidden_size`` and any extra ``**kwargs``; it must expose
    ``reset_parameters()``.  ``use_bias`` is only stored, and ``bidirectional``
    only changes the input width assumed for layers above the first.

    NOTE(review): ``self.projection`` is ``Linear(input_size, hidden_size)``
    and is shared by every layer, so stacked layers presumably require
    ``input_size == hidden_size * num_cell`` -- confirm against callers.
    """

    def __init__(self, cell_class, input_size, hidden_size, num_layers=1,
                 use_bias=True,bidirectional=False, batch_first=False, dropout=0, **kwargs):
        super(PA_RNN, self).__init__()
        self.cell_class = cell_class
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_bias = use_bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.num_cell = 2 if self.bidirectional else 1

        # Maps the current input into hidden space to form the attention query.
        self.projection = nn.Linear(input_size,hidden_size)

        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size * self.num_cell
            cell = cell_class(input_size=layer_input_size,
                              hidden_size=hidden_size,
                              **kwargs)
            # Cells are registered as attributes cell_0, cell_1, ...
            setattr(self, 'cell_{}'.format(layer), cell)
        self.dropout_layer = nn.Dropout(dropout)
        self.reset_parameters()

    def get_cell(self, layer):
        """Return the cell registered for ``layer`` (0-based)."""
        return getattr(self, 'cell_{}'.format(layer))

    def reset_parameters(self):
        """Xavier-initialise the query projection weight and reset every cell."""
        nn.init.xavier_uniform_(self.projection.weight)
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            cell.reset_parameters()

    def _forward_rnn(self, cell, input_, length, hx):
        """Run one layer over a time-major sequence.

        :param cell: callable invoked as ``cell(input_=..., hx=...)``
        :param input_: time-major input, indexed as ``input_[time]``
        :param length: per-sequence lengths, compared against the step index
        :param hx: initial hidden state
        :return: ``(outputs, last_hidden)``; ``outputs`` is batch-major
            ``(batch, time, hidden)`` and holds the *unmasked* cell outputs
        """
        max_time = input_.size(0)
        output = []
        for time in range(max_time):
            h_candidate = self.context_gen(input_[time], output, hx, length)
            if h_candidate is not None:
                # From the second step on the cell starts from the context.
                hx = h_candidate
            h_next = cell(input_=input_[time], hx=hx)
            mask = (time < length).float().unsqueeze(1).expand_as(h_next)
            mask = mask.to(h_next.device)
            # The raw cell output is recorded before masking; padded
            # positions are excluded later by the length-aware softmax.
            output.append(h_next)
            h_next = h_next*mask + hx*(1 - mask)
            hx = h_next
        output = torch.stack(output,1)
        return output, hx

    def context_gen(self, input_,output, last_state, length):
        """Attend over the outputs produced so far.

        :param input_:  Tensor (batch, state)
        :param output: list of per-step Tensors (batch, state); may be empty
        :param last_state: previous hidden state (currently unused)
        :param length: Tensor(batch)
        :return: context Tensor (batch, state), or ``None`` when ``output``
            is empty
        """
        if len(output) == 0:
            return None
        device = input_.device
        memory = torch.stack(output, 1)  # [b, l, h]
        steps = memory.shape[1]
        # Only the steps produced so far can be attended to.
        length = length.to(device).clamp(max=steps)
        input_projected = torch.tanh(self.projection(input_).unsqueeze(1))  # [b, 1, h]
        att_logits = torch.bmm(input_projected, memory.transpose(1, 2)).squeeze(1)  # [b, l]
        att_score = softmax_with_len(att_logits, length).unsqueeze(1)  # [b, 1, l]
        context = torch.bmm(att_score, memory).squeeze(1)
        return context.to(device)

    def location_bias(self,lengths,max_len):
        """Return ``exp(max(0, 1 - max(0, lengths - position) / 10))``.

        :param lengths: Tensor(batch)
        :param max_len: number of positions
        :return: Tensor (batch, max_len) on the same device as ``lengths``
        """
        # Built on the device of ``lengths`` instead of a hard-coded GPU.
        loc = torch.arange(max_len, device=lengths.device).unsqueeze(0)
        distance = (lengths.unsqueeze(1).float() - loc).clamp(min=0) / 10
        one_distance = (1 - distance).clamp(min=0)
        return torch.exp(one_distance)


    def forward(self, input_, length=None, hx=None):
        """Run all layers.

        :param input_: ``(time, batch, feature)``, or ``(batch, time, feature)``
            when ``batch_first`` is set
        :param length: optional per-sequence lengths; defaults to full length
        :param hx: optional initial hidden state ``(batch, hidden)`` shared by
            every layer; defaults to zeros
        :return: ``(output, h_n)`` with ``output`` in the same layout as the
            input and ``h_n`` stacked as ``(num_layers, batch, hidden)``
        """
        if self.batch_first:
            input_ = input_.transpose(0, 1)
        max_time, batch_size, _ = input_.size()
        device = input_.device
        if length is None:
            length = torch.full((batch_size,), max_time,
                                dtype=torch.long, device=device)
        else:
            length = length.to(device)
        if hx is None:
            hx = input_.new_zeros(batch_size, self.hidden_size)
        h_n = []
        layer_output = None
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            layer_output, layer_h_n = self._forward_rnn(
                cell=cell, input_=input_, length=length, hx=hx)
            # _forward_rnn stacks along dim 1 (batch-major), but the next
            # layer iterates dim 0 as time, so switch back to time-major.
            input_ = self.dropout_layer(layer_output).transpose(0, 1)
            h_n.append(layer_h_n)
        output = layer_output
        if not self.batch_first:
            # Outputs are batch-major; return time-major to mirror the input.
            output = output.transpose(0, 1)
        h_n = torch.stack(h_n, 0)
        return output, h_n


class Custom_RNN(nn.Module):

    """Multi-layer, optionally bidirectional RNN over packed sequences.

    ``cell_class`` is instantiated per layer (and per direction) with
    ``input_size``, ``hidden_size`` and any extra ``**kwargs``.  Each cell must
    expose ``reset_parameters()`` and a boolean ``iscell`` attribute: when it
    is true the cell's state and return value are ``(h, c)`` tuples and only
    ``h`` is emitted as output; otherwise the state is a single tensor.

    ``use_bias``, ``batch_first`` and ``dropout`` are stored on the module;
    only ``dropout`` is used here (between layers).
    """

    def __init__(self, cell_class, input_size, hidden_size, num_layers=1,
                 use_bias=True, bidirectional=False, batch_first=False, dropout=0, **kwargs):
        super(Custom_RNN, self).__init__()
        self.cell_class = cell_class
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_bias = use_bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.num_cell = 2 if self.bidirectional else 1
        # Forward cells: cellfw_0, cellfw_1, ...
        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size * self.num_cell
            cell = cell_class(input_size=layer_input_size,
                              hidden_size=hidden_size,
                              **kwargs)
            setattr(self, 'cellfw_{}'.format(layer), cell)
        # Backward cells: cellbw_0, cellbw_1, ... (bidirectional only)
        if bidirectional:
            for layer in range(num_layers):
                layer_input_size = input_size if layer == 0 else hidden_size * self.num_cell
                cell = cell_class(input_size=layer_input_size,
                                  hidden_size=hidden_size,
                                  **kwargs)
                setattr(self, 'cellbw_{}'.format(layer), cell)
        self.dropout_layer = nn.Dropout(dropout)
        self.reset_parameters()

    def iscell(self):
        """Return the ``iscell`` flag of the first forward cell."""
        return getattr(self, 'cellfw_{}'.format(0)).iscell

    def get_forward_cell(self, layer):
        """Return the forward-direction cell of ``layer`` (0-based)."""
        return getattr(self, 'cellfw_{}'.format(layer))

    def get_backward_cell(self, layer):
        """Return the backward-direction cell of ``layer`` (0-based)."""
        return getattr(self, 'cellbw_{}'.format(layer))

    def reset_parameters(self):
        """Re-initialise every cell via its own ``reset_parameters``."""
        for layer in range(self.num_layers):
            cell = self.get_forward_cell(layer)
            cell.reset_parameters()
        if self.bidirectional:
            for layer in range(self.num_layers):
                cell = self.get_backward_cell(layer)
                cell.reset_parameters()

    @staticmethod
    def _forward_rnn(cell, input_, batch_sizes, hx):
        """Run ``cell`` left-to-right over packed data.

        :param cell: callable invoked as ``cell(input_=..., hx=...)``
        :param input_: packed data, rows grouped per time step
        :param batch_sizes: number of rows belonging to each time step
        :param hx: initial state (tensor, or ``(h, c)`` when ``cell.iscell``)
        :return: packed outputs, concatenated along dim 0
        """
        sizes = [int(size) for size in batch_sizes]
        outputs = []
        index = 0
        old_batch_size = None
        for new_batch_size in sizes:
            if old_batch_size is not None and new_batch_size != old_batch_size:
                # Sequences that have ended sit at the bottom of the batch;
                # drop their rows from the running state.
                if cell.iscell:
                    hx = (hx[0][:new_batch_size, :], hx[1][:new_batch_size, :])
                else:
                    hx = hx[:new_batch_size, :]
            input_timestep = input_[index:index + new_batch_size]
            index += new_batch_size
            h_next = cell(input_=input_timestep, hx=hx)
            outputs.append(h_next[0] if cell.iscell else h_next)
            hx = h_next
            old_batch_size = new_batch_size
        if not outputs:
            return torch.Tensor().to(input_.device)
        # One concatenation instead of re-copying the output at every step.
        return torch.cat(outputs, 0)

    @staticmethod
    def _backward_rnn(cell, input_, batch_sizes, hx):
        """Run ``cell`` right-to-left over packed data.

        :param cell: callable invoked as ``cell(input_=..., hx=...)``
        :param input_: packed data, rows grouped per time step
        :param batch_sizes: number of rows belonging to each time step
        :param hx: initial state sized for the *last* time step (tensor, or
            ``(h, c)`` when ``cell.iscell``)
        :return: packed outputs in the original (forward) time order
        """
        sizes = [int(size) for size in batch_sizes]
        outputs = []
        end = input_.size(0)
        old_batch_size = None
        for new_batch_size in reversed(sizes):
            start = end - new_batch_size
            input_timestep = input_[start:end]
            end = start
            if old_batch_size is not None and new_batch_size != old_batch_size:
                # Sequences that begin (in reverse) at this step start from a
                # zero state appended below the running rows.
                grow = new_batch_size - old_batch_size
                if cell.iscell:
                    hx = (torch.cat([hx[0], hx[0].new_zeros(grow, hx[0].size(-1))], 0),
                          torch.cat([hx[1], hx[1].new_zeros(grow, hx[1].size(-1))], 0))
                else:
                    hx = torch.cat([hx, hx.new_zeros(grow, hx.size(-1))], 0)
            h_next = cell(input_=input_timestep, hx=hx)
            outputs.append(h_next[0] if cell.iscell else h_next)
            hx = h_next
            old_batch_size = new_batch_size
        if not outputs:
            return torch.Tensor().to(input_.device)
        outputs.reverse()
        return torch.cat(outputs, 0)

    def forward(self, packed_sequence, hx_forward=None, hx_backward=None):
        """Run all layers over a packed sequence.

        :param packed_sequence: a ``PackedSequence`` or a plain
            ``(data, batch_sizes)`` pair
        :param hx_forward: optional initial forward state with one row per
            sequence in the first time step; defaults to zeros and is reused
            for every layer
        :param hx_backward: optional initial backward state with one row per
            sequence in the last time step; defaults to zeros and is reused
            for every layer
        :return: ``(PackedSequence, placeholder)``; the second item is an
            empty tensor kept for interface compatibility
        """
        packed_cls = nn.utils.rnn.PackedSequence
        if isinstance(packed_sequence, packed_cls):
            # PackedSequence may carry more than two fields (sorting indices),
            # so it cannot be tuple-unpacked into two names.
            input_ = packed_sequence.data
            batch_sizes = packed_sequence.batch_sizes
            sorted_indices = getattr(packed_sequence, 'sorted_indices', None)
            unsorted_indices = getattr(packed_sequence, 'unsorted_indices', None)
        else:
            input_, batch_sizes = packed_sequence[0], packed_sequence[1]
            sorted_indices = unsorted_indices = None
        placeholder = torch.Tensor()
        max_batch = int(batch_sizes[0])
        min_batch = int(batch_sizes[-1])
        layer_output = None
        for layer in range(self.num_layers):
            cell_fw = self.get_forward_cell(layer)
            if hx_forward is None:
                if cell_fw.iscell:
                    hx_forward = (input_.new_zeros(max_batch, self.hidden_size),
                                  input_.new_zeros(max_batch, self.hidden_size))
                else:
                    hx_forward = input_.new_zeros(max_batch, self.hidden_size)
            forward_output = Custom_RNN._forward_rnn(
                cell=cell_fw, input_=input_, batch_sizes=batch_sizes, hx=hx_forward)
            layer_output = forward_output
            if self.bidirectional:
                cell_bw = self.get_backward_cell(layer)
                if hx_backward is None:
                    if cell_bw.iscell:
                        hx_backward = (input_.new_zeros(min_batch, self.hidden_size),
                                       input_.new_zeros(min_batch, self.hidden_size))
                    else:
                        hx_backward = input_.new_zeros(min_batch, self.hidden_size)
                backward_output = Custom_RNN._backward_rnn(
                    cell=cell_bw, input_=input_, batch_sizes=batch_sizes, hx=hx_backward)
                layer_output = torch.cat([layer_output, backward_output], -1)
            input_ = self.dropout_layer(layer_output)
        output = layer_output
        if sorted_indices is None and unsorted_indices is None:
            packed_output = packed_cls(output, batch_sizes)
        else:
            # Keep the sorting metadata so the batch can be restored to its
            # original order when the result is unpacked.
            packed_output = packed_cls(output, batch_sizes,
                                       sorted_indices, unsorted_indices)
        return packed_output, placeholder


# class PA_RNN(nn.Module):
#
#     """A module that runs multiple steps of LSTM."""
#
#     def __init__(self, cell_class, input_size, hidden_size, num_layers=1,
#                  use_bias=True,bidirectional=False, batch_first=False, dropout=0, **kwargs):
#         super(PA_RNN, self).__init__()
#         self.cell_class = cell_class
#         self.input_size = input_size
#         self.hidden_size = hidden_size
#         self.num_layers = num_layers
#         self.use_bias = use_bias
#         self.batch_first = batch_first
#         self.dropout = dropout
#         self.bidirectional = bidirectional
#         self.num_cell = 2 if self.bidirectional else 1
#
#         self.projection_weight = Parameter(torch.Tensor(self.hidden_size, self.num_cell*self.input_size + self.num_cell*self.hidden_size))
#         self.projection_bias = Parameter(torch.Tensor(self.hidden_size))
#         self.hidden_projection_weight = Parameter(torch.Tensor(self.hidden_size,self.hidden_size))
#         self.hidden_projection_bias = Parameter(torch.Tensor(self.hidden_size))
#         self.att_context = Parameter(torch.Tensor(self.hidden_size))
#
#
#         for layer in range(num_layers):
#             layer_input_size = input_size if layer == 0 else hidden_size * self.num_cell
#             cell = cell_class(input_size=layer_input_size,
#                               hidden_size=hidden_size,
#                               **kwargs)
#             setattr(self, 'cell_{}'.format(layer), cell)
#         self.dropout_layer = nn.Dropout(dropout)
#         self.reset_parameters()
#
#     def get_cell(self, layer):
#         return getattr(self, 'cell_{}'.format(layer))
#
#     def reset_parameters(self):
#         for layer in range(self.num_layers):
#             cell = self.get_cell(layer)
#             cell.reset_parameters()
#         nn.init.xavier_uniform_(self.projection_weight)
#         nn.init.xavier_uniform_(self.hidden_projection_weight)
#         nn.init.constant_(self.projection_bias,val=0.0)
#         nn.init.constant_(self.hidden_projection_bias,val=0.0)
#         nn.init.uniform_(self.att_context,-np.sqrt(6.0 / (2 * self.hidden_size)),
#                           np.sqrt(6.0 / (2 * self.hidden_size)))
#
#     def _forward_rnn(self,cell, input_, length, hx):
#         max_time = input_.size(0)
#         output = torch.Tensor()
#         output = output.cuda()
#         for time in range(max_time):
#             h_candidate = self.context_gen(input_[time], output, hx, length)
#             if h_candidate is not None:
#                 hx = h_candidate
#             h_next = cell(input_=input_[time], hx=hx)
#             mask = (time < length).float().unsqueeze(1).expand_as(h_next)
#             if input_.is_cuda:
#                 mask = mask.cuda()
#             output = torch.cat([output,h_next.unsqueeze(1)],1)
#
#             h_next = h_next*mask + hx*(1 - mask)
#             hx = h_next
#         return output, hx
#
#     def context_gen(self, input_,output, last_state, length):
#         """
#         :param input_:  Tensor (batch, state)
#         :param outputs: [ Tensor(batch,1,state), Tensor(...), ... ]
#         :param length: Tensor(batch)
#         :return:
#         """
#         if output.dim() == 1 :
#             return None
#         else:
#             time = output.shape[1]
#             time = torch.Tensor([time]).long()
#             length = torch.min(length,time).cuda()
#             masked = mask_lengths(length)
#             context = torch.cat([input_,last_state],1)
#             # context = input_
#             context_projected = F.linear(context,self.projection_weight,self.projection_bias)
#             # context_projected = F.tanh(context_projected) # [batch, state]
#             output = F.linear(output,self.hidden_projection_weight,self.hidden_projection_bias)
#             att_context = context_projected.unsqueeze(1) + output # [batch, len, state]
#             att_context = torch.mul(att_context,self.att_context.unsqueeze(0).unsqueeze(1))
#             att_logits = torch.sum(att_context,dim=2)
#
#             # att_logits = torch.mul(output,context_projected.unsqueeze(1)) # [batch, len, state]
#             # att_logits = torch.sum(att_logits,dim=2) # [ batch, len]
#             exp = torch.exp(att_logits)
#             exp = masked * exp
#             normalizing_factor = torch.sum(exp,dim=1).unsqueeze(1)  #[batch, 1]
#             att_score = exp / normalizing_factor
#
#             context = att_score.unsqueeze(2) * output
#             context = torch.sum(context,dim=1)
#             if input_.is_cuda:
#                 context = context.cuda()
#             return context
#
#
#     def forward(self, input_, length=None, hx=None):
#         if self.batch_first:
#             input_ = input_.transpose(0, 1)
#         max_time, batch_size, _ = input_.size()
#         if length is None:
#             length = Variable(torch.LongTensor([max_time] * batch_size))
#             if input_.is_cuda:
#                 device = input_.get_device()
#                 length = length.cuda(device)
#         if hx is None:
#             hx = Variable(input_.data.new(batch_size, self.hidden_size).zero_())
#             if input_.is_cuda:
#                 hx = hx.cuda()
#         h_n = []
#         layer_output = None
#         for layer in range(self.num_layers):
#             cell = self.get_cell(layer)
#             layer_output, layer_h_n = self._forward_rnn(
#                 cell=cell, input_=input_, length=length, hx=hx)
#             input_ = self.dropout_layer(layer_output)
#             h_n.append(layer_h_n)
#         output = layer_output
#         if not self.batch_first:
#             output = output.transpose(0, 1)
#         h_n = torch.stack(h_n, 0)
#         return output, h_n


# class Custom_RNN(nn.Module):
#
#     """A module that runs multiple steps of LSTM."""
#
#     def __init__(self, cell_class, input_size, hidden_size, num_layers=1,
#                  use_bias=True,bidirectional=False, batch_first=False, dropout=0, **kwargs):
#         super(Custom_RNN, self).__init__()
#         self.cell_class = cell_class
#         self.input_size = input_size
#         self.hidden_size = hidden_size
#         self.num_layers = num_layers
#         self.use_bias = use_bias
#         self.batch_first = batch_first
#         self.dropout = dropout
#         self.bidirectional = bidirectional
#         self.num_cell = 2 if self.bidirectional else 1
#
#         for layer in range(num_layers):
#             layer_input_size = input_size if layer == 0 else hidden_size * self.num_cell
#             cell = cell_class(input_size=layer_input_size,
#                               hidden_size=hidden_size,
#                               **kwargs)
#             setattr(self, 'cell_{}'.format(layer), cell)
#         self.dropout_layer = nn.Dropout(dropout)
#         self.reset_parameters()
#
#     def get_cell(self, layer):
#         return getattr(self, 'cell_{}'.format(layer))
#
#     def reset_parameters(self):
#         for layer in range(self.num_layers):
#             cell = self.get_cell(layer)
#             cell.reset_parameters()
#
#     @staticmethod
#     def _forward_rnn(cell, input_, length, hx):
#         device = input_.device
#         max_time = input_.size(0)
#         output = []
#         for time in range(max_time):
#             h_next = cell(input_=input_[time], hx=hx)
#             mask = (time < length).float().unsqueeze(1).expand_as(h_next).to(device)
#             h_next = h_next*mask + hx*(1 - mask)
#             output.append(h_next)
#             hx = h_next
#         output = torch.stack(output, 0)
#         return output, hx
#
#     def forward(self, input_, length=None, hx=None):
#         device = input_.device
#         if self.batch_first:
#             input_ = input_.transpose(0, 1)
#         max_time, batch_size, _ = input_.size()
#         if length is None:
#             length = Variable(torch.LongTensor([max_time] * batch_size))
#             if input_.is_cuda:
#                 device = input_.get_device()
#                 length = length.cuda(device)
#         if hx is None:
#             hx = Variable(input_.data.new(batch_size, self.hidden_size).zero_()).to(device)
#
#         h_n = []
#         layer_output = None
#         for layer in range(self.num_layers):
#             cell = self.get_cell(layer)
#             layer_output, layer_h_n = Custom_RNN._forward_rnn(
#                 cell=cell, input_=input_, length=length, hx=hx)
#             input_ = self.dropout_layer(layer_output)
#             h_n.append(layer_h_n)
#         output = layer_output
#         if self.batch_first:
#             output = output.transpose(0, 1)
#         h_n = torch.stack(h_n, 0)
#         return output, h_n


class Selective_RNN(nn.Module):

    """Stack of recurrent cells driven by per-token two-way selection weights.

    For every token a two-way logit is computed from the element-wise product
    of the token with a projected average of the (masked) sequence; the logits
    are turned into weights with ``F.gumbel_softmax`` and passed to the cell
    as ``update_probs``.  The cell is called as
    ``cell(input_=input_t, hx=state, update_probs=probs_t)``.

    ``cell_class`` is instantiated once per layer with ``input_size``,
    ``hidden_size`` and any extra ``**kwargs``; it must expose
    ``reset_parameters()``.  ``use_bias`` is only stored, and ``bidirectional``
    only changes the input width assumed for layers above the first.

    The attribute name ``update_seceltion`` (sic) is kept as-is because it is
    part of the module's parameter names.
    """

    def __init__(self, cell_class, input_size, hidden_size, num_layers=1,
                 use_bias=True,bidirectional=False, batch_first=False, dropout=0, **kwargs):
        super(Selective_RNN, self).__init__()
        self.cell_class = cell_class
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_bias = use_bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.num_cell = 2 if self.bidirectional else 1
        # Two-way selection logits per token.
        self.update_seceltion = nn.Linear(input_size,2)
        # Projection applied to the sequence average before gating.
        self.avg_projection = nn.Sequential(nn.Linear(input_size,input_size),nn.Tanh())

        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size * self.num_cell
            cell = cell_class(input_size=layer_input_size,
                              hidden_size=hidden_size,
                              **kwargs)
            # Cells are registered as attributes cell_0, cell_1, ...
            setattr(self, 'cell_{}'.format(layer), cell)
        self.dropout_layer = nn.Dropout(dropout)
        self.reset_parameters()

    def get_cell(self, layer):
        """Return the cell registered for ``layer`` (0-based)."""
        return getattr(self, 'cell_{}'.format(layer))

    def reset_parameters(self):
        """Reset every cell and re-initialise the gating layers.

        Linear weights use Xavier-normal initialisation (gain 5/3 for the
        layer followed by ``tanh``); biases are set to zero.
        """
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            cell.reset_parameters()
        nn.init.xavier_normal_(self.update_seceltion.weight)
        nn.init.constant_(self.update_seceltion.bias,val=0)
        for i in self.avg_projection:
            if isinstance(i,nn.Linear):
                nn.init.xavier_normal_(i.weight,gain=5.0/3)
                nn.init.constant_(i.bias,val=0)

    @staticmethod
    def _forward_rnn(cell, input_, update_probs, length, hx):
        """Run one layer over a time-major sequence.

        :param cell: callable invoked as
            ``cell(input_=..., hx=..., update_probs=...)``
        :param input_: time-major input, indexed as ``input_[time]``
        :param update_probs: time-major selection weights, indexed per step
        :param length: per-sequence lengths, compared against the step index
        :param hx: initial hidden state, carried through padded steps
        :return: ``(outputs, last_hidden)`` with ``outputs`` stacked along
            dim 0 (time-major)
        """
        device = input_.device
        max_time = input_.size(0)
        output = []
        for time in range(max_time):
            h_next = cell(input_=input_[time],hx=hx, update_probs=update_probs[time])
            mask = (time < length).float().unsqueeze(1).expand_as(h_next).to(device)
            # Past a sequence's end keep the previous hidden state.
            h_next = h_next*mask + hx*(1 - mask)
            output.append(h_next)
            hx = h_next
        output = torch.stack(output, 0)
        return output, hx

    def forward(self, input_, length=None, hx=None):
        """Compute selection weights, then run all layers.

        :param input_: ``(batch, time, feature)`` when ``batch_first`` is set,
            otherwise ``(time, batch, feature)``
        :param length: optional per-sequence lengths; defaults to full length
        :param hx: optional initial hidden state ``(batch, hidden)`` shared by
            every layer; defaults to zeros
        :return: ``(output, h_n)`` with ``output`` in the same layout as the
            input and ``h_n`` stacked as ``(num_layers, batch, hidden)``
        """
        device = input_.device
        # The gating below is written for batch-major data, so normalise the
        # layout first instead of assuming the caller passed batch-major.
        if not self.batch_first:
            input_ = input_.transpose(0, 1)
        b, l, h = input_.size()
        if length is None:
            # Must be resolved before the mask and the average are computed.
            length = torch.full((b,), l, dtype=torch.long, device=device)
        mask = mask_lengths(length).unsqueeze(-1) # [b, l, 1]
        input_ = input_* mask
        # Guard the mean against zero-length sequences (their sum is zero).
        avg = torch.sum(input_, dim=1) / length.float().clamp(min=1).unsqueeze(-1)
        avg_projected = self.avg_projection(avg) #[b,h]
        logits = avg_projected.unsqueeze(1) * input_
        update_selection = self.update_seceltion(logits) #[b,l,2]
        update_probs = F.gumbel_softmax(update_selection.reshape(b*l,2)).view(b,l,2)

        # The recurrence iterates over dim 0, so switch to time-major.
        input_ = input_.transpose(0, 1)
        update_probs = update_probs.transpose(0,1)

        max_time, batch_size, _ = input_.size()
        if hx is None:
            hx = input_.new_zeros(batch_size, self.hidden_size)

        h_n = []
        layer_output = None
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            layer_output, layer_h_n = Selective_RNN._forward_rnn(
                cell=cell, input_=input_, update_probs=update_probs, length=length, hx=hx)
            input_ = self.dropout_layer(layer_output)
            h_n.append(layer_h_n)
        output = layer_output
        if self.batch_first:
            output = output.transpose(0, 1)
        h_n = torch.stack(h_n, 0)
        return output, h_n


class ARNCell(RNNCellBase):
    """Gated cell with a learned element-wise decay on the previous state.

    One step computes::

        gate   = sigmoid(W_ig x + b_ig + W_hh h + b_hh)
        hidden = gate * (W_in x + b_in) + weight_ar * h

    where ``W_ig``/``W_in`` are the two halves of ``weight_ih`` and
    ``weight_ar`` is a per-unit vector initialised to 0.95.

    ``iscell`` is ``False``: the state is a single tensor.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        # The constructor signature of RNNCellBase differs between PyTorch
        # releases (newer ones require input_size/hidden_size/bias/num_chunks
        # and would allocate their own weights), so initialise the Module
        # machinery directly and register this cell's parameters below.
        nn.Module.__init__(self)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.iscell = False
        self.sigmoid = torch.sigmoid
        self.tanh = torch.tanh
        # Rows [0:hidden] feed the gate, rows [hidden:2*hidden] the candidate.
        self.weight_ih = Parameter(torch.Tensor(2 *hidden_size, input_size))
        self.weight_hh = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.weight_ar = Parameter(torch.Tensor(hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.Tensor(2 * hidden_size))
            self.bias_hh = Parameter(torch.Tensor(hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights, zero biases, decay vector set to 0.95."""
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        nn.init.constant_(self.weight_ar,val=0.95)
        if self.bias:
            nn.init.constant_(self.bias_hh,val=0)
            nn.init.constant_(self.bias_ih,val=0)

    def reset_hidden(self):
        """Forget any hidden state stored on the module."""
        self.hidden = None

    def detach_hidden(self):
        """Detach the stored hidden state in place, if there is one."""
        hidden = getattr(self, 'hidden', None)
        if hidden is not None:
            hidden.detach_()

    def forward(self, input_, hx):
        """
        :param input_: [batch_size, input_size]
        :param hx: (h_0) previous hidden state of size (batch_size, hidden_size)
        :return: (h_1) next hidden state of size (batch_size, hidden_size)
        """
        batch_size = hx.size(0)

        gi = F.linear(input_.reshape(batch_size, -1), self.weight_ih, self.bias_ih)
        gh = F.linear(hx, self.weight_hh, self.bias_hh)
        i_i, i_n = gi.chunk(2, 1)
        inputgate = torch.sigmoid(i_i + gh)

        # Gated candidate plus the element-wise decayed previous state.
        hidden = inputgate * i_n + self.weight_ar*(hx)
        return hidden


class CARNCell(RNNCellBase):
    """Recurrent cell that carries an ``(h, c)`` state pair.

    Each step projects the input with ``input_trans``, computes two sigmoid
    gates (``update`` and ``receive``) from the concatenated previous state,
    and scales the new ``c`` element-wise by the learned vector ``weight_ar``
    (initialised to 0.95).

    :param input_size: number of features in the step input
    :param hidden_size: number of features in each of ``h`` and ``c``
    :param bias: stored on the instance; the linear layers always carry a bias
    """

    def __init__(self, input_size, hidden_size, bias=True):
        # Initialise as a plain nn.Module. In recent PyTorch releases
        # RNNCellBase.__init__ requires (input_size, hidden_size, bias,
        # num_chunks) and allocates its own weights, so the argument-less
        # super() call used before raises TypeError there.
        nn.Module.__init__(self)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.iscell = True
        self.input_trans = nn.Linear(input_size, hidden_size, True)
        self.update = nn.Linear(input_size + hidden_size, hidden_size)
        self.receive = nn.Linear(input_size + hidden_size, hidden_size)
        self.weight_ar = Parameter(torch.empty(hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-normal weights, zero biases, and ``weight_ar`` filled with 0.95."""
        nn.init.xavier_normal_(self.input_trans.weight)
        nn.init.xavier_normal_(self.update.weight, gain=1)
        nn.init.xavier_normal_(self.receive.weight, gain=1)
        nn.init.constant_(self.input_trans.bias, val=0)
        nn.init.constant_(self.update.bias, val=0)
        nn.init.constant_(self.receive.bias, val=0)
        nn.init.constant_(self.weight_ar, val=0.95)

    def reset_hidden(self):
        """Clear the cached hidden state."""
        self.hidden = None

    def detach_hidden(self):
        """Detach the cached hidden state in place; no-op when none is cached."""
        hidden = getattr(self, 'hidden', None)
        if hidden is not None:
            hidden.detach_()

    def forward(self, input_, hx):
        """
        :param input_: [batch_size, input_size]
        :param hx: tuple ``(h, c)``, each [batch_size, hidden_size]
        :return: tuple ``(new_h, new_c)``
        """
        h, c = hx
        # NOTE(review): the gates read cat([h, c]) (width 2 * hidden_size)
        # while their Linear layers are built for input_size + hidden_size,
        # so this only runs when input_size == hidden_size -- confirm intent.
        state = torch.cat([h, c], -1)
        new_input = self.input_trans(input_)
        update = torch.sigmoid(self.update(state))
        receive = torch.sigmoid(self.receive(state))
        new_h = receive * c + (1 - receive) * new_input
        # The accumulated memory is decayed element-wise by weight_ar.
        new_c = self.weight_ar * (update * new_input + c)
        return (new_h, new_c)



class BasicRNNCell(RNNCellBase):
    """Plain tanh recurrent cell: ``h' = tanh(W_ih x + b_ih + W_hh h + b_hh)``.

    :param input_size: number of features in the step input
    :param hidden_size: number of features in the hidden state
    :param bias: when false, both bias vectors are registered as ``None``
    """

    def __init__(self, input_size, hidden_size, bias=True):
        # Initialise as a plain nn.Module. In recent PyTorch releases
        # RNNCellBase.__init__ requires (input_size, hidden_size, bias,
        # num_chunks), so the argument-less super() call raises TypeError.
        nn.Module.__init__(self)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.sigmoid = torch.sigmoid
        self.tanh = torch.tanh
        self.weight_ih = Parameter(torch.empty(hidden_size, input_size))
        self.weight_hh = Parameter(torch.empty(hidden_size, hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.empty(hidden_size))
            self.bias_hh = Parameter(torch.empty(hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights; biases (when present) are zeroed."""
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        if self.bias:
            nn.init.constant_(self.bias_hh, val=0)
            nn.init.constant_(self.bias_ih, val=0)

    def reset_hidden(self):
        """Clear the cached hidden state."""
        self.hidden = None

    def detach_hidden(self):
        """Detach the cached hidden state in place; no-op when none is cached."""
        hidden = getattr(self, 'hidden', None)
        if hidden is not None:
            hidden.detach_()

    def forward(self, input_, hx):
        """
        :param input_: [batch_size, input_size]
        :param hx: previous hidden state, [batch_size, hidden_size]
        :return: next hidden state, [batch_size, hidden_size]
        """
        from_input = F.linear(input_, self.weight_ih, self.bias_ih)
        from_hidden = F.linear(hx, self.weight_hh, self.bias_hh)
        return torch.tanh(from_input + from_hidden)


class CFNCell(RNNCellBase):
    """Gated cell with an input gate and a forget gate.

    ``h' = i * tanh(W x) + f * tanh(h)``, where both gates are sigmoids of a
    sum of input and hidden projections.

    :param input_size: number of features in the step input
    :param hidden_size: number of features in the hidden state
    :param bias: when false, both bias vectors are registered as ``None``
    """

    def __init__(self, input_size, hidden_size, bias=True):
        # Initialise as a plain nn.Module. In recent PyTorch releases
        # RNNCellBase.__init__ requires (input_size, hidden_size, bias,
        # num_chunks), so the argument-less super() call raises TypeError.
        nn.Module.__init__(self)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.sigmoid = torch.sigmoid
        self.tanh = torch.tanh
        # Input projection rows: input gate, forget gate, candidate.
        self.weight_ih = Parameter(torch.empty(3 * hidden_size, input_size))
        # Hidden projection rows: input gate, forget gate.
        self.weight_hh = Parameter(torch.empty(2 * hidden_size, hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.empty(3 * hidden_size))
            self.bias_hh = Parameter(torch.empty(2 * hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights; biases (when present) are zeroed."""
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        if self.bias:
            nn.init.constant_(self.bias_hh, val=0)
            nn.init.constant_(self.bias_ih, val=0)

    def reset_hidden(self):
        """Clear the cached hidden state."""
        self.hidden = None

    def detach_hidden(self):
        """Detach the cached hidden state in place; no-op when none is cached."""
        hidden = getattr(self, 'hidden', None)
        if hidden is not None:
            hidden.detach_()

    def forward(self, input_, hx):
        """
        :param input_: step input; reshaped to [batch_size, -1] before projection
        :param hx: previous hidden state, [batch_size, hidden_size]
        :return: next hidden state, [batch_size, hidden_size]
        """
        batch_size = hx.size(0)
        gi = F.linear(input_.view(batch_size, -1), self.weight_ih, self.bias_ih)
        gh = F.linear(hx, self.weight_hh, self.bias_hh)
        i_i, i_f, i_n = gi.chunk(3, 1)
        h_i, h_f = gh.chunk(2, 1)
        inputgate = self.sigmoid(i_i + h_i)
        forgetgate = self.sigmoid(i_f + h_f)
        return inputgate * self.tanh(i_n) + forgetgate * self.tanh(hx)


class GRUCell(RNNCellBase):
    """GRU-style cell with fused gate and candidate weight matrices.

    ``gate_w`` maps ``cat([x, h])`` to the reset and update gates;
    ``candidate_w`` maps ``cat([x, r * h])`` to the candidate state. The new
    state is ``u * h + (1 - u) * tanh(candidate)``.

    :param input_size: number of features in the step input
    :param hidden_size: number of features in the hidden state
    :param bias: when false, both bias vectors are registered as ``None``
    """

    def __init__(self, input_size, hidden_size, bias=True):
        # Initialise as a plain nn.Module. In recent PyTorch releases
        # RNNCellBase.__init__ requires (input_size, hidden_size, bias,
        # num_chunks), so the argument-less super() call raises TypeError.
        nn.Module.__init__(self)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.iscell = False
        self.gate_w = Parameter(torch.empty(2 * hidden_size, input_size + hidden_size))
        self.candidate_w = Parameter(torch.empty(hidden_size, input_size + hidden_size))
        if bias:
            self.gate_b = Parameter(torch.empty(2 * hidden_size))
            self.candidate_b = Parameter(torch.empty(hidden_size))
        else:
            self.register_parameter('gate_b', None)
            self.register_parameter('candidate_b', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Normal-initialise the weights; gate bias is set to 1, candidate bias to 0."""
        stv = math.sqrt(2.0 / (self.input_size + self.hidden_size))
        nn.init.normal_(self.gate_w, std=stv)
        # The candidate feeds a tanh, hence the extra 5/3 factor on its std.
        nn.init.normal_(self.candidate_w, std=(5.0 / 3) * stv)
        if self.bias:
            nn.init.constant_(self.gate_b, val=1.0)
            nn.init.constant_(self.candidate_b, val=0)

    def reset_hidden(self):
        """Clear the cached hidden state."""
        self.hidden = None

    def detach_hidden(self):
        """Detach the cached hidden state in place; no-op when none is cached."""
        hidden = getattr(self, 'hidden', None)
        if hidden is not None:
            hidden.detach_()

    def forward(self, input_, hx):
        """
        :param input_: [batch_size, input_size]
        :param hx: previous hidden state, [batch_size, hidden_size]
        :return: next hidden state, [batch_size, hidden_size]
        """
        gates = torch.sigmoid(
            F.linear(torch.cat([input_, hx], 1), self.gate_w, self.gate_b))
        r, u = gates.chunk(2, 1)
        reseted = torch.cat([input_, r * hx], 1)
        c = torch.tanh(F.linear(reseted, self.candidate_w, self.candidate_b))
        return u * hx + (1 - u) * c


class SGRUCell(RNNCellBase):
    """GRU-style cell whose output is mixed with the previous state.

    A regular GRU update is computed first; the result and the unchanged
    previous state are then combined with the per-example weights in
    ``update_probs``.

    :param input_size: number of features in the step input
    :param hidden_size: number of features in the hidden state
    :param bias: when false, both bias vectors are registered as ``None``
    """

    def __init__(self, input_size, hidden_size, bias=True):
        # Initialise as a plain nn.Module. In recent PyTorch releases
        # RNNCellBase.__init__ requires (input_size, hidden_size, bias,
        # num_chunks), so the argument-less super() call raises TypeError.
        nn.Module.__init__(self)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.gate_w = Parameter(torch.empty(2 * hidden_size, input_size + hidden_size))
        self.candidate_w = Parameter(torch.empty(hidden_size, input_size + hidden_size))
        if bias:
            self.gate_b = Parameter(torch.empty(2 * hidden_size))
            self.candidate_b = Parameter(torch.empty(hidden_size))
        else:
            self.register_parameter('gate_b', None)
            self.register_parameter('candidate_b', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Normal-initialise the weights; gate bias is set to 1, candidate bias to 0."""
        stv = math.sqrt(2.0 / (self.input_size + self.hidden_size))
        nn.init.normal_(self.gate_w, std=stv)
        nn.init.normal_(self.candidate_w, std=(5.0 / 3) * stv)
        if self.bias:
            nn.init.constant_(self.gate_b, val=1.0)
            nn.init.constant_(self.candidate_b, val=0)

    def reset_hidden(self):
        """Clear the cached hidden state."""
        self.hidden = None

    def detach_hidden(self):
        """Detach the cached hidden state in place; no-op when none is cached."""
        hidden = getattr(self, 'hidden', None)
        if hidden is not None:
            hidden.detach_()

    def forward(self, input_, hx, update_probs):
        """
        :param input_: [batch_size, input_size]
        :param hx: previous hidden state, [batch_size, hidden_size]
        :param update_probs: [batch_size, 2]; column 0 weights the GRU update,
            column 1 weights the unchanged previous state
        :return: next hidden state, [batch_size, hidden_size]
        """
        gates = torch.sigmoid(
            F.linear(torch.cat([input_, hx], 1), self.gate_w, self.gate_b))
        r, u = gates.chunk(2, 1)
        reseted = torch.cat([input_, r * hx], 1)
        c = torch.tanh(F.linear(reseted, self.candidate_w, self.candidate_b))
        updated = u * hx + (1 - u) * c
        # [batch, hidden, 2] options, weighted by [batch, 1, 2] probabilities.
        options = torch.stack([updated, hx], -1)
        return torch.sum(update_probs.unsqueeze(1) * options, -1)


class PGRUCell(RNNCellBase):
    """GRU-style cell whose gates and candidate are sampled stochastically.

    ``weight_hh`` and ``weight_ih`` each produce a mean and a (tanh-squashed)
    log-scale for the reset and update gates; ``candidate`` does the same for
    the candidate state. Every quantity is drawn with ``reparametrize`` as
    ``mu + eps * exp(logvar * g ** 2)``, where ``g`` is the matching learned
    scalar (``rg``, ``ug`` or ``cg``, each initialised to 0.7).

    :param input_size: number of features in the step input
    :param hidden_size: number of features in the hidden state
    :param bias: accepted for interface parity; the linear layers always
        carry a bias
    """

    def __init__(self, input_size, hidden_size, bias=True):
        # Initialise as a plain nn.Module. In recent PyTorch releases
        # RNNCellBase.__init__ requires (input_size, hidden_size, bias,
        # num_chunks), so the argument-less super() call raises TypeError.
        nn.Module.__init__(self)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_hh = nn.Linear(hidden_size, 4 * hidden_size)
        self.weight_ih = nn.Linear(input_size, 4 * hidden_size)
        self.ug = Parameter(torch.Tensor([0.7]))
        self.cg = Parameter(torch.Tensor([0.7]))
        self.rg = Parameter(torch.Tensor([0.7]))
        # The candidate layer consumes cat([input_, r * hx]), i.e.
        # input_size + hidden_size features (was input_size + input_size,
        # which only matched when the two sizes were equal).
        self.candidate = nn.Linear(input_size + hidden_size, 2 * hidden_size)
        self.reset_parameters()

    def reset_parameters(self):
        """Normal-initialise the gate projections and zero their biases.

        ``candidate`` keeps the default ``nn.Linear`` initialisation.
        """
        hh_stv = math.sqrt(1.0 / self.hidden_size)
        ih_stv = math.sqrt(2.0 / (self.hidden_size + self.input_size))

        nn.init.normal_(self.weight_hh.weight, std=hh_stv)
        nn.init.constant_(self.weight_hh.bias, val=0)
        nn.init.normal_(self.weight_ih.weight, std=ih_stv)
        nn.init.constant_(self.weight_ih.bias, val=0)

    def reset_hidden(self):
        """Clear the cached hidden state."""
        self.hidden = None

    def detach_hidden(self):
        """Detach the cached hidden state in place; no-op when none is cached."""
        hidden = getattr(self, 'hidden', None)
        if hidden is not None:
            hidden.detach_()

    def reparametrize(self, mu, logvar, var_g):
        """Draw ``mu + eps * exp(logvar * var_g)`` with ``eps ~ N(0, 1)``.

        :param mu: mean tensor
        :param logvar: log-scale tensor, same shape as ``mu``
        :param var_g: multiplier applied to ``logvar`` before exponentiation
        :return: sample with the shape of ``mu``
        """
        std = torch.exp(logvar * var_g)
        # Sample directly with std's dtype and device instead of building a
        # CPU float32 tensor and moving it.
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, input_, hx):
        """
        :param input_: [batch_size, input_size]
        :param hx: previous hidden state, [batch_size, hidden_size]
        :return: next hidden state, [batch_size, hidden_size]
        """
        rm, rv, um, uv = self.weight_hh(hx).chunk(4, 1)
        # The input projection must read the input; it was previously fed hx,
        # which ignored input_ for the gates and broke whenever
        # input_size != hidden_size.
        rm_, rv_, um_, uv_ = self.weight_ih(input_).chunk(4, 1)

        rm = rm + rm_
        rv = torch.tanh(rv + rv_)
        um = um + um_
        uv = torch.tanh(uv + uv_)

        r = torch.sigmoid(self.reparametrize(rm, rv, self.rg ** 2))
        u = torch.sigmoid(self.reparametrize(um, uv, self.ug ** 2))
        reseted = torch.cat([input_, r * hx], 1)
        cm, cv = self.candidate(reseted).chunk(2, 1)

        c = self.reparametrize(cm, torch.tanh(cv), self.cg ** 2)
        return u * hx + (1 - u) * c

class MRNCell(RNNCellBase):
    """Single-gate recurrent cell: ``h' = g * W x + (1 - g) * h``.

    The gate ``g`` is a sigmoid of the sum of an input projection and a
    hidden projection; the candidate is a linear projection of the input.

    :param input_size: number of features in the step input
    :param hidden_size: number of features in the hidden state
    :param bias: when false, both bias vectors are registered as ``None``
    """

    def __init__(self, input_size, hidden_size, bias=True):
        # Initialise as a plain nn.Module. In recent PyTorch releases
        # RNNCellBase.__init__ requires (input_size, hidden_size, bias,
        # num_chunks), so the argument-less super() call raises TypeError.
        nn.Module.__init__(self)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.sigmoid = torch.sigmoid
        self.tanh = torch.tanh
        # Input projection rows: gate pre-activation, candidate.
        self.weight_ih = Parameter(torch.empty(2 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.empty(1 * hidden_size, hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.empty(2 * hidden_size))
            self.bias_hh = Parameter(torch.empty(1 * hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights; biases (when present) are zeroed."""
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        if self.bias:
            nn.init.constant_(self.bias_hh, val=0)
            nn.init.constant_(self.bias_ih, val=0)

    def reset_hidden(self):
        """Clear the cached hidden state."""
        self.hidden = None

    def detach_hidden(self):
        """Detach the cached hidden state in place; no-op when none is cached."""
        hidden = getattr(self, 'hidden', None)
        if hidden is not None:
            hidden.detach_()

    def forward(self, input_, hx):
        """
        :param input_: step input; reshaped to [batch_size, -1] before projection
        :param hx: previous hidden state, [batch_size, hidden_size]
        :return: next hidden state, [batch_size, hidden_size]
        """
        batch_size = hx.size(0)
        gi = F.linear(input_.view(batch_size, -1), self.weight_ih, self.bias_ih)
        gh = F.linear(hx, self.weight_hh, self.bias_hh)
        i_i, i_n = gi.chunk(2, 1)
        inputgate = self.sigmoid(i_i + gh)
        return inputgate * i_n + (1 - inputgate) * hx


class HPA_RNNCell(RNNCellBase):
    """
    hard attention past attending rnn cell
    hard attention based on gumbel softmax

    Each step builds two candidate states, one from the most recent hidden
    state and one from an attention-weighted pass over the whole memory tape,
    and returns their softmax-weighted mixture.

    :param input_size: number of features in the step input
    :param hidden_size: number of features in the hidden state
    :param bias: forwarded to every linear layer except ``alpha_projection``
    :param attention_type: stored on the instance; not read in this class
    """

    def __init__(self, input_size, hidden_size, bias=True, attention_type='add'):
        # Initialise as a plain nn.Module. In recent PyTorch releases
        # RNNCellBase.__init__ requires (input_size, hidden_size, bias,
        # num_chunks), so the argument-less super() call raises TypeError.
        nn.Module.__init__(self)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.attention_type = attention_type
        self.alpha_projection = nn.Linear(self.input_size, self.hidden_size)

        self.weight_ih = nn.Linear(self.input_size, self.hidden_size, bias)
        self.weight_hh = nn.Linear(self.hidden_size, self.hidden_size, bias)
        self.weight_att_ih = nn.Linear(self.input_size, self.hidden_size, bias)
        self.weight_att_hh = nn.Linear(self.hidden_size, self.hidden_size, bias)
        self.hidden_projection = nn.Linear(self.hidden_size, self.hidden_size, bias)
        self.input_projection = nn.Linear(self.input_size, self.hidden_size, bias)
        self.query_vec = Parameter(torch.empty(self.hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniformly initialise ``query_vec``; linear layers keep their defaults."""
        # math.sqrt replaces np.sqrt: numpy was only in scope through a star
        # import, and the bound is a plain Python scalar.
        bound = math.sqrt(6.0 / (2 * self.hidden_size))
        nn.init.uniform_(self.query_vec, -bound, bound)

    def reset_hidden(self):
        """Clear the cached hidden state."""
        self.hidden = None

    def detach_hidden(self):
        """Detach the cached hidden state in place; no-op when none is cached."""
        hidden = getattr(self, 'hidden', None)
        if hidden is not None:
            hidden.detach_()

    def forward(self, input_, memory_tape, lens):
        """
        :param input_: [batch_size, input_size]
        :param memory_tape: [batch_size, len, state_size], or a 1-D tensor on
            the first step when no state has been produced yet
        :param lens: [batch_size] integer sequence lengths
        :return: (h_1) [batch_size, hidden_size]
        """
        if memory_tape.dim() == 1:
            # First step: attend over a single all-zero slot. Tensors follow
            # the input's device and dtype instead of being forced onto CUDA.
            batch_size = input_.size(0)
            att_score = input_.new_ones(batch_size, 1)
            memory_tape = input_.new_zeros(batch_size, 1, self.hidden_size)
            last_hidden = memory_tape.squeeze(1)
        else:
            # Only positions already written to the tape can be attended.
            length = lens.clamp(max=memory_tape.size(1)).to(memory_tape.device)
            last_hidden = memory_tape[:, -1, :]
            ai_logits = self.input_projection(input_)  # [batch, hidden]
            ah_logits = self.hidden_projection(memory_tape)
            att_context = torch.tanh(ai_logits.unsqueeze(1) + ah_logits)
            att_logits = torch.sum(self.query_vec * att_context, dim=-1)
            att_score = self.attention(att_logits, length)  # [batch, len]

        # Candidate from the most recent hidden state.
        hi_last = self.weight_ih(input_)
        hh_last = self.weight_hh(last_hidden)
        newh_last = torch.tanh(hi_last + hh_last)

        # Candidate from the attention-weighted tape.
        hi_att = self.weight_att_ih(input_)
        hh_att = self.weight_att_hh(memory_tape)  # [batch, len, state]
        newh_att = torch.tanh(hi_att.unsqueeze(1) + hh_att)  # [batch, len, state]
        newh_att = torch.sum(att_score.unsqueeze(-1) * newh_att, dim=1)

        # Mix the two candidates with input-dependent softmax weights.
        input_ap = self.alpha_projection(input_)  # [batch, hidden]
        hcat = torch.stack([newh_att, newh_last], 1)  # [batch, 2, state]
        alpha_logits = torch.sum(input_ap.unsqueeze(1) * hcat, dim=-1)  # [batch, 2]
        alpha_probs = F.softmax(alpha_logits, dim=-1)

        return torch.sum(alpha_probs.unsqueeze(-1) * hcat, 1)

    def attention(self, logits, length):
        """
        :param logits: [batch, len]
        :param length: [batch]
        :return: attention scores from ``gumbel_softmax``
        """
        # Assumes mask_lengths returns a [batch, len] mask that is 1 on valid
        # positions and 0 on padding -- TODO confirm against its definition.
        masked = mask_lengths(length)
        logits = logits * masked
        # Padded positions are pushed down by 30 before sampling.
        reverse_mask = 30 * (masked - 1)
        logits = logits + reverse_mask
        return gumbel_softmax(logits, 1)



class PA_RNNCell(RNNCellBase):
    """Recurrent cell that attends over all previous hidden states.

    A tanh RNN update is computed against every slot of the memory tape and
    the results are averaged with soft attention weights derived from the
    current input.

    :param input_size: number of features in the step input
    :param hidden_size: number of features in the hidden state
    :param bias: when false, all bias vectors are registered as ``None``
    """

    def __init__(self, input_size, hidden_size, bias=True):
        # Initialise as a plain nn.Module. In recent PyTorch releases
        # RNNCellBase.__init__ requires (input_size, hidden_size, bias,
        # num_chunks), so the argument-less super() call raises TypeError.
        nn.Module.__init__(self)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.weight_ih = Parameter(torch.empty(hidden_size, input_size))
        self.weight_hh = Parameter(torch.empty(hidden_size, hidden_size))
        self.weight_project = Parameter(torch.empty(self.hidden_size, self.input_size))
        if bias:
            self.bias_ih = Parameter(torch.empty(hidden_size))
            self.bias_hh = Parameter(torch.empty(hidden_size))
            self.bias_project = Parameter(torch.empty(hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
            self.register_parameter('bias_project', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights; biases (when present) are zeroed."""
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        nn.init.xavier_uniform_(self.weight_project)
        if self.bias:
            nn.init.constant_(self.bias_hh, val=0)
            nn.init.constant_(self.bias_ih, val=0)
            nn.init.constant_(self.bias_project, val=0)

    def reset_hidden(self):
        """Clear the cached hidden state."""
        self.hidden = None

    def detach_hidden(self):
        """Detach the cached hidden state in place; no-op when none is cached."""
        hidden = getattr(self, 'hidden', None)
        if hidden is not None:
            hidden.detach_()

    def forward(self, input_, memory_tape, length):
        """
        :param input_: [batch_size, input_size]
        :param memory_tape: [batch_size, len, state_size], or a 1-D tensor on
            the first step when no state has been produced yet
        :param length: [batch_size] integer sequence lengths
        :return: (h_1) [batch_size, hidden_size]
        """
        if memory_tape.dim() == 1:
            # First step: a single all-zero slot with weight 1. The score is
            # [batch, 1]; an extra unsqueeze here used to yield [batch, 1, 1, 1]
            # after the unsqueeze below, which broadcast across the batch and
            # summed every example together.
            batch_size = input_.size(0)
            att_score = input_.new_ones(batch_size, 1)
            memory_tape = input_.new_zeros(batch_size, 1, self.hidden_size)
        else:
            att_score = self.attention(input_, memory_tape, length)  # [batch, len]
        hi = F.linear(input_, self.weight_ih, self.bias_ih)  # [batch, state]
        hh = F.linear(memory_tape, self.weight_hh, self.bias_hh)  # [batch, len, state]

        hidden = torch.tanh(hi.unsqueeze(1) + hh)
        hidden = att_score.unsqueeze(2) * hidden
        return torch.sum(hidden, dim=1)

    def attention(self, input_, output, length):
        """
        :param input_: Tensor (batch, input_size)
        :param output: Tensor (batch, len, state) of previous hidden states
        :param length: Tensor (batch) of integer sequence lengths
        :return: attention scores (batch, len), or None for a 1-D ``output``
        """
        if output.dim() == 1:
            return None
        # Only positions already written to the tape can be attended; keep
        # the lengths on the same device as the tape.
        length = length.clamp(max=output.size(1)).to(output.device)
        # Assumes mask_lengths returns a [batch, len] mask that is 1 on valid
        # positions and 0 on padding -- TODO confirm against its definition.
        masked = mask_lengths(length)
        reverse_mask = 30 * (masked - 1)
        context_projected = torch.tanh(
            F.linear(input_, self.weight_project, self.bias_project))  # [batch, state]
        att_logits = torch.mul(output, context_projected.unsqueeze(1))  # [batch, len, state]
        att_logits = torch.sum(att_logits, dim=2)  # [batch, len]
        # reverse_mask is -30 on padding, so it must be ADDED to suppress
        # those positions; subtracting it boosted padding by +30.
        att_logits = (att_logits * masked) + reverse_mask
        return F.softmax(att_logits, dim=-1)


# class ARNCell(RNNCellBase):
#     def __init__(self, input_size, hidden_size, bias=True):
#         super(ARNCell, self).__init__()
#         self.input_size = input_size
#         self.hidden_size = hidden_size
#         self.bias = bias
#         self.sigmoid = F.sigmoid
#         self.tanh = F.tanh
#         self.weight_ih = Parameter(torch.Tensor(2 * hidden_size, input_size))
#         self.weight_hh = Parameter(torch.Tensor(hidden_size, hidden_size))
#         self.weight_ar = Parameter(torch.Tensor(hidden_size))
#         if bias:
#             self.bias_ih = Parameter(torch.Tensor(2 * hidden_size))
#             self.bias_hh = Parameter(torch.Tensor(hidden_size))
#         else:
#             self.register_parameter('bias_ih', None)
#             self.register_parameter('bias_hh', None)
#         self.reset_parameters()
#
#     def reset_parameters(self):
#         nn.init.xavier_uniform_(self.weight_ih)
#         nn.init.xavier_uniform_(self.weight_hh)
#         nn.init.constant(self.weight_ar,val=1)
#         if self.bias:
#             nn.init.constant(self.bias_hh,val=0)
#             nn.init.constant(self.bias_ih,val=0)
#         self.reset_hidden()
#
#     def reset_hidden(self):
#         self.hidden = None
#
#     def detach_hidden(self):
#         self.hidden.detach_()
#
#     def forward(self, input_data, future=0):
#         batch_size, timesteps, features = input_data.size()
#         # print("t %d, b %d, f %d" % (timesteps, batch_size, features))
#         outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False)
#
#         if self.hidden is None:
#             self.hidden = Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False)
#
#         self.check_forward_input(input_data[0])
#         self.check_forward_hidden(input_data[0], self.hidden)
#
#         for i, input_t in enumerate(input_data.split(1)):
#             gi = F.linear(input_t.view(batch_size, features), self.weight_ih, self.bias_ih)
#             gh = F.linear(self.hidden, self.weight_hh, self.bias_hh)
#             i_i, i_n = gi.chunk(2, 1)
#             h_i = gh
#
#             # f, i = sigmoid(Wx + Vh_tm1 + b)
#             inputgate = self.sigmoid(i_i + h_i)
#             new_input = i_n
#
#             # h_t = f * tanh(h_tm1) + i * tanh(Wx)
#             self.hidden = inputgate * new_input +  self.weight_ar*(self.hidden)
#             outputs[i] = self.hidden
#
#         return outputs

if __name__ == '__main__':
    # Smoke run: one step of a 3-input / 2-hidden ARNCell on a two-example
    # batch, starting from an all-zero hidden state.
    demo_cell = ARNCell(3, 2)
    demo_input = torch.Tensor([[1, 1, 1], [0, 0, 0]])
    demo_hidden = torch.Tensor([[0, 0], [0, 0]])
    print(demo_cell(demo_input, demo_hidden))