import numpy as np

import chainer
from chainer import Variable, Chain, cuda, initializers, utils
import chainer.links as L
import chainer.functions as F
from chainer.configuration import config

#class ResLayer(Chain):
#    def __init__(self, ch, kernel, dropout):
#        self.dropout = dropout
#        super().__init__(
#                conv1 = L.ConvolutionND(1, ch, ch, kernel, pad=(kernel-1)//2,
#                                        nobias=True),
#                bn1 = L.BatchNormalization(ch),
#                conv2 = L.ConvolutionND(1, ch, ch, kernel, pad=(kernel-1)//2,
#                                        nobias=True),
#                bn2 = L.BatchNormalization(ch))
#
#    def __call__(self, x):
#        h = F.relu(self.bn1(self.conv1(F.dropout(x, self.dropout))))
#        h = F.relu(self.bn2(self.conv2(h)))
#        return F.relu(h + x)
#
#
#class ResNetTransducer(Chain):
#    def __init__(self, alphabet, features, embedding_size,
#                 n_layers, ch, kernel_size, dropout):
#        super().__init__(
#            embeddings=L.EmbedID(len(alphabet), embedding_size,
#                                 ignore_label=-1),
#            conv1 = L.ConvolutionND(
#                1, embedding_size+len(features), ch, kernel_size,
#                pad=(kernel_size-1)//2, nobias=True),
#            bn1 = L.BatchNormalization(ch),
#            output = L.Convolution1D(1, ch, len(alphabet), 1))
#
#        self.forward = [('res%d'%i, ResLayer(ch, kernel_size, dropout))
#                        for i in range(n_layers)]
#
#        for link in self.forward: self.add_link(*link)
#
#    def __call__(self, x):
#        x = self.embeddings(x)
#        x = F.relu(self.bn1(self.conv1(x)))
#        for name, _ in self.forward:
#            f = getattr(self, name)
#            x = f(x)
#        x = self.output(x)
#        return x

# class RNNTransducerState:
#     def __init__(self, model, encoded, features, h=None, c=None):
#         self.model = model
#         self.encoded = encoded
#         self.features = features
#         self.h = h
#         self.c = c
#
#     def __call__(self, x_tm1):
#         h, c, y = self.model.decoder(
#                 self.h, self.c,
#                 [F.expand_dims(x, 0)
#                  for x in F.concat((self.model.embeddings(x_tm1),
#                                     self.encoded, self.features))
#                 ])
#         state = RNNTransducerState(
#                 self.model, self.encoded, self.features, h, c)
#         state.p = self.model.output(h[-1])
#         return state
#
#
# class RNNTransducer(Chain):
#     def __init__(
#             self,
#             alphabet,
#             features,
#             embedding_size=32,
#             encoder_size=128,
#             decoder_size=128,
#             features_size=128,
#             dropout=0.0):
#
#         self.alphabet = alphabet
#         self.features = features
#         self.encoder_size = encoder_size
#
#         super().__init__(
#             embeddings=L.EmbedID(len(alphabet), embedding_size,
#                                  ignore_label=-1),
#             encoder=L.NStepLSTM(1, embedding_size, encoder_size,
#                                 dropout),
#             decoder=L.NStepLSTM(1, encoder_size+features_size+embedding_size,
#                                 decoder_size, dropout),
#             features_encoder=L.Linear(len(features), features_size),
#             output=L.Linear(decoder_size, len(alphabet)))
#
#     def __call__(self, in_strings, features):
#         _, _, encoded = self.encoder(
#                 None, None,
#                 [self.embeddings(s) for s in in_strings])
#         encoded = F.concat(encoded, axis=0)
#
#         encoded_features = F.relu(self.features_encoder(features))
#         duplicated_features = F.concat([
#             encoded_features[i:i+1]
#             for i,s in enumerate(in_strings)
#             for _ in range(s.shape[0])], axis=0)
#
#         return  RNNTransducerState(self, encoded, duplicated_features)




