import torch
import torch.nn as nn
from autoencoder import Encoder


class BaseEncoder(Encoder):
    """Common front end for encoders that map a token batch to one vector each.

    The class owns the token embedding table, a bias-free projection from the
    embedding size to ``hidden_size`` and, when ``config.variational`` is set,
    the two linear heads that produce the mean and log-variance of a Gaussian
    posterior. Subclasses supply the sequence model itself by overriding
    :meth:`_to_hidden_representation`.
    """

    def __init__(self, config):
        """Read hyper-parameters from ``config`` and build the shared layers.

        Args:
            config: Configuration object. Every attribute accessed below must
                exist on it; it is also forwarded to the parent constructor.
        """
        super().__init__(config)

        # Stored on the instance but not read anywhere in this class;
        # presumably consumed by subclasses or callers.
        self.teacher_forcing_ratio = config.teacher_forcing_ratio
        self.gaussian_noise_std = config.gaussian_noise_std
        self.teacher_forcing_batchwise = config.teacher_forcing_batchwise
        self.eos_idx = config.eos_idx
        self.sos_idx = config.sos_idx
        self.layers = config.layers

        self.config = config
        self.device = config.device

        self.hidden_size = config.hidden_size
        self.input_size = config.hidden_size
        self.max_sequence_len = config.max_sequence_len

        # Switches that shape what ``encode`` computes and returns.
        self.unit_sphere = config.unit_sphere
        self.variational = config.variational
        self.use_l0drop = config.use_l0drop

        # plus 1 to make the 0th word a "padding" one.
        self.vocab_size = config.vocab_size + 1

        # Index 0 denotes padding.
        self.embedding = nn.Embedding(
            self.vocab_size, config.embedding_size, padding_idx=0)

        self.input_projection = nn.Linear(
            config.embedding_size, self.hidden_size, bias=False)

        if config.variational:
            self.hidden2mean = nn.Linear(self.hidden_size, self.hidden_size)
            self.hidden2logv = nn.Linear(self.hidden_size, self.hidden_size)

    def _to_hidden_representation(self, embedded, lengths):
        """Turn projected token embeddings into the hidden representation.

        Must be overridden. ``encode`` passes the output of
        ``input_projection`` as ``embedded`` together with the (possibly
        clamped) ``lengths``. When ``encode`` runs with ``train=True`` on a
        non-variational encoder that has ``use_l0drop`` set, the return value
        is unpacked as a ``(hidden, l0_loss)`` pair; otherwise it is used as
        the hidden tensor directly.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Needs to be implemented by the subclass.")

    def encode(self, x, lengths, train=False, reparameterize=True):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids, indexed ``[batch, position]``.
            lengths: Tensor of sequence lengths, forwarded to
                ``_to_hidden_representation``.
            train: If true, auxiliary training outputs are returned alongside
                the representation (see Returns).
            reparameterize: Variational encoders only. If true the result is
                sampled as ``mean + eps * std``; otherwise the mean is used.

        Returns:
            * ``(h, mean, logv)`` when ``train`` and the encoder is
              variational;
            * ``(h, l0_loss)`` when ``train`` and ``use_l0drop`` is set on a
              non-variational encoder;
            * ``h`` alone otherwise.

            ``h`` is expected to be ``(batch, hidden_size)``.
        """
        if self.max_sequence_len > 0:
            # Truncate over-long inputs and keep the lengths consistent.
            x = x[:, :self.max_sequence_len]
            lengths = lengths.clamp(max=self.max_sequence_len)

        # Token embeddings, then projection to the hidden size.
        embedded = self.input_projection(self.embedding(x))

        h = self._to_hidden_representation(embedded, lengths)

        # Split off the L0Drop loss *before* any tensor operation touches the
        # representation; otherwise unit-sphere normalisation below would be
        # applied to a tuple. The condition mirrors the return branch exactly.
        returns_l0_loss = train and self.use_l0drop and not self.variational
        l0_loss = None
        if returns_l0_loss:
            h, l0_loss = h

        mean = logv = None
        if self.variational:
            mean = self.hidden2mean(h)
            logv = self.hidden2logv(h)
            if reparameterize:
                std = torch.exp(0.5 * logv)
                # randn_like follows std's device and dtype, so sampling keeps
                # working after the module has been moved with ``.to()``.
                h = torch.randn_like(std) * std + mean
            else:
                h = mean

        if self.unit_sphere:
            # L2-normalise the last dimension; the built-in eps guard maps an
            # all-zero row to zeros instead of NaN.
            h = nn.functional.normalize(h, p=2.0, dim=-1)

        # (batch, hidden_size)
        if train and self.variational:
            return h, mean, logv
        if returns_l0_loss:
            return h, l0_loss
        return h
