from typing import Dict, TypeVar, List
from antu.io.vocabulary import Vocabulary
from collections import Counter
from antu.io.ext_embedding_readers import fasttext_reader
from antu.nn.dynet.attention.biaffine import BiaffineAttention
from antu.nn.dynet.units.graph_nn_unit import GraphNNUnit
from antu.nn.dynet.multi_layer_perception import MLP
Indices = TypeVar("Indices", List[int], List[List[int]])
import numpy as np
import dynet as dy
import random, math


def leaky_relu(x, alpha=.1):
    """Apply a leaky ReLU element-wise, computed as ``max(alpha * x, x)``.

    :param x: DyNet expression to activate.
    :param alpha: slope applied to negative inputs.  It should lie in
        ``[0, 1]``; for larger values ``alpha * x`` would also win the
        maximum for positive inputs.  Defaults to ``0.1``, the slope this
        function used before it became configurable.
    :returns: the DyNet expression produced by :code:`dy.bmax`.
    """
    return dy.bmax(alpha*x, x)

class GraphRepresentation(object):
    """Score token-to-token "prefix" and "suffix" links for a batch of sequences.

    Every token is embedded as the concatenation of a frozen pre-trained word
    vector, a word-occurrence embedding, a tag embedding and a tag-occurrence
    embedding.  Two pairs of projections turn the token matrix into a prefix
    score matrix and a suffix score matrix.  In training mode the scores feed
    a softmax cross-entropy loss; otherwise they are decoded greedily and
    compared against the identity order.
    """

    def __init__(
        self,
        model,
        cfg,
        vocabulary: Vocabulary):
        """Create the embedding tables and projection parameters.

        :param model: parent parameter collection; all trainable parameters
            are registered in a sub-collection created from it.
        :param cfg: configuration object.  This class reads ``MAX_NUM``,
            ``NUM_DIM``, ``TAG_DIM``, ``WORD_DIM`` and ``GLOVE`` here, and
            ``WORD_DROP`` / ``TAG_DROP`` in :code:`__call__`.
        :param vocabulary: vocabulary providing the size of the ``'tag'``
            namespace.
        """
        pc = model.add_subcollection()
        # NOTE: the order in which parameters are added below is kept stable
        # on purpose, so that previously saved models keep loading.

        # word occurrence embeddings
        self.nglookup = pc.lookup_parameters_from_numpy(
            np.random.randn(cfg.MAX_NUM, cfg.NUM_DIM).astype(np.float32))
        # tag occurrence embeddings
        self.ntlookup = pc.lookup_parameters_from_numpy(
            np.random.randn(cfg.MAX_NUM, cfg.NUM_DIM).astype(np.float32))
        # UPOS tag embeddings
        tag_num = vocabulary.get_vocab_size('tag')
        self.tlookup = pc.lookup_parameters_from_numpy(
            np.random.randn(tag_num, cfg.TAG_DIM).astype(np.float32))
        # Pre-trained word embeddings, read with the fasttext reader from the
        # path in cfg.GLOVE.  They are stored as a plain numpy array rather
        # than registered in `pc`, so they are not updated by training.
        _, glove_vec = fasttext_reader(cfg.GLOVE, False)
        glove_dim = len(glove_vec[0])
        # Two all-zero rows are prepended, presumably for the UNK and PAD
        # entries of the 'glove' namespace -- TODO confirm against the
        # vocabulary's index layout.
        unk_pad_vec = [[0.0 for _ in range(glove_dim)]]
        glove_vec = unk_pad_vec + unk_pad_vec + glove_vec
        self.glookup = np.array(glove_vec, dtype=np.float32)

        # Size of one token vector.  cfg.WORD_DIM has to equal the width of
        # the pre-trained vectors, otherwise the projections in __call__ do
        # not match the concatenated token vectors.
        self.token_dim = cfg.WORD_DIM + cfg.TAG_DIM + cfg.NUM_DIM * 2
        self.PWH = pc.add_parameters((self.token_dim, self.token_dim), 'glorot')
        self.SWH = pc.add_parameters((self.token_dim, self.token_dim), 'glorot')
        self.PWD = pc.add_parameters((self.token_dim, self.token_dim), 'glorot')
        self.SWD = pc.add_parameters((self.token_dim, self.token_dim), 'glorot')

        # PWX and GNN are not used by __call__ in this class; they are still
        # created so the layout of the parameter collection does not change.
        self.PWX = pc.add_parameters((self.token_dim, self.token_dim), 'glorot')
        self.GNN = GraphNNUnit(pc, self.token_dim, self.token_dim, leaky_relu, 'orthonormal')
        self.vocabulary = vocabulary
        self.pc, self.cfg = pc, cfg
        self.spec = (cfg, vocabulary)

    def __call__(
        self,
        indexes: Dict[str, List[Indices]],
        masks: Dict[str, np.ndarray],
        truth: Dict[str, List[int]],
        is_train=False):
        """Compute the training loss, or decode and count exact matches.

        :param indexes: batch indices.  ``indexes['head']`` supplies the
            batch size and the (padded) sequence length;
            ``indexes['word']['glove']`` and ``indexes['tag']['tag']`` hold
            per-sentence lists of word and tag ids.
        :param masks: arrays under the keys ``'H-2D'``, ``'D-2D'``,
            ``'H-flat'`` and ``'D-flat'``.  The 2-D masks disable entries of
            the score matrices, the flat masks weight the per-token losses.
        :param truth: ``truth['H']`` / ``truth['D']`` are the gold indices
            for the prefix / suffix losses (training only);
            ``truth['length']`` gives per-sentence lengths (inference only).
        :param is_train: when true, apply word/tag dropout and return the
            loss.
        :returns: in training mode a DyNet expression, the sum of the masked
            prefix and suffix losses, each divided by the ``'H-flat'`` token
            count.  Otherwise the tuple ``(cnt, batch_num, 1, 1)`` where
            ``cnt`` is the number of sentences whose greedy decoding equals
            ``list(range(length - 1))``.
        """
        seq_len = len(indexes['head'][0])
        batch_num = len(indexes['head'])
        flat_batch = seq_len * batch_num
        # Last valid row of the occurrence-embedding tables.  Counts are
        # clamped to it so a token repeated MAX_NUM or more times in one
        # sentence reuses the final row instead of indexing out of range.
        max_count = self.cfg.MAX_NUM - 1
        word_ids, tag_ids = indexes['word']['glove'], indexes['tag']['tag']
        word_seen = [Counter() for _ in range(batch_num)]
        tag_seen = [Counter() for _ in range(batch_num)]
        vectors = []
        for i in range(seq_len):
            # map token indexes -> vector
            g_idxes = [word_ids[b][i] for b in range(batch_num)]
            ng_idxes = self._prior_counts(word_seen, g_idxes, max_count)
            t_idxes = [tag_ids[b][i] for b in range(batch_num)]
            nt_idxes = self._prior_counts(tag_seen, t_idxes, max_count)
            w_vec = dy.inputTensor(self.glookup[g_idxes, :].T, True)
            t_vec = dy.lookup_batch(self.tlookup, t_idxes)
            ng_vec = dy.lookup_batch(self.nglookup, ng_idxes)
            nt_vec = dy.lookup_batch(self.ntlookup, nt_idxes)

            # build token mask with dropout scale
            # For only word dropped: tag * 3
            # For only tag dropped: word * 1.5
            # For both word and tag dropped: 0 vector
            # For neither dropped: both * 1
            if is_train:
                wm = np.random.binomial(1, 1.-self.cfg.WORD_DROP, batch_num).astype(np.float32)
                tm = np.random.binomial(1, 1.-self.cfg.TAG_DROP, batch_num).astype(np.float32)
                # The epsilon only guards the 0/0 case where both are dropped.
                scale = np.logical_or(wm, tm) * 3 / (2*wm + tm + 1e-12)
                wm *= scale
                tm *= scale
                w_vec *= dy.inputTensor(wm, batched=True)
                t_vec *= dy.inputTensor(tm, batched=True)

            w_vec = dy.concatenate([w_vec, ng_vec])
            t_vec = dy.concatenate([t_vec, nt_vec])
            vectors.append(dy.concatenate([w_vec, t_vec]))

        # Masked-out entries get a large negative score so that softmax
        # assigns them (numerically) zero probability.
        maskH_2D = 1e9*(1-dy.inputTensor(masks['H-2D'], True))  # [zero] + [ins]
        maskD_2D = 1e9*(1-dy.inputTensor(masks['D-2D'], True))  # [ins] + [zero]
        maskH_flat = dy.inputTensor(masks['H-flat'], True)      # [zero] + [ins]
        maskD_flat = dy.inputTensor(masks['D-flat'], True)      # [ins] + [zero]
        total_token = int(masks['H-flat'].sum())
        # Score scaling is currently disabled (divide by 1); the alternative
        # left by the author was math.sqrt(self.token_dim).
        reg = 1
        X = dy.concatenate_cols(vectors)
        PH, PD, SH, SD = self.PWH*X, self.PWD*X, self.SWH*X, self.SWD*X
        PA = (leaky_relu(dy.transpose(PH)*PD))/reg-maskH_2D
        SA = (leaky_relu(dy.transpose(SH)*SD))/reg-maskD_2D
        if is_train:
            # Fold the columns of each score matrix into the batch dimension
            # so every token becomes one classification over seq_len rows.
            prefix = dy.reshape(PA, (seq_len,), flat_batch)
            prefix_loss = dy.pickneglogsoftmax_batch(prefix, truth['H']) # [0,0,1,2,3]
            prefix_loss = dy.sum_batches(prefix_loss*maskH_flat)/total_token
            suffix = dy.reshape(SA, (seq_len,), flat_batch)
            suffix_loss = dy.pickneglogsoftmax_batch(suffix, truth['D']) # [1,2,3,4,0]
            suffix_loss = dy.sum_batches(suffix_loss*maskD_flat)/total_token
            return prefix_loss + suffix_loss

        # DyNet's npvalue() omits the trailing batch axis for a batch holding
        # a single element; atleast_3d restores it and is a no-op otherwise.
        A_prefix = np.atleast_3d(dy.softmax(PA).npvalue())
        A_suffix = np.atleast_3d(dy.softmax(SA).npvalue())
        cnt = 0
        for b in range(batch_num):
            length = truth['length'][b]
            order = self._greedy_order(A_prefix[:, :, b], A_suffix[:, :, b], length)
            if order == list(range(length-1)):
                cnt += 1

        return cnt, batch_num, 1, 1

    @staticmethod
    def _prior_counts(counters, ids, max_count):
        """Return how often each id was seen before, then record this sighting.

        ``counters[k]`` tracks the ids of sentence ``k``; the returned counts
        are clamped to ``max_count``.
        """
        seen = []
        for counter, idx in zip(counters, ids):
            seen.append(min(counter[idx], max_count))
            counter[idx] += 1
        return seen

    @staticmethod
    def _greedy_order(prefix, suffix, length):
        """Greedily chain positions starting from position 0.

        At each step the unvisited position ``i`` in ``1 .. length-2`` with
        the highest ``prefix[now][i] + suffix[i][now]`` is appended.  If no
        candidate scores above zero, position 0 is appended again.
        """
        order = [0]
        visited = [False] * length
        visited[0] = True
        now = 0
        for _ in range(length-2):
            best_idx, best_score = 0, 0
            for cand in range(1, length-1):
                if visited[cand]:
                    continue
                score = prefix[now][cand]+suffix[cand][now]
                if score > best_score:
                    best_idx, best_score = cand, score
            visited[best_idx] = True
            order.append(best_idx)
            now = best_idx
        return order

    @staticmethod
    def from_spec(spec, model):
        """Create and return a new instance with the needed parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        cfg, vocabulary = spec
        return GraphRepresentation(model, cfg, vocabulary)

    def param_collection(self):
        """Return a :code:`dynet.ParameterCollection` object with the parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        return self.pc