# class DecoderState:
#     def __init__(self, model, encoded, invariant, h=None, c=None):
#         self.model = model
#         self.encoded = encoded
#         self.invariant = invariant
#         self.c = c
#         self.h = h
#
#     def __call__(self, x_tm1):
#         attention, summary = self.model.attention(
#                 self.h, self.encoded, self.invariant)
#         c, h = self.model.decoder(
#                 self.c, self.h,
#                 F.concat((self.model.embeddings(x_tm1), summary)))
#         state = DecoderState(self.model, self.encoded, self.invariant, h, c)
#         state.p = self.model.output(F.tanh(self.model.hidden(h)))
#         state.attention = attention
#         return state
#
#
# class EncoderDecoder(Chain):
#     def __init__(
#             self,
#             alphabet,
#             embedding_size=32,
#             encoder_size=128,
#             decoder_size=128,
#             attention_size=128):
#
#         self.alphabet = alphabet
#         self.alphabet_idx = {c:i for i,c in enumerate(alphabet)}
#
#         super().__init__(
#             embeddings=L.EmbedID(len(alphabet), embedding_size,
#                                  ignore_label=-1),
#             encoder=L.LSTM(embedding_size, encoder_size),
#             decoder=L.StatelessLSTM(encoder_size+embedding_size, decoder_size),
#             hidden=L.Linear(decoder_size, embedding_size),
#             output=L.Linear(embedding_size, len(alphabet)),
#             attention=Attention(encoder_size, decoder_size, attention_size))
#
#     def __call__(self, source):
#         # source: list of int32(sequence_length,) sorted by length
#         batch_size = len(source)
#
#         self.encoder.reset_state()
#         encoded = F.transpose_sequence(
#             [self.encoder(self.embeddings(x_t))
#              for x_t in F.transpose_sequence(source)])
#
#         invariant = self.attention.get_invariant(encoded)
#
#         return DecoderState(self, encoded, invariant)

class Attention(Chain):
    """Additive (MLP-style) attention over a batch of variable-length sequences.

    Links:
        invariant: projects each attended vector; the result does not depend on
            the decoder state, so it is precomputed once by ``get_invariant``.
        variant: projects the decoder state (no bias).
        predict: maps each tanh-activated combination to a scalar score.
    """

    def __init__(self, attended_size, decoder_size, attention_size):
        super().__init__(
                invariant=L.Linear(attended_size, attention_size),
                variant=L.Linear(decoder_size, attention_size, nobias=True),
                predict=L.Linear(attention_size, 1))

    def get_invariant(self, attended):
        """Apply the ``invariant`` projection to every attended sequence.

        Returns a list with one projected sequence per element of *attended*,
        to be passed back into ``__call__`` at every decoding step.
        """
        return [self.invariant(x) for x in attended]

    def __call__(self, decoder_state, attended, invariant):
        """Compute attention weights and the attention-weighted summaries.

        Args:
            decoder_state: (batch_size, decoder_size) or None. When None (first
                decoding step), the scores depend only on *invariant*.
            attended: list (batch_size) of (sent_len, encoder_size).
            invariant: list (batch_size) of (sent_len, attention_size), as
                returned by ``get_invariant(attended)``.

        Returns:
            (attention, result): ``attention`` is a list of (sent_len, 1)
            weight columns, one per example. ``result`` stacks the weighted sums
            of the attended rows into (batch_size, encoder_size).
        """
        def weights(pre_activation):
            # (sent_len, attention_size) -> (sent_len, 1). expand_dims moves
            # sent_len to axis 1, which is the axis F.softmax normalizes over.
            return F.softmax(F.expand_dims(
                self.predict(F.tanh(pre_activation)), 0))[0]

        if decoder_state is None:
            attention = [weights(y) for y in invariant]
        else:
            variant = self.variant(decoder_state)
            # Each example's projected decoder state is added to every
            # position of that example's invariant projection.
            attention = [weights(F.broadcast_to(x, y.shape) + y)
                         for x, y in zip(variant, invariant)]

        # Scale each attended row by its weight, then sum over positions.
        result = F.vstack([
            F.sum(F.scale(x, p.reshape(p.shape[:1]), axis=0), axis=0)
            for x, p in zip(attended, attention)])

        return attention, result

