import torch
from torch import nn
from models.activations.func import tanh_exp

class MSU(nn.Module):
  """Excite up to two auxiliary modalities with `l_emb` and fold them back into it.

  The `l_` / `a_` / `v_` prefixes presumably stand for language / audio /
  visual streams (not established in this file). A modality is disabled by
  passing ``None`` for its dimension at construction time, in which case its
  embedding must also be ``None`` in `forward`.

  Args:
    l_dim: Feature size of `l_emb`.
    a_dim: Feature size of `a_emb`, or ``None`` to disable that branch.
    v_dim: Feature size of `v_emb`, or ``None`` to disable that branch.
    seq_size: Sequence length, forwarded to the sub-modules.
    l_head: Number of attention heads, forwarded to the sub-modules.
  """

  def __init__(self, l_dim, a_dim, v_dim, seq_size, l_head):
    super().__init__()
    self.a_integ = None if a_dim is None else ModalityExciter(l_dim, a_dim, seq_size, l_head)
    self.v_integ = None if v_dim is None else ModalityExciter(l_dim, v_dim, seq_size, l_head)
    self.emb = MSUEmbedder(l_dim, seq_size, l_head)

  def forward(self, l_emb, a_emb, v_emb, att_mask=None):
    """Return ``(l_exc, a_exc, v_exc)``.

    `a_exc` / `v_exc` are the outputs of the per-modality exciters (``None``
    for a modality whose embedding is ``None``). `l_exc` is produced by
    `self.emb` from the sum of the available excited modalities and `l_emb`.

    Raises:
      ValueError: If both `a_emb` and `v_emb` are ``None``, or if an embedding
        is supplied for a modality that was disabled at construction.
    """
    # Fail early with a clear message instead of an AttributeError/TypeError
    # raised from deep inside the sub-modules.
    if a_emb is None and v_emb is None:
      raise ValueError("MSU.forward requires at least one of a_emb / v_emb")
    if a_emb is not None and self.a_integ is None:
      raise ValueError("a_emb was given but MSU was built with a_dim=None")
    if v_emb is not None and self.v_integ is None:
      raise ValueError("v_emb was given but MSU was built with v_dim=None")

    a_exc = None if a_emb is None else self.a_integ(l_emb, a_emb, att_mask)
    v_exc = None if v_emb is None else self.v_integ(l_emb, v_emb, att_mask)

    # Combine whichever excited modalities are present.
    if v_exc is None:
      modals = a_exc
    elif a_exc is None:
      modals = v_exc
    else:
      modals = a_exc + v_exc

    l_exc = self.emb(modals, l_emb, att_mask)
    return l_exc, a_exc, v_exc

class ModalityExciter(nn.Module):
  """Mix one modality embedding with `l_emb` and run the result through an MSUEmbedder.

  Args:
    l_dim: Feature size of `l_emb`; also the output size of both projections.
    modal_dim: Feature size of `modal_emb`.
    seq_size: Sequence length, forwarded to `MSUEmbedder`.
    l_head: Number of attention heads, forwarded to `MSUEmbedder`.
  """

  def __init__(self, l_dim, modal_dim, seq_size, l_head):
    super().__init__()
    # Projects the concatenation [l_emb ; modal_emb] down to l_dim.
    self.cat_dense = nn.Linear(l_dim + modal_dim, l_dim)
    # Projects the raw modality embedding to l_dim.
    self.modal_dense = nn.Linear(modal_dim, l_dim)
    self.emb = MSUEmbedder(l_dim, seq_size, l_head)

  def forward(self, l_emb, modal_emb, att_mask=None):
    """Return the output of `self.emb` for this modality.

    The concatenated-and-activated features are passed as the embedder's
    first argument and the projected modality as its second; `att_mask` is
    passed through unchanged.
    """
    joined = torch.cat((l_emb, modal_emb), -1)
    mixed = tanh_exp(self.cat_dense(joined))
    projected = self.modal_dense(modal_emb)
    return self.emb(mixed, projected, att_mask)

class MSUEmbedder(nn.Module):
  """Cross-attention from `target` onto `sorce` with a norm-limited residual.

  `target` is used as the attention query and `sorce` as key and value. The
  attended result goes through a linear layer, is scaled down relative to
  `target`, passed through dropout, added to `target` and layer-normalised.

  Args:
    l_dim: Feature size of `target` (the attention embedding dimension).
    seq_size: Sequence length. The LayerNorm normalises jointly over the
      ``(seq_size, l_dim)`` axes, so inputs must have exactly this length.
    l_head: Number of attention heads.
    emb_norm_min: Lower bound on the divisor applied to the attended update.
    source_dim: Feature size of `sorce` when it differs from `l_dim`
      (forwarded as ``kdim`` / ``vdim``); ``None`` means `l_dim`.
  """

  def __init__(self, l_dim, seq_size, l_head, emb_norm_min=2, source_dim=None):
    super().__init__()
    self.dense = nn.Linear(l_dim, l_dim)
    self.norm = nn.LayerNorm((seq_size, l_dim))
    self.dropout = nn.Dropout(0.1)
    self.attn = nn.MultiheadAttention(l_dim, l_head, kdim=source_dim, vdim=source_dim)
    self.emb_norm_min = emb_norm_min

  def forward(self, sorce, target, att_mask=None):
    """Return the layer-normalised ``target + scaled_attention_update``.

    Args:
      sorce: Key/value tensor, batch-first ``(batch, seq, dim)``. The
        misspelled name is kept so existing keyword callers keep working.
      target: Query tensor, batch-first ``(batch, seq_size, l_dim)``.
      att_mask: Optional ``(batch, seq)`` mask over `sorce` positions;
        non-zero means "attend", zero means "ignore".

    Returns:
      Tensor with the same shape as `target`.
    """
    # nn.MultiheadAttention's key_padding_mask uses True for positions to
    # ignore, i.e. the inverse of att_mask. Use `is None`: `==` on a tensor is
    # an element-wise operator and only worked here via the fallback path.
    key_padding_mask = None if att_mask is None else ~att_mask.bool()

    # The attention module is sequence-first (batch_first=False), so feed it
    # transposed views and leave `target` itself untouched for the residual.
    query = target.transpose(0, 1).contiguous()
    memory = sorce.transpose(0, 1).contiguous()
    atted = self.attn(query, memory, memory, key_padding_mask=key_padding_mask)[0]
    atted_densed = self.dense(atted.transpose(0, 1).contiguous())

    # Per-position L2 norms, floored at eps to avoid division by zero.
    eps = 1e-6
    atted_densed_norm = atted_densed.norm(2, dim=-1).clamp(min=eps)
    target_norm = target.norm(2, dim=-1).clamp(min=eps)

    # Divisor is max(|update| / |target|, emb_norm_min): the update is always
    # shrunk by at least emb_norm_min and never ends up larger than the target.
    modal_alignment = (atted_densed_norm / target_norm).clamp(min=self.emb_norm_min)

    scaled_update = atted_densed / modal_alignment.unsqueeze(-1)
    return self.norm(self.dropout(scaled_update) + target)
