import torch
from torch import nn
from collections import OrderedDict
import numpy as np
from random import shuffle
from autoencoders.base_ar_decoder import BaseARDecoder
from fast_transformers.builders import TransformerDecoderBuilder, RecurrentDecoderBuilder, TransformerEncoderBuilder
from fast_transformers.masking import TriangularCausalMask, LengthMask
from autoencoders.torch_utils import PositionalEncoder
import random


def generate_linear_layers(in_size, layer_count, identity):
    """Create ``layer_count`` square ``in_size -> in_size`` linear layers.

    Args:
        in_size: input and output width of every layer.
        layer_count: how many layers to create.
        identity: if true, each layer is created without a bias and its weight
            is replaced by ``_init_random_identity(in_size)`` (identity matrix
            plus small uniform noise); otherwise PyTorch's default init is kept
            and a bias is included.

    Returns:
        A plain Python list of ``torch.nn.Linear`` modules.
    """
    def _build_one():
        layer = torch.nn.Linear(in_size, in_size, bias=not identity)
        if identity:
            # overwrite the default init with a near-identity matrix
            layer.weight.data = nn.Parameter(_init_random_identity(in_size))
        return layer

    return [_build_one() for _ in range(layer_count)]


def generate_carry_gate_layers(in_size, layer_count, identity):
    """Create ``layer_count`` carry-gate networks ``Linear -> ReLU -> Linear``.

    Every gate maps ``in_size`` features to ``in_size`` features.  The first
    linear layer has no bias; the second one has a bias.

    Args:
        in_size: feature width of every linear layer.
        layer_count: number of gate networks to create.
        identity: if true, the weight of every ``Linear`` inside each gate is
            overwritten with ``_init_random_identity(in_size)``.  (An
            ``nn.Sequential`` has no ``weight`` of its own, so the previous
            ``l.weight`` access raised ``AttributeError`` on this path.)

    Returns:
        A plain Python list of ``nn.Sequential`` modules.
    """
    gates = []
    for _ in range(layer_count):
        gate = nn.Sequential(
            torch.nn.Linear(in_size, in_size, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(in_size, in_size, bias=True),
        )

        if identity:
            for module in gate:
                if isinstance(module, torch.nn.Linear):
                    with torch.no_grad():
                        module.weight.copy_(_init_random_identity(in_size))
        gates.append(gate)

    return gates


def _init_random_identity(n):
    """Random normal initialization around 0., but add 1. at the diagonal"""
    init_weights = np.random.uniform(
        size=(n, n),
        low=-1. / float(n),
        high=+1. / float(n)).astype(
        np.float32)
    init_weights += np.eye(n, dtype=np.float32)
    #init_weights = np.reshape(init_weights, (-1))
    init_weights = torch.from_numpy(init_weights)
    return init_weights


class Id(nn.Module):
    """Pass-through module: ``forward`` hands its argument back untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself (same object, no copy)."""
        result = x
        return result


class BovIdentity(nn.Module):
    """No-op BoV mapping: hands back the input vectors and their lengths untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x, lens, Y=None, teacher_forcing=1.0, decode_till_Y=False):
        """Return ``(x, lens)``; the remaining arguments only mirror the mapping interface."""
        passthrough = (x, lens)
        return passthrough


class BovOracle(nn.Module):
    """Oracle BoV mapping: ignores its input and returns the provided target ``Y``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, lens, Y=None, teacher_forcing=1.0, decode_till_Y=False):
        """Return ``Y`` unchanged.

        Raises:
            ValueError: if no ``Y`` was supplied.
        """
        if Y is not None:
            return Y
        raise ValueError("BovOracle received no Y.")


