from argparse import ArgumentParser

import torch
import torch.nn as nn

class BiLSTM(nn.Module):
    """Single-layer bidirectional LSTM encoder over token ids.

    Token ids are embedded, run through a bidirectional LSTM, and each time
    step's concatenated forward/backward output is projected back down to
    ``hidden_size`` by a linear layer.

    Args:
        args: Object with a ``gpu`` attribute; a truthy value places the
            model on CUDA, otherwise on CPU.
        config: Object providing ``hidden_size``, ``vocab_size`` and
            ``pad_token_id``.
    """

    def __init__(self, args, config):
        super().__init__()
        hidden_size = config.hidden_size
        vocab_size = config.vocab_size
        self.device = torch.device('cuda') if args.gpu else torch.device('cpu')
        self.hidden_size = hidden_size

        # The embedding table is allocated with 1000 rows beyond vocab_size.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size + 1000,
            embedding_dim=hidden_size,
            padding_idx=config.pad_token_id,
        )
        # softmax / l1 / r1 are not used by forward(); they are kept so the
        # module's attributes and state_dict keys stay unchanged.
        self.softmax = nn.Softmax(dim=0)
        self.l1 = nn.Linear(1, hidden_size)
        self.r1 = nn.ReLU(inplace=True)
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        # Projects the concatenated (forward + backward) LSTM output back to
        # hidden_size.
        self.fc = nn.Linear(2 * hidden_size, hidden_size)

        # Place every submodule (embedding included) on the device selected
        # by args.gpu instead of unconditionally calling .cuda() on some of
        # them, which failed on CPU-only machines and left the embedding on
        # a different device from the rest of the model.
        self.to(self.device)

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token-id sequences.

        Args:
            input_ids: Integer tensor indexed as ``[batch_size, max_len]``.
            attention_mask: Accepted for interface compatibility; it is not
                used, so padded positions are processed like any other.

        Returns:
            OUTPUT: ``last_hidden_state`` is the projected LSTM output of
            shape ``[batch_size, max_len, hidden_size]``; ``pooler_output``
            is its slice at the first time step, of shape
            ``[batch_size, hidden_size]``.
        """
        batch_size = input_ids.shape[0]

        embedded = self.embedding(input_ids)  # [batch_size, max_len, hidden_size]

        # Initial states: [num_layers(=1) * num_directions(=2), batch_size, hidden_size].
        # They follow the device/dtype of the embedded input so the model
        # keeps working after the caller moves it with .to()/.cuda()/.cpu().
        # NOTE(review): the states are drawn at random on every call, so the
        # outputs are not deterministic even in eval mode -- confirm this is
        # intended (nn.LSTM defaults to zero states when none are passed).
        state_shape = (1 * 2, batch_size, self.hidden_size)
        hidden_state = torch.randn(state_shape).to(
            device=embedded.device, dtype=embedded.dtype)
        cell_state = torch.randn(state_shape).to(
            device=embedded.device, dtype=embedded.dtype)

        # outputs: [batch_size, max_len, 2 * hidden_size]
        outputs, _ = self.lstm(embedded, (hidden_state, cell_state))

        # nn.Linear acts on the last dimension, so no flatten/unflatten round
        # trip is needed: out is [batch_size, max_len, hidden_size].
        out = self.fc(outputs)
        return OUTPUT(out[:, 0, :], out)

class OUTPUT:
    """Plain holder for the two values stored by the constructor.

    Attributes are set exactly as passed in:
        last_hidden_state: the ``last_hidden_state`` argument.
        pooler_output: the ``pooler_output`` argument.
    """

    def __init__(self, pooler_output, last_hidden_state):
        # Stored as given; no copying or validation is performed.
        self.last_hidden_state, self.pooler_output = (
            last_hidden_state,
            pooler_output,
        )