import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from autoencoder import Decoder
from autoencoders.base_ar_decoder import BaseARDecoder

# TODO: rename to RNNAE.


class RNNDecoder(BaseARDecoder):
    """Autoregressive decoder backed by a multi-layer LSTM or GRU.

    The recurrent module is chosen by ``config.type`` ("LSTM" or "GRU") and is
    always built with ``batch_first=True``. LSTM recurrent state is an
    ``(h, c)`` tuple; GRU recurrent state is a single tensor.
    """

    # Supported values of ``config.type`` mapped to the torch module to build.
    _RNN_TYPES = {"LSTM": nn.LSTM, "GRU": nn.GRU}

    def __init__(self, config):
        """Build the recurrent decoder.

        Args:
            config: configuration object; this method reads ``type``,
                ``layers`` and ``input_size``. The hidden size is taken from
                ``self.hidden_size``, presumably set by
                ``BaseARDecoder.__init__`` (TODO confirm).

        Raises:
            ValueError: if ``config.type`` is neither "LSTM" nor "GRU".
        """
        super().__init__(config)

        self.type = config.type
        self.layers = config.layers

        try:
            rnn_cls = self._RNN_TYPES[self.type]
        except KeyError:
            # Fail fast: previously an unknown type left ``self.decoder`` unset
            # and only surfaced later as an AttributeError.
            raise ValueError(
                f"Unsupported RNN type {self.type!r}; "
                f"expected one of {sorted(self._RNN_TYPES)}"
            ) from None

        self.decoder = rnn_cls(
            input_size=config.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.layers,
            batch_first=True,
        )

    def _get_output_and_update_memory(self, embedded_input, state, embedding, t):
        """Advance the recurrent decoder by one step.

        Args:
            embedded_input: batch-first input for this step.
            state: current recurrent state (tensor for GRU, ``(h, c)`` for LSTM).
            embedding: not used by this implementation; presumably part of the
                hook signature expected by ``BaseARDecoder`` (TODO confirm).
            t: current time step.

        Returns:
            ``(output, new_state, t + 1)``.
        """
        out, new_state = self.decoder(embedded_input, state)
        return out, new_state, t + 1

    def _decode_all(self, embedded_teacher, h, l):
        """Run the RNN over a whole padded teacher-forcing batch at once.

        Args:
            embedded_teacher: padded batch-first input
                (batch, seq_len, input_size).
            h: initial recurrent state (tensor for GRU, ``(h, c)`` for LSTM),
                in the same batch order as ``embedded_teacher``.
            l: per-sequence lengths, as a list of ints or a 1-D tensor, in any
                order.

        Returns:
            Padded output of shape (batch, max(l), hidden_size), with rows in
            the same order as ``embedded_teacher``.
        """
        # pack_padded_sequence requires tensor lengths to live on the CPU.
        if torch.is_tensor(l):
            l = l.cpu()
        # enforce_sorted=False accepts lengths in any order: the RNN permutes
        # ``h`` internally, and pad_packed_sequence restores the batch order.
        packed_teacher = pack_padded_sequence(
            embedded_teacher, l, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.decoder(packed_teacher, h)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        return output

    def init_hidden_greedy(self, x):
        """Build an initial recurrent state from the encoding ``x``.

        ``x`` is squeezed on dim 1 and repeated across layers, so every layer
        starts from the same vector. ``x`` is presumably
        (batch, 1, hidden_size) (TODO confirm against the encoder).

        Returns:
            For GRU, a tensor of shape (layers, batch, hidden_size). For LSTM,
            an ``(h, c)`` tuple where ``c`` is zeros of that shape.
        """
        x = x.squeeze(1)
        h = x.repeat(self.layers, 1, 1)
        if self.type == "LSTM":
            # The cell state starts at zero. Match x's dtype, because nn.LSTM
            # rejects h/c pairs of differing dtypes.
            c = torch.zeros(self.layers, x.shape[0], self.hidden_size,
                            device=self.device, dtype=x.dtype)
            return h, c
        # __init__ guarantees the only other type is GRU.
        return h

    def init_hidden_batchwise(self, x):
        """Same initial state as :meth:`init_hidden_greedy`."""
        return self.init_hidden_greedy(x)

    def _hidden_from_beam(self, incomplete):
        """Collate per-beam recurrent states into one batched RNN state.

        Args:
            incomplete: mapping from a batch key to an iterable of beams. Each
                beam's ``hidden_state`` tensors are presumably
                (num_layers, hidden_size), as produced by ``_hidden_to_beam``
                with a scalar index (TODO confirm).

        Returns:
            State of shape (num_layers, total_beams, hidden_size): a tensor for
            GRU, an ``(h, c)`` tuple for LSTM. Beams are ordered by iterating
            ``incomplete`` and then each beam collection.
        """
        beams = [beam for batch in incomplete for beam in incomplete[batch]]

        def collate(states):
            # stack -> (beams, layers, hidden); the RNN wants (layers, beams, hidden)
            return torch.stack(states).permute(1, 0, 2).contiguous()

        if self.type == "LSTM":
            return (collate([beam.hidden_state[0] for beam in beams]),
                    collate([beam.hidden_state[1] for beam in beams]))
        return collate([beam.hidden_state for beam in beams])

    def _hidden_to_beam(self, h, indices):
        """Select batch entries ``indices`` (dim 1) from the recurrent state ``h``.

        Each state tensor is (num_layers, batch, hidden_size), as returned by
        the torch RNN. For LSTM, both ``h`` and ``c`` are indexed.
        """
        if self.type == "LSTM":
            return h[0][:, indices], h[1][:, indices]
        return h[:, indices]
