from .dep_decoder import DependencyDecoder
from antu.io.vocabulary import Vocabulary
from antu.io.configurators.ini_configurator import IniConfigurator
from antu.nn.dynet.multi_layer_perception import MLP
from antu.nn.dynet.attention.biaffine import BiaffineAttention
from antu.nn.dynet.units.graph_nn_unit import GraphNNUnit
from utils.mst_decoder import MST_inference
import dynet as dy
import numpy as np


class LabelGNNDecoder(DependencyDecoder):
    """Graph-NN decoder that scores (head, relation) pairs jointly.

    Every token representation produced by the head/dependent MLPs is paired
    with each of the ``rel_num`` relation embeddings (columns of ``rel_mat``).
    A biaffine scorer then yields, for each column of the ((L*R, L), B) score
    matrix, a softmax over all L*R (head, relation) combinations. Between
    scoring layers, two ``GraphNNUnit`` s refine the head-side and
    dependent-side token states using the soft arc probabilities.

    Relation labels are not predicted yet: decoding returns a placeholder
    relation id of 1 for every position.
    """

    def __init__(
        self,
        model: dy.ParameterCollection,
        cfg: IniConfigurator,
        vocabulary: Vocabulary):
        pc = model.add_subcollection()

        def leaky_relu(x):
            # Leaky ReLU with slope 0.1 on the negative side.
            return dy.bmax(.1*x, x)

        # MLP layers producing head-side and dependent-side token features.
        self.head_MLP = MLP(
            pc, cfg.MLP_SIZE, leaky_relu, 'orthonormal', cfg.MLP_BIAS, cfg.MLP_DROP)
        self.son_MLP = MLP(
            pc, cfg.MLP_SIZE, leaky_relu, 'orthonormal', cfg.MLP_BIAS, cfg.MLP_DROP)

        # The MLP output (size MLP_SIZE[-1]) is split into an arc part and a
        # rel part; rel_size is whatever remains after ARC_SIZE.
        mlp_size = cfg.MLP_SIZE[-1]
        arc_size = cfg.ARC_SIZE
        rel_size = mlp_size - arc_size
        # Each scored column is [token features ; relation embedding], so its
        # size is arc_size + REL_DIM + rel_size. (This was previously
        # hard-coded as arc_size + 32 + rel_size, i.e. it assumed REL_DIM == 32.)
        node_size = mlp_size + cfg.REL_DIM

        # NOTE: parameter creation order below is unchanged from the original
        # so that previously saved models still load.
        # One biaffine scorer per graph layer plus one for the final layer.
        self.arc_attn_mat = [
            BiaffineAttention(pc, node_size, node_size, 1, cfg.ARC_BIAS, 'orthonormal')
            for _ in range(cfg.GRAPH_LAYERS+1)]
        rel_num = vocabulary.get_vocab_size('rel')
        # Relation embedding table with one REL_DIM column per relation label.
        self.rel_mat = pc.add_parameters((cfg.REL_DIM, rel_num), init=0)
        # rel_mask and rel_attn are not used by __call__ at the moment; they
        # are kept so the parameter collection layout is unchanged.
        self.rel_mask = np.array([1] + [0] * (rel_num-1))
        self.rel_attn = BiaffineAttention(pc, rel_size, rel_size, rel_num, cfg.REL_BIAS, 'orthonormal')

        # Graph NN layers. In __call__ the fused message FX has node_size rows
        # and the refined state has mlp_size rows; these arguments replace the
        # previously hard-coded (350, 382).
        self.head_graphNN = GraphNNUnit(pc, mlp_size, node_size, leaky_relu, 'orthonormal')
        self.son_graphNN = GraphNNUnit(pc, mlp_size, node_size, leaky_relu, 'orthonormal')

        # Save variables (spec is required by the DyNet save/load protocol).
        self.arc_size, self.rel_size, self.rel_num = arc_size, rel_size, rel_num
        self.pc, self.cfg = pc, cfg
        self.spec = (cfg, vocabulary)

    def _expand_with_rel(self, token_mat, rel_mat):
        """Pair every token column with every relation embedding.

        ``token_mat`` is ((M_H, L), B) and ``rel_mat`` is ((R_H, L*R), 1).
        The token block is tiled R times along the columns and stacked on
        top of the relation embeddings, giving ((M_H + R_H, L*R), B).
        """
        tiled = dy.concatenate_cols([token_mat] * self.rel_num)  # ((M_H, L*R), B)
        return dy.concatenate([tiled, rel_mat])

    def _score_arcs(self, layer, H_rel, D_rel, sent_len, batch_size):
        """Return the ((L*R, L), B) arc score matrix of biaffine ``layer``."""
        H_dim = H_rel.dim()[0][0]
        # Move the R relation copies into the batch axis so that one biaffine
        # call scores all relations at once.
        H = dy.reshape(H_rel, (H_dim, sent_len), self.rel_num*batch_size)  # ((H, L), R*B)
        D = dy.reshape(D_rel, (H_dim, sent_len), self.rel_num*batch_size)  # ((H, L), R*B)
        arc_mat = self.arc_attn_mat[layer](H, D)  # ((L, L), R*B)
        arc_mat = dy.reshape(arc_mat, (sent_len, sent_len*self.rel_num), batch_size)  # ((L, L*R), B)
        return dy.transpose(arc_mat)  # ((L*R, L), B)

    def _arc_loss(self, arc_mat, truth, masks_flat, total_token, sent_len, flat_len):
        """Masked, token-averaged NLL of the gold (head, rel) index per column."""
        # Each column of the ((L*R, L), B) matrix becomes its own batch element.
        arc_mat = dy.reshape(arc_mat, (sent_len*self.rel_num,), flat_len)  # ((L*R,), L*B)
        arc_losses = dy.pickneglogsoftmax_batch(arc_mat, truth['headarc'])
        return dy.sum_batches(arc_losses*masks_flat)/total_token

    def __call__(self, inputs, masks, truth, is_train=True, is_tree=True):
        """Score arcs and return either training losses or decoded trees.

        Args:
            inputs: list of L DyNet expressions, one per position; the batch
                size B is read from ``inputs[0].dim()[1]``.
            masks: dict. ``masks['flat']`` (array-like, supports ``.tolist()``)
                holds L*B mask values used both as the loss mask and as the
                per-sentence padding mask for decoding.
            truth: dict. ``truth['headarc']`` gives, for each of the L*B
                positions, the gold index into the L*R (head, rel) scores.
            is_train: if True, apply dropout and return losses.
            is_tree: decoding mode; only tree (MST) decoding is implemented.

        Returns:
            Training: ``(losses, losses_list)``. ``losses`` is
            ``LAMBDA2 * final_loss + LAMBDA1 * sum(gnn_losses)``, and
            ``losses_list`` is the per-graph-layer losses followed by the
            final loss.
            Inference: ``{'head': [...], 'rel': [...]}``, flat lists
            concatenated over the batch; ``'rel'`` is a placeholder of 1s.

        Raises:
            NotImplementedError: if ``is_train`` and ``is_tree`` are both False.
        """
        sent_len = len(inputs)
        batch_size = inputs[0].dim()[1]
        flat_len = sent_len * batch_size

        # H -> hidden size, L -> sentence length, B -> batch size,
        # R -> number of relations, R_H -> relation embedding size
        X = dy.concatenate_cols(inputs)  # ((H, L), B)
        if is_train:
            X = dy.dropout_dim(X, 1, self.cfg.MLP_DROP)
        # M_H -> MLP hidden size
        head_arc = self.head_MLP(X, is_train)  # ((M_H, L), B)
        son_arc = self.son_MLP(X, is_train)    # ((M_H, L), B)
        if is_train:
            total_token = sum(masks['flat'].tolist())
            head_arc = dy.dropout_dim(head_arc, 1, self.cfg.MLP_DROP)
            son_arc = dy.dropout_dim(son_arc, 1, self.cfg.MLP_DROP)

        masks_flat = dy.inputTensor(masks['flat'], True)

        # Relation embeddings laid out so that column (l + r*L) holds the
        # embedding of relation r, matching the tiled token columns.
        R_H = self.cfg.REL_DIM
        rel_mat_repeat = dy.concatenate([self.rel_mat for _ in range(sent_len)])  # ((R_H*L, R), 1)
        rel_mat = dy.reshape(rel_mat_repeat, (R_H, sent_len*self.rel_num), 1)     # ((R_H, L*R), 1)
        H_rel = self._expand_with_rel(head_arc, rel_mat)  # ((H, L*R), B)
        D_rel = self._expand_with_rel(son_arc, rel_mat)   # ((H, L*R), B)

        gnn_losses = []
        for k in range(self.cfg.GRAPH_LAYERS):
            arc_mat = self._score_arcs(k, H_rel, D_rel, sent_len, batch_size)  # ((L*R, L), B)
            arc_prob = dy.softmax(arc_mat)  # column-wise softmax, ((L*R, L), B)
            if is_train:
                arc_prob = dy.dropout(arc_prob, 0.1)
            # Re-arrange the probabilities so that D_rel aggregates along the
            # other endpoint of each arc.
            arc_prob_T = dy.reshape(dy.transpose(arc_prob), (sent_len, sent_len), self.rel_num*batch_size)
            arc_prob_T = dy.transpose(arc_prob_T)
            arc_prob_T = dy.reshape(arc_prob_T, (sent_len, sent_len*self.rel_num), batch_size)
            arc_prob_T = dy.transpose(arc_prob_T)
            if is_train:
                gnn_losses.append(self._arc_loss(
                    arc_mat, truth, masks_flat, total_token, sent_len, flat_len))
            HX = H_rel * arc_prob    # ((H, L), B)
            SX = D_rel * arc_prob_T  # ((H, L), B)
            FX = HX + SX             # fused head/dependent message
            head_arc = self.head_graphNN(FX, head_arc, is_train)
            son_arc = self.son_graphNN(FX, son_arc, is_train)
            H_rel = self._expand_with_rel(head_arc, rel_mat)
            D_rel = self._expand_with_rel(son_arc, rel_mat)

        arc_mat = self._score_arcs(-1, H_rel, D_rel, sent_len, batch_size)  # ((L*R, L), B)

        if is_train:
            arc_loss = self._arc_loss(
                arc_mat, truth, masks_flat, total_token, sent_len, flat_len)
            losses = arc_loss*self.cfg.LAMBDA2
            if gnn_losses:
                losses = losses + dy.esum(gnn_losses)*self.cfg.LAMBDA1
            losses_list = gnn_losses + [arc_loss]
            return losses, losses_list

        if not is_tree:
            raise NotImplementedError(
                'LabelGNNDecoder only supports tree (MST) decoding; '
                'is_tree=False is not implemented.')

        arc_prob = dy.softmax(arc_mat)      # ((L*R, L), B)
        arc_prob = dy.transpose(arc_prob)   # ((L, L*R), B)
        arc_prob = dy.reshape(arc_prob, (sent_len, sent_len, self.rel_num), batch_size)  # ((L, L, R), B)
        arc_prob = dy.max_dim(arc_prob, 2)  # ((L, L), B): best relation per cell
        # npvalue() drops the batch axis when B == 1; the reshape restores it.
        arc_probs = np.reshape(
            dy.transpose(arc_prob).npvalue(), (sent_len, sent_len, batch_size), 'F')  # (L, L, B)
        # Make the batch axis leading so iteration yields one (L, L) matrix per
        # sentence. Without this, zip() iterated over the first L axis.
        # NOTE(review): this mirrors the reference biaffine decoder layout;
        # confirm the row/column orientation that MST_inference expects.
        arc_probs = np.transpose(arc_probs)  # (B, L, L)
        arc_masks = [np.array(masks['flat'][i:i+sent_len])
                     for i in range(0, flat_len, sent_len)]
        arc_pred = []
        for msk, sent_probs in zip(arc_masks, arc_probs):
            msk[0] = 1  # position 0 is always counted (msk is a copy)
            seq_len = int(np.sum(msk))
            tmp_pred = MST_inference(sent_probs, seq_len, msk)
            tmp_pred[0] = 0  # position 0 (presumably ROOT) gets head 0
            arc_pred.extend(tmp_pred)
        rel_pred = [1 for _ in range(len(arc_pred))]  # placeholder relations
        pred = {}
        pred['head'], pred['rel'] = arc_pred, rel_pred
        return pred

    @staticmethod
    def from_spec(spec, model):
        """Create and return a new instance with the needed parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        cfg, vocabulary = spec
        return LabelGNNDecoder(model, cfg, vocabulary)

    def param_collection(self):
        """Return a :code:`dynet.ParameterCollection` object with the parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        return self.pc

