from torch import nn
from pipetools import pipe
from models.activations.func import tanh_exp

class TransformerFFN(nn.Module):
  """Feed-forward sub-block: linear -> ``tanh_exp`` -> linear.

  Both linear layers map ``dim`` features to ``dim`` features, so the
  size of the last dimension of the input is preserved.

  Args:
    dim: Number of input and output features of each linear layer.
  """

  def __init__(self, dim):
    super().__init__()
    # First ("_f") and last ("_l") projections; creation order is kept so
    # parameter initialisation and state-dict key order are unchanged.
    self.linear_f = nn.Linear(dim, dim)
    self.linear_l = nn.Linear(dim, dim)

  def forward(self, x):
    """Apply the first projection, the activation, then the last projection.

    Args:
      x: Input tensor whose last dimension has ``dim`` features.

    Returns:
      The result of ``linear_l(tanh_exp(linear_f(x)))``.
    """
    projected = self.linear_f(x)
    activated = tanh_exp(projected)
    return self.linear_l(activated)

class EncoderLayer(nn.Module):
  """One Transformer encoder layer: self-attention followed by a feed-forward block.

  Each sub-layer is wrapped in a residual connection and then normalised
  with LayerNorm (normalisation is applied after the residual add).

  Args:
    dim: Embedding width; used for the attention, the feed-forward block
      and both LayerNorms.
    head: Number of attention heads given to ``nn.MultiheadAttention``.
    dropout: Dropout probability applied to the feed-forward output.
      Defaults to 0.1, the value that used to be hard-coded.
  """

  def __init__(self, dim, head=1, dropout=0.1):
    super().__init__()
    # Sub-modules are created in the same order as before so that parameter
    # initialisation (RNG consumption) and state-dict keys do not change.
    self.attn = nn.MultiheadAttention(dim, head)
    self.linear_o = TransformerFFN(dim)
    self.do = nn.Dropout(dropout)
    self.ln_f = nn.LayerNorm(dim)
    self.ln_l = nn.LayerNorm(dim)

  def forward(self, x, mask):
    """Run the layer on a batch-first input.

    Args:
      x: Rank-3 tensor. Its first two axes are swapped before attention
        because ``nn.MultiheadAttention`` is built with its default
        ``batch_first=False`` layout, so ``x`` is expected as
        (batch, seq, dim).
      mask: Boolean tensor that is inverted and passed as
        ``key_padding_mask``; i.e. True marks positions that may be
        attended to (presumably shaped (batch, seq) -- confirm with the
        caller). ``None`` disables padding masking.

    Returns:
      Tensor with the same axis layout as ``x``.
    """
    # Switch to the (seq, batch, dim) layout the attention module expects.
    x = x.permute(1, 0, 2)

    # key_padding_mask uses True for positions to IGNORE, hence the inversion.
    key_padding_mask = None if mask is None else ~mask

    # The attention weights are never used, so skip computing them.
    attn_out = self.attn(
      x, x, x,
      key_padding_mask=key_padding_mask,
      need_weights=False,
    )[0]
    atted = self.ln_f(attn_out + x)

    # Dropout is applied to the feed-forward output only, as before.
    ffn_out = self.do(self.linear_o(atted))

    # Restore the caller's original axis order.
    return self.ln_l(ffn_out + atted).permute(1, 0, 2)

class TransformerEncoder(nn.Module):
  """A stack of ``EncoderLayer`` modules applied one after another.

  Args:
    dim: Embedding width forwarded to every ``EncoderLayer``.
    head: Number of attention heads forwarded to every ``EncoderLayer``.
    layers: How many ``EncoderLayer`` instances to stack.
  """

  def __init__(self, dim, head, layers):
    super().__init__()
    self.layers = nn.ModuleList()
    for _ in range(layers):
      self.layers.append(EncoderLayer(dim, head))

  def forward(self, x, mask):
    """Feed ``x`` through every layer in order, reusing the same ``mask``.

    Args:
      x: Input passed to the first layer.
      mask: Passed unchanged as the second argument of every layer.

    Returns:
      The output of the last layer, or ``x`` itself when the stack is empty.
    """
    out = x
    for block in self.layers:
      out = block(out, mask)
    return out

