from .dep_decoder import DependencyDecoder
from antu.io.vocabulary import Vocabulary
from antu.io.configurators.ini_configurator import IniConfigurator
from antu.nn.dynet.multi_layer_perception import MLP
from antu.nn.dynet.attention.biaffine import BiaffineAttention
from antu.nn.dynet.units.graph_nn_unit import GraphNNUnit
from utils.mst_decoder import MST_inference
import dynet as dy
import numpy as np


class GraphRELDecoder(DependencyDecoder):
    """Biaffine dependency decoder with graph-NN refinement.

    Token representations are projected by two MLPs (a "head" view and a
    "son"/dependent view).  Each projection is split into an arc part and a
    relation part.  The arc parts go through ``cfg.GRAPH_LAYERS`` rounds of
    soft-attention message passing (:code:`GraphNNUnit`), each round scored by
    its own biaffine arc classifier; a final biaffine classifier produces the
    arc scores that are used for the main arc loss / for decoding, and its
    soft arc distribution is used once more to refine the relation parts
    before relation labels are scored.
    """

    # Dropout rate applied to the soft arc distribution while training.
    ARC_PROB_DROP = 0.2
    # Large constant subtracted from masked-out scores so that they become
    # negligible under softmax / argmax.
    MASK_PENALTY = 1e9

    def __init__(
        self,
        model: dy.ParameterCollection,
        cfg: IniConfigurator,
        vocabulary: Vocabulary):
        """Build all sub-modules inside a fresh sub-collection of ``model``.

        :param model: parent parameter collection; the decoder's parameters
            live in a sub-collection created from it.
        :param cfg: configuration object.  This class reads ``MLP_SIZE``,
            ``MLP_BIAS``, ``MLP_DROP``, ``ARC_SIZE``, ``ARC_BIAS``,
            ``REL_BIAS`` and ``GRAPH_LAYERS`` here, plus ``LAMBDA1`` and
            ``LAMBDA2`` in :code:`__call__`.
        :param vocabulary: vocabulary; only the size of its ``'rel'``
            namespace is used.
        """
        # NOTE(review): presumably the creation order of the sub-modules
        # below must stay stable so previously saved models still load --
        # confirm before reordering.
        pc = model.add_subcollection()

        # MLP layer
        def leaky_relu(x):
            return dy.bmax(.1*x, x)
        self.head_MLP = MLP(
            pc, cfg.MLP_SIZE, leaky_relu, 'orthonormal', cfg.MLP_BIAS, cfg.MLP_DROP)
        self.son_MLP = MLP(
            pc, cfg.MLP_SIZE, leaky_relu, 'orthonormal', cfg.MLP_BIAS, cfg.MLP_DROP)

        # Biaffine attention layer.
        # The last MLP layer is split: the first ARC_SIZE rows feed the arc
        # classifier, the remaining rows feed the relation classifier.
        arc_size = cfg.ARC_SIZE
        rel_size = cfg.MLP_SIZE[-1] - arc_size

        # One arc classifier per graph layer plus one for the final scores.
        self.arc_attn_mat = [
            BiaffineAttention(pc, arc_size, arc_size, 1, cfg.ARC_BIAS, 0)
            for _ in range(cfg.GRAPH_LAYERS+1)]
        rel_num = vocabulary.get_vocab_size('rel')
        # One-hot on relation index 0; used at inference time to exclude that
        # label from the argmax (presumably a padding/unknown label -- TODO
        # confirm against the vocabulary).
        self.rel_mask = np.array([1] + [0] * (rel_num-1))
        self.rel_attn = BiaffineAttention(pc, rel_size, rel_size, rel_num, cfg.REL_BIAS, 0)

        # Graph NN layer
        self.head_graphNN = GraphNNUnit(pc, arc_size, arc_size, leaky_relu, 'orthonormal')
        self.son_graphNN = GraphNNUnit(pc, arc_size, arc_size, leaky_relu, 'orthonormal')
        self.hrel_graphNN = GraphNNUnit(pc, rel_size, rel_size, leaky_relu, 'orthonormal')
        self.srel_graphNN = GraphNNUnit(pc, rel_size, rel_size, leaky_relu, 'orthonormal')

        # Save variable
        self.arc_size, self.rel_size, self.rel_num = arc_size, rel_size, rel_num
        self.pc, self.cfg = pc, cfg
        self.spec = (cfg, vocabulary)

    def _arc_scores(self, layer, head_arc, son_arc, masks_2D, is_train):
        """Score arcs with the ``layer``-th biaffine classifier.

        :returns: ``(arc_mat, arc_prob)`` -- the masked raw scores and their
            softmax (with dropout applied when ``is_train`` is true).
        """
        arc_mat = (self.arc_attn_mat[layer](head_arc, son_arc)
                   - self.MASK_PENALTY*(1-masks_2D))
        arc_prob = dy.softmax(arc_mat)
        if is_train:
            arc_prob = dy.dropout(arc_prob, self.ARC_PROB_DROP)
        return arc_mat, arc_prob

    @staticmethod
    def _masked_mean_nll(scores, gold, masks_flat, total_token):
        """Masked negative log-likelihood averaged over ``total_token``.

        Computes the per-batch-element softmax NLL of ``gold`` under
        ``scores``, multiplies it by ``masks_flat``, sums over the batch and
        divides by ``total_token``.
        """
        nll = dy.pickneglogsoftmax_batch(scores, gold)
        return dy.sum_batches(nll*masks_flat)/total_token

    @staticmethod
    def _predict_heads(arc_mat, masks, sent_len, batch_size, is_tree):
        """Decode one head index per token from flattened arc scores.

        :param arc_mat: arc scores reshaped to ``((sent_len,), flat_len)``.
        :param masks: mask dict; only ``masks['flat']`` is read.
        :param is_tree: when true, decode each sentence with
            :code:`MST_inference`; otherwise take the per-token argmax.
        :returns: a flat list (tree decoding) or a numpy array (argmax) of
            head positions, indexed within each sentence.
        """
        if not is_tree:
            return np.argmax(arc_mat.npvalue(), 0)

        flat_len = sent_len * batch_size
        # Column-major reshape followed by a full transpose yields one
        # (sent_len, sent_len) probability matrix per sentence.
        arc_probs = np.transpose(np.reshape(
            dy.softmax(arc_mat).npvalue(),
            (sent_len, sent_len, batch_size), order='F'))
        arc_masks = [
            np.array(masks['flat'][i:i+sent_len])
            for i in range(0, flat_len, sent_len)]
        arc_pred = []
        for msk, sent_prob in zip(arc_masks, arc_probs):
            # Position 0 is always counted as a valid token for decoding
            # (presumably the artificial ROOT -- TODO confirm).
            msk[0] = 1
            seq_len = int(np.sum(msk))
            tmp_pred = MST_inference(sent_prob, seq_len, msk)
            # Position 0 is always assigned head 0.
            tmp_pred[0] = 0
            arc_pred.extend(tmp_pred)
        return arc_pred

    def __call__(self, inputs, masks, truth, is_train=True, is_tree=True):
        """Compute training losses or predict heads and relations.

        :param inputs: sequence of DyNet expressions, one per token position;
            its length is the sentence length and ``inputs[0].dim()[1]`` is
            the batch size.
        :param masks: dict with ``'2D'`` and ``'flat'`` entries, both fed to
            :code:`dy.inputTensor` as batched tensors.  ``'flat'`` must also
            support ``.tolist()`` and slicing (e.g. a numpy array).
        :param truth: dict with gold ``'head'``, ``'flat_head'`` and ``'rel'``
            entries; only read when ``is_train`` is true.
        :param is_train: build losses (with dropout) when true, decode
            otherwise.
        :param is_tree: at inference time, decode heads with
            :code:`MST_inference` instead of a plain argmax.
        :returns: when training, ``(losses, losses_list)`` where ``losses`` is
            ``LAMBDA2 * (arc_loss + rel_loss) + LAMBDA1 * sum(gnn_losses)``
            and ``losses_list`` is ``gnn_losses + [arc_loss, rel_loss]``;
            otherwise a dict with keys ``'head'`` and ``'rel'``.
        """
        sent_len = len(inputs)
        batch_size = inputs[0].dim()[1]
        flat_len = sent_len * batch_size

        # H -> hidden size, L -> sentence length, B -> batch size
        X = dy.concatenate_cols(inputs)     # ((H, L), B)
        if is_train:
            X = dy.dropout_dim(X, 1, self.cfg.MLP_DROP)
        # M_H -> MLP hidden size
        head_mat = self.head_MLP(X, is_train)   # ((M_H, L), B)
        son_mat = self.son_MLP(X, is_train)     # ((M_H, L), B)
        total_token = None  # only needed (and only defined) for training
        if is_train:
            total_token = sum(masks['flat'].tolist())
            head_mat = dy.dropout_dim(head_mat, 1, self.cfg.MLP_DROP)
            son_mat = dy.dropout_dim(son_mat, 1, self.cfg.MLP_DROP)

        # A_H -> Arc hidden size, R_H -> Label hidden size
        # A_H + R_H = M_H
        head_arc = head_mat[:self.arc_size]     # ((A_H, L), B)
        son_arc = son_mat[:self.arc_size]       # ((A_H, L), B)
        head_rel = head_mat[self.arc_size:]     # ((R_H, L), B)
        son_rel = son_mat[self.arc_size:]       # ((R_H, L), B)

        masks_2D = dy.inputTensor(masks['2D'], True)
        masks_flat = dy.inputTensor(masks['flat'], True)

        # Graph layers: each layer scores arcs (contributing an auxiliary
        # loss while training) and uses the soft arc distribution to mix
        # head/son arc features.
        gnn_losses = []
        for k in range(self.cfg.GRAPH_LAYERS):
            arc_mat, arc_prob = self._arc_scores(
                k, head_arc, son_arc, masks_2D, is_train)
            if is_train:
                arc_mat = dy.reshape(arc_mat, (sent_len,), flat_len)
                gnn_losses.append(self._masked_mean_nll(
                    arc_mat, truth['head'], masks_flat, total_token))
            HX = head_arc * arc_prob
            SX = son_arc * dy.transpose(arc_prob)
            FX = HX + SX    # Fusion head and dept representation
            head_arc = self.head_graphNN(FX, head_arc, is_train)
            # The son update sees the already-updated head representation.
            FX_new = head_arc * arc_prob + SX
            son_arc = self.son_graphNN(FX_new, son_arc, is_train)

        # Final arc scores; their soft distribution also refines the
        # relation features.
        arc_mat, arc_prob = self._arc_scores(
            -1, head_arc, son_arc, masks_2D, is_train)
        HR = head_rel * arc_prob
        DR = son_rel * dy.transpose(arc_prob)
        FR = HR + DR
        head_rel = self.hrel_graphNN(FR, head_rel, is_train)
        son_rel = self.srel_graphNN(FR, son_rel, is_train)
        arc_mat = dy.reshape(arc_mat, (sent_len,), flat_len)
        # predict relation
        # head_rel becomes a single (R_H, L*B) matrix so that pick_batch can
        # gather one head column per token; son_rel becomes a batch of L*B
        # vectors.
        head_rel = dy.reshape(head_rel, (self.rel_size, flat_len))
        son_rel = dy.reshape(son_rel, (self.rel_size,), flat_len)

        if is_train:
            arc_loss = self._masked_mean_nll(
                arc_mat, truth['head'], masks_flat, total_token)
            # Relations are scored against the gold heads while training.
            truth_rel = dy.pick_batch(head_rel, truth['flat_head'], 1)
            rel_mat = self.rel_attn(son_rel, truth_rel)
            rel_loss = self._masked_mean_nll(
                rel_mat, truth['rel'], masks_flat, total_token)
            losses = (rel_loss+arc_loss)*self.cfg.LAMBDA2
            if gnn_losses:
                losses += dy.esum(gnn_losses)*self.cfg.LAMBDA1
            losses_list = gnn_losses + [arc_loss, rel_loss]
            return losses, losses_list

        arc_pred = self._predict_heads(
            arc_mat, masks, sent_len, batch_size, is_tree)
        # Turn per-sentence head positions into column indices of the
        # flattened head_rel matrix.
        flat_pred = [j+(i//sent_len)*sent_len for i, j in enumerate(arc_pred)]
        pred_rel = dy.pick_batch(head_rel, flat_pred, 1)
        rel_mat = self.rel_attn(son_rel, pred_rel)
        # Exclude relation index 0 from the argmax.
        rel_mask = dy.inputTensor(self.rel_mask)
        rel_mat = rel_mat - self.MASK_PENALTY*rel_mask
        rel_pred = np.argmax(dy.reshape(rel_mat, (self.rel_num,)).npvalue(), 0)
        pred = {}
        pred['head'], pred['rel'] = arc_pred, rel_pred
        return pred

    @staticmethod
    def from_spec(spec, model):
        """Create and return a new instance with the needed parameters.

        ``spec`` is the ``(cfg, vocabulary)`` tuple stored on the instance as
        :code:`self.spec`.

        It is one of the prerequisites for Dynet save/load method.
        """
        cfg, vocabulary = spec
        return GraphRELDecoder(model, cfg, vocabulary)

    def param_collection(self):
        """Return a :code:`dynet.ParameterCollection` object with the parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        return self.pc