class HighwayNetwork(torch.nn.Module):
    """Stack of highway layers with an optional final projection.

    Each layer computes ``out = g * H + (1 - g) * out`` with
    ``g = nonlinear_function(carry_gate(out))`` and
    ``H = activation(linear_term(out))``.  Carry gates come from
    ``generate_carry_gate_layers`` (Linear -> ReLU -> Linear); the transform
    terms come from ``generate_linear_layers`` with near-identity init.

    Args:
        in_size: feature width of the input and of every highway layer.
        out_size: output width of the optional final layer.  NOTE: when
            ``final_layer`` is false the output keeps ``in_size`` features and
            ``out_size`` is only stored, not enforced.
        layer_count: number of highway layers.
        nonlinear_function: module squashing the gate logits; a fresh
            ``torch.nn.Sigmoid()`` when omitted.
        activation: module applied to the transform term; a fresh ``Id()``
            when omitted.
        bias: value every carry-gate bias is filled with.  With the default
            sigmoid gate, a negative value pushes ``g`` towards 0, i.e. towards
            carrying the input through unchanged.
        final_layer: if true, append ``Linear(in_size, out_size)``.
    """

    def __init__(self, in_size, out_size, layer_count, nonlinear_function=None, activation=None, bias=-1., final_layer=False):
        super(HighwayNetwork, self).__init__()

        # Build default modules per instance.  Module instances used as default
        # argument values would be created once and shared by every network.
        if nonlinear_function is None:
            nonlinear_function = torch.nn.Sigmoid()
        if activation is None:
            activation = Id()

        self.carry_gate_list = torch.nn.ModuleList(
            generate_carry_gate_layers(in_size, layer_count, identity=False))
        self.linear_term_list = torch.nn.ModuleList(
            generate_linear_layers(in_size, layer_count, identity=True))
        self.nonlinear_function = nonlinear_function
        self.out_size = out_size
        self.activation = activation

        # optional projection to out_size after the highway stack
        if final_layer:
            self.final_layer = torch.nn.Linear(in_size, out_size)
        else:
            self.final_layer = None

        # Initialise every bias inside the carry gates to the same constant.
        for carry_gate in self.carry_gate_list:
            for l in carry_gate:
                if hasattr(l, "bias") and l.bias is not None:
                    l.bias.data.fill_(bias)

    def forward(self, x):
        """Apply all highway layers (and the final layer, if any) to ``x``."""
        out = x

        for carry_gate, linear_term in zip(self.carry_gate_list, self.linear_term_list):
            gate = self.nonlinear_function(carry_gate(out))
            H = self.activation(linear_term(out))
            out = gate * H + (1.0 - gate) * out

        if self.final_layer is not None:
            out = self.final_layer(out)

        return out


class OffsetVectorMLP(nn.Module):
    """Stack of residual "offset" layers: ``x <- x + offset(x)`` per layer.

    For each of ``n_layers`` layers the offset is
    ``[mlp_out](activation(mlp_in(dropout(x))))``.  It optionally passes through
    offset dropout and is then added to ``x``.  With ``activate_result`` the
    sum is passed through another ``nonlinearity`` module after every layer.

    Args:
        embedding_size: feature width; every linear layer is square.
        n_layers: number of offset layers.
        nonlinearity: class name inside ``torch.nn`` (e.g. ``"SELU"``).
        dropout: class name inside ``torch.nn`` used for input dropout.
        dropout_p: input dropout probability; 0 disables input dropout.
        outlayers: add a second linear layer after the activation of each offset.
        activate_result: apply ``nonlinearity`` to ``x`` after each residual add.
        offset_dropout_p: ``nn.Dropout`` probability applied to each offset
            vector before it is added; 0 disables it.
    """

    def __init__(self, embedding_size, n_layers, nonlinearity="SELU", dropout="AlphaDropout",
                 dropout_p=0, outlayers=False, activate_result=False, offset_dropout_p=0.):
        super(OffsetVectorMLP, self).__init__()

        # NOTE(review): eval() on configuration strings; getattr(nn, name) would be
        # safer if these values can ever come from untrusted input.
        activation = eval("nn." + nonlinearity)
        dropout = eval("nn." + dropout)
        self.dropout_p = dropout_p
        self.offset_dropout_p = offset_dropout_p
        self.embedding_size = embedding_size
        self.hidden_size = embedding_size
        self.outlayers = outlayers
        self.activate_result = activate_result
        if activate_result:
            self.final_activation = eval("nn." + nonlinearity)()

        inlayers = []
        if self.outlayers:
            outlayers = []
        activations = []
        dropouts = []
        offset_dropouts = []
        for i in range(n_layers):
            inlayers.append(nn.Linear(self.hidden_size, self.hidden_size))
            activations.append(activation())
            if self.outlayers:
                outlayers.append(nn.Linear(self.hidden_size, self.hidden_size))

            if dropout_p != 0:
                dropouts.append(dropout(p=dropout_p))

            if offset_dropout_p > 0.:
                offset_dropouts.append(nn.Dropout(p=offset_dropout_p))

        self.mlp_in = torch.nn.ModuleList(inlayers)
        if self.outlayers:
            self.mlp_out = torch.nn.ModuleList(outlayers)
        self.activations = torch.nn.ModuleList(activations)
        self.dropouts = torch.nn.ModuleList(dropouts)
        self.offset_dropouts = torch.nn.ModuleList(offset_dropouts)

    def forward(self, embeddings):
        """Apply every offset layer in turn and return the updated embeddings."""
        x = embeddings
        for i in range(len(self.mlp_in)):
            if self.dropout_p > 0:
                x = self.dropouts[i](x)

            # compute offset vector
            offset = self.mlp_in[i](x)
            offset = self.activations[i](offset)
            if self.outlayers:
                offset = self.mlp_out[i](offset)

            if self.offset_dropout_p > 0.:
                # Keep the dropout result; it was previously discarded, so offset
                # dropout had no effect at all.
                offset = self.offset_dropouts[i](offset)

            # and add it
            x = x + offset

            # Applied after every residual addition, not only after the last layer.
            if self.activate_result:
                x = self.final_activation(x)

        return x


