import math
import torch
import torch.functional
from torch import nn
from torch.nn import init
from torch.autograd import Variable
import matplotlib.pyplot as plt

class MultiplicativeAttention(nn.Module):
    def __init__(self, input_dim):
        super(MultiplicativeAttention, self).__init__()
        self.W = nn.Parameter(torch.rand(1, input_dim))
        self.W.register_hook(self.printnorm)

    def printnorm(self, input):
        print('W norm: {}'.format((torch.norm(input.grad)) if input.grad is not None else ""))

    def forward(self, src, trg, encoded_doc, src_id, trg_id, start, end, src_dict, trg_dict, ux_word):
        encoded_doc = torch.t(encoded_doc)
        if(start == end):
            print(src_id, trg_id)
        encoded_interior = torch.index_select(encoded_doc, 1,
         Variable(torch.LongTensor(range(start, end))))
        if ux_word is None:
            ux_word = torch.mm(self.W, encoded_doc)  # 1 x doc_length
        scores = torch.index_select(ux_word, 1,
         Variable(torch.LongTensor(range(start, end))))
        scores = scores.squeeze(0)
        scores = nn.functional.softmax(scores, 0)
        return scores, torch.mv(encoded_interior, scores), src_dict, trg_dict, ux_word

class MultilinearAttention(nn.Module):
    """Whole-document attention driven by a (source, target) pair.

    Source, target and every document position are projected into a shared
    ``n_components``-dimensional space. The score of a document position is
    the sum over components of (src projection * trg projection * word
    projection), softmax-normalised over the whole document.
    """

    def __init__(self, input_dim, n_components):
        super().__init__()
        # Parameters are created in the order src, trg, word.
        self.U_src = nn.Parameter(torch.rand(n_components, input_dim))
        self.U_trg = nn.Parameter(torch.rand(n_components, input_dim))
        self.U_word = nn.Parameter(torch.rand(n_components, input_dim))

    def forward(self, src, trg, encoded_doc, src_id, trg_id, src_dict, trg_dict, ux_word):
        """Score every document position and return the weighted document encoding.

        Args:
            src, trg: ``(input_dim,)`` vectors projected by ``U_src`` / ``U_trg``.
            encoded_doc: 2-D tensor, transposed before use, so it must be
                ``(doc_length, input_dim)``.
            src_id, trg_id: keys under which the src/trg projections are
                memoised in ``src_dict`` / ``trg_dict``.
            src_dict, trg_dict: caller-owned caches, mutated in place and
                returned. Cached tensors keep their autograd graph; presumably
                the caller resets them between backward passes (TODO confirm).
            ux_word: optional ``(n_components, doc_length)`` word projections
                from an earlier call on the same document.

        Returns:
            ``(scores, weighted, src_dict, trg_dict, ux_word)`` with ``scores``
            of length ``doc_length``.
        """
        doc_t = encoded_doc.t()  # (input_dim, doc_length)

        if src_id not in src_dict:
            src_dict[src_id] = self.U_src.mv(src)
        proj_src = src_dict[src_id]  # (n_components,)

        if trg_id not in trg_dict:
            trg_dict[trg_id] = self.U_trg.mv(trg)
        proj_trg = trg_dict[trg_id]  # (n_components,)

        if ux_word is None:
            ux_word = self.U_word.mm(doc_t)  # (n_components, doc_length)

        logits = ux_word.t().mv(proj_src * proj_trg)  # (doc_length,)
        weights = nn.functional.softmax(logits, 0)
        return weights, doc_t.mv(weights), src_dict, trg_dict, ux_word


