# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F

from framework.common.dataclass_options import argfield
from framework.common.utils import DotDict
from framework.generation_v2.layers.global_attention import GlobalAttention
from framework.torch_extra.layers.graph_rnn import GraphRNNEncoder
from framework.torch_extra.layers.sentence import SentenceEmbeddings
from framework.torch_extra.model_base import ModelBase
from framework.torch_extra.pretrained import ExternalEmbeddingPlugin


class HyperParams(ModelBase.HyperParams):
    """Hyper-parameters shared by the HRG / CFG / SHRG selectors in this file."""
    # input embedding options; SHRGSelector calls ``.create(vocabs, plugins)`` on it
    sentence_embedding: SentenceEmbeddings.Options
    # NOTE(review): not read anywhere in this file; presumably consumed by plugin setup
    # elsewhere — confirm
    external_embedding: ExternalEmbeddingPlugin.Options

    # options forwarded to the GraphRNNEncoder that produces node embeddings
    encoder: GraphRNNEncoder.Options

    # embedding size of an HRG rule (HRGSelector.hrg_embeddings)
    hrg_size: int = 100
    # embedding size of a CFG rule (CFGSelector.cfg_embeddings)
    cfg_size: int = 40
    # embedding size of a nonterminal; three of them are concatenated to each HRG rule
    # embedding
    nonterminal_size: int = 50

    # dropout applied to the concatenated partition representation in HRGSelector
    dropout: float = 0.2

    # a training/eval instance is kept only if its number of unmasked candidate choices
    # is at least this value (checked with ``>=``)
    hrg_min_choices: int = 1
    cfg_min_choices: int = 2

    # embedding sizes of the border-node type / order features
    border_type_size: int = 5
    border_order_size: int = 10
    # NOTE(review): not read in this file; HRGSelector derives its border size from
    # input_size + border_type_size + border_order_size instead — confirm whether it is
    # used elsewhere
    border_size: int = 50

    # ``reduction`` argument passed to F.cross_entropy by both selectors
    loss_reduction: str = argfield(default='mean', choices=['sum', 'mean'])

    # attention variant passed to GlobalAttention
    attention_type: str = 'general'


