import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import math
from sklearn.cluster import MiniBatchKMeans
from sklearn.cluster import KMeans


from model.tree import Tree, head_to_tree, tree_to_adj
from model.sinkhorn import SinkhornDistance
from RAMS.classification_code.code.utils import constant, torch_utils

from RAMS.classification_code.baseline.LSR.code.models.reasoner import DynamicReasoner

class GCNClassifier(nn.Module):
    """Classification head stacked on top of a GCNRelationModel.

    The wrapped model yields a pooled feature vector together with a vector
    of per-class scores for every example. Both are concatenated along the
    feature axis and projected to ``opt['num_class']`` logits by a single
    linear layer.
    """

    def __init__(self, opt, emb_matrix=None):
        """Build the wrapped relation model and the final linear projection.

        Args:
            opt: Option mapping; this class reads ``hidden_dim`` and
                ``num_class`` and passes the whole mapping to the wrapped model.
            emb_matrix: Optional pretrained embedding matrix, forwarded
                unchanged to GCNRelationModel.
        """
        super().__init__()
        self.gcn_model = GCNRelationModel(opt, emb_matrix=emb_matrix)
        num_class = opt['num_class']
        # Input width = pooled features (hidden_dim) + one score per class.
        self.classifier = nn.Linear(opt['hidden_dim'] + num_class, num_class)
        self.opt = opt

    def conv_l2(self):
        """Return the L2 penalty reported by the inner GCN module."""
        gcn = self.gcn_model.gcn
        return gcn.conv_l2()

    def forward(self, inputs):
        """Compute class logits for a batch.

        Args:
            inputs: Batch tuple, handed to the wrapped model as-is.

        Returns:
            The output of the linear classifier applied to the concatenation
            of the pooled features and the class scores.
        """
        pooled, class_scores = self.gcn_model(inputs)
        features = torch.cat((pooled, class_scores), dim=1)
        return self.classifier(features)