class ResNet(nn.Module):
    """Residual MLP followed by a linear read-out layer.

    The body is an ``OffsetVectorMLP`` with per-offset output layers and an
    activation after every residual addition.  Its result is passed through a
    square ``nn.Linear``.
    """

    def __init__(self, embedding_size, n_layers, nonlinearity="SELU", dropout="AlphaDropout",
                 dropout_p=0, offset_dropout_p=0.):
        super(ResNet, self).__init__()

        body_options = dict(
            nonlinearity=nonlinearity,
            dropout=dropout,
            dropout_p=dropout_p,
            offset_dropout_p=offset_dropout_p,
            outlayers=True,
            activate_result=True,
        )
        self.offsetnet = OffsetVectorMLP(embedding_size, n_layers, **body_options)
        self.final_layer = nn.Linear(embedding_size, embedding_size)

    def forward(self, embeddings):
        """Run the residual body, then the final linear layer."""
        hidden = self.offsetnet(embeddings)
        return self.final_layer(hidden)


class LearnableOffsetVector(nn.Module):
    """Adds one trainable vector (initialised from ``torch.randn``) to its input."""

    def __init__(self, embedding_size):
        super(LearnableOffsetVector, self).__init__()
        initial_offset = torch.randn(embedding_size)
        self.weight = nn.Parameter(initial_offset)

    def forward(self, x):
        """Return ``x`` shifted by the learned vector (broadcast over leading dims)."""
        shifted = x + self.weight
        return shifted


class MeanOffsetVectorMLP(nn.Module):
    """Shift inputs by a learnable multiple of the difference of two mean embeddings.

    At construction time the mean ``encoder`` output is estimated once, without
    gradients, over a random sample of ``x_input`` and of ``y_input``.
    ``forward`` then returns ``x + factor * (y_mean - x_mean)``; only ``factor``
    is trainable.

    Args:
        factor_init: initial value of the trainable scale ``factor``.
        encoder: callable mapping a slice of inputs to a tensor whose first
            dimension enumerates the encoded items.
        x_input, y_input: sequences of raw inputs.  Shuffled copies are used;
            the caller's objects are not reordered.
        num_samples: maximum number of items drawn from each collection
            (default 100, the previously hard-coded value).
        batch_size: number of items passed to ``encoder`` per call (default 10).

    Raises:
        ValueError: if either input collection (or ``num_samples``) is empty.
    """

    def __init__(self, factor_init, encoder, x_input, y_input, num_samples=100, batch_size=10):
        super(MeanOffsetVectorMLP, self).__init__()
        # Force a floating dtype: an integer tensor cannot require gradients.
        self.factor = nn.Parameter(
            torch.tensor(factor_init, dtype=torch.get_default_dtype()))

        # Shuffle copies so the caller's data is not reordered as a side effect.
        x_input = list(x_input)
        y_input = list(y_input)
        shuffle(x_input)
        shuffle(y_input)

        with torch.no_grad():
            x_mean = self._mean_embedding(encoder, x_input, num_samples, batch_size)
            y_mean = self._mean_embedding(encoder, y_input, num_samples, batch_size)

        # Non-persistent buffers follow .to(device) but stay out of state_dict,
        # which keeps checkpoints compatible with the former plain attributes.
        self.register_buffer("x_mean", x_mean.detach(), persistent=False)
        self.register_buffer("y_mean", y_mean.detach(), persistent=False)

    @staticmethod
    def _mean_embedding(encoder, items, num_samples, batch_size):
        """Average ``encoder`` output over the first ``num_samples`` items, encoded in batches."""
        limit = min(num_samples, len(items))
        if limit <= 0:
            raise ValueError("MeanOffsetVectorMLP needs at least one input sample.")

        total = None
        count = 0
        for start in range(0, limit, batch_size):
            emb = encoder(items[start: min(start + batch_size, limit)])
            batch_sum = emb.sum(dim=0)
            total = batch_sum if total is None else total + batch_sum
            # Divide by the number of items actually encoded, not a fixed constant.
            count += emb.size(0)
        return total / float(count)

    def forward(self, x):
        """Return ``x + factor * (y_mean - x_mean)``."""
        return x + self.factor * (-self.x_mean + self.y_mean)


