from parser.modules.scalar_mix import ScalarMix

import torch
import torch.nn as nn
from allennlp.common import FromParams
from allennlp.modules.elmo import _ElmoBiLm
from allennlp.nn.util import remove_sentence_boundaries


class Elmo(torch.nn.Module, FromParams):
    """ELMo token embedder built on AllenNLP's ``_ElmoBiLm``.

    Produces one vector per token from a single biLM layer, or from a learned
    scalar mix of the last two layers. The output can optionally be turned
    into a neighbour-difference representation (``fd_repr``).
    """

    def __init__(self,
                 layer=None,
                 dropout=0.33,
                 options_file="data/options.json",
                 weight_file="data/weights.hdf5",
                 requires_grad=False,
                 fd_repr=True):
        """
        Args:
            layer (int, optional):
                Index into the list of layer activations returned by the biLM.
                If ``None``, the last two activations are combined with a
                learned ``ScalarMix``. Default: ``None``.
            dropout (float):
                Dropout probability applied to the representation before the
                optional ``fd_repr`` transform. Default: 0.33.
            options_file (str):
                ELMo options JSON path, passed to ``_ElmoBiLm``.
            weight_file (str):
                ELMo HDF5 weight path, passed to ``_ElmoBiLm``.
            requires_grad (bool):
                Passed to ``_ElmoBiLm``; whether its weights receive gradients.
                Default: ``False``.
            fd_repr (bool):
                If ``True``, ``forward`` returns the neighbour-difference
                representation instead of the raw ELMo vectors.
                Default: ``True``.
        """
        super().__init__()
        self.fd_repr = fd_repr
        self.layer = layer
        self.dropout = dropout
        self._dropout = nn.Dropout(p=dropout)
        self.requires_grad = requires_grad
        self._elmo_lstm = _ElmoBiLm(
            options_file=options_file,
            weight_file=weight_file,
            requires_grad=requires_grad,
        )
        if layer is None:
            # Learned weighting over the last two biLM activations.
            self.scalar_mix = ScalarMix(n_layers=2)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += f"layer={self.layer if self.layer is not None else 'all'}, "
        s += f"dropout={self.dropout}, "
        s += f"n_out={self._elmo_lstm.get_output_dim()}, "
        s += f"fd_repr={self.fd_repr}"
        if self.requires_grad:
            s += f", requires_grad={self.requires_grad}"
        s += ')'

        return s

    def forward(self, chars):
        """
        Args:
            chars (torch.Tensor):
                Character ids for the current batch, shape
                ``(batch_size, timesteps, 50)``.

        Returns:
            torch.Tensor:
                Token representations, shape ``(batch_size, seq_len, n_out)``,
                with the sentence-boundary positions removed by
                ``remove_sentence_boundaries``. When ``fd_repr`` is set,
                positions outside the returned mask are explicitly zeroed.
        """
        bilm_output = self._elmo_lstm(chars)
        layer_activations = bilm_output["activations"]
        mask_with_bos_eos = bilm_output["mask"]

        if self.layer is None:
            representation_with_bos_eos = self.scalar_mix(
                layer_activations[-2:])
        else:
            representation_with_bos_eos = layer_activations[self.layer]
        res, mask = remove_sentence_boundaries(representation_with_bos_eos,
                                               mask_with_bos_eos)
        res = self._dropout(res)

        if self.fd_repr:
            res = self._neighbour_difference(res)
            # The forward difference at the first padded position evaluates
            # to -forward[last_token], which would leak into padding. Re-apply
            # the mask so only real tokens carry values. The mask may be bool
            # or long depending on the AllenNLP version, so cast it.
            res = res * mask.unsqueeze(-1).to(res.dtype)
        return res

    @staticmethod
    def _neighbour_difference(res):
        """Replace each half of the last dimension by neighbour differences.

        The last dimension is split in two. The first half (``forward``)
        becomes ``forward[:, t] - forward[:, t - 1]``, with position 0 kept
        unchanged. The second half (``backward``) becomes
        ``backward[:, t] - backward[:, t + 1]``, with the last position kept
        unchanged. The output has the same shape as ``res``.
        """
        forward, backward = torch.chunk(res, 2, dim=-1)

        forward_minus = torch.cat(
            [forward[:, :1], forward[:, 1:] - forward[:, :-1]], dim=1)
        backward_minus = torch.cat(
            [backward[:, :-1] - backward[:, 1:], backward[:, -1:]], dim=1)
        # [batch_size, seq_len, n_elmo]
        return torch.cat([forward_minus, backward_minus], dim=-1)