class GCNRelationModel(nn.Module):
    """Encoder that pools GCN-module outputs into one vector per example.

    Owns the embedding tables, runs the ``GCN`` sub-module, pools its token
    representations three ways (all tokens, trigger tokens, argument tokens)
    and feeds the concatenation through an MLP.
    """

    def __init__(self, opt, emb_matrix=None):
        """Create embeddings, the GCN sub-module and the output MLP.

        Args:
            opt: Option mapping. Keys read here: ``vocab_size``, ``emb_dim``,
                ``pos_dim``, ``ner_dim``, ``hidden_dim``, ``num_layers``,
                ``mlp_layers``, ``topn`` and (in forward) ``pooling``.
            emb_matrix: Optional numpy array of pretrained word embeddings
                copied into the word embedding table.
        """
        super().__init__()
        self.opt = opt
        self.emb_matrix = emb_matrix

        # create embedding layers
        self.emb = nn.Embedding(opt['vocab_size'], opt['emb_dim'], padding_idx=constant.PAD_ID)
        self.pos_emb = nn.Embedding(len(constant.POS_TO_ID), opt['pos_dim']) if opt['pos_dim'] > 0 else None
        self.ner_emb = nn.Embedding(len(constant.NER_TO_ID), opt['ner_dim']) if opt['ner_dim'] > 0 else None
        embeddings = (self.emb, self.pos_emb, self.ner_emb)
        self.init_embeddings()

        # gcn layer
        self.gcn = GCN(opt, embeddings, opt['hidden_dim'], opt['num_layers'])

        # output mlp layers
        # Three pooled vectors are concatenated in forward(); the extra factor
        # of 2 presumably matches the bidirectional LSTM output width returned
        # by the GCN module (assumes rnn_hidden == hidden_dim -- TODO confirm).
        in_dim = opt['hidden_dim']*3*2
        layers = [nn.Linear(in_dim, opt['hidden_dim']), nn.ReLU()]
        for _ in range(self.opt['mlp_layers']-1):
            layers += [nn.Linear(opt['hidden_dim'], opt['hidden_dim']), nn.ReLU()]
        self.out_mlp = nn.Sequential(*layers)

    def init_embeddings(self):
        """Initialise the word embedding table and apply the finetuning policy.

        Without a pretrained matrix, every row except row 0 is drawn uniformly
        from [-1, 1]; otherwise the pretrained matrix is copied in. ``topn``
        then selects between freezing the table, finetuning part of it through
        a gradient hook, or finetuning all of it.
        """
        if self.emb_matrix is None:
            self.emb.weight.data[1:,:].uniform_(-1.0, 1.0)
        else:
            self.emb_matrix = torch.from_numpy(self.emb_matrix)
            self.emb.weight.data.copy_(self.emb_matrix)
        # decide finetuning
        if self.opt['topn'] <= 0:
            print("Do not finetune word embedding layer.")
            self.emb.weight.requires_grad = False
        elif self.opt['topn'] < self.opt['vocab_size']:
            print("Finetune top {} word embeddings.".format(self.opt['topn']))
            # keep_partial_grad is defined in torch_utils (not visible here);
            # presumably it restricts the gradient to the first `topn` rows.
            self.emb.weight.register_hook(
                lambda x: torch_utils.keep_partial_grad(x, self.opt['topn']))
        else:
            print("Finetune all embeddings.")

    def forward(self, inputs):
        """Encode a batch and return pooled features plus class scores.

        Args:
            inputs: 9-tuple ``(words, masks, head, trigger_pos, arg_pos,
                dep_path, adj, bert, scores)``; the whole tuple is also passed
                to the GCN sub-module.

        Returns:
            ``(outputs, class_scores)`` where ``outputs`` is the MLP output
            over the concatenated pooled vectors and ``class_scores`` is
            passed through from the GCN sub-module.
        """
        words, masks, head, trigger_pos, arg_pos, dep_path, adj, bert, scores = inputs # unpack

        # The original code also copied `masks` to the host here to derive
        # per-example lengths that were never used; that device sync is gone.
        h, pool_mask, class_scores = self.gcn(inputs)

        # pooling
        # eq(0).eq(0) is True wherever the position value is non-zero; pool()
        # fills those entries, so only positions whose value is 0 contribute.
        subj_mask = trigger_pos.eq(0).eq(0).unsqueeze(2)
        obj_mask = arg_pos.eq(0).eq(0).unsqueeze(2)
        pool_type = self.opt['pooling']
        h_out = pool(h, pool_mask, type=pool_type)
        subj_out = pool(h, subj_mask, type=pool_type)
        obj_out = pool(h, obj_mask, type=pool_type)
        outputs = torch.cat([h_out, subj_out, obj_out], dim=1)
        outputs = self.out_mlp(outputs)
        return outputs, class_scores

