from .dep_decoder import DependencyDecoder
from antu.io.vocabulary import Vocabulary
from antu.io.configurators.ini_configurator import IniConfigurator
from antu.nn.dynet.multi_layer_perception import MLP
from antu.nn.dynet.attention.biaffine import BiaffineAttention
from antu.nn.dynet.units.graph_nn_unit import GraphNNUnit
from utils.mst_decoder import MST_inference
import dynet as dy
import numpy as np


class LabelGNNDecoder(DependencyDecoder):
    """Dependency decoder that scores (head, relation) pairs jointly.

    For every token the decoder produces a score over all ``L * R``
    (candidate position, relation label) combinations, where ``L`` is the
    number of input columns and ``R`` the size of the ``'rel'`` vocabulary.
    ``cfg.GRAPH_LAYERS`` rounds of soft-attention message passing
    (``GraphNNUnit``) refine the head/dependent representations before the
    final scoring step.
    """

    def __init__(
        self,
        model: dy.ParameterCollection,
        cfg: IniConfigurator,
        vocabulary: Vocabulary):
        """Allocate all parameters in a fresh sub-collection of ``model``.

        Args:
            model: parent DyNet parameter collection.
            cfg: configuration; reads ``MLP_SIZE``, ``MLP_BIAS``,
                ``MLP_DROP``, ``REL_DIM``, ``ARC_BIAS`` and ``GRAPH_LAYERS``
                here (``LAMBDA1``/``LAMBDA2`` are read in ``__call__``).
            vocabulary: must provide the size of the ``'rel'`` namespace.
        """
        pc = model.add_subcollection()

        # MLP layer
        def leaky_relu(x):
            return dy.bmax(.1*x, x)
        self.head_MLP = MLP(
            pc, cfg.MLP_SIZE, leaky_relu, 'orthonormal', cfg.MLP_BIAS, cfg.MLP_DROP)
        self.dept_MLP = MLP(
            pc, cfg.MLP_SIZE, leaky_relu, 'orthonormal', cfg.MLP_BIAS, cfg.MLP_DROP)

        # Relation embedding matrix: one REL_DIM-sized column per relation.
        rel_num = vocabulary.get_vocab_size('rel')
        self.rel_mat = pc.add_parameters((cfg.REL_DIM, rel_num), init=0)
        V = cfg.MLP_SIZE[-1]
        V_R = V+cfg.REL_DIM

        # NOTE: these biaffine layers are not used by ``__call__`` any more,
        # but they are kept so that the parameter collection layout (and thus
        # previously saved models) stays unchanged.
        self.BiAttn = [
            BiaffineAttention(pc, V_R, V_R, 1, cfg.ARC_BIAS, 'orthonormal')
            for _ in range(cfg.GRAPH_LAYERS+1)]

        # Graph NN layer
        self.head_GNN = GraphNNUnit(pc, V, V_R, leaky_relu, 'orthonormal')
        self.dept_GNN = GraphNNUnit(pc, V, V_R, leaky_relu, 'orthonormal')

        # Parameters of the factorised scorer used in ``_joint_scores``.
        self.W = pc.add_parameters((V, V), 0)
        self.BH = pc.add_parameters((V, cfg.REL_DIM), 0)
        self.BD = pc.add_parameters((V, cfg.REL_DIM), 0)

        # Save variable
        self.rel_num = rel_num
        self.pc, self.cfg = pc, cfg
        self.spec = (cfg, vocabulary)

    def _joint_scores(self, H, D, L, R, B, masks_LR):
        """Return the masked joint score matrix, ((L*R, L), B).

        The score is the sum of a bilinear term ``H^T W D`` (shared by all
        relations) and, per relation, the outer product of a D-side and an
        H-side relation score. ``masks_LR`` is subtracted so that masked
        entries become very negative before any softmax.
        """
        s1 = dy.transpose(H) * self.W
        s1 = s1 * D                                   # ((L, L), B)
        s2 = dy.transpose(H) * self.BH
        s2 = s2 * self.rel_mat                        # ((L, R), B)
        s3 = dy.transpose(D) * self.BD
        s3 = s3 * self.rel_mat                        # ((L, R), B)
        s1 = dy.concatenate([s1 for _ in range(R)])   # ((R*L, L), B)
        # One batch element per (relation, sentence) pair.
        s2 = dy.reshape(s2, (L,), R*B)
        s3 = dy.reshape(s3, (L,), R*B)
        ss = s3 * dy.transpose(s2)                    # ((L, L), R*B)
        ss = dy.reshape(ss, (L, L*R), B)
        return s1 + dy.transpose(ss) - masks_LR       # ((L*R, L), B)

    def __call__(self, inputs, masks, truth, is_train=True, is_tree=True):
        """Run the decoder.

        Args:
            inputs: sequence of DyNet expressions, one per position; the
                batch size is read from ``inputs[0].dim()[1]``.
            masks: mapping with keys ``'flat'`` and ``'2D'``; both are fed to
                ``dy.inputTensor(..., True)`` and ``'flat'`` must also
                support ``.tolist()``.
            truth: mapping whose ``'headarc'`` entry holds the gold joint
                (position, relation) class index for every token of the
                batch; only read when ``is_train`` is true.
            is_train: enables dropout and makes the call return losses.
            is_tree: accepted for interface compatibility; currently unused
                (the MST decoding path was unreachable and has been removed).

        Returns:
            When ``is_train`` is true: ``(losses, losses_list)`` where
            ``losses`` is the ``LAMBDA1``/``LAMBDA2`` weighted sum and
            ``losses_list`` holds the per-layer losses followed by the final
            loss. Otherwise a dict ``{'head': ..., 'rel': ...}`` of flat
            numpy index arrays.
        """
        L, R, B = len(inputs), self.rel_num, inputs[0].dim()[1]
        LR, LB, RB = L*R, L*B, R*B

        # H -> hidden size, L -> sentence length, B -> batch size
        X = dy.concatenate_cols(inputs)     # ((H, L), B)
        if is_train: X = dy.dropout_dim(X, 1, self.cfg.MLP_DROP)
        # M_H -> MLP hidden size
        H = self.head_MLP(X, is_train)   # ((M_H, L), B)
        D = self.dept_MLP(X, is_train)   # ((M_H, L), B)
        if is_train:
            total_token = sum(masks['flat'].tolist())
            H = dy.dropout_dim(H, 1, self.cfg.MLP_DROP)
            D = dy.dropout_dim(D, 1, self.cfg.MLP_DROP)

        masks_flat = dy.inputTensor(masks['flat'], True)
        masks_2D = dy.transpose(dy.inputTensor(masks['2D'], True))
        # The block for relation index 0 is all zeros, i.e. fully masked
        # after the ``1e9*(1-...)`` transform; the other R-1 blocks reuse
        # the 2D mask.
        masks_LR = [dy.zeros(masks_2D.dim()[0], masks_2D.dim()[1]), ]
        masks_LR += [masks_2D for _ in range(R-1)]
        masks_LR = 1e9*(1-dy.concatenate(masks_LR))

        R_dim = self.cfg.REL_DIM
        R_repeat = dy.concatenate([self.rel_mat for _ in range(L)])  # ((R_H*L, R), 1)
        R_repeat = dy.reshape(R_repeat, (R_dim, LR), 1)              # ((R_H, L*R), 1)

        gnn_losses = []
        for _ in range(self.cfg.GRAPH_LAYERS):
            # Token states repeated once per relation and stacked on top of
            # the relation embeddings. Built from the *current* H/D so that
            # nothing is constructed after the last layer.
            HR = dy.concatenate(
                [dy.concatenate_cols([H for _ in range(R)]), R_repeat])  # ((H, L*R), B)
            DR = dy.concatenate(
                [dy.concatenate_cols([D for _ in range(R)]), R_repeat])  # ((H, L*R), B)

            attn_mat = self._joint_scores(H, D, L, R, B, masks_LR)  # ((L*R, L), B)
            attn_prob = dy.softmax(attn_mat)                        # ((L*R, L), B)
            if is_train: attn_prob = dy.dropout(attn_prob, 0.1)

            # NOTE(review): the original code first assigned
            # ``dy.transpose(attn_prob)`` to ``attn_prob_T`` and immediately
            # overwrote it with the reshape of ``attn_prob`` below. The dead
            # transpose was dropped; confirm that reshaping ``attn_prob``
            # (rather than its transpose) is what was intended.
            attn_prob_T = dy.reshape(attn_prob, (L, L), RB)   # ((L, L), R*B)
            attn_prob_T = dy.transpose(attn_prob_T)           # ((L, L), R*B)
            attn_prob_T = dy.reshape(attn_prob_T, (L, LR), B) # ((L, L*R), B)
            attn_prob_T = dy.transpose(attn_prob_T)           # ((L*R, L), B)

            if is_train:
                # Auxiliary supervision on the intermediate layer.
                attn_mat = dy.reshape(attn_mat, (LR,), LB)    # ((L*R,), L*B)
                loss = dy.pickneglogsoftmax_batch(attn_mat, truth['headarc'])
                loss = dy.sum_batches(loss*masks_flat)/total_token
                gnn_losses.append(loss)

            HX = HR * attn_prob     # ((H, L), B)
            SX = DR * attn_prob_T   # ((H, L), B)
            FX = HX + SX    # Fusion head and dept representation
            H = self.head_GNN(FX, H, is_train)
            D = self.dept_GNN(FX, D, is_train)

        attn_mat = self._joint_scores(H, D, L, R, B, masks_LR)  # ((L*R, L), B)

        if is_train:
            attn_mat = dy.reshape(attn_mat, (LR,), LB) # ((L*R,), L*B)
            loss = dy.pickneglogsoftmax_batch(attn_mat, truth['headarc'])
            loss = dy.sum_batches(loss*masks_flat)/total_token
            losses = loss*self.cfg.LAMBDA2
            if gnn_losses:
                losses += dy.esum(gnn_losses)*self.cfg.LAMBDA1
            losses_list = gnn_losses + [loss, ]
            return losses, losses_list

        attn_prob = dy.softmax(attn_mat)                 # ((L*R, L), B)
        attn_prob = dy.transpose(attn_prob)              # ((L, L*R), B)
        attn_prob = dy.reshape(attn_prob, (L, L, R), B)  # ((L, L, R), B)
        # Best relation score per arc, then greedy arg-max over axis 1.
        arc_prob = dy.max_dim(attn_prob, 2)              # ((L, L), B)
        arc_pred = np.argmax(arc_prob.npvalue(), 1)      # (L, B)
        arc_pred = arc_pred.flatten('F')                 # (L*B)
        # Best relation for every arc, then pick the one of the chosen arc.
        rel_prob = np.argmax(attn_prob.npvalue(), 2).transpose((1,0,2))  # ((L, L, B))
        rel_prob = np.reshape(rel_prob, (L, LB), 'F').T  # (L*B, L)
        rel_pred = rel_prob[np.arange(LB), arc_pred]     # (L*B)
        pred = {'head': arc_pred, 'rel': rel_pred}
        return pred

    @staticmethod
    def from_spec(spec, model):
        """Create and return a new instance with the needed parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        cfg, vocabulary = spec
        return LabelGNNDecoder(model, cfg, vocabulary)

    def param_collection(self):
        """Return a :code:`dynet.ParameterCollection` object with the parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        return self.pc

