import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import geoopt.manifolds.poincare.math as pmath_geo
import math, itertools


class SimpleAttn(torch.nn.Module):
    """Pool a padded sequence into a single vector, by additive attention or by mean.

    With ``use_attention=True`` the result is an attention-weighted sum of
    ``full`` whose scores are ``V(tanh(W1(last) + W2(full)))``. With
    ``use_attention=False`` it is the mean of ``full`` along ``dim``. In both
    modes an optional ``lens`` tensor excludes padded time steps.
    """

    def __init__(self, in_shape, use_attention=True, maxlen=None):
        """
        Args:
            in_shape: feature size of ``full`` and ``last``.
            use_attention: pool with learned attention (True) or a plain mean (False).
            maxlen: optional padded sequence length. Kept for backward
                compatibility; the mask is now sized from ``full`` itself, so
                it is no longer required when ``lens`` is passed to ``forward``.
        """
        super().__init__()
        self.use_attention = use_attention
        if self.use_attention:
            self.W1 = torch.nn.Linear(in_shape, in_shape)
            self.W2 = torch.nn.Linear(in_shape, in_shape)
            self.V = torch.nn.Linear(in_shape, 1)
        if maxlen is not None:
            # Plain attribute (not a registered buffer) so state_dict keys are
            # unchanged and existing checkpoints still load.
            self.arange = torch.arange(maxlen)

    @staticmethod
    def _length_mask(lens, steps, device):
        """Return a (B, steps) bool mask that is True where position < length."""
        # Accept lens shaped (B,) or (B, 1); both flatten to (B,).
        lens = lens.reshape(-1).to(device)
        # Built on the data's device so CPU/GPU placement never mismatches.
        positions = torch.arange(steps, device=device)
        return positions[None, :] < lens[:, None]

    def forward(self, full, last, lens=None, dim=1):
        """Pool ``full`` over its time axis.

        Args:
            full: B*T*in_shape padded sequence.
            last: B*1*in_shape query vector; only used when ``use_attention``.
            lens: optional valid lengths, shaped (B,) or (B, 1). Positions at
                index >= length are ignored.
            dim: axis to pool over. Masking assumes the time axis is axis 1.

        Returns:
            B*in_shape pooled tensor. With attention, a row whose length is 0
            has every score masked to -inf and yields NaN.
        """
        mask = None
        if lens is not None:
            mask = self._length_mask(lens, full.size(dim), full.device)

        if self.use_attention:
            score = self.V(torch.tanh(self.W1(last) + self.W2(full)))  # B*T*1
            if mask is not None:
                # Out-of-place fill: padded steps get zero weight after softmax.
                score = score.masked_fill(~mask.unsqueeze(-1), float("-inf"))
            attention_weights = F.softmax(score, dim=dim)
            return torch.sum(attention_weights * full, dim=dim)

        if mask is None:
            return torch.mean(full, dim=dim)

        weights = mask.unsqueeze(-1).to(full.dtype)  # B*T*1, 1.0 on valid steps
        total = torch.sum(full * weights, dim=dim)
        # Divide by the number of valid steps (not T) so padding does not
        # dilute the mean; clamp avoids 0/0 for zero-length rows.
        count = torch.sum(weights, dim=dim).clamp(min=1)
        return total / count


def one_rnn_transform(W, h, U, x, c):
    """Compute the hyperbolic RNN pre-activation (W ⊗ h) ⊕ (U ⊗ x).

    Both terms use geoopt's Poincaré-ball Möbius matrix-vector product and
    are combined by Möbius addition. ``c`` is forwarded unchanged as the
    ``c`` argument of every geoopt call.

    Args:
        W: matrix applied to the hidden state ``h``.
        h: hidden state.
        U: matrix applied to the input ``x``.
        x: input.
        c: curvature argument forwarded to the geoopt ops.

    Returns:
        The Möbius sum of the two transformed terms.
    """
    hidden_term = pmath_geo.mobius_matvec(W, h, c=c)
    input_term = pmath_geo.mobius_matvec(U, x, c=c)
    return pmath_geo.mobius_add(hidden_term, input_term, c=c)