class GCN(nn.Module):
    """ A GCN/Contextualized GCN module operated on dependency graphs.

    forward() embeds the tokens, encodes them with a bidirectional LSTM, uses
    a Sinkhorn optimal-transport plan between off-path and on-path tokens to
    pick extra tokens, runs graph convolutions over three masked variants of
    the adjacency matrix, and scores the selected tokens against a fixed set
    of random class vectors.

    Note that forward() returns the *LSTM* outputs as token representations;
    the graph-convolution output only enters ``class_scores`` (scaled by 1e-6).
    """
    def __init__(self, opt, embeddings, mem_dim, num_layers):
        """Create the LSTM, the linear stacks and the OT helpers.

        Args:
            opt: Option mapping. Keys read: ``cuda``, ``emb_dim``, ``rnn``,
                ``rnn_hidden``, ``rnn_layers``, ``rnn_dropout``,
                ``input_dropout``, ``gcn_dropout``, ``num_class``,
                ``batch_size``.
            embeddings: Tuple ``(emb, pos_emb, ner_emb)`` of embedding modules.
            mem_dim: Output width of every graph-convolution layer.
            num_layers: Number of graph-convolution layers.
        """
        super(GCN, self).__init__()
        self.opt = opt
        self.layers = num_layers
        self.use_cuda = opt['cuda']
        self.mem_dim = mem_dim
        # Word embedding width plus 768 extra features: forward() concatenates
        # the `bert` input to the word embeddings before the LSTM.
        self.in_dim = opt['emb_dim'] + 768

        self.emb, self.pos_emb, self.ner_emb = embeddings

        # rnn layer
        # NOTE(review): forward() uses self.rnn / self.rnn_drop unconditionally,
        # so the module only works when opt['rnn'] is truthy.
        if self.opt.get('rnn', False):
            input_size = self.in_dim
            self.rnn = nn.LSTM(input_size, opt['rnn_hidden'], opt['rnn_layers'], batch_first=True,
                    dropout=opt['rnn_dropout'], bidirectional=True)
            self.in_dim = opt['rnn_hidden'] * 2
            self.rnn_drop = nn.Dropout(opt['rnn_dropout']) # use on last layer output

        self.in_drop = nn.Dropout(opt['input_dropout'])
        self.gcn_drop = nn.Dropout(opt['gcn_dropout'])
        self.lstm_w = nn.Linear(2*opt['rnn_hidden'],opt['rnn_hidden'])
        self.gcn_w = nn.Sequential(nn.Linear(opt['rnn_hidden'],opt['rnn_hidden']),nn.Sigmoid())

        # gcn layer
        # Four independent stacks with identical shapes. They are built in the
        # order W, W1, W2, W3 so parameter initialisation under a fixed seed
        # matches the previous hand-unrolled loops.
        self.W = self._linear_stack()
        self.W1 = self._linear_stack()
        self.W2 = self._linear_stack()
        self.W3 = self._linear_stack()

        ## OT
        self.sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='mean')
        # Fixed random class vectors. They are a plain tensor attribute, i.e.
        # neither trained nor stored in state_dict.
        # NOTE(review): a reloaded model therefore gets different class vectors
        # unless the RNG seed matches -- confirm this is intended.
        self.classes = torch.rand(opt['num_class'],1*opt['rnn_hidden'])
        if self.use_cuda:
            # Previously moved to CUDA unconditionally, which broke CPU-only runs.
            self.classes = self.classes.cuda()
        # Only referenced by a clustering variant that is currently disabled.
        self.kmeans = MiniBatchKMeans(n_clusters=20,
                                 random_state=0,
                                 batch_size=self.opt['batch_size'],
                                 max_iter=10)

    def _linear_stack(self):
        """Return one ModuleList of per-layer linear maps (in_dim -> mem_dim, then mem_dim -> mem_dim)."""
        return nn.ModuleList([
            nn.Linear(self.in_dim if layer == 0 else self.mem_dim, self.mem_dim)
            for layer in range(self.layers)
        ])

    def conv_l2(self):
        """Return the sum of squared weights and biases of the ``W`` stack."""
        return sum(p.pow(2).sum() for w in self.W for p in (w.weight, w.bias))

    def encode_with_rnn(self, rnn_inputs, masks, batch_size):
        """Run the LSTM over a padded batch and return the padded outputs.

        Args:
            rnn_inputs: Batch-first input tensor for the LSTM.
            masks: 2-D mask; entries equal to ``constant.PAD_ID`` are counted
                per row to obtain the sequence lengths.
            batch_size: Number of sequences in the batch.
        """
        # tolist() yields plain Python ints. The former list(x.squeeze()) raised
        # for batch size 1 because squeeze() collapsed the result to a 0-d tensor.
        seq_lens = masks.data.eq(constant.PAD_ID).long().sum(1).tolist()
        # Create the initial state on the same kind of device as the inputs
        # instead of always on CUDA.
        h0, c0 = rnn_zero_state(batch_size, self.opt['rnn_hidden'], self.opt['rnn_layers'],
                                use_cuda=rnn_inputs.is_cuda)
        rnn_inputs = nn.utils.rnn.pack_padded_sequence(rnn_inputs, seq_lens, batch_first=True)
        rnn_outputs, (ht, ct) = self.rnn(rnn_inputs, (h0, c0))
        rnn_outputs, _ = nn.utils.rnn.pad_packed_sequence(rnn_outputs, batch_first=True)
        return rnn_outputs

    def forward(self, inputs):
        """Encode a batch and compute per-class scores.

        Args:
            inputs: 9-tuple ``(words, masks, head, trigger_pos, arg_pos,
                dep_path, adj, bert, scores)``. ``head``, ``trigger_pos`` and
                ``arg_pos`` are not used here.

        Returns:
            ``(lstm_outputs, masks.unsqueeze(2), class_scores)``.
        """
        words, masks, head, trigger_pos, arg_pos, dep_path, adj, bert, scores = inputs # unpack

        ## embedding
        word_embs = self.emb(words)
        embs = torch.cat([word_embs, bert], dim=2)
        embs = self.in_drop(embs)

        ## RNN
        lstm_outputs = self.rnn_drop(self.encode_with_rnn(embs, masks, words.size()[0]))
        gcn_inputs = lstm_outputs.clone()
        # Every helper tensor below is created on this device (was hard-coded .cuda()).
        device = lstm_outputs.device
        seq_len = lstm_outputs.shape[1]

        ## OT
        # Per example:
        #   left  - LSTM states of tokens off the dependency path, zero-padded
        #           to seq_len rows;
        #   right - states of tokens on the path plus one extra row (the
        #           element-wise max over the padded `left`), zero-padded to
        #           seq_len + 1 rows;
        #   mu    - softmax of the off-path scores (mass over `left` rows);
        #   nu    - 0.3 * softmax of the on-path scores plus a fixed mass of
        #           0.7 on the extra row (mass over `right` rows).
        sf0 = nn.Softmax(0)
        lefts = []
        rights = []
        mus = []
        nus = []
        for i in range(dep_path.shape[0]):
            on_path = dep_path[i].bool()
            off_path = ~on_path
            left = lstm_outputs[i][off_path]
            left = torch.cat([left, torch.zeros(seq_len-left.shape[0], left.shape[1], device=device)], dim=0)
            lefts.append(left)
            right = torch.cat([lstm_outputs[i][on_path], torch.max(left, 0)[0].unsqueeze(0)], dim=0)
            right = torch.cat([right, torch.zeros(seq_len-right.shape[0]+1, right.shape[1], device=device)])
            rights.append(right)
            right_scores = scores[i][on_path].float()
            left_scores = scores[i][off_path].float()
            nu = torch.cat([0.3*sf0(right_scores), torch.tensor([0.7], device=device)], dim=0)
            nu = torch.cat([nu, torch.zeros(seq_len-nu.shape[0]+1, device=device)], dim=0)
            nus.append(nu)
            mu = sf0(left_scores)
            mu = torch.cat([mu, torch.zeros(seq_len-mu.shape[0], device=device)], dim=0)
            mus.append(mu)

        # No squeeze() here: it used to drop the batch axis for batch size 1
        # (and any other size-1 axis), leaving `lefts` inconsistent with
        # `rights`, `mus` and `nus`.
        lefts = torch.stack(lefts, dim=0)
        rights = torch.stack(rights, dim=0)
        mus = torch.stack(mus, dim=0)
        nus = torch.stack(nus, dim=0)

        # The transport plan is computed on the CPU and moved back afterwards.
        _, pi, _ = self.sinkhorn(lefts.cpu(), rights.cpu(), mu=mus.cpu(), nu=nus.cpu(), cuda=False)
        pi = pi.to(device)
        # For every `left` row take the index of the `right` row with most mass.
        algins = torch.max(pi, 2)[1]
        path_num = dep_path.sum(1).unsqueeze(1).repeat(1,algins.shape[1])
        # After the subtraction an index is negative exactly when it is smaller
        # than the number of on-path tokens; those rows become 1, the rest 0.
        algins -= path_num
        algins[algins >= 0] = 0
        algins[algins < 0] = 1
        # NOTE(review): rows of `left` are the off-path tokens in compacted
        # order, whereas `algins` is used below as a per-position mask --
        # confirm the index alignment is intended.
        algins = algins.masked_fill(masks, 0)
        algins += dep_path
        algins[algins > 0] = 1

        sel = algins.unsqueeze(1).repeat(1,adj.shape[1],1)
        adj1 = sel*adj
        # NOTE(review): `sel` is an integer tensor, so `~sel` is a bitwise NOT
        # (0 -> -1, 1 -> -2), not a logical negation. adj2, notsel and hence
        # adj3 end up with non-positive weights -- confirm against the intended
        # formulation before changing; left untouched to keep trained models
        # reproducible.
        adj2 = ~sel*adj
        notsel = ~sel.transpose(1,2)
        sel = sel.float().bmm(notsel.float())
        sel[sel > 0] = 1
        adj3 = sel*adj

        denom1 = adj1.sum(2).unsqueeze(2) + 1
        denom2 = adj2.sum(2).unsqueeze(2) + 1
        denom3 = adj3.sum(2).unsqueeze(2) + 1
        for l in range(self.layers):
            Ax1 = adj1.bmm(gcn_inputs)
            Ax2 = adj2.bmm(gcn_inputs)
            Ax3 = adj3.bmm(gcn_inputs)
            AxW1 = self.W1[l](Ax1)
            AxW2 = self.W2[l](Ax2)
            AxW3 = self.W3[l](Ax3)
            AxW = AxW1 + AxW2 + AxW3 + self.W1[l](gcn_inputs)  # self loop
            AxW = AxW / (denom1+denom2+denom3)

            gAxW = F.relu(AxW)
            gcn_inputs = self.gcn_drop(gAxW) if l < self.layers - 1 else gAxW

        # NOTE(review): the result of this second pass (gcn_inputs0) is never
        # read afterwards. It is kept because removing it would change how
        # often dropout draws random numbers during training.
        gcn_inputs0 = lstm_outputs.clone()
        denom = adj.sum(2).unsqueeze(2) + 1
        for l in range(self.layers):
            Ax = adj.bmm(gcn_inputs0)
            AxW = self.W[l](Ax)
            AxW = AxW + self.W1[l](gcn_inputs0)  # self loop
            AxW = AxW / denom

            gAxW = F.relu(AxW)
            gcn_inputs0 = self.gcn_drop(gAxW) if l < self.layers - 1 else gAxW

        # No-op when the class vectors already live on the working device.
        classes = self.classes.to(device)
        class_scores = []
        for i in range(dep_path.shape[0]):
            keep = algins[i].bool()
            selected_gcn = gcn_inputs[i][keep]
            selected_lstm = lstm_outputs[i][keep]
            selected = 0.000001 * self.gcn_w(selected_gcn) + self.lstm_w(selected_lstm)
            # Max-pool over the selected tokens, then take one dot product per class.
            rep = torch.max(selected, 0)[0]
            class_score = [torch.dot(rep, classes[j]) for j in range(self.opt['num_class'])]
            class_scores.append(torch.stack(class_score, dim=0))

        class_scores = torch.stack(class_scores, dim=0)

        return lstm_outputs, masks.unsqueeze(2), class_scores

