import torch
from autoencoder import Decoder
from autoencoders.base_ar_decoder import BaseARDecoder
from fast_transformers.builders import TransformerDecoderBuilder, RecurrentDecoderBuilder
from fast_transformers.masking import TriangularCausalMask, LengthMask
from autoencoders.transformer_encoder import PositionalEncoder
import torch.nn.functional as F


class TransformerDecoder(BaseARDecoder):
    """Autoregressive decoder backed by two ``fast_transformers`` decoders.

    Two decoder stacks are built from the same configuration:

    * ``transformer_decoder`` (``TransformerDecoderBuilder``) processes a
      whole teacher-forced sequence at once in ``_decode_all``.
    * ``recurrent_transformer_decoder`` (``RecurrentDecoderBuilder``) processes
      one step at a time in ``_get_output_and_update_memory``.

    ``_tie_weights`` makes the recurrent decoder reuse the parameters of the
    full-sequence decoder so both always compute with the same weights.
    """

    def __init__(self, config):
        """Build both decoders from ``config`` and tie their parameters.

        Reads ``layers``, ``self_att_type``, ``cross_att_type``, ``heads``,
        ``input_size``, ``ff_dimension``, ``positional_embeddings`` and
        ``decoder_dropout`` from ``config``.
        """
        super().__init__(config)

        self.layers = config.layers
        self.self_att_type = config.self_att_type
        self.cross_att_type = config.cross_att_type
        self.heads = config.heads
        self.input_size = config.input_size
        self.ff_dimension = config.ff_dimension
        self.positional_embeddings = config.positional_embeddings
        self.dropout = config.decoder_dropout

        self.transformer_decoder = self._set_builder_attributes(
            TransformerDecoderBuilder()).get()
        self.recurrent_transformer_decoder = self._set_builder_attributes(
            RecurrentDecoderBuilder()).get()

        if self.positional_embeddings:
            self.pe = PositionalEncoder(self.input_size)

        self._tie_weights()

    def _tie_weights(self):
        """Make the recurrent decoder share the full decoder's parameters.

        Both module trees are walked in lockstep and every parameter that is
        directly owned by a module of ``transformer_decoder`` is assigned
        under the same name to the corresponding module of
        ``recurrent_transformer_decoder``.

        NOTE(review): this relies on both builders producing module trees
        with the same traversal order and parameter names -- confirm when
        upgrading ``fast_transformers``.
        """
        recurrent_modules = self.recurrent_transformer_decoder.named_modules()
        full_modules = self.transformer_decoder.named_modules()

        for (_, recurrent_module), (_, full_module) in zip(recurrent_modules, full_modules):
            for name, param in full_module.named_parameters(recurse=False):
                setattr(recurrent_module, name, param)

    def _get_output_and_update_memory(self, embedded_input, state, embedding, t):
        """Run a single decoding step through the recurrent decoder.

        Args:
            embedded_input: Input for the current step; its dimension 1 is
                squeezed before the decoder call (presumably
                ``(batch, 1, input_size)`` -- TODO confirm).
            state: Recurrent decoder state from the previous step (``None``
                is what ``init_hidden_greedy`` returns for non-gated
                attention).
            embedding: Pair ``(memory, memory_len)`` of encoder memory and
                its lengths.
            t: Current position; forwarded to the positional encoder.

        Returns:
            Tuple ``(output, new_state, t + 1)`` where ``output`` has the
            squeezed dimension 1 restored.

        Raises:
            ValueError: If the input batch size is not a multiple of the
                memory batch size.
        """
        if self.positional_embeddings:
            embedded_input = self.pe(embedded_input, t=t)

        memory = embedding[0]
        if self.cross_att_type == "gated":
            # The last two features are stripped for gated cross attention;
            # init_hidden_greedy reads the gate values from that tail.
            memory = memory[..., :-2]
        memory_len = embedding[1]

        n_inputs = embedded_input.size(0)
        if n_inputs % memory.size(0) != 0:
            raise ValueError(
                "Embedded input and memory have differing batch sizes! "
                "(memory: {}, embedded_input: {})".format(
                    tuple(memory.size()), tuple(embedded_input.size())))

        # When the input batch is a multiple of the memory batch, the memory
        # is tiled as a whole ([m0, m1, ..., m0, m1, ...]). Callers that lay
        # their rows out differently must expand the memory themselves.
        memory = memory.repeat(n_inputs // memory.size(0), 1, 1)
        memory_len = memory_len.repeat(n_inputs // memory_len.size(0))

        out, state = self.recurrent_transformer_decoder(
            embedded_input.squeeze(1),
            memory=memory,
            memory_length_mask=LengthMask(memory_len, max_len=memory.shape[1]),
            state=state)
        return out.unsqueeze(1), state, t + 1

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` and re-tie the two decoders afterwards.

        Returns:
            Whatever the parent ``load_state_dict`` returns, so callers can
            still inspect it.
        """
        result = super().load_state_dict(state_dict, strict=strict)

        # Tie again after loading so both decoders keep sharing parameters.
        self._tie_weights()
        return result

    def _decode_all(self, embedded_teacher, h, l):
        """Decode a full teacher-forced sequence with the non-recurrent decoder.

        Args:
            embedded_teacher: Embedded teacher inputs; dimension 1 is used as
                the sequence length for the masks.
            h: Pair ``(memory, memory_len)`` as returned by
                ``init_hidden_batchwise``.
            l: Lengths of the teacher sequences, used for the length mask.

        Returns:
            The output of ``transformer_decoder`` for the whole sequence.
        """
        if self.positional_embeddings:
            embedded_teacher = self.pe(embedded_teacher)

        memory, memory_len = h[0], h[1]
        device = embedded_teacher.device
        seq_len = embedded_teacher.size(1)

        memory_len_mask = LengthMask(
            memory_len, max_len=memory.size(1), device=device)
        length_mask = LengthMask(l, max_len=seq_len, device=device)
        causal_mask = TriangularCausalMask(seq_len, device=device)

        return self.transformer_decoder(
            embedded_teacher, memory,
            x_mask=causal_mask, x_length_mask=length_mask,
            memory_mask=None, memory_length_mask=memory_len_mask)

    def init_hidden_greedy(self, x):
        """Return the initial recurrent state for step-wise decoding.

        For gated cross attention a per-layer list is returned whose entries
        carry ``x[0][:, :, -2]`` (the encoder appends the gate values to the
        last dimension of the memory); otherwise ``None`` is returned.
        """
        if self.cross_att_type == "gated":
            gates = x[0][:, :, -2]
            return [(None, (None, None, gates)) for _ in range(self.layers)]
        return None

    def init_hidden_batchwise(self, x):
        """Return ``x`` unchanged; ``_decode_all`` unpacks it itself."""
        return x

    def _set_builder_attributes(self, builder):
        """Configure a ``fast_transformers`` builder from this decoder's settings.

        Query and value dimensions are ``input_size / heads`` truncated to an
        integer. The same ``builder`` instance is returned for chaining.
        """
        head_dimensions = int(self.input_size / self.heads)

        builder.n_layers = self.layers
        builder.n_heads = self.heads
        builder.feed_forward_dimensions = self.ff_dimension
        builder.query_dimensions = head_dimensions
        builder.value_dimensions = head_dimensions
        builder.dropout = self.dropout
        builder.attention_dropout = self.dropout
        builder.self_attention_type = self.self_att_type
        builder.cross_attention_type = self.cross_att_type
        return builder

    def _hidden_from_beam(self, incomplete):
        """Stack the per-beam hidden states into one batched decoder state.

        Args:
            incomplete: Mapping from batch index to its list of beam nodes.
                Each node's ``hidden_state`` is indexed as
                layers -> attention states -> pair of tensors.

        Returns:
            A list with one tuple per layer; each tuple holds one
            ``(stacked_first, stacked_second)`` pair per attention state.
            Beams are stacked in iteration order of ``incomplete``, i.e. all
            beams of the first batch element, then all of the second, etc.
        """
        beam_states = [beam.hidden_state
                       for batch in incomplete
                       for beam in incomplete[batch]]

        new_layers = []
        # zip(*beam_states) regroups "beams of layers" into "layers of beams".
        for layer_states in zip(*beam_states):
            attentions = []
            # ... and again into "attention states of beams" within a layer.
            for att_states in zip(*layer_states):
                state_parts = list(zip(*att_states))
                attentions.append((torch.stack(state_parts[0]),
                                   torch.stack(state_parts[1])))
            new_layers.append(tuple(attentions))
        return new_layers

    def _hidden_to_beam(self, h, indices):
        """Select the state entries at ``indices`` out of a batched state.

        Args:
            h: Batched decoder state; for every layer the first two attention
                states are read, each as a pair of tensors.
            indices: Index applied to the first dimension of every tensor.

        Returns:
            A list with one tuple per layer containing two
            ``(first[indices], second[indices])`` pairs.
        """
        return [
            tuple((h[layer][att][0][indices], h[layer][att][1][indices])
                  for att in range(2))
            for layer in range(self.layers)
        ]

    def _sequence_from_beam(self, node):
        """Walk ``previous_node`` links back from ``node`` and return the word ids.

        The result is in generation order and ends with ``self.eos_idx``.
        """
        seq = [self.eos_idx]
        while node is not None:
            seq.append(node.word_id.cpu().item())
            node = node.previous_node
        return seq[::-1]

    def beam_decode(self, x, beam_width=10):
        """Decode every batch element with an independent beam search.

        Args:
            x: Pair ``(memory, memory_len)`` produced by the encoder; the
                first dimension of both is the batch.
            beam_width: Number of hypotheses kept per batch element.

        Returns:
            ``self.clip_predictions`` applied to the list of decoded id
            sequences (one per batch element). A batch element is finished as
            soon as the end-of-sequence token shows up among its top
            candidates; elements that never finish are cut off at their most
            probable hypothesis.
        """
        batch_size = x[0].shape[0]
        h = self.init_hidden_greedy(x)
        decoded = [None] * batch_size

        # beam_width placeholder nodes per batch element, all starting at SOS.
        incomplete = {
            ba: [self.BeamNode(h, None,
                               torch.tensor(self.sos_idx, device=self.device),
                               0, 1)
                 for _ in range(beam_width)]
            for ba in range(batch_size)
        }

        # First step: one decoder call for the whole batch.
        embedded_input = self.input_projection(self._get_initial_inputs(x))
        pos = 1
        decoder_output, h, pos = self._get_output_and_update_memory(
            embedded_input, h, x, pos)

        for b in range(batch_size):
            log_probs = F.log_softmax(
                self._outlayer(decoder_output[b]), dim=1).squeeze(0)
            k_log_probs, k_indices = torch.topk(log_probs, beam_width)
            # All first-step hypotheses of one element share the same state.
            beam_state = self._hidden_to_beam(h, b)
            incomplete[b] = [
                self.BeamNode(beam_state, prev_node,
                              k_indices[i], k_log_probs[i], 2)
                for i, prev_node in enumerate(incomplete[b])
            ]

        for _ in range(2, self.max_sequence_len):
            if not incomplete:
                break

            # Row layout for this step: [ batch1_beams | batch2_beams | ... ].
            input_order = list(incomplete)
            word_ids = torch.tensor(
                [beam.word_id
                 for batch in input_order
                 for beam in incomplete[batch]],
                device=self.device)
            embedded_input = self.embedding(word_ids.reshape(-1, 1))
            embedded_input = self.input_projection(embedded_input)

            h = self._hidden_from_beam(incomplete)

            # Expand the encoder memory so that every beam row is paired with
            # the memory of its own batch element. repeat_interleave yields
            # [m0, m0, ..., m1, m1, ...], matching the row layout above; a
            # plain tiling repeat would pair beams with the wrong memory as
            # soon as more than one batch element is still incomplete.
            memory = x[0][input_order].repeat_interleave(beam_width, dim=0)
            memory_len = x[1][input_order].repeat_interleave(beam_width, dim=0)

            decoder_output, h, pos = self._get_output_and_update_memory(
                embedded_input, h, (memory, memory_len), pos)

            for batch_index, batch in enumerate(input_order):
                # Each batch element is a separate beam search.
                offset = batch_index * beam_width
                beams = incomplete[batch]
                log_probs = F.log_softmax(
                    self._outlayer(
                        decoder_output[offset:offset + beam_width].squeeze(1)),
                    dim=1)

                # One flat vector of full-sequence log-probabilities over
                # (beam, word) candidates.
                seq_probs = torch.cat(
                    [beams[i].log_prob + log_probs[i]
                     for i in range(beam_width)])
                k_seq_probs, k_indices = torch.topk(seq_probs, beam_width)

                new_beams = []
                for seq_prob, index in zip(k_seq_probs, k_indices):
                    beam_index = index // self.vocab_size
                    word_index = index % self.vocab_size
                    prev_beam = beams[beam_index]
                    if word_index == self.eos_idx:
                        # End of sequence: this batch element is complete.
                        decoded[batch] = self._sequence_from_beam(prev_beam)
                        del incomplete[batch]
                        break
                    new_beams.append(
                        self.BeamNode(
                            self._hidden_to_beam(h, offset + beam_index),
                            prev_beam,
                            word_index,
                            seq_prob,
                            prev_beam.length + 1))

                # Only keep the new hypotheses if the element did not finish.
                if batch in incomplete:
                    incomplete[batch] = new_beams

        # Elements that hit the maximum length are cut off at the most
        # probable hypothesis so far (beams are stored best-first).
        for batch, beams in incomplete.items():
            decoded[batch] = self._sequence_from_beam(beams[0])

        return self.clip_predictions(decoded)


def flatten(l):
    """Return one list holding the items of every sub-iterable of ``l``, in order."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat
