from .dep_decoder import DependencyDecoder
from antu.io.vocabulary import Vocabulary
from antu.io.configurators.ini_configurator import IniConfigurator
from antu.nn.dynet.multi_layer_perception import MLP
from antu.nn.dynet.attention.biaffine import BiaffineAttention
import dynet as dy
import numpy as np

class SeqHeadSelDecoder(DependencyDecoder):
    """Head-selection dependency decoder scored with biaffine attention.

    The input sequence is projected by two MLPs, one producing a "head" view
    and one producing a "son" (dependent) view.  The first ``arc_size`` rows
    of each projection feed the arc scorer; the remaining ``rel_size`` rows
    feed the relation scorer.
    """

    def __init__(
        self,
        model: dy.ParameterCollection,
        cfg: IniConfigurator,
        vocabulary: Vocabulary):
        """Build the MLP and biaffine layers inside a new sub-collection.

        Args:
            model: Parent parameter collection; the decoder registers its
                parameters in a sub-collection of it.
            cfg: Configuration object.  The attributes read here are
                ``MLP_SIZE``, ``MLP_BIAS``, ``MLP_DROP``, ``ARC_SIZE``,
                ``ARC_BIAS`` and ``REL_BIAS``.
            vocabulary: Vocabulary queried for the size of the ``'rel'``
                namespace, which fixes the number of relation labels.
        """
        pc = model.add_subcollection()

        # MLP layer
        def leaky_relu(x):
            """Leaky ReLU with a negative slope of 0.1."""
            return dy.bmax(.1*x, x)

        self.head_MLP = MLP(
            pc, cfg.MLP_SIZE, leaky_relu, 'orthonormal', cfg.MLP_BIAS, cfg.MLP_DROP)
        self.son_MLP = MLP(
            pc, cfg.MLP_SIZE, leaky_relu, 'orthonormal', cfg.MLP_BIAS, cfg.MLP_DROP)

        # Biaffine attention layer.  The last MLP layer is split in two:
        # arc_size rows for arc scoring, the rest for relation scoring.
        arc_size = cfg.ARC_SIZE
        rel_size = cfg.MLP_SIZE[-1] - arc_size
        self.arc_attn = BiaffineAttention(pc, arc_size, arc_size, 1, cfg.ARC_BIAS, 0)
        rel_num = vocabulary.get_vocab_size('rel')
        self.rel_attn = BiaffineAttention(pc, rel_size, rel_size, rel_num, cfg.REL_BIAS, 0)

        # Save variable
        self.rel_num = rel_num
        self.arc_size, self.rel_size = arc_size, rel_size
        self.pc, self.cfg = pc, cfg
        # Constructor arguments kept so that `from_spec` can rebuild the object.
        self.spec = (cfg, vocabulary)

    def __call__(self, inputs, masks, truth, is_train=True):
        """Score arcs and relations; return a loss or the predictions.

        Args:
            inputs: Sequence of batched DyNet expressions, one per sentence
                position.  The batch size is read from ``inputs[0].dim()[1]``.
            masks: Mapping with the keys ``'2D'`` and ``'flat'``.  ``'2D'`` is
                transposed and loaded as a batched tensor that masks the arc
                score matrix (presumably 1 for valid cells, 0 for padding --
                verify against the caller).  ``'flat'`` holds one weight per
                flattened (sentence, position) slot and is only read when
                training.
            truth: Mapping with the keys ``'head'``, ``'rel'`` and
                ``'flat_head'``.  Only read when ``is_train`` is true.
            is_train: When true, apply dropout and return the loss; otherwise
                return predictions.

        Returns:
            When training, a DyNet expression equal to the mean of the masked
            arc loss and the masked relation loss.  Otherwise a dict with the
            keys ``'head'`` and ``'rel'``, each a NumPy array holding one
            argmax index per flattened slot.
        """
        sent_len = len(inputs)
        batch_size = inputs[0].dim()[1]
        flat_len = sent_len * batch_size

        # H -> hidden size, L -> sentence length, B -> batch size
        X = dy.concatenate_cols(inputs)     # ((H, L), B)
        if is_train:
            X = dy.dropout_dim(X, 1, self.cfg.MLP_DROP)
        # M_H -> MLP hidden size
        head_mat = self.head_MLP(X, is_train)   # ((M_H, L), B)
        son_mat = self.son_MLP(X, is_train)     # ((M_H, L), B)
        if is_train:
            head_mat = dy.dropout_dim(head_mat, 1, self.cfg.MLP_DROP)
            son_mat = dy.dropout_dim(son_mat, 1, self.cfg.MLP_DROP)

        # A_H -> Arc hidden size, R_H -> Label hidden size
        # A_H + R_H = M_H
        head_arc = head_mat[:self.arc_size]     # ((A_H, L), B)
        son_arc = son_mat[:self.arc_size]       # ((A_H, L), B)
        head_rel = head_mat[self.arc_size:]     # ((R_H, L), B)
        son_rel = son_mat[self.arc_size:]       # ((R_H, L), B)

        masks_2D = dy.inputTensor(np.transpose(np.array(masks['2D'])), True)

        arc_mat = self.arc_attn(head_arc, son_arc)  # ((L, L), B)
        # Push masked cells to a very large negative score, then fold the
        # sentence axis into the batch axis: one score column per slot.
        arc_mat = dy.reshape(arc_mat-1e9*(1-masks_2D), (sent_len,), flat_len)

        # head_rel becomes a plain (R_H, L*B) matrix so that pick_batch can
        # select one head column per slot; son_rel becomes a batch of vectors.
        head_rel = dy.reshape(head_rel, (self.rel_size, flat_len))
        son_rel = dy.reshape(son_rel, (self.rel_size,), flat_len)

        if is_train:
            # Relations are scored against the gold heads while training.
            truth_rel = dy.pick_batch(head_rel, truth['flat_head'], 1)
            rel_mat = self.rel_attn(son_rel, truth_rel)

            masks_flat = dy.inputTensor(np.array(masks['flat']), True)
            arc_losses = dy.pickneglogsoftmax_batch(arc_mat, truth['head'])
            rel_losses = dy.pickneglogsoftmax_batch(rel_mat, truth['rel'])
            # Guard against a batch with no unmasked slot: the masked sums
            # are zero in that case, so dividing by 1 yields a zero loss
            # instead of NaN/inf.
            total_token = max(sum(masks['flat']), 1)
            arc_loss = dy.sum_batches(arc_losses*masks_flat) / total_token
            rel_loss = dy.sum_batches(rel_losses*masks_flat) / total_token
            return (arc_loss + rel_loss) * 0.5

        # Inference.  npvalue() drops the batch axis when it has size 1, so
        # restore the (rows, flat_len) layout explicitly before the argmax.
        arc_scores = arc_mat.npvalue().reshape((sent_len, flat_len))
        arc_pred = np.argmax(arc_scores, 0)
        # Convert each within-sentence head index to a column of the
        # flattened head_rel matrix by adding the sentence's column offset.
        flat_pred = [
            int(head) + (slot // sent_len) * sent_len
            for slot, head in enumerate(arc_pred)]
        pred_rel = dy.pick_batch(head_rel, flat_pred, 1)
        rel_mat = self.rel_attn(son_rel, pred_rel)

        rel_scores = dy.reshape(rel_mat, (self.rel_num,)).npvalue()
        rel_pred = np.argmax(rel_scores.reshape((self.rel_num, flat_len)), 0)
        return {'head': arc_pred, 'rel': rel_pred}

    @staticmethod
    def from_spec(spec, model):
        """Create and return a new instance with the needed parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        cfg, vocabulary = spec
        return SeqHeadSelDecoder(model, cfg, vocabulary)

    def param_collection(self):
        """Return a :code:`dynet.ParameterCollection` object with the parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        return self.pc

