from modules.Layer import *

class WordLSTM(nn.Module):
    """Bidirectional ``MyLSTM`` encoder applied to pre-computed word embeddings.

    The module owns no embedding table of its own: ``forward`` receives the
    embedded inputs directly and feeds them to a bidirectional ``MyLSTM``
    configured from ``config``.

    Args:
        vocab: Accepted for interface compatibility; not used by this module.
        config: Object providing ``word_dims``, ``lstm_hiddens``,
            ``lstm_layers``, ``dropout_lstm_input``, ``dropout_lstm_hidden``
            (read at construction) and ``dropout_emb`` (read in ``forward``).
    """

    def __init__(self, vocab, config):
        super().__init__()
        self.config = config
        # NOTE: a local nn.Embedding (vocab.vocab_size x config.word_dims) was
        # previously disabled here, which is why `vocab` goes unused.
        lstm_kwargs = dict(
            input_size=config.word_dims,
            hidden_size=config.lstm_hiddens,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )
        self.lstm = MyLSTM(**lstm_kwargs)

    def forward(self, x_extword_embed, masks):
        """Encode embedded words and zero the outputs wherever ``masks`` is 0.

        Args:
            x_extword_embed: Embedded input sequence passed straight to the
                LSTM (presumably last dim == ``config.word_dims`` -- confirm).
            masks: Tensor passed to the LSTM and then broadcast-multiplied
                (after ``unsqueeze(-1)``) into its output. Assumes
                (batch, seq) with 1 for real tokens -- TODO confirm.

        Returns:
            The LSTM output with dims 0 and 1 swapped, made contiguous, and
            multiplied by ``masks.unsqueeze(-1)``.
        """
        embedded = x_extword_embed
        if self.training:
            # Input dropout is only applied while the module is in train mode.
            embedded = drop_input(embedded, self.config.dropout_emb)

        hidden_states, _ = self.lstm(embedded, masks, None)
        # NOTE(review): presumably MyLSTM yields time-major output even with
        # batch_first=True, hence the swap back -- confirm against MyLSTM.
        hidden_states = hidden_states.transpose(1, 0).contiguous()
        # NOTE: output dropout (drop_sequence_sharedmask with
        # config.dropout_mlp) was previously disabled at this point.
        return hidden_states * masks.unsqueeze(-1)
