import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

from rnn_decoder import RNNDecoder
from autoencoder import Encoder
from autoencoders.base_encoder import BaseEncoder


class RNNEncoder(BaseEncoder):
    """Bidirectional recurrent encoder producing one summary vector per sequence.

    The final hidden states of every layer and both directions are averaged,
    so ``_to_hidden_representation`` returns a tensor of shape
    ``(batch, 1, hidden_size)``.
    """

    # Supported values of ``config.type`` and the torch module each one builds.
    # All three share the same constructor signature and h_n layout.
    _RNN_CLASSES = {"LSTM": nn.LSTM, "GRU": nn.GRU, "RNN": nn.RNN}

    def __init__(self, config):
        """Build the recurrent encoder described by ``config``.

        Args:
            config: must provide ``layers``, ``type`` and ``input_size``.
                ``self.hidden_size`` is used below but not set here; presumably
                ``BaseEncoder.__init__`` sets it from ``config`` (confirm).

        Raises:
            ValueError: if ``config.type`` is not "LSTM", "GRU" or "RNN".
        """
        super().__init__(config)

        self.layers = config.layers
        # Only mean-reduction is implemented in _to_hidden_representation.
        self.reduction = "mean"
        self.type = config.type

        try:
            rnn_cls = self._RNN_CLASSES[self.type]
        except KeyError:
            # Fail fast instead of leaving self.encoder unset, which would
            # otherwise only surface later as an AttributeError in forward.
            raise ValueError(
                f"Unsupported RNN type {self.type!r}; "
                f"expected one of {sorted(self._RNN_CLASSES)}"
            ) from None

        self.encoder = rnn_cls(
            input_size=config.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.layers,
            bidirectional=True,
            batch_first=True,
        )

    def _to_hidden_representation(self, embedded, lengths):
        """Encode a padded batch and average its final hidden states.

        Args:
            embedded: padded, batch-first input of shape
                ``(batch, seq_len, input_size)``.
            lengths: true length of each sequence, as a list of ints or a 1-D
                integer tensor. The tensor may be on any device, and the
                batch does not need to be sorted by length.

        Returns:
            Tensor of shape ``(batch, 1, hidden_size)``, in the same batch
            order as ``embedded``.
        """
        # pack_padded_sequence requires tensor lengths to live on the CPU.
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()

        # enforce_sorted=False accepts batches in any length order; the RNN
        # un-permutes h_n back to the caller's original batch order.
        packed = pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )

        _, h = self.encoder(packed)

        # LSTM returns (h_n, c_n); GRU/RNN return h_n alone. h_n has shape
        # (num_layers * num_directions, batch, hidden_size) even with
        # batch_first=True.
        if isinstance(h, tuple):
            h = h[0]

        # Average over all layers AND both directions -> (batch, hidden_size).
        h = h.mean(dim=0)

        return h.unsqueeze(1)
