# -*- coding: utf-8 -*-

import torch
from torch import nn

from ...common.dataclass_options import OptionsBase
from ..activations import get_activation


def collect_neighbor_nodes(node_embedding, indices):
    """
    Look up, for every node, the embedding of each of its neighbors.

    node_embedding: [batch_size, num_nodes, feature_dim]
    indices: [batch_size, num_nodes, num_neighbors], positions along
        dimension 1 of `node_embedding`
    return: [batch_size, num_nodes, num_neighbors, feature_dim]
    """
    batch_size, num_nodes, feature_dim = node_embedding.size()

    # Merge the (node, neighbor) axes so a single gather along dim 1 suffices.
    # shape: [batch_size, num_nodes * num_neighbors]
    flat_indices = indices.view(batch_size, -1)
    # Repeat every index across the feature axis (a view, no copy):
    #     gather_index[i][j * num_neighbors + k][l] = indices[i][j][k]
    # shape: [batch_size, num_nodes * num_neighbors, feature_dim]
    gather_index = flat_indices.unsqueeze(-1).expand(-1, -1, feature_dim)

    # gathered[i][j * num_neighbors + k][l] = \
    #     node_embedding[i][indices[i][j][k]][l]
    gathered = torch.gather(node_embedding, 1, gather_index)

    # Split the merged axis back into [num_nodes, num_neighbors].
    return gathered.view(batch_size, num_nodes, -1, feature_dim)


class Highway(nn.Module):
    """
    Stack of highway layers.

    Each layer blends an activated projection and a plain linear projection
    of its input, weighted by a learned sigmoid gate.
    """

    def __init__(self, size, num_layers, activation):
        super().__init__()

        def make_stack():
            # One independent `size -> size` projection per highway layer.
            return nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])

        # Creation order is kept as nonlinears, linears, gates so that
        # parameter initialisation and registration order stay the same.
        self.nonlinears = make_stack()
        self.linears = make_stack()
        self.gates = make_stack()
        self.activation = activation

    def forward(self, x):
        """
        :param x: tensor whose last dimension is `size`, e.g. [batch_size, size]
        :return: tensor with the same shape as `x`
        """
        layers = zip(self.gates, self.linears, self.nonlinears)
        for gate_proj, linear_proj, nonlinear_proj in layers:
            # Mixing weight in (0, 1), computed per feature.
            mix = torch.sigmoid(gate_proj(x))
            transformed = self.activation(nonlinear_proj(x))
            carried = linear_proj(x)

            x = mix * transformed + (1 - mix) * carried
        return x


class GraphRNNGate(nn.Module):
    """
    Single gate of the graph RNN.

    Computes
        activation(w_in(incoming) + w_out(outgoing)
                   + u_in(prev_incoming_hidden) + u_out(prev_outgoing_hidden))
    where every term is a linear projection from `input_size` to
    `output_size`.
    """

    def __init__(self, input_size, output_size, activation):
        super().__init__()
        # Only `u_out` owns a bias, so the summed pre-activation contains
        # exactly one bias vector.
        projections = (('w_in', False), ('w_out', False), ('u_in', False), ('u_out', True))
        for name, has_bias in projections:
            setattr(self, name, nn.Linear(input_size, output_size, bias=has_bias))
        self.activation = activation

    def forward(self, incoming_neighbor, outgoing_neighbor,
                prev_incoming_neighbor_hidden, prev_outgoing_neighbor_hidden):
        """
        All four inputs have `input_size` as their last dimension and are
        summed after projection, so their leading dimensions must match.

        return: tensor with `output_size` as its last dimension
        """
        pre_activation = self.w_in(incoming_neighbor)
        pre_activation = pre_activation + self.w_out(outgoing_neighbor)
        pre_activation = pre_activation + self.u_in(prev_incoming_neighbor_hidden)
        pre_activation = pre_activation + self.u_out(prev_outgoing_neighbor_hidden)
        return self.activation(pre_activation)


