import jax
import numpy as np
from absl.testing import parameterized

from trax import shapes, fastmath, layers as tl
from trax.models.research.hourglass import HourglassLM


class HourglassTest(parameterized.TestCase):
  """Forward-shape and autoregressive-property tests for `HourglassLM`."""

  @staticmethod
  def _build_hourglass_lm(vocab_size, d_model):
    """Returns the small `HourglassLM` configuration shared by the tests.

    Args:
      vocab_size: passed to `HourglassLM` as its first positional argument.
      d_model: value used for both the `d_model` and `d_ff` arguments.
    """
    return HourglassLM(
        vocab_size,
        hierarchy_shorten_factors=(3, 2),
        hierarchy_n_layers=(2, 2),
        vanilla_layers=(1, 1),
        d_model=d_model,
        d_ff=d_model,
        n_heads=2,
    )

  def _check_forward_shape(self, model, input_shape, output_vocab_size):
    """Asserts that `model` maps int32 tokens to per-position vocab outputs.

    Args:
      model: layer to initialize and run; it is initialized in place.
      input_shape: shape of the all-ones int32 input fed to the model.
      output_vocab_size: expected size of the trailing output axis.
    """
    tokens = np.ones(input_shape, dtype=np.int32)
    model.init(shapes.signature(tokens))
    logits = model(tokens)
    expected_shape = tuple(input_shape) + (output_vocab_size,)
    self.assertEqual(logits.shape, expected_shape)

  def test_hourglass_lm_forward_shape(self):
    """The output shape is the input shape plus a trailing vocab axis."""
    vocab_size = 7
    model = self._build_hourglass_lm(vocab_size, d_model=16)

    batch_size, seq_len = 3, 96
    self._check_forward_shape(model,
                              input_shape=(batch_size, seq_len),
                              output_vocab_size=vocab_size)

  def _test_autoregressive_property(self, model, input_shape,
                                    output_vocab_size):
    """Checks that outputs do not change when only later tokens change.

    Two random token arrays are drawn. For every split index `i`, an input
    made of the first `i` columns of the first array followed by the remaining
    columns of the second array is run through the model, and its first
    `i + 1` output rows are compared with those of the unmodified first array.

    Args:
      model: layer under test; it is re-initialized before every forward pass.
      input_shape: shape of the random token arrays; `input_shape[1]` is the
        number of split indices that are exercised.
      output_vocab_size: exclusive upper bound of the sampled token ids.
    """
    base_key = jax.random.PRNGKey(0)
    other_key = jax.random.PRNGKey(1)
    seq_len = input_shape[1]

    def first_output(uninitialized_model: tl.Layer, tokens):
      # The same key and `use_cache=False` are used on every call, presumably
      # so each forward pass starts from identical weights.
      uninitialized_model.init(
          shapes.signature(tokens), rng=base_key, use_cache=False)
      outputs = uninitialized_model(tokens, rng=base_key)
      # NOTE(review): this keeps only the first item yielded by the output.
      # If the model returns a single array, that is presumably the first
      # batch row -- confirm this is intended for batch sizes above 1.
      head, *_ = outputs
      return head

    with fastmath.use_backend(fastmath.Backend.JAX):
      reference_tokens = jax.random.randint(
          base_key, input_shape, 0, output_vocab_size)
      reference_logits = first_output(model, reference_tokens)

      replacement_tokens = jax.random.randint(
          other_key, input_shape, 0, output_vocab_size)

      for split in range(seq_len):
        # Keep the reference prefix and swap in different tokens from `split`.
        mixed_tokens = np.concatenate(
            (reference_tokens[:, :split], replacement_tokens[:, split:]),
            axis=1)

        mixed_logits = first_output(model, mixed_tokens)
        self.assertEqual(mixed_logits.shape[0], seq_len)

        # Row `split` is compared too, although its input token may differ;
        # presumably the model only looks at strictly earlier tokens --
        # TODO confirm against HourglassLM.
        n_compared = split + 1
        np.testing.assert_array_almost_equal(
            reference_logits[:n_compared], mixed_logits[:n_compared])

  def test_hourglass_lm_autoregressive_property(self):
    """Runs the autoregressive check on a small single-example model."""
    vocab_size = 26
    model = self._build_hourglass_lm(vocab_size, d_model=8)

    input_shape = (1, 48)
    self._test_autoregressive_property(model, input_shape,
                                       output_vocab_size=vocab_size)