class FixOffsetVectorMLP(nn.Module):
    """Adds ``n_layers`` independently learned offset vectors to the input, one after another."""

    def __init__(self, embedding_size, n_layers):
        super(FixOffsetVectorMLP, self).__init__()
        self.offsetvectors = torch.nn.ModuleList(
            [LearnableOffsetVector(embedding_size) for _ in range(n_layers)])

    def forward(self, embeddings):
        """Apply every learned offset in order and return the shifted embeddings."""
        shifted = embeddings
        for offset_module in self.offsetvectors:
            shifted = offset_module(shifted)
        return shifted

    def print_vecs(self):
        """Print the norm and the values of every learned offset vector."""
        for layer_idx, offset_module in enumerate(self.offsetvectors):
            message = "Layer {}:\n norm {} \n vector: {}".format(
                layer_idx, offset_module.weight.norm(), offset_module.weight)
            print(message)


class MLP(nn.Module):
    """Feed-forward network mapping ``embedding_size`` features back to ``embedding_size``.

    There are ``n_layers`` hidden layers of width ``hidden_size``.  Each is
    followed by ``nonlinearity`` and, when ``dropout_p != 0``, by ``dropout``.
    A linear output layer comes last.

    * ``residual_connections`` add each hidden layer's input to its
      pre-activation output.  This is skipped for the first layer when the
      widths differ.
    * ``skip_connections`` add the original embeddings to every hidden
      pre-activation, only when ``embedding_size == hidden_size``.

    ``nonlinearity`` and ``dropout`` are class names inside ``torch.nn``.
    """

    def __init__(self, embedding_size, n_layers, hidden_size, nonlinearity="SELU", residual_connections=False, skip_connections=False, dropout="AlphaDropout", dropout_p=0):
        super(MLP, self).__init__()

        # NOTE(review): eval() on configuration strings; getattr(nn, name) would be
        # safer if these values can ever come from untrusted input.
        activation = eval("nn." + nonlinearity)
        dropout = eval("nn." + dropout)
        self.dropout_p = dropout_p
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size

        layers = []
        activations = []
        dropouts = []
        for i in range(n_layers):
            # the first hidden layer reads the embeddings, later ones the previous hidden state
            in_features = embedding_size if i == 0 else hidden_size
            layers.append(nn.Linear(in_features, hidden_size))
            activations.append(activation())
            if dropout_p != 0:
                dropouts.append(dropout(p=dropout_p))

        # Output layer.  Without hidden layers it reads the embeddings directly;
        # otherwise an ``n_layers == 0`` network with hidden_size != embedding_size
        # would fail on its first forward pass.
        out_in_features = hidden_size if n_layers > 0 else embedding_size
        layers.append(nn.Linear(out_in_features, embedding_size))

        self.mlp = torch.nn.ModuleList(layers)
        self.activations = torch.nn.ModuleList(activations)
        self.dropouts = torch.nn.ModuleList(dropouts)
        self.residual_connections = residual_connections
        self.skip_connections = skip_connections

    def forward(self, embeddings):
        """Run the hidden layers (with optional residual/skip connections) and the output layer."""
        x = embeddings
        last_index = len(self.mlp) - 1
        for i, layer in enumerate(self.mlp):
            xx = layer(x)

            if i != last_index:
                # residual needs matching widths: always true after the first layer
                if self.residual_connections and (self.embedding_size == self.hidden_size or i != 0):
                    xx = xx + x
                if self.skip_connections and self.embedding_size == self.hidden_size:
                    xx = xx + embeddings
                xx = self.activations[i](xx)
                if self.dropout_p != 0:
                    xx = self.dropouts[i](xx)

            x = xx

        return x


