from typing import Dict, TypeVar, List
from antu.io.vocabulary import Vocabulary
from collections import Counter
from antu.io.ext_embedding_readers import fasttext_reader
from antu.nn.dynet.attention.biaffine import BiaffineAttention
# Constrained type variable: either a flat list of ints or a list of int lists.
Indices = TypeVar("Indices", List[int], List[List[int]])
import numpy as np
import dynet as dy
import random


class WinRepresentation(object):
    """Window-based token representation.

    Every token is encoded as the concatenation of

    * a fixed (non-trainable) pre-trained vector read from ``cfg.GLOVE``,
    * a trainable embedding of how many times the same pre-trained index
      has already occurred earlier in the sequence,
    * a trainable tag embedding,
    * a trainable embedding of how many times the same tag has already
      occurred earlier in the sequence.

    The representation returned for position ``i`` is the concatenation of
    the token vectors inside a window centred on ``i``; positions that fall
    outside the sequence are filled with zero vectors.
    """

    def __init__(
        self,
        model,
        cfg,
        vocabulary: Vocabulary):
        """Allocate the lookup tables and load the pre-trained vectors.

        :param model: object providing ``add_subcollection()``
            (presumably a ``dynet.ParameterCollection`` -- TODO confirm).
        :param cfg: configuration object; the attributes read by this class
            are ``MAX_NUM``, ``NUM_DIM``, ``TAG_DIM``, ``WORD_DIM``,
            ``GLOVE``, ``WORD_DROP``, ``TAG_DROP`` and ``WIN``.
        :param vocabulary: vocabulary used to size the tag lookup table.
        """
        pc = model.add_subcollection()
        tag_num = vocabulary.get_vocab_size('tag')
        # Occurrence-count embeddings (one table for words, one for tags).
        # The tables have cfg.MAX_NUM rows, so a count used as an index must
        # stay below cfg.MAX_NUM.
        self.nglookup = pc.lookup_parameters_from_numpy(
            np.random.randn(cfg.MAX_NUM, cfg.NUM_DIM).astype(np.float32))
        self.ntlookup = pc.lookup_parameters_from_numpy(
            np.random.randn(cfg.MAX_NUM, cfg.NUM_DIM).astype(np.float32))
        self.tlookup = pc.lookup_parameters_from_numpy(
            np.random.randn(tag_num, cfg.TAG_DIM).astype(np.float32))
        _, glove_vec = fasttext_reader(cfg.GLOVE, False)
        glove_dim = len(glove_vec[0])
        # Two all-zero rows are prepended to the pre-trained matrix
        # (presumably for the UNK and PAD entries -- TODO confirm against
        # the vocabulary's index layout).
        unk_pad_vec = [[0.0 for _ in range(glove_dim)]]
        glove_vec = unk_pad_vec + unk_pad_vec + glove_vec
        # Kept as a plain numpy array: it is not registered in ``pc``.
        self.glookup = np.array(glove_vec, dtype=np.float32)
        # NOTE(review): glove_dim is not part of token_dim although the
        # pre-trained vector is part of the token vector built in
        # __call__ -- confirm this is intended.
        self.token_dim = cfg.WORD_DIM + cfg.TAG_DIM
        self.vocabulary = vocabulary
        self.pc, self.cfg = pc, cfg
        self.spec = (cfg, vocabulary)

    def __call__(
        self,
        indexes: Dict[str, List[Indices]],
        is_train=False) -> List[dy.Expression]:
        """Build one window representation per sequence position.

        :param indexes: batch of indexed sequences. The entries read are
            ``indexes['head']`` (its length is the batch size and the length
            of its first element is the sequence length),
            ``indexes['word']['glove'][b][i]`` and
            ``indexes['tag']['tag'][b][i]``.
        :param is_train: when true, word and tag vectors are randomly
            dropped per batch element and the survivors are rescaled.
        :return: a list with one batched expression per position; an empty
            list when the batch or the sequence is empty.

        The window slice spans ``2 * (WIN // 2) + 1`` positions, so all
        windows have ``cfg.WIN`` vectors only when ``cfg.WIN`` is odd.
        """
        heads = indexes['head']
        # Nothing to encode: avoid indexing into an empty batch/sequence.
        if not heads or not heads[0]:
            return []
        len_s = len(heads[0])
        batch_num = len(heads)
        vectors = []
        # Per-sequence running counts of the word / tag indexes seen so far.
        cntg = [Counter() for _ in range(batch_num)]
        cntt = [Counter() for _ in range(batch_num)]
        for i in range(len_s):
            # map token indexes -> vector
            g_idxes = [indexes['word']['glove'][x][i] for x in range(batch_num)]
            # Number of earlier occurrences of the same index (read the
            # count first, then record the current occurrence).
            ng_idxes = [cntg[k][x] for k, x in enumerate(g_idxes)]
            for k, x in enumerate(g_idxes):
                cntg[k][x] += 1
            t_idxes = [indexes['tag']['tag'][x][i] for x in range(batch_num)]
            nt_idxes = [cntt[k][x] for k, x in enumerate(t_idxes)]
            for k, x in enumerate(t_idxes):
                cntt[k][x] += 1
            # Rows are gathered from the fixed matrix and transposed so the
            # batch is the last axis, as required by ``batched=True``.
            w_vec = dy.inputTensor(self.glookup[g_idxes, :].T, True)
            t_vec = dy.lookup_batch(self.tlookup, t_idxes)
            ng_vec = dy.lookup_batch(self.nglookup, ng_idxes)
            nt_vec = dy.lookup_batch(self.ntlookup, nt_idxes)

            # build token mask with dropout scale
            # For only word dropped: tag * 3
            # For only tag dropped: word * 1.5
            # For both word and tag dropped: 0 vector
            if is_train:
                wm = np.random.binomial(1, 1.-self.cfg.WORD_DROP, batch_num).astype(np.float32)
                tm = np.random.binomial(1, 1.-self.cfg.TAG_DROP, batch_num).astype(np.float32)
                # The epsilon only guards the division when both are dropped;
                # the logical_or factor makes the scale exactly 0 there.
                scale = np.logical_or(wm, tm) * 3 / (2*wm + tm + 1e-12)
                wm *= scale
                tm *= scale
                w_vec *= dy.inputTensor(wm, batched=True)
                t_vec *= dy.inputTensor(tm, batched=True)

            w_vec = dy.concatenate([w_vec, ng_vec])
            t_vec = dy.concatenate([t_vec, nt_vec])
            vectors.append(dy.concatenate([w_vec, t_vec]))

        res = []
        win = self.cfg.WIN
        half = win // 2
        # Zero vector with the same shape and batch size as a token vector.
        zerox = dy.zeros(vectors[0].dim()[0], vectors[0].dim()[1])
        for i in range(len_s):
            tmp = vectors[max(i - half, 0): i + half + 1]
            if i < half:
                # Only the positions before the start of the sequence are
                # padded on the left. For sequences shorter than the window
                # the rest of the deficit belongs on the right; padding it
                # all on the left would shift token i off the window centre.
                n_left = min(half - i, win - len(tmp))
                tmp = [zerox] * n_left + tmp
            if i >= len_s - half:
                tmp = tmp + [zerox] * (win - len(tmp))
            res.append(dy.concatenate(tmp))
        return res

    @staticmethod
    def from_spec(spec, model):
        """Create and return a new instance with the needed parameters.

        It is one of the prerequisites for Dynet save/load method.

        :param spec: the ``(cfg, vocabulary)`` pair stored in ``self.spec``.
        :param model: passed through to the constructor.
        """
        cfg, vocabulary = spec
        return WinRepresentation(model, cfg, vocabulary)

    def param_collection(self):
        """Return a :code:`dynet.ParameterCollection` object with the parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        return self.pc
