import sys
import argparse

import torch
import torch.nn as nn
import numpy as np
from antu.io import Vocabulary
from antu.io import glove_reader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from module.dropout import IndependentDropout, SharedDropout
from module.bilstm import BiLSTM
from module.mlp import MLP
from module.biaffine import Biaffine
from module.transformer import XformerEncoder


class Parser(nn.Module):
    """Biaffine dependency parser with a BiLSTM or Transformer encoder.

    Each token is represented by the sum of a pre-trained (optionally frozen)
    GloVe embedding and a trainable, zero-initialised word embedding,
    concatenated with a POS-tag embedding. The encoder output is projected by
    two MLPs (a "head" view and a "dependent" view); each MLP output is split
    into an arc part (first D_ARC features) and a relation part (remaining
    D_REL features), which are scored by separate biaffine layers.
    """

    def __init__(
        self,
        vocabulary: Vocabulary,
        cfg: argparse.Namespace):
        """Build all sub-modules of the parser.

        Args:
            vocabulary: provides sizes and padding indices for the 'glove',
                'word', 'tag' and 'rel' namespaces.
            cfg: hyper-parameters. Always reads GLOVE, IS_FIX_GLOVE, D_TAG,
                EMB_DROP, MODEL_TYPE, D_ARC, D_REL and MLP_DROP. For
                MODEL_TYPE == 'RNN' it also reads D_RNN_HID, N_RNN_LAYER and
                RNN_DROP. For MODEL_TYPE == 'Xformer' it also reads PE_TYPE
                and D_MODEL, and the whole cfg is handed to XformerEncoder.
                NOTE: in the 'Xformer' case cfg.D_PE is written in place when
                PE_TYPE contains 'cat' or 'add'.

        Raises:
            ValueError: if cfg.MODEL_TYPE is neither 'RNN' nor 'Xformer'.
        """
        super(Parser, self).__init__()

        # Pre-trained word lookup. Two all-zero rows are prepended to the
        # vectors read from file (presumably for the reserved padding/unknown
        # entries of the 'glove' namespace -- TODO confirm).
        _, v_glove = glove_reader(cfg.GLOVE)
        d_glove, n_glove = len(v_glove[0]), vocabulary.get_vocab_size('glove')
        v_glove = [[0.0] * d_glove, [0.0] * d_glove] + v_glove
        v_glove = np.array(v_glove, dtype=np.float32)
        PAD = vocabulary.get_padding_index('glove')
        self.glookup = torch.nn.Embedding(n_glove, d_glove, padding_idx=PAD)
        self.glookup.weight.data.copy_(torch.from_numpy(v_glove))
        self.glookup.weight.requires_grad = not cfg.IS_FIX_GLOVE

        # Trainable word lookup, same width as GloVe so the two can be summed.
        # Zero-initialised, so at the start of training the sum equals the
        # GloVe vector alone.
        n_word = vocabulary.get_vocab_size('word')
        PAD = vocabulary.get_padding_index('word')
        self.wlookup = torch.nn.Embedding(n_word, d_glove, padding_idx=PAD)
        self.wlookup.weight.data.fill_(0)

        # POS tag lookup.
        n_tag = vocabulary.get_vocab_size('tag')
        PAD = vocabulary.get_padding_index('tag')
        self.tlookup = torch.nn.Embedding(n_tag, cfg.D_TAG, padding_idx=PAD)

        # Embedding dropout, applied jointly to word and tag vectors.
        self.emb_drop = IndependentDropout(cfg.EMB_DROP)

        # Encoder layer. forward() dispatches on which attribute exists.
        if cfg.MODEL_TYPE == 'RNN':
            D_RNN_IN = d_glove + cfg.D_TAG
            self.bilstm = BiLSTM(
                D_RNN_IN, cfg.D_RNN_HID, cfg.N_RNN_LAYER, cfg.RNN_DROP)
            self.bilstm_drop = SharedDropout(cfg.RNN_DROP)
            # Bidirectional: forward and backward states are concatenated.
            D_MLP_IN = cfg.D_RNN_HID * 2
        elif cfg.MODEL_TYPE == 'Xformer':
            # D_PE is derived here and consumed by XformerEncoder(cfg).
            if 'cat' in cfg.PE_TYPE:
                # Concatenated PE fills the remaining width up to D_MODEL.
                cfg.D_PE = cfg.D_MODEL - (cfg.D_TAG + d_glove)
            elif 'add' in cfg.PE_TYPE:
                cfg.D_PE = cfg.D_MODEL
            self.xformer = XformerEncoder(cfg)
            self.xformer_drop = SharedDropout(cfg.MLP_DROP)
            D_MLP_IN = cfg.D_MODEL
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on D_MLP_IN.
            raise ValueError(
                "Unknown MODEL_TYPE {!r}; expected 'RNN' or 'Xformer'".format(
                    cfg.MODEL_TYPE))

        # MLP layer: each output holds arc features followed by rel features.
        self.mlp_d = MLP(D_MLP_IN, cfg.D_ARC + cfg.D_REL, cfg.MLP_DROP)
        self.mlp_h = MLP(D_MLP_IN, cfg.D_ARC + cfg.D_REL, cfg.MLP_DROP)
        self.d_arc = cfg.D_ARC
        # Bi-affine layers: one output channel for arcs, one per relation
        # label for relations.
        self.arc_attn = Biaffine(cfg.D_ARC, 1, True, False)
        n_rel = vocabulary.get_vocab_size('rel')
        self.rel_attn = Biaffine(cfg.D_REL, n_rel, True, True)
        self.activation = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        """Score head attachments and relation labels for a padded batch.

        Args:
            x: dict of tensors. Reads 'w_lookup', 'g_lookup', 't_lookup'
                (batch-first index tensors; dim 1 is the sequence length) and
                'mask' (boolean, batch x seq; position 0 of each row is the
                ROOT token). In training mode it also reads 'head' and 'rel',
                which are indexed with mask.view(-1) (presumably flattened
                gold labels -- TODO confirm).
                NOTE: x['mask'][:, 0] is set to False in place.

        Returns:
            In training mode, a tuple (pred_arc, gold_arc, pred_rel,
            gold_rel):
                - pred_arc: arc score rows for every non-ROOT token.
                - gold_arc: the gold heads for those tokens.
                - pred_rel: relation scores taken at the gold head.
                - gold_rel: the gold relation labels.
            In evaluation mode, (heads, rels): predicted head indices and
            relation ids as Python lists, one entry per non-ROOT token.
        """
        max_len = x['w_lookup'].size(1)
        lens = x['mask'].sum(dim=1)

        # Embedding layer.
        v_w = self.wlookup(x['w_lookup']) + self.glookup(x['g_lookup'])
        v_t = self.tlookup(x['t_lookup'])
        v_w, v_t = self.emb_drop(v_w, v_t)
        v = torch.cat((v_w, v_t), dim=-1)

        if hasattr(self, 'bilstm'):
            # pack_padded_sequence requires `lengths` on the CPU; passing a
            # CUDA tensor raises a RuntimeError on recent PyTorch versions.
            v = pack_padded_sequence(
                v, lens.cpu(), batch_first=True, enforce_sorted=False)
            v, _ = self.bilstm(v)
            v, _ = pad_packed_sequence(v, batch_first=True, total_length=max_len)
            v = self.bilstm_drop(v)
        elif hasattr(self, 'xformer'):
            # The encoder works sequence-first; swap batch and sequence dims
            # around it. ~mask marks the padded positions.
            v = v.permute(1, 0, 2)
            v = self.xformer(v, ~x['mask'])
            v = v.permute(1, 0, 2)
            v = self.xformer_drop(v)

        # MLP layer, then split into arc and relation features.
        h, d = self.mlp_h(v), self.mlp_d(v)
        h_arc, d_arc = h[..., :self.d_arc], d[..., :self.d_arc]
        h_rel, d_rel = h[..., self.d_arc:], d[..., self.d_arc:]

        # Arc scores. Padded candidate heads get -inf so they never win.
        s_arc = self.arc_attn(d_arc, h_arc)
        s_arc.masked_fill_(~x['mask'].unsqueeze(1), float('-inf'))

        # ROOT is never a dependent. This mutates the caller's mask in place
        # (kept as-is, since callers may rely on it).
        x['mask'][:, 0] = 0
        pred_arc = s_arc[x['mask']]

        # Relation scores, moved so the label dimension is last.
        s_rel = self.rel_attn(d_rel, h_rel).permute(0, 2, 3, 1)
        pred_rel = s_rel[x['mask']]

        if self.training:
            flat_mask = x['mask'].view(-1)
            gold_arc = x['head'][flat_mask]
            gold_rel = x['rel'][flat_mask]
            # Keep only the relation scores at each token's gold head.
            rows = torch.arange(len(gold_arc), device=gold_arc.device)
            pred_rel = pred_rel[rows, gold_arc]
            return pred_arc, gold_arc, pred_rel, gold_rel

        pred_arc = pred_arc.argmax(-1)
        # Pick each token's label from the scores at its predicted head.
        rows = torch.arange(len(pred_arc), device=pred_arc.device)
        pred_rel = pred_rel[rows, pred_arc].argmax(-1)
        return pred_arc.tolist(), pred_rel.tolist()