class HRGSelector(nn.Module):
    """Scores candidate HRG rule applications ("partitions") against encoded graph nodes.

    Each partition is represented by attention summaries of its center/left/right node
    sets and its left/right border nodes (the attention query is derived from the rule
    embedding). The partition is then scored against the rule embedding with a bilinear
    layer. Candidates are grouped into instances via ``inputs.instances``. During
    training, column 0 of every instance is the gold choice.
    """

    def __init__(self, options, vocabs, input_size, num_hrg_grammars, grammar_nonterminals, **_):
        """Build the selector.

        Args:
            options: hyper-parameters; reads ``hrg_min_choices``, ``loss_reduction``,
                ``dropout``, ``border_type_size``, ``border_order_size``,
                ``attention_type``, ``hrg_size`` and ``nonterminal_size``.
            vocabs: vocabulary collection; ``vocabs.get('nonterminal')`` must support
                ``len``.
            input_size: size of each node embedding produced by the graph encoder.
            num_hrg_grammars: number of HRG rules (rows of ``hrg_embeddings``).
            grammar_nonterminals: per-rule nonterminal ids, gathered by rule id with
                ``index_select``; presumably a ``[num_hrg_grammars, 3]`` long tensor —
                confirm against the caller.
        """
        super().__init__()

        self.vocabs = vocabs
        self.hrg_min_choices = options.hrg_min_choices
        self.loss_reduction = options.loss_reduction

        # NOTE(review): kept as a plain attribute rather than a registered buffer, so
        # ``Module.to()`` / ``.cuda()`` will not move it; the caller must put it on the
        # right device.
        self.grammar_nonterminals = grammar_nonterminals

        # also sets ``self.hrg_size`` (rule embedding + 3 nonterminal embeddings)
        self.init_embeddings(options, num_hrg_grammars)

        self.dropout = nn.Dropout(options.dropout)

        # vocabulary sizes 3 (border types) and 16 (border orders) are hard-coded here
        self.type_embeddings = nn.Embedding(3, options.border_type_size)
        self.order_embeddings = nn.Embedding(16, options.border_order_size)

        # a border node = node embedding + type embedding + order embedding
        self.border_size = input_size + options.border_type_size + options.border_order_size

        # center/left/right node summaries + left/right border summaries
        self.partition_size = input_size * 3 + 2 * self.border_size
        self.scorer = nn.Bilinear(self.partition_size, self.hrg_size, 1)

        # project the rule embedding into the two attention query spaces
        self.project_graph = nn.Linear(self.hrg_size, input_size)
        self.project_border = nn.Linear(self.hrg_size, self.border_size)

        self.graph_attention = GlobalAttention(input_size, options.attention_type)
        self.border_attention = GlobalAttention(self.border_size, options.attention_type)

    def init_embeddings(self, options, num_hrg_grammars):
        """Create nonterminal / rule embeddings and set ``self.hrg_size`` accordingly."""
        # index 0 is the padding nonterminal
        self.nonterminal_embeddings = nn.Embedding(len(self.vocabs.get('nonterminal')),
                                                   options.nonterminal_size,
                                                   padding_idx=0)
        self.hrg_embeddings = nn.Embedding(num_hrg_grammars, options.hrg_size)

        # size of the vector returned by ``compute_embeddings``
        self.hrg_size = options.hrg_size + options.nonterminal_size * 3

    def compute_embeddings(self, inputs):
        """Return ``[num_partitions, hrg_size]`` rule vectors for ``inputs.hrg_indices``.

        Each vector concatenates the embeddings of the rule's nonterminals (flattened)
        with the rule's own embedding.
        """
        hrg_indices = inputs.hrg_indices
        num_partitions = hrg_indices.size(0)

        # shape: [num_partitions, 3]
        hrg_nonterminals = self.grammar_nonterminals.index_select(0, hrg_indices)
        # shape: [num_partitions, 3 * nonterminal_size + hrg_size]
        hrgs = torch.cat([
            self.nonterminal_embeddings(hrg_nonterminals).view(num_partitions, -1),
            self.hrg_embeddings(hrg_indices)
        ], dim=-1)
        return hrgs

    def collect_nodes(self, node_indices, nodes_mask, node_embeddings, hrgs):
        """Summarize one node set per partition with rule-conditioned attention.

        Args:
            node_indices: row indices into ``node_embeddings``; its ``size()`` gives the
                leading dims of the gathered tensor (presumably
                ``[num_partitions, num_nodes]``).
            nodes_mask: multiplied into the gathered nodes; ``squeeze(-1) != 0`` of it
                is the attention memory mask (presumably ``[num_partitions, num_nodes, 1]``).
            node_embeddings: 2-D node embedding matrix ``[total_nodes, input_size]``.
            hrgs: ``[num_partitions, hrg_size]`` rule vectors.

        Returns:
            The first output of ``graph_attention``. ``compute_partitions`` concatenates
            it with ``[num_partitions, input_size]`` tensors, so it is presumably
            ``[num_partitions, input_size]``.
        """
        # shape: [num_partitions, num_nodes, input_size]
        nodes = node_embeddings.index_select(0, node_indices.view(-1))
        nodes = nodes_mask * nodes.view(*node_indices.size(), -1)

        # shape (see docstring): [num_partitions, input_size]
        return self.graph_attention(self.project_graph(hrgs), nodes,
                                    memory_mask=(nodes_mask.squeeze(-1) != 0))[0]

    def compute_border_embeddings(self, node_embeddings, hrgs,
                                  x_border_indices, x_borders_mask,
                                  x_border_types, x_border_orders):
        """Summarize border nodes (embedding + type + order features) per partition.

        Masked-out border slots are zeroed in all three feature parts and excluded
        from attention. The result is presumably ``[num_partitions, border_size]``,
        matching the zero fallback in ``compute_partitions``.
        """
        # shape: [num_partitions, num_border_nodes, input_size]
        x_borders = node_embeddings.index_select(0, x_border_indices.view(-1))
        x_borders = x_borders_mask * x_borders.view(*x_border_indices.size(), -1)

        # shape: [num_partitions, num_border_nodes, type_size]
        x_border_types = self.type_embeddings(x_border_types) * x_borders_mask
        # shape: [num_partitions, num_border_nodes, order_size]
        x_border_orders = self.order_embeddings(x_border_orders) * x_borders_mask
        # shape: [num_partitions, num_border_nodes, border_size]
        x_borders = torch.cat([x_borders, x_border_types, x_border_orders], dim=-1)

        # shape (see docstring): [num_partitions, border_size]
        return self.border_attention(self.project_border(hrgs), x_borders,
                                     memory_mask=(x_borders_mask.squeeze(-1) != 0))[0]

    def compute_partitions(self, inputs, node_embeddings):
        """Build ``[num_partitions, partition_size]`` partition vectors and rule vectors.

        Any node set / border set that is ``None`` in ``inputs`` contributes zeros.
        Dropout is applied to the concatenated partition vectors.

        Returns:
            ``(partitions, hrgs)``.
        """
        # shape: [num_partitions, hrg_size]
        hrgs = self.compute_embeddings(inputs)

        num_partitions = hrgs.size(0)
        node_size = node_embeddings.size(-1)

        # shape: [num_partitions, node_size]
        center_nodes = inputs.center_nodes
        if center_nodes is None:
            center_parts = torch.zeros(num_partitions, node_size, device=hrgs.device)
        else:
            center_parts = \
                self.collect_nodes(center_nodes, inputs.center_nodes_mask, node_embeddings, hrgs)

        # shape: [num_partitions, node_size]
        left_nodes = inputs.left_nodes
        if left_nodes is None:
            left_parts = torch.zeros_like(center_parts)
        else:
            left_parts = \
                self.collect_nodes(left_nodes, inputs.left_nodes_mask, node_embeddings, hrgs)

        # shape: [num_partitions, node_size]
        right_nodes = inputs.right_nodes
        if right_nodes is None:
            right_parts = torch.zeros_like(center_parts)
        else:
            right_parts = \
                self.collect_nodes(right_nodes, inputs.right_nodes_mask, node_embeddings, hrgs)

        # shape: [num_partitions, border_size]
        left_borders = inputs.left_borders
        if left_borders is None:
            left_borders = torch.zeros(num_partitions, self.border_size, device=hrgs.device)
        else:
            left_borders = self.compute_border_embeddings(
                node_embeddings, hrgs,
                left_borders, inputs.left_borders_mask,
                inputs.left_border_types, inputs.left_border_orders)

        # shape: [num_partitions, border_size]
        right_borders = inputs.right_borders
        if right_borders is None:
            right_borders = torch.zeros_like(left_borders)
        else:
            right_borders = self.compute_border_embeddings(
                node_embeddings, hrgs,
                inputs.right_borders, inputs.right_borders_mask,
                inputs.right_border_types, inputs.right_border_orders)

        partitions_parts = [center_parts, left_parts, right_parts, left_borders, right_borders]
        # shape: [num_partitions, partition_size]
        partitions = torch.cat(partitions_parts, dim=-1)
        partitions = self.dropout(partitions)

        return partitions, hrgs

    def compute_scores(self, inputs, node_embeddings, hrg_min_choices=None):
        """Score every partition and regroup the scores by instance.

        Masked candidate slots get a score of -1e20. Instances with fewer than
        ``hrg_min_choices`` unmasked candidates (default ``self.hrg_min_choices``) are
        dropped.

        Returns:
            ``(instance_scores, partitions)``, where ``instance_scores`` is
            ``[num_valid_instances, num_partitions_an_instance]`` and ``partitions`` is
            unfiltered.
        """
        if hrg_min_choices is None:
            hrg_min_choices = self.hrg_min_choices

        partitions, hrgs = self.compute_partitions(inputs, node_embeddings)

        # shape: [num_partitions]
        partition_scores = self.scorer(partitions, hrgs).squeeze(-1)

        size = inputs.instances.size()
        instances_mask = inputs.instances_mask
        instance_scores = partition_scores.index_select(0, inputs.instances.view(-1))
        # shape: [num_instances, num_partitions_an_instance]
        instance_scores = instance_scores.view(*size)
        instance_scores = torch.where(instances_mask,
                                      instance_scores,
                                      torch.ones_like(instance_scores) * (-1e20))

        valid_mask = instances_mask.sum(dim=1) >= hrg_min_choices
        # shape: [num_valid_instances, num_partitions_an_instance]
        instance_scores = instance_scores[valid_mask]

        return instance_scores, partitions

    @staticmethod
    def _selection_outputs(scores, reduction):
        """Loss / accuracy for a batch of choice scores whose gold choice is column 0.

        Args:
            scores: ``[num_instances, num_choices]`` scores (``num_instances`` may be 0).
            reduction: ``'mean'`` or ``'sum'``, forwarded to ``F.cross_entropy``.

        Returns:
            dict with ``loss`` (tensor), ``correct`` (int) and ``total`` (int).
        """
        num_instances = scores.size(0)
        targets = torch.zeros(num_instances, dtype=torch.long, device=scores.device)
        pred_indices = scores.argmax(dim=-1)
        if num_instances == 0:
            # cross_entropy with reduction='mean' returns NaN on an empty batch, which
            # would poison any loss summed with it. Use a zero that stays attached to
            # the graph so backward() still works.
            loss = scores.sum()
        else:
            loss = F.cross_entropy(scores, targets, reduction=reduction)
        return {
            'loss': loss,
            'correct': (pred_indices == targets).sum().item(),
            'total': pred_indices.size(0),
        }

    def forward(self, inputs, node_embeddings, return_partitions=False, hrg_min_choices=None):
        """Run the selector.

        Returns:
            DotDict. In training mode it holds ``loss``, ``correct`` and ``total``;
            otherwise it holds ``scores``. It also holds ``partitions`` when
            ``return_partitions`` is true.
        """
        instance_scores, partitions = self.compute_scores(inputs, node_embeddings,
                                                          hrg_min_choices=hrg_min_choices)
        outputs = DotDict()
        if self.training:
            outputs.update(**self._selection_outputs(instance_scores, self.loss_reduction))
        else:
            outputs.scores = instance_scores

        if return_partitions:
            outputs.partitions = partitions
        return outputs


