import torch
from autoencoder import Decoder
from autoencoders.base_ar_decoder import BaseARDecoder
from fast_transformers.builders import TransformerEncoderBuilder, RecurrentEncoderBuilder
from fast_transformers.masking import TriangularCausalMask, LengthMask
from autoencoders.transformer_encoder import PositionalEncoder

class SimpleTransformerDecoder(BaseARDecoder):
    """Autoregressive decoder built from ``fast_transformers`` encoder stacks.

    Two stacks are created from the same hyper-parameters:

    * ``transformer_decoder`` processes a whole sequence at once with a
      causal mask and a length mask (see ``_decode_all``).
    * ``recurrent_transformer_decoder`` processes one time step at a time
      while carrying a state (see ``_get_output_and_update_memory``).

    The parameters of the recurrent stack are replaced by those of the
    batched stack (see ``_tie_weights``), so both views share one set of
    weights.
    """

    def __init__(self, config):
        """Build both stacks from ``config`` and tie their weights.

        Attributes read from ``config``: ``layers``, ``att_type``, ``heads``,
        ``input_size``, ``ff_dimension``, ``positional_embeddings``,
        ``dropout``, ``att_dropout`` and ``learn_embedding_transformation``.
        """
        super(SimpleTransformerDecoder, self).__init__(config)

        self.layers = config.layers
        self.attn_type = config.att_type
        self.heads = config.heads
        self.input_size = config.input_size
        self.ff_dimension = config.ff_dimension
        self.positional_embeddings = config.positional_embeddings
        self.dropout = config.dropout
        self.att_dropout = config.att_dropout

        # Same configuration, two execution modes (batched and recurrent).
        self.transformer_decoder = self._set_builder_attributes(
            TransformerEncoderBuilder()
        ).get()
        self.recurrent_transformer_decoder = self._set_builder_attributes(
            RecurrentEncoderBuilder()
        ).get()

        if self.positional_embeddings:
            self.pe = PositionalEncoder(self.input_size)

        self._tie_weights()

        self.learn_embedding_transformation = config.learn_embedding_transformation
        if self.learn_embedding_transformation:
            # Optional linear map applied to the conditioning embedding
            # before it is fed to the decoder as the first input.
            self.embedding_transformation = torch.nn.Linear(
                self.input_size, self.input_size
            )

    def _tie_weights(self):
        """Make the recurrent stack reuse the batched stack's parameters.

        The modules of both stacks are walked in lockstep and every direct
        parameter of a batched-stack module is assigned, under the same name,
        onto the paired recurrent-stack module. This relies on both stacks
        yielding their sub-modules in the same order; ``zip`` stops silently
        at the shorter sequence if they do not.
        """
        recurrent_modules = self.recurrent_transformer_decoder.named_modules()
        batched_modules = self.transformer_decoder.named_modules()

        for (_, recurrent_m), (_, batched_m) in zip(recurrent_modules, batched_modules):
            for name, param in batched_m.named_parameters(recurse=False):
                setattr(recurrent_m, name, param)

    def _get_output_and_update_memory(self, embedded_input, state, embedding, t):
        """Feed ``embedded_input`` step by step through the recurrent stack.

        Args:
            embedded_input: tensor indexed as ``[:, i, :]`` per step
                (presumably ``(batch, steps, input_size)`` -- TODO confirm).
            state: recurrent state passed to, and returned by, the recurrent
                stack; ``init_hidden_greedy`` supplies ``None`` initially.
            embedding: unused by this implementation.
            t: position handed to the positional encoder for the first step;
                incremented once per processed step.

        Returns:
            ``(out, state, t)`` where ``out`` is the output of the last step
            with a singleton dimension inserted at index 1, ``state`` is the
            updated recurrent state and ``t`` is the advanced position.

        Raises:
            ValueError: if ``embedded_input`` has no time steps, in which
                case there is no output to return.
        """
        #assert not self.training
        n_steps = embedded_input.size(1)
        if n_steps == 0:
            # Without this guard the return below would fail with an opaque
            # UnboundLocalError on ``out``.
            raise ValueError("embedded_input must contain at least one time step")

        for i in range(n_steps):
            input_at_t = embedded_input[:, i, :]
            if self.positional_embeddings:
                # The positional encoder is called on a length-1 sequence
                # together with the absolute position ``t``.
                input_at_t = self.pe(input_at_t.unsqueeze(1), t=t).squeeze(1)
            out, state = self.recurrent_transformer_decoder(input_at_t, state=state)
            t = t + 1
        return out.unsqueeze(1), state, t

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` and re-tie the weights of both stacks.

        Returns:
            Whatever the parent ``load_state_dict`` returns (for
            ``torch.nn.Module`` this is the missing/unexpected keys report),
            so callers can still inspect it.
        """
        result = super(SimpleTransformerDecoder, self).load_state_dict(
            state_dict, strict=strict
        )

        # need to make sure the weights are tied after loading
        self._tie_weights()
        return result

    def _decode_eval(self, x, beam_width):
        """Decode ``x`` at evaluation time.

        ``beam_width`` is currently ignored: greedy decoding is always used.
        """
        # TODO: need to implement beamsearch for this decoder
        return self.greedy_decode(x)

    def _decode_all(self, embedded_teacher, h, l):
        """Run the batched stack over the whole sequence in one pass.

        Args:
            embedded_teacher: embedded input sequence; concatenated after
                ``h`` along dimension 1.
            h: conditioning embedding placed at position 0 of the sequence
                (``init_hidden_batchwise`` produces it with a singleton
                dimension at index 1).
            l: sequence lengths of ``embedded_teacher``; one is added to
                account for the prepended embedding.

        Returns:
            The stack's outputs with the first position (the one belonging to
            the prepended embedding) removed.
        """
        if self.learn_embedding_transformation:
            h = self.embedding_transformation(h)

        # append embedding to 0th position
        inputs = torch.cat([h, embedded_teacher], dim=1)

        if self.positional_embeddings:
            inputs = self.pe(inputs)

        device = embedded_teacher.device
        seq_len = inputs.size(1)
        length_mask = LengthMask(l + 1, max_len=seq_len, device=device)
        causal_mask = TriangularCausalMask(seq_len, device=device)
        outputs = self.transformer_decoder(
            inputs, attn_mask=causal_mask, length_mask=length_mask
        )

        # we need to remove the first output because it corresponds to the embedding
        return outputs[:, 1:, :]

    def init_hidden_greedy(self, x):
        """Return the initial recurrent state for greedy decoding (``None``)."""
        return None

    def init_hidden_batchwise(self, x):
        """Return ``x`` with a singleton dimension inserted at index 1."""
        return x.unsqueeze(1)

    def _get_initial_inputs(self, x):
        """Build the first decoder inputs: the embedding followed by SOS.

        ``x`` is optionally passed through ``embedding_transformation``, then
        concatenated (as position 0) with the embedded start-of-sequence
        token, repeated once per row of ``x``.
        """
        if self.learn_embedding_transformation:
            x = self.embedding_transformation(x)

        sos_tokens = torch.tensor([[self.sos_idx]], device=self.device).repeat(
            x.shape[0], 1
        )
        initial_inputs = self.embedding(sos_tokens)
        initial_inputs = torch.cat([x.unsqueeze(1), initial_inputs], dim=1)
        return initial_inputs

    def _set_builder_attributes(self, builder):
        """Copy this decoder's hyper-parameters onto ``builder`` and return it.

        Query and value dimensions are both set to ``input_size / heads``
        truncated to an integer.
        """
        head_dimensions = int(self.input_size / self.heads)

        builder.n_layers = self.layers
        builder.n_heads = self.heads
        builder.feed_forward_dimensions = self.ff_dimension
        builder.query_dimensions = head_dimensions
        builder.value_dimensions = head_dimensions
        builder.dropout = self.dropout
        builder.attention_dropout = self.att_dropout
        builder.attention_type = self.attn_type
        return builder

    def _hidden_from_beam(self, incomplete):
        """Stack the per-beam hidden states into batched per-layer states.

        Args:
            incomplete: mapping whose values are iterables of beams, each
                exposing a ``hidden_state`` attribute. Each hidden state is
                indexed as ``[layer][attention][part]`` with two parts per
                attention state (presumably tensors -- they are combined with
                ``torch.stack``).

        Returns:
            A list with one entry per layer; each entry is a tuple of
            ``(part0, part1)`` pairs whose tensors are stacked over all beams
            along a new leading dimension.
        """
        beam_states = [
            beam.hidden_state for batch in incomplete for beam in incomplete[batch]
        ]

        new_layers = []
        # Transpose beams x layers into layers x beams.
        for layer_states in zip(*beam_states):
            attentions = []
            # Transpose beams x attention-states into attention-states x beams.
            for att_states in zip(*layer_states):
                state_parts = list(zip(*att_states))
                stacked_first = torch.stack(state_parts[0])
                stacked_second = torch.stack(state_parts[1])
                attentions.append((stacked_first, stacked_second))
            new_layers.append(tuple(attentions))

        return new_layers

    def _hidden_to_beam(self, h, indices):
        """Select the entries ``indices`` from every tensor of state ``h``.

        Args:
            h: per-layer state indexed as ``h[layer][attention][part]`` with
                two attention states per layer and two parts per state.
            indices: index applied to each part (``part[indices]``).

        Returns:
            A list with one tuple per layer, mirroring the structure of ``h``
            with every part indexed by ``indices``.
        """
        states = []
        for i in range(self.layers):
            layer_states = h[i]

            # for each state-type (cross, and self-attention), and for each
            # of si, zi within it
            att_types = [
                (layer_states[j][0][indices], layer_states[j][1][indices])
                for j in range(2)
            ]
            states.append(tuple(att_types))

        return states

def flatten(l):
    """Flatten one level of nesting.

    Args:
        l: an iterable of iterables.

    Returns:
        A new list holding the items of every sub-iterable of ``l``, in
        order. Only the outermost level of nesting is removed.
    """
    return [item for sublist in l for item in sublist]