class ForwardBackwardState:
    """Decoder state for one direction of a ``ForwardBackward`` model.

    Calling a state with the previously emitted symbols ``x_tm1`` runs one
    decoder step and returns a new state. The new state additionally carries:
        p          -- output scores from ``model.fb_output``
        attention  -- the attention weights used for this step

    ``forward`` selects the f_* links (decoder input also includes
    ``encoded_features``); otherwise the b_* links are used.

    While training with a non-zero ``model.recurrent_dropout``, the same
    ``F.Dropout`` objects are passed on from state to state. Presumably this is
    so that one dropout mask is reused across time steps (variational dropout);
    confirm against the Chainer version in use. Outside training, or with a
    zero rate, identity functions are used instead.
    """

    def __init__(self, model, forward, encoded, invariant,
                 h=None, c=None, encoded_features=None, dropout_states=None):
        self.model = model
        self.forward = forward
        self.backward = not forward
        self.encoded, self.invariant = encoded, invariant
        self.h, self.c = h, c
        self.encoded_features = encoded_features

        n_dropout_states = 1
        apply_dropout = config.train and model.recurrent_dropout != 0
        if not apply_dropout:
            self.dropout_states = [
                    (lambda v: v) for _ in range(n_dropout_states)]
        elif dropout_states is not None:
            # Reuse the caller's dropout objects instead of creating new ones.
            self.dropout_states = dropout_states
        else:
            self.dropout_states = [
                    F.Dropout(model.recurrent_dropout)
                    for _ in range(n_dropout_states)]

    def __call__(self, x_tm1):
        m = self.model
        if self.forward:
            attend, decoder, hidden = m.f_attention, m.f_decoder, m.f_hidden
        else:
            attend, decoder, hidden = m.b_attention, m.b_decoder, m.b_hidden

        # Attention is computed from the un-dropped previous hidden state.
        attention, summary = attend(self.h, self.encoded, self.invariant)

        pieces = (m.embeddings(x_tm1), summary)
        if self.forward:
            pieces = pieces + (self.encoded_features,)
        decoder_input = F.concat(pieces)

        h_prev = self.h if self.h is None else self.dropout_states[0](self.h)
        c, h = decoder(self.c, h_prev, decoder_input)

        next_state = ForwardBackwardState(
                m, self.forward, self.encoded, self.invariant, h, c,
                encoded_features=self.encoded_features,
                dropout_states=self.dropout_states)
        next_state.p = m.fb_output(F.tanh(hidden(h)))
        next_state.attention = attention
        return next_state