class CFGSelector(nn.Module):
    """Scores candidate CFG rules for each SHRG instance given its gold HRG partition.

    Each candidate CFG rule embedding is scored against the instance's gold partition
    vector with a bilinear layer. During training, column 0 of
    ``inputs.shrg_instances`` is the gold choice.
    """

    def __init__(self, options, num_cfg_grammars, partition_size, **_):
        """Build the selector.

        Args:
            options: hyper-parameters; reads ``cfg_min_choices``, ``loss_reduction``
                and ``cfg_size``.
            num_cfg_grammars: number of CFG rules (rows of ``cfg_embeddings``).
            partition_size: size of the partition vectors produced by the HRG selector.
        """
        super().__init__()

        self.cfg_min_choices = options.cfg_min_choices
        self.loss_reduction = options.loss_reduction

        self.init_embeddings(options, num_cfg_grammars)

        # bilinear score between a partition vector and a CFG rule embedding
        self.scorer = nn.Bilinear(partition_size, options.cfg_size, 1)

    def init_embeddings(self, options, num_cfg_grammars):
        """Create the CFG rule embedding table."""
        self.cfg_embeddings = nn.Embedding(num_cfg_grammars, options.cfg_size)

    def compute_scores(self, inputs, gold_partitions):
        """Score candidate CFG rules.

        Masked slots get -1e20. Instances with fewer than ``cfg_min_choices`` unmasked
        candidates are dropped.

        Args:
            inputs: provides ``shrg_instances`` (candidate CFG rule ids) and the boolean
                ``shrg_instances_mask`` of the same shape.
            gold_partitions: ``[num_shrg_instances, num_choices, partition_size]``.

        Returns:
            ``[num_valid_shrg_instances, num_choices]`` scores.
        """
        # shape: [num_shrg_instances, num_choices, cfg_size]
        cfgs = self.cfg_embeddings(inputs.shrg_instances)

        shrg_instances_mask = inputs.shrg_instances_mask
        # shape: [num_shrg_instances, num_choices]
        shrg_instance_scores = self.scorer(gold_partitions, cfgs).squeeze(-1)
        shrg_instance_scores = torch.where(shrg_instances_mask,
                                           shrg_instance_scores,
                                           torch.ones_like(shrg_instance_scores) * (-1e20))

        valid_mask = shrg_instances_mask.sum(dim=1) >= self.cfg_min_choices
        # select valid instances
        return shrg_instance_scores[valid_mask]

    @staticmethod
    def _selection_outputs(scores, reduction):
        """Loss / accuracy for a batch of choice scores whose gold choice is column 0.

        Args:
            scores: ``[num_instances, num_choices]`` scores (``num_instances`` may be 0).
            reduction: ``'mean'`` or ``'sum'``, forwarded to ``F.cross_entropy``.

        Returns:
            dict with ``loss`` (tensor), ``correct`` (int) and ``total`` (int).
        """
        num_instances = scores.size(0)
        targets = torch.zeros(num_instances, dtype=torch.long, device=scores.device)
        pred_indices = scores.argmax(dim=-1)
        if num_instances == 0:
            # cross_entropy with reduction='mean' returns NaN on an empty batch (likely
            # here when every instance has fewer than cfg_min_choices candidates). Use a
            # graph-attached zero instead.
            loss = scores.sum()
        else:
            loss = F.cross_entropy(scores, targets, reduction=reduction)
        return {
            'loss': loss,
            'correct': (pred_indices == targets).sum().item(),
            'total': pred_indices.size(0),
        }

    def forward(self, inputs, partitions=None, gold_partitions=None):
        """Run the selector.

        If ``gold_partitions`` is not given, it is taken from ``partitions`` using column
        0 of ``inputs.instances``. It is then repeated once per CFG candidate.
        NOTE(review): this assumes the rows of ``inputs.instances`` line up one-to-one
        with the rows of ``inputs.shrg_instances`` — confirm against the data loader.

        Returns:
            DotDict. In training mode it holds ``loss``, ``correct`` and ``total``;
            otherwise it holds ``scores``.
        """
        if gold_partitions is None:
            num_choices = inputs.shrg_instances.size(1)
            # shape: [num_shrg_instances, partition_size]
            gold_partitions = partitions.index_select(0, inputs.instances[:, 0])
            # shape: [num_shrg_instances, num_choices, partition_size]
            gold_partitions = gold_partitions.unsqueeze(dim=1).repeat(1, num_choices, 1)

        shrg_instance_scores = self.compute_scores(inputs, gold_partitions)

        outputs = DotDict()
        if self.training:
            outputs.update(**self._selection_outputs(shrg_instance_scores, self.loss_reduction))
        else:
            outputs.scores = shrg_instance_scores

        return outputs