class BovToBovMapping(nn.Module):
    """Autoregressive transformer mapping from one bag of vectors (BoV) to another.

    A fast-transformers recurrent decoder cross-attends over the input vectors
    ``X`` and emits one output vector per step, starting from a learned start
    embedding.  Optional extras, all configured through ``config``:

    * ``offset``: add the input vector at the same position to the decoder output.
    * ``point_gen``: pointer-generator style mixing of a copied input vector
      (attention over the inputs) with the generated vector.  It can add a
      learned offset (``point_gen_offset``) and coverage (``point_gen_coverage``).
    * ``project_input_dimension``: when it differs from ``input_size``, linear
      projections are added in front of and behind the transformer.
    """

    def __init__(self, config):
        super(BovToBovMapping, self).__init__()

        self.layers = config.layers
        self.self_att_type = config.self_att_type
        self.cross_att_type = config.cross_att_type
        self.heads = config.heads
        self.input_size = config.input_size
        self.ff_dimension = config.ff_dimension
        self.positional_embeddings = config.positional_embeddings
        self.dropout = config.dropout
        self.att_dropout = config.att_dropout
        self.backprop_through_outputs = config.backprop_through_outputs
        self.offset = config.offset
        self.output_layer = config.output_layer
        self.project_input_dimension = config.project_input_dimension
        self.point_gen = config.point_gen
        self.point_gen_offset = config.point_gen_offset
        self.point_gen_context_vector = config.point_gen_context_vector
        self.point_gen_offset_copy_dependence = config.point_gen_offset_copy_dependence
        self.point_gen_coverage = config.point_gen_coverage
        self.out_to_in = config.point_gen_out_to_in
        self.mask_first_vector = config.mask_first_vector

        # projection from the external vector width to the transformer width
        if self.project_input_dimension != self.input_size:
            self.input_projection = nn.Linear(
                self.project_input_dimension, self.input_size)
        else:
            self.input_projection = None

        # projection from the transformer width back to the external width
        if self.output_layer or self.project_input_dimension != self.input_size:
            self.outlayer = nn.Linear(
                config.input_size, config.project_input_dimension)
        else:
            self.outlayer = None

        if self.point_gen:

            # p_gen network: probability of generating instead of copying
            if self.point_gen_context_vector:
                self.pgen = nn.Sequential(
                    nn.Linear(2 * config.input_size, 2 * config.input_size),
                    nn.ReLU(),
                    nn.Linear(2 * config.input_size, 1),
                    nn.Sigmoid())
            else:
                self.pgen = nn.Sequential(
                    nn.Linear(config.input_size, config.input_size),
                    nn.ReLU(),
                    nn.Linear(config.input_size, 1),
                    nn.Sigmoid())
            # turns input vectors into attention keys for copying
            self.cpy_transformation = nn.Sequential(nn.Linear(
                config.input_size, config.input_size), nn.ReLU(), nn.Linear(
                config.input_size, config.input_size))
            self.point_att_softmax = nn.Softmax(dim=-1)

            if self.point_gen_offset:
                if self.point_gen_context_vector or self.point_gen_offset_copy_dependence:
                    self.offset_net = nn.Sequential(
                        nn.Linear(2 * config.input_size, config.input_size), nn.ReLU(), nn.Linear(config.input_size, config.input_size))
                else:
                    self.offset_net = nn.Sequential(
                        nn.Linear(config.input_size, config.input_size), nn.ReLU(), nn.Linear(config.input_size, config.input_size))

        self.recurrent_transformer_decoder = self._set_builder_attributes(
            RecurrentDecoderBuilder()).get()

        if self.positional_embeddings:
            self.pe = PositionalEncoder(
                self.input_size, mul_by_sqrt=False, learned_embeddings=config.learned_positional_embeddings)

        self.max_lens = config.max_length
        self.adaptive_max_lens = config.adaptive_max_len
        self.start_embedding = nn.Parameter(
            torch.randn(self.project_input_dimension))

    def forward(self, X, X_lens, Y=None, teacher_forcing=1.0, decode_till_Y=False):
        """Autoregressively decode a sequence of output vectors from ``X``.

        Args:
            X: input vectors, indexed below as ``(batch, length, dim)``.
            X_lens: number of valid input vectors per batch row.
            Y: optional ``(Y_emb, Y_len)`` target pair.  It drives teacher
                forcing, ``adaptive_max_len`` and ``decode_till_Y``.
            teacher_forcing: probability of feeding the target vector instead
                of the model's own previous output (only in training mode).
            decode_till_Y: stop after the longest target instead of the
                configured maximum length.

        Returns:
            ``(outputs, out_lens)``.  ``outputs`` concatenates the per-step
            outputs along dim 1.  ``out_lens`` holds, for every row, the number
            of decoded steps.

        Raises:
            ValueError: if ``adaptive_max_len`` is set but no ``Y`` is given, or
                if the decoding horizon leaves no step to decode.
        """
        X_lens = X_lens.to(X.device)

        if Y is not None:
            Y_emb, Y_len = Y
            max_y_len = torch.max(Y_len).item()
        elif self.adaptive_max_lens:
            # Previously this surfaced as an opaque NameError on ``Y_len``.
            raise ValueError(
                "BovToBovMapping with adaptive_max_len needs Y=(Y_emb, Y_len).")

        batch_size = X.size(0)
        embedded_input = self.start_embedding.repeat(
            batch_size, 1)
        state = None  # let fast-transformers initialize the state
        outputs = []

        if self.adaptive_max_lens:
            if (Y_len == 0).any():
                print("Something zero!")
                print(Y_len)
            mlens = max_y_len + 1
        else:
            mlens = self.max_lens

        if self.offset or self.point_gen:
            # Unprojected inputs, zero-padded or truncated to ``mlens`` positions;
            # row ``t - 1`` is the input aligned with decoding step ``t``.
            base_vector = torch.zeros(
                (batch_size, mlens, X.size(2)), device=X.device)

            minsize = min(mlens, X.size(1))
            base_vector[:, :minsize,
                        :] = X[:, :minsize, :]

        if self.input_projection is not None:
            X = self.input_projection(X)

        if Y is not None and decode_till_Y:
            max_seq_len = min(mlens, max_y_len + 1)
        else:
            max_seq_len = mlens

        if max_seq_len < 2:
            # the decoding loop would not run and there would be nothing to return
            raise ValueError(
                "BovToBovMapping needs at least one decoding step, got max_seq_len={}.".format(max_seq_len))

        # The cross-attention memory and its padding mask do not change between steps.
        memory = X
        memory_mask = LengthMask(
            X_lens.to(memory.device), max_len=memory.shape[1])

        if self.point_gen and self.point_gen_coverage:
            # accumulated attention mass already spent on each input position
            cum_weights = torch.zeros(
                (base_vector.size(0), base_vector.size(1)), device=base_vector.device)

        for t in range(1, max_seq_len):

            if self.input_projection is not None:
                embedded_input = self.input_projection(embedded_input)

            if self.positional_embeddings:
                embedded_input = self.pe(
                    embedded_input.unsqueeze(1), t=t).squeeze(1)

            out, state = self.recurrent_transformer_decoder(embedded_input, memory=memory,
                                                            memory_length_mask=memory_mask,
                                                            state=state)

            if self.offset:
                out = base_vector[:, t - 1] + out

            # decoder output before the output layer / pointer mixing
            dec_out = out

            if self.point_gen:
                if self.point_gen_context_vector:
                    # NOTE(review): relies on the internal layout of the
                    # fast-transformers recurrent state; confirm for the installed version.
                    last_layer_states = state[-1]
                    self_att_states = last_layer_states[0]
                    att_out_projection = self_att_states[0]

                    # retrieve start symbol embedding
                    context_v = att_out_projection[:, :, 0, :]
                    # concat heads
                    context_v = torch.reshape(
                        context_v, (context_v.size(0), -1))
                    pgen_vector = torch.cat([context_v, out], dim=-1)

                else:
                    pgen_vector = out

                # compute probability of generating (one value per batch row)
                p_gen = self.pgen(pgen_vector)

                # attention keys computed from the (unprojected) input vectors
                keys = self.cpy_transformation(base_vector)
                if not self.point_gen_context_vector:
                    context_v = out.unsqueeze(1)
                # NOTE(review): with point_gen_context_vector, context_v is 2-D
                # here and ``transpose(1, 2)`` fails; that path looks unfinished.
                # squeeze(-1) rather than squeeze() so a batch of one keeps its batch dim.
                att_logits = torch.bmm(
                    keys, context_v.transpose(1, 2)).squeeze(-1)

                # make sure we mask out invalid tokens
                X_mask = (torch.arange(base_vector.size(1),
                                       device=base_vector.device).unsqueeze(0).expand(base_vector.size(0), -1) < X_lens.unsqueeze(1))
                att_logits = att_logits - (1. - X_mask.float()) * 10e10

                if self.mask_first_vector:
                    att_logits[:, 0] = -10e10

                if self.point_gen_coverage:
                    # do not copy positions that are (almost) fully covered already
                    att_logits[cum_weights > 0.99] = -10e10

                att_weights = self.point_att_softmax(att_logits)

                if self.point_gen_coverage:
                    cum_weights = cum_weights + (1 - cum_weights) * att_weights
                    att_w = (1 - cum_weights).unsqueeze(-1) * \
                        att_weights.unsqueeze(-1)
                    cpy_vector = (att_w * base_vector).sum(1)
                else:
                    cpy_vector = (att_weights.unsqueeze(-1)
                                  * base_vector).sum(1)

                # compute offset vector
                if self.point_gen_offset:
                    if self.point_gen_context_vector:
                        offset_vector = self.offset_net(
                            torch.cat([context_v, out], dim=-1))
                    elif self.point_gen_offset_copy_dependence:
                        # context_v is out.unsqueeze(1) here; squeeze(1) keeps the
                        # batch dim even for a batch of one.
                        offset_vector = self.offset_net(
                            torch.cat([context_v.squeeze(1), cpy_vector], dim=-1))
                    else:
                        offset_vector = self.offset_net(
                            out)

            if self.outlayer is not None:
                out = self.outlayer(out)

            if self.point_gen:
                if self.point_gen_offset:
                    out = (cpy_vector + offset_vector) * \
                        (1 - p_gen) + p_gen * out
                else:
                    out = cpy_vector * (1 - p_gen) + p_gen * out

            # which vector is fed back as the next decoder input
            o = out if self.out_to_in else dec_out

            if self.backprop_through_outputs:
                embedded_input = o
            elif Y is not None and t - 1 < max_y_len and random.random() < teacher_forcing and self.training:
                embedded_input = Y_emb[:, t - 1]
            else:
                embedded_input = o.detach()

            outputs.append(out.unsqueeze(1))

        # every row is decoded for the same number of steps
        out_lens = torch.full(
            (batch_size,), len(outputs), dtype=torch.long, device=X.device)

        outputs = torch.cat(outputs, dim=1)
        return outputs, out_lens

    def _is_done(self, out, t):
        """
        Supposed to return a vector of size out.size(0) with True, False flags
        in those positions that are "done".

        Currently a placeholder: each row is flagged with probability 0.1.  It
        is not called by ``forward``.
        """
        return torch.rand(out.size(0)) < 0.1

    def _set_builder_attributes(self, builder):
        """Copy the decoder hyper-parameters onto a fast-transformers builder and return it.

        Per-head query/value sizes are ``int(input_size / heads)``; presumably
        ``input_size`` is divisible by ``heads`` (not checked here).
        """
        builder.n_layers = self.layers
        builder.n_heads = self.heads
        builder.feed_forward_dimensions = self.ff_dimension
        builder.query_dimensions = int(self.input_size / self.heads)
        builder.value_dimensions = int(self.input_size / self.heads)
        builder.dropout = self.dropout
        builder.attention_dropout = self.att_dropout
        builder.self_attention_type = self.self_att_type
        builder.cross_attention_type = self.cross_att_type
        return builder