class InteriorMultilinear(nn.Module):
    """Multilinear (source, target, word) attention restricted to a document span.

    Source, target and every document position are projected into a shared
    ``n_components``-dimensional space. Only positions ``start:end`` are scored,
    each by the sum over components of (src projection * trg projection * word
    projection). The scores are softmax-normalised over the span.
    """

    def __init__(self, input_dim, n_components):
        super(InteriorMultilinear, self).__init__()
        self.U_src = nn.Parameter(torch.rand(n_components, input_dim))
        self.U_trg = nn.Parameter(torch.rand(n_components, input_dim))
        self.U_word = nn.Parameter(torch.rand(n_components, input_dim))

    def forward(self, src, trg, encoded_doc, src_id, trg_id, start, end, src_dict, trg_dict, ux_word):
        """Attend over positions ``start:end`` of ``encoded_doc``.

        Args:
            src, trg: ``(input_dim,)`` vectors projected by ``U_src`` / ``U_trg``.
            encoded_doc: 2-D tensor, transposed before use, so it must be
                ``(doc_length, input_dim)``.
            src_id, trg_id: keys under which the src/trg projections are
                memoised in ``src_dict`` / ``trg_dict``; also printed (debug)
                when the span is empty.
            start, end: half-open range of document positions to attend over.
            src_dict, trg_dict: caller-owned caches, mutated in place and
                returned. Cached tensors keep their autograd graph; presumably
                the caller resets them between backward passes (TODO confirm).
            ux_word: optional ``(n_components, doc_length)`` word projections
                for the whole document from an earlier call; computed when
                ``None``.

        Returns:
            ``(scores, weighted, src_dict, trg_dict, ux_word)`` where ``scores``
            has length ``end - start`` and ``weighted`` is the ``(input_dim,)``
            score-weighted sum of the span encodings.
        """
        encoded_doc = torch.t(encoded_doc)  # (input_dim, doc_length)
        if start == end:
            # Debug aid: flag an empty span.
            print(src_id, trg_id)
        # One span index, built on the document's device, is reused for both
        # selections; a bare LongTensor would always live on the CPU.
        span = torch.tensor(list(range(start, end)), dtype=torch.long,
                            device=encoded_doc.device)
        encoded_interior = torch.index_select(encoded_doc, 1, span)  # (input_dim, end - start)
        if src_id in src_dict:
            ux_src = src_dict[src_id]
        else:
            ux_src = torch.mv(self.U_src, src)  # n_components
            src_dict[src_id] = ux_src
        if trg_id in trg_dict:
            ux_trg = trg_dict[trg_id]
        else:
            ux_trg = torch.mv(self.U_trg, trg)  # n_components
            trg_dict[trg_id] = ux_trg
        if ux_word is None:
            ux_word = torch.mm(self.U_word, encoded_doc)  # n_components x doc_length
        ux_interior = torch.index_select(ux_word, 1, span)  # n_components x (end - start)
        scores = torch.mv(torch.t(ux_interior), torch.mul(ux_src, ux_trg))  # end - start
        scores = nn.functional.softmax(scores, 0)
        return scores, torch.mv(encoded_interior, scores), src_dict, trg_dict, ux_word


class MLPAttention(nn.Module):
    """Additive (MLP) attention over a batch of sequences.

    Each position's ``concat_size`` feature vector goes through
    ``tanh(Linear(concat_size -> concat_size // 3))``. The result is scored by a
    bias-free linear "context" layer, and the scores are softmax-normalised
    within each sequence. The returned average uses only the LAST
    ``concat_size // 3`` features of each position. Presumably these are the
    word part of a three-way concatenation — TODO confirm against the caller.

    Raises:
        ValueError: if ``concat_size`` is not divisible by 3.
    """

    def __init__(self, concat_size):
        super(MLPAttention, self).__init__()
        self.concat_size = concat_size
        if self.concat_size % 3 != 0:
            raise ValueError(
                'concat_size must be divisible by 3, got {}'.format(concat_size))
        self.hidden_size = self.concat_size // 3
        self.linear = nn.Linear(self.concat_size, self.hidden_size)
        self.context = nn.Linear(self.hidden_size, 1, bias=False)
        # Normalise over the sequence axis of the (batch, seq_len) scores; this
        # is the dim the deprecated implicit default chose for 2-D input.
        self.sm = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()

    def forward(self, H, lengths=None):
        """
        H: context, shape=(batch_sz, source_length, dim), dim == concat_size
        lengths: optional per-sequence valid lengths (a sequence of ints or a
            1-D tensor, one entry per batch row). Positions at or beyond a
            row's length get zero attention weight. ``None`` or an empty
            sequence disables masking. A length of 0 masks a whole row, and
            its softmax is then NaN.

        Returns ``(scores, weighted_avg)``:
            scores: (batch_sz, source_length) attention weights.
            weighted_avg: (batch_sz, dim // 3) score-weighted average of the
                last ``dim // 3`` features of H.
        """
        batch_size, seq_len, sz = H.size()
        assert sz == self.concat_size, \
            'expected feature size {}, got {}'.format(self.concat_size, sz)
        # Compute alignment scores independently for every position.
        # reshape (unlike view) also accepts non-contiguous H.
        U = self.tanh(self.linear(H.reshape(-1, self.concat_size)))
        scores = self.context(U).view(batch_size, seq_len)
        # Mask padded positions with -inf before the softmax. This is done
        # out-of-place so autograd sees it, instead of writing through .data.
        if lengths is not None and len(lengths) > 0:
            valid = torch.as_tensor(lengths, device=scores.device).view(-1, 1)
            positions = torch.arange(seq_len, device=scores.device).unsqueeze(0)
            scores = scores.masked_fill(positions >= valid, -float('inf'))
        scores = self.sm(scores)
        # Average the trailing hidden_size features, weighted by attention.
        values = H[:, :, self.concat_size - self.hidden_size:]  # (batch, seq_len, hidden)
        weighted_avg = torch.bmm(values.transpose(1, 2), scores.unsqueeze(2))
        weighted_avg = weighted_avg.squeeze(2)
        return scores, weighted_avg
