import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as f

from torch.nn import LayerNorm
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack


class BiLSTMEncoder(nn.Module):
    """Single-layer bidirectional LSTM over right-padded, variable-length batches.

    Inputs are batch-first. The output concatenates the forward and backward
    hidden states, so its last dimension is ``2 * hidden_size``.
    """

    def __init__(self, input_size, hidden_size):
        """Build the encoder.

        Args:
            input_size: feature size of each time step of the input.
            hidden_size: hidden size of each LSTM direction.
        """
        super().__init__()
        self.encoder = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, input, length):
        """Encode a padded batch, ignoring the padded time steps.

        Args:
            input: tensor of shape ``(batch, seq_len, input_size)`` whose rows
                are right-padded to ``seq_len``.
            length: valid length of each row; a 1-D tensor (on any device),
                list, or array of ints, each in ``1..seq_len``.

        Returns:
            Tensor of shape ``(batch, seq_len, 2 * hidden_size)`` in the same
            row order as ``input``; positions past each row's length are zero.
        """
        # Compact the weights into one contiguous chunk (matters for cuDNN,
        # e.g. after DataParallel replication); harmless elsewhere.
        self.encoder.flatten_parameters()
        total_length = input.size(1)
        # pack_padded_sequence wants lengths as a 1-D int64 CPU tensor; this
        # also accepts plain lists/arrays and GPU tensors without a per-item
        # device sync.
        lengths = torch.as_tensor(length, dtype=torch.int64).cpu()
        # enforce_sorted=False makes PyTorch sort the batch by length and
        # record the permutation; pad_packed_sequence then restores the
        # original row order, replacing the manual sort/unsort bookkeeping.
        packed = pack(input, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.encoder(packed)
        # total_length pads the output back to the input's full seq_len even
        # when every row is shorter than it.
        output, _ = unpack(packed_out, batch_first=True, total_length=total_length)
        return output