class ForwardBackward(Chain):
    """Two attention-based encoder-decoders sharing character embeddings.

    ``forward(source, features, languages)`` encodes *source* with
    ``f_encoder`` and returns a ``ForwardBackwardState`` whose decoder is also
    conditioned on an encoding of *features*.

    ``backward(source, languages)`` encodes *source* with ``b_encoder``,
    predicts per-feature scores (``state.p_features``) from the last encoded
    position, and returns a ``ForwardBackwardState`` without feature
    conditioning.

    Both directions share ``embeddings`` and ``fb_output``.
    """

    def __init__(
            self,
            alphabet,
            features,
            language_embeddings,
            embedding_size=32,
            encoder_size=128,
            decoder_size=128,
            features_size=128,
            language_embedding_size=64,
            attention_size=128,
            dropout=0.0,
            recurrent_dropout=0.0,
            bidirectional=False,
            use_lembs=False):
        """
        Args:
            alphabet: sequence of symbols; its length sizes the character
                embedding and output layers.
            features: sequence of feature names; its length sizes the feature
                encoder input and the feature predictor output.
            language_embeddings: pair ``(languages, initial_weights)``.
                ``languages`` is enumerated into ``languages_idx``.
                ``initial_weights`` is used as ``initialW`` of
                ``l_embeddings``.
            use_lembs: if true, a language-embedding row is appended to every
                encoded sequence along the time axis. That row is the
                embedding duplicated side by side, so ``encoder_size`` must
                equal ``2 * language_embedding_size``.
            bidirectional: if true, the encoder outputs are re-encoded
                right-to-left by ``f_encoder_b`` / ``b_encoder_b``.

        Raises:
            ValueError: if ``use_lembs`` is set and the sizes above disagree
                (otherwise the concat would fail later, on the first batch).
        """
        if use_lembs and encoder_size != 2 * language_embedding_size:
            raise ValueError(
                'use_lembs requires encoder_size == 2 * '
                'language_embedding_size (got %d and %d)'
                % (encoder_size, language_embedding_size))

        self.features = features
        self.alphabet = alphabet
        self.language_embeddings = language_embeddings[1]
        self.languages_idx = {l: i for i, l in enumerate(language_embeddings[0])}
        self.alphabet_idx = {c: i for i, c in enumerate(alphabet)}
        self.features_idx = {c: i for i, c in enumerate(features)}
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        self.bidirectional = bidirectional
        self.use_lembs = use_lembs

        kwargs = {}
        if bidirectional:
            kwargs['f_encoder_b'] = L.LSTM(encoder_size, encoder_size)
            kwargs['b_encoder_b'] = L.LSTM(encoder_size, encoder_size)

        super().__init__(
            embeddings=L.EmbedID(len(alphabet), embedding_size,
                                 ignore_label=-1),
            l_embeddings=L.EmbedID(len(language_embeddings[0]),
                        language_embedding_size,
                        initialW=self.language_embeddings),
            f_encoder=L.LSTM(embedding_size, encoder_size),
            b_encoder=L.LSTM(embedding_size, encoder_size),
            f_decoder=L.StatelessLSTM(
                encoder_size+embedding_size+features_size, decoder_size),
            b_decoder=L.StatelessLSTM(
                encoder_size+embedding_size, decoder_size),
            f_hidden=L.Linear(decoder_size, embedding_size),
            b_hidden=L.Linear(decoder_size, embedding_size),
            fb_output=L.Linear(embedding_size, len(alphabet)),
            f_features1=L.Linear(len(features), features_size),
            f_features2=L.Linear(features_size, features_size),
            b_features1=L.Linear(encoder_size, features_size),
            b_features2=L.Linear(features_size, len(features)),
            f_attention=Attention(encoder_size, decoder_size, attention_size),
            b_attention=Attention(encoder_size, decoder_size, attention_size),
            **kwargs)

    def _append_language_embeddings(self, encoded, languages):
        """Append each example's language embedding as an extra final step.

        The embedding is placed side by side with itself, giving width
        2 * language_embedding_size. This must match the encoder output width
        for the axis-0 concat.
        """
        appended = []
        for idx, c_encoded in enumerate(encoded):
            lemb = self.l_embeddings(languages[idx])
            appended.append(
                F.concat([c_encoded, F.concat([lemb, lemb])], axis=0))
        return appended

    @staticmethod
    def _encode_reversed(lstm, encoded):
        """Run the stateful *lstm* right-to-left over each encoded sequence.

        The outputs are returned in the original left-to-right order. The LSTM
        is reset first, as the main encoders are; otherwise the state left over
        from the previous call would leak into this batch.
        """
        lstm.reset_state()
        outputs = F.transpose_sequence(
            [lstm(x_t)
             for x_t in F.transpose_sequence([xs[::-1] for xs in encoded])])
        return [ys[::-1] for ys in outputs]

    def forward(self, source, features, languages):
        """Encode *source* and return the initial forward decoder state.

        Args:
            source: list of int32 (sequence_length,) arrays, sorted by length
                (longest first, as F.transpose_sequence requires).
            features: float32 (batch_size, len(features)) feature values for
                the source sequences.
            languages: indexed per example and fed to ``l_embeddings`` (only
                used when ``use_lembs``); presumably int32 language ids.
                TODO confirm.

        Returns:
            ForwardBackwardState with ``forward=True``.
        """
        # Feature MLP: two ReLU layers, each followed by dropout.
        encoded_features = F.dropout(
                F.relu(self.f_features2(
                    F.dropout(
                        F.relu(self.f_features1(features)),
                        self.dropout))),
                self.dropout)

        self.f_encoder.reset_state()
        encoded = F.transpose_sequence(
            [self.f_encoder(self.embeddings(x_t))
             for x_t in F.transpose_sequence(source)])

        if self.use_lembs:
            encoded = self._append_language_embeddings(encoded, languages)
        if self.bidirectional:
            encoded = self._encode_reversed(self.f_encoder_b, encoded)

        invariant = self.f_attention.get_invariant(encoded)

        return ForwardBackwardState(self, True, encoded, invariant,
                                    encoded_features=encoded_features)

    def backward(self, source, languages):
        """Encode *source*, predict its features, return a backward state.

        Args:
            source: list of int32 (sequence_length,) arrays, sorted by length
                (longest first).
            languages: as in ``forward``.

        Returns:
            ForwardBackwardState with ``forward=False``. Its ``p_features``
            attribute holds (batch_size, len(features)) feature scores.
        """
        # NOTE(review): presumably re-enables l_embeddings updates in case they
        # were disabled elsewhere; forward() does not do this.
        self.l_embeddings.enable_update()

        self.b_encoder.reset_state()
        encoded = F.transpose_sequence(
            [self.b_encoder(self.embeddings(x_t))
             for x_t in F.transpose_sequence(source)])

        if self.use_lembs:
            encoded = self._append_language_embeddings(encoded, languages)
        if self.bidirectional:
            encoded = self._encode_reversed(self.b_encoder_b, encoded)

        # Feature prediction reads the last encoded position of each example.
        features_h = F.relu(self.b_features1(
                        F.stack([e[-1] for e in encoded], 0)))
        features = self.b_features2(F.dropout(features_h, self.dropout))

        invariant = self.b_attention.get_invariant(encoded)

        state = ForwardBackwardState(self, False, encoded, invariant)
        state.p_features = features
        return state
