from modules.Layer import *

class Decoder(nn.Module):
    """Action-scoring head: maps a feature vector to one score per action.

    ``hidden_state`` is fed through ``self.mlp`` (a ``NonLinear`` layer
    configured with a LeakyReLU(0.1) activation). In training mode dropout is
    applied to that output. The result is then projected by a bias-free
    ``nn.Linear`` to ``len(vocab._id2ac)`` scores.

    Args:
        vocab: object exposing ``_id2ac``. Only its length is used here, to
            size the output layer.
        config: object read for ``lstm_hiddens`` and ``hidden_size`` at
            construction, and for ``dropout_mlp`` during training-mode
            forward passes. It is kept as ``self.config``.
    """

    def __init__(self, vocab, config):
        super(Decoder, self).__init__()
        self.config = config

        # Input width is lstm_hiddens * 2 * 4. Per the original "four feats"
        # note this is presumably four concatenated bidirectional-LSTM
        # vectors (hence the factor 2) -- TODO confirm against the feature
        # extractor that builds ``hidden_state``.
        num_feats = 4
        in_width = config.lstm_hiddens * 2 * num_feats

        # Build order (mlp first, then output) is deliberate. It fixes the
        # order of parameter registration and of any random weight init.
        self.mlp = NonLinear(
            input_size=in_width,
            hidden_size=config.hidden_size,
            activation=nn.LeakyReLU(0.1),
        )
        num_actions = len(vocab._id2ac)
        self.output = nn.Linear(
            in_features=config.hidden_size,
            out_features=num_actions,
            bias=False,
        )

    def forward(self, hidden_state, cut=None):
        """Return raw action scores for ``hidden_state``.

        Args:
            hidden_state: input tensor. Its last dimension presumably equals
                ``config.lstm_hiddens * 2 * 4`` (the MLP input size) -- TODO
                confirm the full shape against the caller.
            cut: optional tensor added element-wise (with broadcasting) to
                the scores. It is presumably an additive mask, e.g. large
                negative values on disallowed actions; confirm with the
                caller.

        Returns:
            ``self.output`` applied to the MLP features (after dropout in
            training mode), plus ``cut`` when it is given.
        """
        features = self.mlp(hidden_state)
        # Dropout only while the module is in training mode. Going by its
        # name, drop_sequence_sharedmask presumably reuses one mask along the
        # sequence axis -- see modules.Layer.
        if self.training:
            features = drop_sequence_sharedmask(features, self.config.dropout_mlp)
        logits = self.output(features)
        if cut is None:
            return logits
        return logits + cut



