from rnn_decoder import RNNDecoder
from autoencoder import Encoder
from autoencoders.base_encoder import BaseEncoder
from fast_transformers.builders import TransformerEncoderBuilder, TransformerDecoderBuilder
from fast_transformers.masking import LengthMask
import torch
from math import ceil
from l0drop import L0Drop, DotProductAlphaEstimator, ContextualizedAlphaEstimator
from torch_utils import PositionalEncoder
from centroid_attention import CentroidAttention, random_sampling_initialization, avgpool, set_to_k, reduce_by_k, init_with_fixed
from torch.nn.modules.pooling import AvgPool1d
from fast_transformers.builders.transformer_builders import RecurrentDecoderBuilder
import numpy as np


class TransformerEncoder(BaseEncoder):
    """Transformer encoder with several optional ways of shrinking its output.

    The embedded input is run through a fast-transformers model (a plain
    encoder, or a recurrent decoder that attends over the input when
    ``config.autoregressive_encoder`` is set).  Depending on the config flags
    the resulting sequence of vectors is then reduced by one of: a fixed-size
    (single vector) mean bottleneck, L0Drop, strided selection, sum pooling,
    and/or centroid attention.
    """

    def __init__(self, config):
        """Read the hyper-parameters from ``config`` and build all sub-modules.

        Args:
            config: object exposing the attributes read below (``layers``,
                ``heads``, ``input_size``, ``att_type``, ...).

        Raises:
            ValueError: if centroid attention is enabled and
                ``config.num_centroids_f`` or ``config.init_f`` is not one of
                the supported names.
        """
        super(TransformerEncoder, self).__init__(config)

        self.layers = config.layers
        self.reduction = "mean"
        self.att_type = config.att_type
        self.heads = config.heads
        self.input_size = config.input_size
        self.ff_dimension = config.ff_dimension
        self.positional_embeddings = config.positional_embeddings
        self.fixed_size_bottleneck = config.fixed_size_bottleneck
        self.output_reduction = config.output_reduction
        self.use_l0drop = config.use_l0drop
        self.dropout = config.encoder_dropout
        self.output_positional_embeddings = config.output_positional_embeddings
        self.centroid_attention = config.centroid_attention
        self.output_reduction_avgpooling = config.output_reduction_avgpooling
        self.autoregressive_encoder = config.autoregressive_encoder
        self.append_gates = config.append_gates
        self.dont_discard_vectors = config.dont_discard_vectors
        self.dont_apply_gates = config.dont_apply_gates

        # Per-head dimensionality of queries and values.
        head_dimension = int(self.input_size / self.heads)

        # Create the builder for our transformers
        if self.autoregressive_encoder:
            # Learned first input of the autoregressive loop.
            self.start_embedding = torch.nn.Parameter(
                torch.randn(self.input_size))
            builder = RecurrentDecoderBuilder.from_kwargs(
                n_layers=self.layers,
                n_heads=self.heads,
                query_dimensions=head_dimension,
                value_dimensions=head_dimension,
                feed_forward_dimensions=self.ff_dimension,
                dropout=self.dropout,
                attention_dropout=self.dropout,
                cross_attention_type=self.att_type,
                self_attention_type=self.att_type
            )
        else:
            builder = TransformerEncoderBuilder.from_kwargs(
                attention_type=self.att_type,
                n_layers=self.layers,
                n_heads=self.heads,
                query_dimensions=head_dimension,
                value_dimensions=head_dimension,
                feed_forward_dimensions=self.ff_dimension,
                dropout=self.dropout,
                attention_dropout=self.dropout
            )

        self.transformer = builder.get()
        self.act = config.act

        if self.act:
            # Maps every decoder output to a halting probability in (0, 1).
            self.halt_probability_net = torch.nn.Sequential(
                torch.nn.Linear(self.input_size, self.input_size),
                torch.nn.Linear(self.input_size, 1),
                torch.nn.Sigmoid())
            self.mseloss = torch.nn.MSELoss(reduction='none')

        if self.positional_embeddings or self.output_positional_embeddings:
            self.pe = PositionalEncoder(self.input_size)

        self.target_ratio = config.target_ratio
        if self.use_l0drop:
            # SECURITY NOTE: eval() executes the config string as Python code.
            # This is only safe while the config comes from a trusted source.
            # NOTE(review): consider an explicit name -> class mapping instead.
            self.l0drop = L0Drop(
                self.input_size,
                keep_dropped_vectors=config.keep_dropped_vectors,
                learned_dummy=config.learned_dummy,
                alpha_estimator=eval(config.alpha_estimator),
                target_ratio=config.target_ratio,
                target_mse=config.target_mse,
                discard_epsilon=config.discard_epsilon,
                append_gates=self.append_gates,
                dont_discard_vectors=self.dont_discard_vectors,
                dont_apply_gates=self.dont_apply_gates)

        if self.centroid_attention:

            # Strategy that decides how many centroids each sequence gets.
            if config.num_centroids_f == "reduce":
                def num_centroids_f(x, xlen):
                    return reduce_by_k(
                        x, xlen, k=config.output_reduction,
                        max_centroids=config.max_centroids)
            elif config.num_centroids_f == "set":
                def num_centroids_f(x, xlen):
                    return set_to_k(x, xlen, k=config.output_reduction)
            else:
                # Fail with a clear message instead of an UnboundLocalError
                # further down.
                raise ValueError(
                    "Unknown num_centroids_f {!r}; expected 'reduce' or "
                    "'set'.".format(config.num_centroids_f))

            # Strategy that produces the initial centroids.
            if config.init_f == "avgpool":
                init_f = avgpool
            elif config.init_f == "randomsample":
                init_f = random_sampling_initialization
            elif config.init_f == "fixedinit":
                max_centroids = config.max_centroids
                # NOTE(review): this is a plain tensor attribute, not a
                # Parameter or registered buffer, so (assuming the base class
                # is a torch.nn.Module) it will neither follow .to(device)
                # nor appear in the state dict -- confirm this is intended.
                self.centroids = torch.zeros(
                    (max_centroids, self.input_size))
                self.centroids = torch.nn.init.xavier_uniform_(self.centroids)

                def init_f(X, X_len, num_centroids):
                    return init_with_fixed(
                        X, X_len, num_centroids, centroids=self.centroids)
            else:
                raise ValueError(
                    "Unknown init_f {!r}; expected 'avgpool', 'randomsample' "
                    "or 'fixedinit'.".format(config.init_f))

            self.centroid_attention_module = CentroidAttention(
                config.input_size,
                T=config.centroid_attention_iterations,
                reduction=config.output_reduction,
                initialization_f=init_f,
                num_centroids_f=num_centroids_f
            )
        if self.output_reduction_avgpooling:
            self.pooling = AvgPool1d(
                self.output_reduction, padding=int(self.output_reduction / 2))

    def _make_mask(self, bsize, max_lens, lens):
        """Return a ``(bsize, max_lens)`` boolean mask of valid positions.

        Entry ``[b, t]`` is True iff ``t < lens[b]``.  The mask lives on the
        device of ``lens``.
        """
        mask = torch.arange(max_lens, device=lens.device)
        mask = mask.unsqueeze(0).expand(bsize, -1)
        mask = mask < lens.unsqueeze(1)
        return mask

    def _decoder_encoder(self, X, X_lens):
        """Autoregressively emit a shortened sequence that attends over ``X``.

        The number of emitted vectors per sample is the input length divided
        by ``output_reduction`` (or multiplied by it when it is below 1),
        truncated to an integer and clamped to at least 1.

        Args:
            X: input vectors, indexed as (batch, position, feature).
            X_lens: number of valid positions of every sample in ``X``.

        Returns:
            Tuple ``(outputs, out_lens)``: ``outputs`` has shape
            (batch, max(out_lens), feature).  Positions beyond a sample's
            own ``out_lens`` are not zeroed here.
        """
        # The cross-attention mask must describe the *input* sequence.  The
        # original code built it from the already reduced lengths, which hid
        # all but the first len/output_reduction input positions.
        memory_mask = LengthMask(X_lens.to(X.device), max_len=X.shape[1])

        if self.output_reduction >= 1:
            out_lens = (X_lens / self.output_reduction).long()
        else:
            out_lens = (X_lens * self.output_reduction).long()
        out_lens = torch.maximum(out_lens, torch.ones_like(out_lens))

        batch_size = X.size(0)
        num_steps = int(out_lens.max().item())
        step_input = self.start_embedding.repeat(batch_size, 1)
        state = None  # let fast-transformers initialize the state

        outputs = torch.zeros(batch_size, num_steps,
                              X.size(-1), device=X.device)
        for step in range(num_steps):
            out, state = self.transformer(step_input, memory=X,
                                          memory_length_mask=memory_mask,
                                          state=state)
            # Feed the output back in as the next input.
            step_input = out
            outputs[:, step, :] = out

        return outputs, out_lens

    def _decoder_encoder_act(self, X, X_lens):
        """Autoregressive encoding whose output size is set by halting probabilities.

        One vector is decoded per input position.  Every decoded vector gets
        a halting probability; vectors are weighted by it and summed into the
        output slot determined by the running sum of the probabilities
        (slot ``i`` collects the steps whose cumulative probability lies in
        ``(i - 1, i]``).

        Args:
            X: input vectors, indexed as (batch, position, feature).
            X_lens: number of valid positions of every sample in ``X``.

        Returns:
            Tuple ``(outputs, out_lens, act_cost)``.  ``outputs`` has shape
            (batch, number of slots, feature).  ``out_lens`` is the
            (float, ceil-ed) number of used slots per sample.  ``act_cost``
            is the per-sample sum of halting probabilities, or, when
            ``target_ratio > 0``, the squared error between that sum divided
            by the input length and ``target_ratio``.
        """
        # X_lens = (X_lens / self.output_reduction).long(
        # ) if self.output_reduction >= 1 else (X_lens * self.output_reduction).long()
        in_lens = torch.maximum(X_lens, torch.ones_like(X_lens))
        batch_size = X.size(0)
        num_steps = int(in_lens.max().item())
        step_input = self.start_embedding.repeat(batch_size, 1)
        state = None  # let fast-transformers initialize the state
        # Loop invariant, so it is built once.
        memory_mask = LengthMask(in_lens.to(X.device), max_len=X.shape[1])

        outputs = []
        halt_probabilities = []
        for _ in range(num_steps):
            out, state = self.transformer(step_input, memory=X,
                                          memory_length_mask=memory_mask,
                                          state=state)
            halt_probabilities.append(
                self.halt_probability_net(out))  # [batch, 1]
            step_input = out
            outputs.append(out.unsqueeze(1))

        lens_mask = self._make_mask(batch_size, num_steps, in_lens)
        outputs = torch.cat(outputs, dim=1) * lens_mask.unsqueeze(-1)

        halt_probabilities = torch.cat(halt_probabilities, dim=1)

        # compute cumulative halting probabilities.. but we disregard the
        # invalid outputs with respect to length
        cumul_probabilities = torch.cumsum(halt_probabilities, dim=1)
        # everything set to zero will be ignored in the following
        cumul_probabilities = cumul_probabilities * lens_mask

        max_vectors = int(np.ceil(cumul_probabilities.max().item()))

        outputs = halt_probabilities.unsqueeze(-1) * outputs

        # make sure we sum the right vectors
        summed_vectors = []
        for upper in range(1, max_vectors + 1):
            # The upper bound is inclusive: max_vectors and out_lens are
            # computed with ceil(), so a cumulative value that is exactly an
            # integer belongs to the slot it closes.  With two strict
            # comparisons such a vector would land in no slot at all.
            in_slot = ((upper - 1) < cumul_probabilities) & \
                (cumul_probabilities <= upper)
            in_slot = in_slot.unsqueeze(-1).to(outputs.dtype)
            summed_vectors.append(
                (outputs * in_slot).sum(dim=1, keepdim=True))

        outputs = torch.cat(summed_vectors, dim=1)
        out_lens = torch.ceil(cumul_probabilities.max(dim=1)[0])

        # ACT cost
        act_cost = (halt_probabilities * lens_mask).sum(1)
        if self.target_ratio > 0.:
            # Output size relative to the *input* length.  The original code
            # divided by the freshly computed output lengths (the variable
            # had been overwritten), which yields roughly 1 for any input.
            ratio = act_cost / in_lens.to(act_cost)
            act_cost = self.mseloss(
                ratio, torch.ones_like(ratio) * self.target_ratio)
        return outputs, out_lens, act_cost

    def _to_hidden_representation(self, embedded, lengths):
        """Encode ``embedded`` and apply the configured output reduction.

        Args:
            embedded: embedded input, indexed as (batch, position, feature).
            lengths: number of valid positions of every sample.

        Returns:
            ``((output, lengths), l0_loss)`` when ``use_l0drop`` is set,
            ``((output, lengths), act_cost)`` when ``act`` is set (and
            ``use_l0drop`` is not), otherwise ``(output, lengths)``.
            Invalid positions of ``output`` are zeroed before the optional
            centroid attention is applied.

        Raises:
            ValueError: for flag combinations under which the auxiliary loss
                that has to be returned is never computed.
        """
        # These combinations used to fail with an UnboundLocalError at the
        # return statement; report the actual configuration problem instead.
        if self.use_l0drop and self.fixed_size_bottleneck:
            raise ValueError(
                "use_l0drop cannot be combined with fixed_size_bottleneck: "
                "the bottleneck takes precedence, so no L0 loss is computed.")
        if self.act and not self.use_l0drop and not self.autoregressive_encoder:
            raise ValueError(
                "act requires autoregressive_encoder: the ACT cost is only "
                "computed by the autoregressive encoder.")

        if self.positional_embeddings:
            embedded = self.pe(embedded)

        if not self.autoregressive_encoder:
            mask = LengthMask(lengths, max_len=embedded.size(1),
                              device=embedded.device)
            outs = self.transformer(embedded, length_mask=mask)
        elif self.act:
            outs, lengths, act_cost = self._decoder_encoder_act(
                embedded, lengths)
        else:
            outs, lengths = self._decoder_encoder(embedded, lengths)

        if self.fixed_size_bottleneck:

            # zero out the invalid items
            mask = self._make_mask(
                outs.size(0), outs.size(1), lengths)
            outs = outs * mask.unsqueeze(-1)

            # mean over all valid positions -> a single vector per sample
            output = outs.sum(dim=1) / lengths.unsqueeze(1)
            output = output.unsqueeze(1)
            lengths = torch.ones((output.size(0)), device=output.device).long()
        elif self.use_l0drop:
            (output, lengths), l0_loss = self.l0drop(outs, lengths)

        elif (not self.autoregressive_encoder and self.output_reduction > 1
              and not self.centroid_attention):

            if not self.output_reduction_avgpooling:
                # Keep every output_reduction-th vector.
                number_of_elements = ceil(outs.size(1) / self.output_reduction)
                selected_indices = torch.arange(
                    number_of_elements, device=outs.device) * \
                    self.output_reduction
                reduced_l = (lengths.float() / self.output_reduction)
                # need to subtract a tiny amount because torch.ceil might fail to
                # correctly compute the result sometimes
                reduced_l = reduced_l - 0.0001
                lengths = torch.ceil(reduced_l).long()
                output = outs[:, selected_indices]
            else:
                # Zero the padded positions first so that they cannot leak
                # into pooling windows that straddle the end of a sequence
                # (same reasoning as in the fixed-size bottleneck above).
                valid = self._make_mask(
                    outs.size(0), outs.size(1), lengths)
                output = (outs * valid.unsqueeze(-1)).transpose(2, 1)
                output = self.pooling(output)
                # we want to do SumPooling instead of average pooling
                output = output * self.output_reduction
                output = output.transpose(2, 1)
                lengths = torch.ceil(lengths / self.output_reduction).long()
        else:
            output = outs

        if self.output_positional_embeddings:
            output = self.pe(output)

        # zero out the invalid items
        mask = self._make_mask(
            output.size(0), output.size(1), lengths)
        output = output * mask.unsqueeze(-1)

        if self.centroid_attention:
            output, lengths = self.centroid_attention_module(
                output, lengths.to(output.device))

        if self.use_l0drop:
            return (output, lengths.to(output.device)), l0_loss
        elif self.act:
            return (output, lengths.to(output.device)), act_cost
        else:
            return output, lengths.to(output.device)