class GraphRNNEncoder(nn.Module):
    """
    Graph encoder that refines per-node states with an LSTM-style update.

    Every node gets a fixed input built from the embeddings of its incoming
    and outgoing neighbors together with the embeddings of the connecting
    edge labels. `forward` then runs `num_layers` update steps that all share
    the same gate/cell weights; in each step a node reads the summed hidden
    states its neighbors had after the previous step.
    """

    class Options(OptionsBase):
        # Size of the per-node hidden state and cell state.
        hidden_size: int = 256
        # Size of the edge-label embeddings.
        edge_size: int = 20

        # Number of highway layers applied to the node embeddings (0 disables).
        num_highway_layers: int = 0
        # Number of update steps performed in `forward`.
        num_layers: int = 4
        # Dropout probability applied to the projected neighbor embeddings.
        dropout: float = 0.1

        # Activation spec handed to `get_activation`.
        activation: str = 'leaky_relu/0.1'

    def __init__(self, options: Options, input_size, num_edges):
        """
        options: hyper-parameters, see `Options`
        input_size: feature size of the node embeddings passed to `forward`
        num_edges: number of distinct edge labels (rows of the edge embedding table)
        """
        super().__init__()

        self.input_size = input_size
        self.num_layers = options.num_layers

        self.dropout = nn.Dropout(options.dropout)

        self.edge_embeddings = nn.Embedding(num_edges, options.edge_size)
        self.activation = get_activation(options.activation)

        if options.num_highway_layers > 0:
            self.multi_highway = Highway(input_size, options.num_highway_layers, self.activation)
        else:
            self.multi_highway = None

        self.hidden_size = hidden_size = options.hidden_size
        self.project_neighbor = nn.Linear(input_size + options.edge_size, hidden_size)

        # A single module produces the input, output and forget gates at once
        # (hence 3 * hidden_size); `forward` slices them apart.
        self.gate = GraphRNNGate(hidden_size, 3 * hidden_size, torch.sigmoid)
        self.cell = GraphRNNGate(hidden_size, hidden_size, torch.tanh)

    def compute_neighbor_embedding(self, node_embeddings,
                                   x_edge_labels, x_node_indices, x_nodes_mask):
        """
        Build the per-node input from one direction (incoming or outgoing) of neighbors.

        node_embeddings: [batch_size, num_nodes, input_size]
        x_edge_labels: [batch_size, num_nodes, num_neighbors], edge label ids
        x_node_indices: [batch_size, num_nodes, num_neighbors], neighbor positions
        x_nodes_mask: multiplied onto a [batch_size, num_nodes, num_neighbors, *]
            tensor; presumably [batch_size, num_nodes, num_neighbors, 1] with
            zeros at padded neighbor slots -- TODO confirm against the caller
        return: [batch_size, num_nodes, hidden_size]
        """
        # shape: [batch_size, num_nodes, num_neighbors, edge_size]
        x_edge_embeddings = self.edge_embeddings(x_edge_labels)
        # shape: [batch_size, num_nodes, num_neighbors, input_size]
        x_node_embeddings = collect_neighbor_nodes(node_embeddings, x_node_indices)
        # shape: [batch_size, num_nodes, num_neighbors, input_size + edge_size]
        x_neigh_embeddings = torch.cat([x_edge_embeddings, x_node_embeddings], dim=-1)
        x_neigh_embeddings = x_neigh_embeddings * x_nodes_mask
        # Sum over the neighbor axis.
        # shape: [batch_size, num_nodes, input_size + edge_size]
        x_neigh_embeddings = x_neigh_embeddings.sum(dim=2)

        return self.dropout(self.activation(self.project_neighbor(x_neigh_embeddings)))

    def compute_neighbor_hidden(self, state_h, x_node_indices, x_nodes_mask, nodes_mask):
        """
        Sum the previous hidden states of each node's neighbors.

        state_h: [batch_size, num_nodes, hidden_size]
        x_node_indices: [batch_size, num_nodes, num_neighbors]
        x_nodes_mask: multiplied onto the gathered neighbor states, see
            `compute_neighbor_embedding`
        nodes_mask: currently unused; kept so existing callers keep working
        return: [batch_size, num_nodes, hidden_size]
        """
        # shape: [batch_size, num_nodes, num_neighbors, hidden_size]
        prev_x_neighbor_hidden = collect_neighbor_nodes(state_h, x_node_indices)
        prev_x_neighbor_hidden = prev_x_neighbor_hidden * x_nodes_mask
        # shape: [batch_size, num_nodes, hidden_size]
        prev_x_neighbor_hidden = prev_x_neighbor_hidden.sum(dim=2)
        return prev_x_neighbor_hidden

    def forward(self, inputs, node_embeddings):
        """
        inputs: object exposing `incoming_labels`, `incoming_indices`,
            `incoming_mask`, `outgoing_labels`, `outgoing_indices`,
            `outgoing_mask` and `nodes_mask`. `nodes_mask` is multiplied onto
            the [batch_size, num_nodes, hidden_size] cell state, presumably
            with shape [batch_size, num_nodes, 1] -- TODO confirm
        node_embeddings: [batch_size, num_nodes, input_size]
        return: `(layer_outputs, (state_h, state_c))` where `layer_outputs`
            holds the `(state_h, state_c)` pair after every update step and the
            second element is the pair after the last step; each state has
            shape [batch_size, num_nodes, hidden_size]
        """
        hidden_size = self.hidden_size
        device = node_embeddings.device
        # Remember the input dtype so the recurrent state is allocated to
        # match it rather than the global default dtype.
        dtype = node_embeddings.dtype

        if self.multi_highway is not None:
            node_embeddings = self.multi_highway(node_embeddings)

        # shape: [batch_size, num_nodes, hidden_size]
        incoming_neighbor = self.compute_neighbor_embedding(node_embeddings,
                                                            inputs.incoming_labels,
                                                            inputs.incoming_indices,
                                                            inputs.incoming_mask)
        # shape: [batch_size, num_nodes, hidden_size]
        outgoing_neighbor = self.compute_neighbor_embedding(node_embeddings,
                                                            inputs.outgoing_labels,
                                                            inputs.outgoing_indices,
                                                            inputs.outgoing_mask)

        # Same leading dimensions as `nodes_mask`, last one replaced by hidden_size.
        hidden_shape = inputs.nodes_mask.shape[:-1] + (hidden_size, )

        # Without an explicit dtype these would always be created in the
        # default dtype and clash with the layer weights when the model runs
        # in another precision (e.g. after `.double()` or `.half()`).
        state_h = torch.zeros(hidden_shape, dtype=dtype, device=device)
        state_c = torch.zeros(hidden_shape, dtype=dtype, device=device)

        layer_outputs = []
        for _ in range(self.num_layers):
            # Four tensors, each of shape [batch_size, num_nodes, hidden_size]:
            # the two fixed neighbor inputs and the two neighbor-hidden sums
            # derived from the previous step's `state_h`.
            context = (incoming_neighbor,
                       outgoing_neighbor,
                       self.compute_neighbor_hidden(state_h,
                                                    inputs.incoming_indices,
                                                    inputs.incoming_mask,
                                                    inputs.nodes_mask),
                       self.compute_neighbor_hidden(state_h,
                                                    inputs.outgoing_indices,
                                                    inputs.outgoing_mask,
                                                    inputs.nodes_mask))
            # shape: [batch_size, num_nodes, 3 * hidden_size]
            gates = self.gate(*context)
            input_gate = gates[:, :, :hidden_size]
            output_gate = gates[:, :, hidden_size: 2 * hidden_size]
            forget_gate = gates[:, :, 2 * hidden_size:]

            # [batch_size, num_nodes, hidden_size]
            state_c = forget_gate * state_c + input_gate * self.cell(*context)
            # Zeroing the cell state also zeroes `state_h` below (tanh(0) == 0).
            state_c = state_c * inputs.nodes_mask
            # [batch_size, num_nodes, hidden_size]
            state_h = output_gate * torch.tanh(state_c)

            layer_outputs.append((state_h, state_c))

        return layer_outputs, (state_h, state_c)