def pool(h, mask, type='max'):
    """Reduce ``h`` along dimension 1, ignoring entries where ``mask`` is set.

    Args:
        h: Tensor to pool.
        mask: Mask passed to ``masked_fill``; entries where it is set are
            excluded from the reduction.
        type: ``'max'`` for max pooling, ``'avg'`` for the mean over the
            unmasked entries; any other value gives a plain sum.

    Returns:
        The pooled tensor with dimension 1 reduced away.
    """
    if type == 'max':
        # Masked entries are pushed to a very large negative value so they
        # cannot win the max.
        filled = h.masked_fill(mask, -constant.INFINITY_NUMBER)
        values, _ = torch.max(filled, 1)
        return values

    # Both remaining modes start from the sum over unmasked entries.
    summed = h.masked_fill(mask, 0).sum(1)
    if type != 'avg':
        return summed
    # Divide by the number of unmasked positions along dimension 1.
    kept = mask.size(1) - mask.float().sum(1)
    return summed / kept

def rnn_zero_state(batch_size, hidden_dim, num_layers, bidirectional=True, use_cuda=True):
    """Build an all-zero ``(h0, c0)`` initial state for an RNN.

    Args:
        batch_size: Number of sequences in the batch.
        hidden_dim: Hidden size per direction.
        num_layers: Number of stacked layers.
        bidirectional: When true, the leading dimension is doubled.
        use_cuda: When true, both states are moved to the CUDA device.

    Returns:
        A pair of zero tensors shaped ``(layers * directions, batch_size,
        hidden_dim)`` that do not require gradients.
    """
    directions = 2 if bidirectional else 1
    zeros = Variable(torch.zeros(num_layers * directions, batch_size, hidden_dim),
                     requires_grad=False)
    if not use_cuda:
        # On the CPU path both states are the very same tensor object.
        return zeros, zeros
    return zeros.cuda(), zeros.cuda()

