import torch
from torch import nn


class BaseModel(nn.Module):
    """Single linear layer followed by a sigmoid.

    ``forward`` computes ``sigmoid(self.mlp(seqs))``. Several other modules and
    parameters are also constructed in ``__init__`` but are not used by
    ``forward``. They are kept so that the module's ``state_dict`` keys, and the
    order in which the global RNG is consumed during initialisation, stay the
    same. This keeps previously saved checkpoints loadable.

    Args:
        d_model: Size of the last dimension of the input to ``forward``.
        mlp_size: Output width of ``self.mlp`` and therefore of ``forward``.
        tag_size: Output width of ``self.proj`` (not used by ``forward``).
    """

    def __init__(self, d_model: int = 17, mlp_size: int = 1, tag_size: int = 2) -> None:
        super().__init__()
        # The only layer used by forward: (..., d_model) -> (..., mlp_size).
        self.mlp = nn.Linear(d_model, mlp_size, bias=True)

        # --- Not used by forward; retained for state_dict / checkpoint compatibility. ---
        # Creation order is preserved on purpose: each randomly initialised
        # tensor advances the global RNG, so reordering would change the
        # values drawn for anything created afterwards.
        self.proj = nn.Linear(mlp_size, tag_size)
        # Trainable, but forward never reads these, so their .grad stays None
        # after backpropagating through forward.
        self.mlp_weight = nn.Parameter(torch.randn((d_model, 1), dtype=torch.float32), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros((1), dtype=torch.float32), requires_grad=True)
        # Hard-coded 0/1 integer mask stored as a non-trainable parameter.
        # NOTE(review): this mask has 18 entries while d_model defaults to 17,
        # so multiplying a (..., 17) input by it would fail to broadcast.
        # Confirm the intended width before using it for masking.
        self.mlp_mask = nn.Parameter(
            torch.LongTensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0]).view(1, -1),
            requires_grad=False,
        )

        self.activation = nn.Sigmoid()

    def forward(self, seqs: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer and a sigmoid.

        Args:
            seqs: Tensor whose last dimension equals ``d_model``.

        Returns:
            Tensor of shape ``(..., mlp_size)`` with values in (0, 1).
        """
        logits = self.mlp(seqs)
        return self.activation(logits)
