import trax.layers as tl
from trax.layers.research.rel_attention import RelativeAttentionWrapper, \
  get_rel_att_inputs
from trax.layers.research.resampling import AttentionResampling, \
  AveragePooling, LinearUpsampling
from trax.models.research.funnel_transformer import _RelativeDecoderBlock


def HourglassLM(vocab_size,
                d_model=512,
                d_ff=2048,
                vanilla_layers=(1, 1),
                hierarchy_shorten_factors=(3,),
                hierarchy_n_layers=(6,),
                n_heads=8,
                dropout=0.1,
                dropout_shared_axes=None,
                mode='train',
                ff_activation=tl.FastGelu,
                vanilla_attn_type=RelativeAttentionWrapper,
                middle_attn_type=RelativeAttentionWrapper,
                downsampling_fn=AttentionResampling,
                upsampling_fn=AttentionResampling,
                attention_downsampling_fn=AveragePooling,
                attention_upsampling_fn=LinearUpsampling):
  """Returns a hierarchical Transformer language model.

  This model performs autoregressive language modeling:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor of per-position activations over possible token
      IDs; shape is (batch_size, sequence_length, `vocab_size`). The last
      layer is a plain `tl.Dense(vocab_size)`, so the values are not
      normalized into log-probabilities by this model.

  This model uses only the decoder part of the overall Transformer.

  The layer stack is: token embedding, `vanilla_layers[0]` full-resolution
  decoder blocks, a (possibly nested) "valley" that shortens the sequence,
  processes it and upsamples it back with a residual connection, and finally
  `vanilla_layers[1]` full-resolution decoder blocks followed by the output
  projection.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each decoder
        block.
    vanilla_layers: (pre_layers, post_layers) tuple - number of full token-level
        Transformer decoder layers before and after shortening.
    hierarchy_shorten_factors: tuple of arbitrary (non-zero) length denoting by
        how much to shorten the sequence at each pooling stage.
    hierarchy_n_layers: number of Transformer decoder blocks after each stage of
        pooling - tuple of the same length as `hierarchy_shorten_factors`. For
        every stage except the innermost one, this many blocks are created
        both before and after the nested stage.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: str: 'train' or 'eval'. 'predict' is currently rejected.
    ff_activation: Type of activation function at the end of each decoder
        block; must be an activation-type subclass of `Layer`.
    vanilla_attn_type: class: attention class such as SelfAttention to use in
        the layers before and after shortening (vanilla layers).
    middle_attn_type: class: attention class to use in the middle layers
        (these operating on the shortened sequence).
    downsampling_fn: function that takes full token-level vectors of
        length `l` and transforms them into `l` / `k` vectors, where `k`
        denotes `shorten_factor` parameter.
    upsampling_fn: function that takes shortened representations of a sequence,
        consisting of `l` / `k` vectors and transforms them into full
        token-level representations of length `l`.
    attention_downsampling_fn: Downsampling function that transforms token-level
        vectors into query vectors with reduced length. Necessary only when
        AttentionResampling is used as `downsampling_fn`.
    attention_upsampling_fn: Upsampling function for AttentionResampling.
        Valid only when AttentionResampling is used as a `upsampling_fn`.

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.

  Raises:
    AssertionError: if `mode` is 'predict', or if `hierarchy_n_layers` and
        `hierarchy_shorten_factors` differ in length or are empty.
  """
  assert mode != 'predict', "'predict' mode is unsupported for now."
  assert len(hierarchy_n_layers) == len(hierarchy_shorten_factors), (
      'hierarchy_n_layers and hierarchy_shorten_factors must have equal '
      'length.')

  token_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)]

  # Created once; the same two layer objects are handed to every decoder block
  # and every resampling layer built below.
  context_bias_layer, location_bias_layer = get_rel_att_inputs(d_model,
                                                               n_heads)

  n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers

  def create_decoder_blocks(n_layers, total_pooling, attention_type):
    """Returns `n_layers` relative decoder blocks followed by a LayerNorm."""
    decoder_blocks = [
        # pylint: disable=g-complex-comprehension
        _RelativeDecoderBlock(attention_type, d_model, d_ff, n_heads, dropout,
                              dropout_shared_axes, mode, ff_activation,
                              context_bias_layer, location_bias_layer,
                              total_pooling)
        for _ in range(n_layers)]
    return decoder_blocks + [tl.LayerNorm()]

  def create_hourglass_valley(rest_shorten_factors,
                              rest_n_funnel_blocks,
                              current_total_pooling):
    """Returns the layers of one shorten -> process -> upsample stage.

    Recurses on the remaining shorten factors, so every further stage is
    nested between this stage's pre- and post-blocks.

    Args:
      rest_shorten_factors: shorten factors of this and all deeper stages.
      rest_n_funnel_blocks: decoder block counts of this and all deeper
          stages; same length as `rest_shorten_factors`.
      current_total_pooling: product of the shorten factors of all enclosing
          stages (1 for the outermost stage).

    Returns:
      A list of layers with a residual connection around the whole stage.
    """
    assert len(rest_shorten_factors) > 0
    assert len(rest_shorten_factors) == len(rest_n_funnel_blocks)

    current_sf = rest_shorten_factors[0]
    current_n_layers = rest_n_funnel_blocks[0]
    # Pooling factor seen by the blocks that run on the shortened sequence.
    inner_total_pooling = current_total_pooling * current_sf

    # Keyword arguments common to the down- and upsampling constructors.
    resampling_kwargs = dict(
        d_ff=d_ff,
        n_heads=n_heads,
        dropout=dropout,
        dropout_shared_axes=dropout_shared_axes,
        mode=mode,
        ff_activation=ff_activation,
        context_bias_layer=context_bias_layer,
        location_bias_layer=location_bias_layer,
        total_pooling=current_total_pooling)

    # `d_model` stays positional here and keyword below, exactly as callers of
    # custom resampling functions have come to rely on.
    shortening_layer = downsampling_fn(
        current_sf, d_model,
        is_upsampling=False,
        resampling_fn=attention_downsampling_fn,
        **resampling_kwargs)

    upsampling_layer = upsampling_fn(
        current_sf,
        d_model=d_model,
        is_upsampling=True,
        resampling_fn=attention_upsampling_fn,
        **resampling_kwargs)

    if len(rest_shorten_factors) > 1:  # we need to go deeper again
      shortened_layers = [
          create_decoder_blocks(current_n_layers, inner_total_pooling,
                                middle_attn_type),
          *create_hourglass_valley(rest_shorten_factors[1:],
                                   rest_n_funnel_blocks[1:],
                                   inner_total_pooling),
          create_decoder_blocks(current_n_layers, inner_total_pooling,
                                middle_attn_type),
      ]
    else:
      shortened_layers = [
          create_decoder_blocks(current_n_layers, inner_total_pooling,
                                middle_attn_type),
      ]

    # The wrapper is built once for both cases so that they cannot drift apart
    # (previously the innermost stage's ShiftRight did not receive `mode`).
    return [
        tl.Dup(),  # keep a copy of the input for the residual Add below
        # Extra shift before pooling -- presumably so that a pooled group never
        # exposes tokens from the future; confirm against the resampling layer.
        tl.ShiftRight(current_sf - 1, mode=mode),
        shortening_layer,
        *shortened_layers,
        upsampling_layer,
        tl.LayerNorm(),
        tl.Add()
    ]

  pre_decoder_blocks = create_decoder_blocks(n_pre_decoder_blocks,
                                             1,
                                             vanilla_attn_type)

  post_decoder_blocks = create_decoder_blocks(n_post_decoder_blocks,
                                              1,
                                              vanilla_attn_type)

  valley = create_hourglass_valley(hierarchy_shorten_factors,
                                   hierarchy_n_layers,
                                   1)

  # Assemble and return the model.
  return tl.Serial(              # tokens (or chunked tuple of tokens)
      tl.ShiftRight(mode=mode),  # toks
      token_encoder,             # vecs
      pre_decoder_blocks,        # vecs
      valley,                    # vecs (shortened, processed, upsampled, added)
      post_decoder_blocks,       # vecs
      tl.Dense(vocab_size),      # vecs
  )
