import torch
import torch.nn as nn
class MLP(nn.Module):
    """Position-wise MLP encoder with a transformer-style ``forward``.

    Each scalar in ``input_ids`` is cast to float and passed, independently
    of every other position, through four linear layers (ReLU after the
    first three, dropout after the last).  The result is one
    ``hidden_size``-wide vector per input position.

    NOTE(review): the ``forward`` signature and the ``pooler_output`` /
    ``last_hidden_state`` attribute names presumably mirror a Hugging Face
    BERT-style encoder so this class can stand in for one -- confirm
    against the callers.
    """

    def __init__(self, args, hidden_size):
        """Build the layer stack.

        Args:
            args: Configuration object; only ``args.gpu`` is read.  A truthy
                value makes ``self.device`` a CUDA device, otherwise CPU.
                The module's parameters are NOT moved here; the caller is
                still responsible for ``model.to(...)``.
            hidden_size: Output width of every linear layer.
        """
        super(MLP, self).__init__()
        self.hidden_size = hidden_size
        self.device = torch.device('cuda') if args.gpu else torch.device('cpu')
        self.linear = nn.Sequential(
            nn.Linear(1, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.Dropout(p=0.3),
        )
        # Not used by forward(); kept so the public attribute stays available.
        self.softmax = nn.Softmax(dim=0)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Encode every position of ``input_ids`` independently.

        Args:
            input_ids: 2-D tensor; its first two dimensions are used as
                ``(shape[0], shape[1])``.  Values are cast to float.
            attention_mask: Accepted but not used.
            token_type_ids: Accepted but not used.

        Returns:
            An ``OUTPUT`` whose ``last_hidden_state`` has shape
            ``(shape[0], shape[1], hidden_size)`` and whose ``pooler_output``
            is the slice ``last_hidden_state[:, 0, :]``.
        """
        rows, cols = input_ids.shape[0], input_ids.shape[1]
        # Follow the layers' actual device rather than the one guessed from
        # ``args.gpu`` at construction time: the two only agree if the caller
        # moved the module accordingly, and a mismatch would crash the matmul.
        target_device = self.linear[0].weight.device
        # reshape (not view) so non-contiguous inputs, e.g. a transposed
        # batch, are accepted instead of raising a RuntimeError.
        flat = input_ids.reshape(rows * cols, 1).float().to(target_device)
        hidden = self.linear(flat).reshape(rows, cols, self.hidden_size)
        return OUTPUT(hidden[:, 0, :], hidden)

class OUTPUT:
    """Plain container pairing a pooled vector with the full hidden states.

    Attributes are stored exactly as given; no copying or validation is done.
    """

    def __init__(self, pooler_output, last_hidden_state):
        """Store both values under attributes of the same names.

        Args:
            pooler_output: Value exposed as ``self.pooler_output``.
            last_hidden_state: Value exposed as ``self.last_hidden_state``.
        """
        self.pooler_output = pooler_output
        self.last_hidden_state = last_hidden_state