class SHRGSelector(nn.Module):
    """End-to-end SHRG selector: graph encoder, then HRG and (optionally) CFG selection.

    ``HRG_CLASS`` / ``CFG_CLASS`` are class attributes so that subclasses can plug in
    alternative selector implementations.
    """

    HRG_CLASS = HRGSelector
    CFG_CLASS = CFGSelector

    def __init__(self, options, vocabs, plugins, num_hrg_grammars, num_cfg_grammars, **kwargs):
        super().__init__()

        # Submodules are built and registered in a fixed order
        # (embeddings -> encoder -> hrg -> cfg). Parameter / state_dict order follows
        # registration order, so keep it.
        self.input_embeddings = options.sentence_embedding.create(vocabs, plugins)
        embedding_size = self.input_embeddings.output_size
        self.graph_encoder = GraphRNNEncoder(options.encoder,
                                             embedding_size,
                                             len(vocabs.get('edge')))

        encoder_size = self.graph_encoder.hidden_size
        self.hrg = self.HRG_CLASS(options, vocabs,
                                  input_size=encoder_size,
                                  num_hrg_grammars=num_hrg_grammars,
                                  **kwargs)
        self.cfg = self.CFG_CLASS(options,
                                  num_cfg_grammars=num_cfg_grammars,
                                  partition_size=self.hrg.partition_size,
                                  **kwargs)

    def run_encoder(self, inputs):
        """Encode the graph and return the final hidden states as a 2-D matrix.

        Returns:
            ``[batch_size * num_nodes, graph_encoder.hidden_size]``: the ``h`` part of
            the encoder's final ``(h, c)`` state, flattened over all leading dims.
        """
        embedded = self.input_embeddings(inputs)
        _, (hidden, _cell) = self.graph_encoder(inputs, embedded)
        return hidden.view(-1, hidden.size(-1))

    def forward(self, inputs, mode):
        """Run the selectors selected by ``mode``.

        ``mode == 'hrg'`` runs only the HRG selector. Any other mode also runs the CFG
        selector on the HRG partitions. In training mode, ``outputs.loss`` is the HRG
        loss for ``'hrg'``, the CFG loss for ``'cfg'``, and their sum otherwise.
        """
        outputs = DotDict()

        run_cfg = mode != 'hrg'
        node_embeddings = self.run_encoder(inputs)
        outputs.hrg = self.hrg(inputs, node_embeddings, return_partitions=run_cfg)

        if run_cfg:
            # move the partitions out of the HRG outputs and into the CFG selector
            hrg_partitions = outputs.hrg.pop('partitions')
            outputs.cfg = self.cfg(inputs, hrg_partitions)

        if not self.training:
            return outputs

        if not run_cfg:
            total_loss = outputs.hrg.loss
        elif mode == 'cfg':
            total_loss = outputs.cfg.loss
        else:
            total_loss = outputs.hrg.loss + outputs.cfg.loss
        outputs.loss = total_loss

        return outputs