class SimpleBovMapping(nn.Module):
    """Non-autoregressive BoV mapping: one transformer-encoder pass over the input vectors.

    The input vectors go through an optional input projection and optional
    positional encodings, then a fast-transformers encoder, then an optional
    output layer.  With ``offset`` the raw input is added back as a residual.
    """

    def __init__(self, config):
        super(SimpleBovMapping, self).__init__()

        self.layers = config.layers
        self.self_att_type = config.self_att_type
        self.heads = config.heads
        self.input_size = config.input_size
        self.ff_dimension = config.ff_dimension
        self.positional_embeddings = config.positional_embeddings
        self.dropout = config.dropout
        self.att_dropout = config.att_dropout
        self.offset = config.offset
        self.output_layer = config.output_layer
        self.project_input_dimension = config.project_input_dimension

        # projection from the external vector width to the transformer width
        if self.project_input_dimension != self.input_size:
            self.input_projection = nn.Linear(
                self.project_input_dimension, self.input_size)
        else:
            self.input_projection = None

        # projection from the transformer width back to the external width
        if self.output_layer or self.project_input_dimension != self.input_size:
            self.outlayer = nn.Linear(
                config.input_size, config.project_input_dimension)
        else:
            self.outlayer = None

        self.transformer_encoder = self._set_builder_attributes(
            TransformerEncoderBuilder()).get()

        if self.positional_embeddings:
            self.pe = PositionalEncoder(
                self.input_size, mul_by_sqrt=False, learned_embeddings=config.learned_positional_embeddings)

    def forward(self, X, X_lens, Y=None, teacher_forcing=1.0, decode_till_Y=False):
        """Map the input vectors ``X`` in a single encoder pass.

        Args:
            X: input vectors; dim 1 is the sequence axis (used for the mask length).
            X_lens: number of valid vectors per batch row.
            Y: optional ``(Y_emb, Y_len)``.  Only ``Y_emb.size(1)`` is used, to
                widen the length mask.
            teacher_forcing, decode_till_Y: unused; kept for interface
                compatibility with the autoregressive mapping.

        Returns:
            ``(outputs, X_lens)``: one output vector per input position and the
            (unchanged) input lengths.
        """
        X_lens = X_lens.to(X.device)

        # Raw input (before projection / positional encoding) used for the
        # residual offset.  The old code referenced an undefined ``mlens`` here,
        # so ``offset=True`` always crashed.
        base_vector = X

        if self.input_projection is not None:
            X = self.input_projection(X)

        if self.positional_embeddings:
            X = self.pe(
                X)

        memory = X
        memory_len = X_lens.to(memory.device)
        maxlen = memory.size(1) if Y is None else max(Y[0].size(1), memory.size(1))
        mask = LengthMask(memory_len, max_len=maxlen,
                          device=memory.device)
        outputs = self.transformer_encoder(memory, length_mask=mask)

        if self.outlayer is not None:
            outputs = self.outlayer(outputs)

        if self.offset:
            # Added after the output layer so widths match: the raw input and
            # the (projected-back) outputs both have the external width.
            outputs = base_vector + outputs

        return outputs, X_lens

    def _is_done(self, out, t):
        """
        Supposed to return a vector of size out.size(0) with True, False flags
        in those positions that are "done".

        Currently a placeholder: each row is flagged with probability 0.1.  It
        is not called by ``forward``.
        """
        return torch.rand(out.size(0)) < 0.1

    def _set_builder_attributes(self, builder):
        """Copy the encoder hyper-parameters onto a fast-transformers builder and return it.

        Per-head query/value sizes are ``int(input_size / heads)``; presumably
        ``input_size`` is divisible by ``heads`` (not checked here).
        """
        builder.n_layers = self.layers
        builder.n_heads = self.heads
        builder.feed_forward_dimensions = self.ff_dimension
        builder.query_dimensions = int(self.input_size / self.heads)
        builder.value_dimensions = int(self.input_size / self.heads)
        builder.dropout = self.dropout
        builder.attention_dropout = self.att_dropout
        builder.attention_type = self.self_att_type
        return builder
