from typing import Dict, TypeVar, List
from antu.io.vocabulary import Vocabulary
from collections import Counter
from antu.io.ext_embedding_readers import fasttext_reader
from antu.nn.dynet.attention.biaffine import BiaffineAttention
from antu.nn.dynet.units.graph_nn_unit import GraphNNUnit
from antu.nn.dynet.multi_layer_perception import MLP
# Token-index payloads: either one id per token (flat list) or several ids
# per token (nested list).
Indices = TypeVar(
    'Indices',
    List[int],
    List[List[int]],
)
import numpy as np
import dynet as dy
import random, math


def leaky_relu(x):
    """Return the leaky ReLU of ``x``: elementwise max of ``0.1 * x`` and ``x``."""
    scaled = 0.1 * x
    return dy.bmax(scaled, x)

class GraphRepresentation(object):
    """Score every ordered pair of tokens in a padded batch of sentences.

    Each token vector concatenates, in order: a fixed pretrained word vector
    (loaded from ``cfg.GLOVE``), a trainable embedding of how many times the
    same word id already occurred earlier in the sentence, a trainable tag
    embedding, and a trainable embedding of the tag's earlier-occurrence
    count. Head/dependent MLPs feed a biaffine attention whose scaled scores
    go through a sigmoid and are multiplied by a pair mask.
    """

    # Rows in each occurrence-count lookup table. Larger counts (presumably
    # from long runs of repeated padding ids) reuse the last row instead of
    # indexing past the table.
    MAX_OCCURRENCE = 100
    # Evaluation threshold: a cell whose score is above it predicts label 1.
    ARC_THRESHOLD = 0.2

    def __init__(
        self,
        model,
        cfg,
        vocabulary: Vocabulary):
        """Create all parameters inside a sub-collection of ``model``.

        Args:
            model: DyNet parameter collection that will own the parameters
                (``add_subcollection`` is called on it).
            cfg: configuration; this method reads ``NUM_DIM``, ``TAG_DIM``,
                ``WORD_DIM``, ``GLOVE``, ``MLP_BIAS`` and ``MLP_DROP``.
            vocabulary: provides the size of the ``'tag'`` namespace.

        Raises:
            ValueError: if the pretrained vectors are not ``cfg.WORD_DIM``
                wide; the MLP input size is computed from ``cfg.WORD_DIM``.
        """
        pc = model.add_subcollection()
        tag_num = vocabulary.get_vocab_size('tag')
        # Creation order is kept stable: DyNet save/load relies on it.
        self.nglookup = pc.lookup_parameters_from_numpy(
            np.random.randn(self.MAX_OCCURRENCE, cfg.NUM_DIM).astype(np.float32))
        self.ntlookup = pc.lookup_parameters_from_numpy(
            np.random.randn(self.MAX_OCCURRENCE, cfg.NUM_DIM).astype(np.float32))
        self.tlookup = pc.lookup_parameters_from_numpy(
            np.random.randn(tag_num, cfg.TAG_DIM).astype(np.float32))

        _, glove_vec = fasttext_reader(cfg.GLOVE, False)
        glove_dim = len(glove_vec[0])
        if glove_dim != cfg.WORD_DIM:
            raise ValueError(
                'pretrained vectors have dimension %d but cfg.WORD_DIM is %d'
                % (glove_dim, cfg.WORD_DIM))
        # Ids 0 and 1 get all-zero rows (presumably the unknown and padding
        # ids -- confirm against the vocabulary); pretrained vectors start at
        # id 2. Stored as a plain numpy array, so these vectors are fed as
        # inputs and are not trained.
        zero_row = [0.0] * glove_dim
        self.glookup = np.array([zero_row, zero_row] + glove_vec, dtype=np.float32)

        self.token_dim = cfg.WORD_DIM + cfg.TAG_DIM + cfg.NUM_DIM * 2
        self.head_MLP = MLP(pc, [self.token_dim, self.token_dim], leaky_relu, 'orthonormal', cfg.MLP_BIAS, cfg.MLP_DROP)
        self.dept_MLP = MLP(pc, [self.token_dim, self.token_dim], leaky_relu, 'orthonormal', cfg.MLP_BIAS, cfg.MLP_DROP)
        self.BiAttn = BiaffineAttention(pc, self.token_dim, self.token_dim, 1, True, 0)
        # NOTE: the GNN units are allocated in ``pc`` (so they are saved and
        # loaded with the model) but ``__call__`` does not use them.
        self.head_GNN = GraphNNUnit(pc, self.token_dim, self.token_dim, leaky_relu, 'orthonormal')
        self.dept_GNN = GraphNNUnit(pc, self.token_dim, self.token_dim, leaky_relu, 'orthonormal')
        self.vocabulary = vocabulary
        self.pc, self.cfg = pc, cfg
        # Everything ``from_spec`` needs to rebuild this object.
        self.spec = (cfg, vocabulary)

    def __call__(
        self,
        indexes: dict,
        masks: Dict[str, np.ndarray],
        truth: dict,
        is_train=False):
        """Score all token pairs of a padded batch.

        Args:
            indexes: token ids. ``indexes['head']`` is only used for its
                sizes (batch size and padded length ``len_s``);
                ``indexes['word']['glove']`` and ``indexes['tag']['tag']``
                are indexed as ``[sentence][position]``.
            masks: ``masks['2D']`` is fed to DyNet as a batched tensor (last
                axis = sentence) and multiplied into the scores; its sum is
                the number of scored cells. Presumably shaped
                ``(len_s, len_s, batch)`` -- TODO confirm.
            truth: gold data. ``'LP1'`` is a flat 0/1 list consumed through
                ``dy.inputTensor`` + ``dy.reshape``, i.e. in DyNet's
                column-major layout with the sentence index outermost.
                ``'REG'`` is a flat per-token target for the row and column
                sums of the score matrix. ``'length'`` holds sentence
                lengths (evaluation only).
            is_train: enables dropout and returns the training loss.

        Returns:
            In training, a scalar loss expression. It is the batch sum of the
            binary log losses plus the row/column-sum squared distances,
            divided by ``masks['2D'].sum()``. Otherwise, a tuple
            ``(correct_cells, total_cells, sentences_ok, sentences)``.
        """
        len_s = len(indexes['head'][0])
        batch_num = len(indexes['head'])
        glove_ids = indexes['word']['glove']
        tag_ids = indexes['tag']['tag']
        # Per-sentence running counts of every word / tag id seen so far.
        word_seen = [Counter() for _ in range(batch_num)]
        tag_seen = [Counter() for _ in range(batch_num)]

        vectors = []
        for i in range(len_s):
            g_idxes = [glove_ids[b][i] for b in range(batch_num)]
            t_idxes = [tag_ids[b][i] for b in range(batch_num)]
            ng_idxes = self._occurrence_indexes(word_seen, g_idxes)
            nt_idxes = self._occurrence_indexes(tag_seen, t_idxes)
            w_vec = dy.inputTensor(self.glookup[g_idxes, :].T, True)
            t_vec = dy.lookup_batch(self.tlookup, t_idxes)
            ng_vec = dy.lookup_batch(self.nglookup, ng_idxes)
            nt_vec = dy.lookup_batch(self.ntlookup, nt_idxes)
            if is_train:
                wm, tm = self._token_dropout_masks(batch_num)
                w_vec *= dy.inputTensor(wm, batched=True)
                t_vec *= dy.inputTensor(tm, batched=True)
            vectors.append(dy.concatenate([w_vec, ng_vec, t_vec, nt_vec]))

        X = dy.concatenate_cols(vectors)
        if is_train:
            X = dy.dropout_dim(X, 1, self.cfg.MLP_DROP)
        H = leaky_relu(self.head_MLP(X, is_train))
        D = leaky_relu(self.dept_MLP(X, is_train))
        if is_train:
            H = dy.dropout_dim(H, 1, self.cfg.MLP_DROP)
            D = dy.dropout_dim(D, 1, self.cfg.MLP_DROP)
        masks_2D = dy.inputTensor(masks['2D'], True)
        total_token = masks['2D'].sum()
        # Scale by 1/sqrt(token_dim) before the sigmoid, then zero masked cells.
        A = dy.cmult(dy.logistic(self.BiAttn(H, D) / math.sqrt(self.token_dim)), masks_2D)
        if is_train:
            return self._loss(A, truth, len_s, batch_num, total_token)
        return self._evaluate(A, truth, len_s, batch_num, total_token)

    def _occurrence_indexes(self, counters, ids):
        """Return how often each ``ids[k]`` already occurred in sentence ``k``.

        Counts are capped at ``MAX_OCCURRENCE - 1`` so they stay valid rows of
        the occurrence lookup tables. ``counters`` is updated in place.
        """
        cap = self.MAX_OCCURRENCE - 1
        result = [min(counters[k][x], cap) for k, x in enumerate(ids)]
        for k, x in enumerate(ids):
            counters[k][x] += 1
        return result

    def _token_dropout_masks(self, batch_num):
        """Sample per-sentence word/tag keep masks with compensating scale.

        Resulting (word, tag) factors: both kept -> (1, 1); only the word
        kept -> (1.5, 0); only the tag kept -> (0, 3); both dropped -> (0, 0).
        """
        wm = np.random.binomial(1, 1. - self.cfg.WORD_DROP, batch_num).astype(np.float32)
        tm = np.random.binomial(1, 1. - self.cfg.TAG_DROP, batch_num).astype(np.float32)
        # The epsilon avoids 0/0 when both are dropped (the numerator is 0).
        scale = np.logical_or(wm, tm) * 3 / (2 * wm + tm + 1e-12)
        wm *= scale
        tm *= scale
        return wm, tm

    def _loss(self, A, truth, len_s, batch_num, total_token):
        """Build the training loss from the masked score matrix ``A``."""
        Y = dy.reshape(dy.inputTensor(truth['LP1'], True), (len_s, len_s), batch_num)
        REG = dy.reshape(dy.inputTensor(truth['REG']), (len_s,), batch_num)
        col_sums = dy.sum_dim(A, [0])
        row_sums = dy.sum_dim(A, [1])
        arc_losses = dy.binary_log_loss(A, Y)
        reg_losses = dy.squared_distance(col_sums, REG) + dy.squared_distance(row_sums, REG)
        # float() keeps the divisor a plain Python number for DyNet.
        return (dy.sum_batches(arc_losses) + dy.sum_batches(reg_losses)) / float(total_token)

    def _evaluate(self, A, truth, len_s, batch_num, total_token):
        """Count correctly thresholded cells and well-ordered sentences."""
        # npvalue() omits the batch axis when batch_num == 1 and lays values
        # out column-major; a Fortran-order reshape yields a uniform
        # (row, col, sentence) array for any batch size.
        scores = np.reshape(A.npvalue(), (len_s, len_s, batch_num), order='F')
        # Flatten in the same (Fortran, sentence-outermost) order in which
        # truth['LP1'] is laid out for DyNet during training.
        pred = np.where(scores.reshape(-1, order='F') <= self.ARC_THRESHOLD, 0, 1)
        wrong_cells = np.count_nonzero(pred != np.array(truth['LP1']))
        sentences_ok = sum(
            1 for b in range(batch_num)
            if self._adjacent_pairs_dominate(scores[:, :, b], truth['length'][b]))
        return total_token - wrong_cells, total_token, sentences_ok, batch_num

    @staticmethod
    def _adjacent_pairs_dominate(a, length):
        """Check the adjacency ordering of one sentence's score matrix ``a``.

        Return True when, for every ``j`` in ``[1, length - 2]``, the symmetric
        score ``a[j-1][j] + a[j][j-1]`` is not exceeded by
        ``a[j-1][k] + a[k][j-1]`` for any ``k`` with ``j < k < length``.
        """
        for j in range(1, length - 1):
            adjacent = a[j - 1][j] + a[j][j - 1]
            for k in range(j + 1, length):
                if a[j - 1][k] + a[k][j - 1] > adjacent:
                    return False
        return True

    @staticmethod
    def from_spec(spec, model):
        """Create and return a new instance with the needed parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        cfg, vocabulary = spec
        return GraphRepresentation(model, cfg, vocabulary)

    def param_collection(self):
        """Return a :code:`dynet.ParameterCollection` object with the parameters.

        It is one of the prerequisites for Dynet save/load method.
        """
        return self.pc
