from modules.Layer import *

class EDULSTM(nn.Module):
    """Bidirectional LSTM run over one pooled vector per EDU.

    The word-level representations of each EDU are pooled into a single
    vector with ``AvgPooling`` and the resulting per-document sequence of
    EDU vectors is fed through ``MyLSTM``.

    NOTE: an EDU-type embedding (concatenated onto the pooled input) and a
    shared-mask sequence dropout on the outputs were tried at some point;
    both are currently disabled and not part of this module.
    """

    def __init__(self, vocab, config):
        """Build the EDU-level LSTM from ``config``.

        Args:
            vocab: Not used by the current implementation; kept so the
                constructor signature stays stable for callers.
            config: Object providing ``lstm_hiddens``, ``lstm_layers``,
                ``dropout_lstm_input`` and ``dropout_lstm_hidden``.
        """
        super(EDULSTM, self).__init__()
        self.config = config

        hidden_units = config.lstm_hiddens
        # The LSTM input width is twice the hidden size -- presumably the
        # output width of a bidirectional word-level encoder (TODO confirm).
        self.lstm = MyLSTM(
            input_size=hidden_units * 2,
            hidden_size=hidden_units,
            num_layers=config.lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout_in=config.dropout_lstm_input,
            dropout_out=config.dropout_lstm_hidden,
        )

    def forward(self, word_represents, masks, word_denominator):
        """Pool words into EDU vectors and encode the EDU sequence.

        Args:
            word_represents: 4-D tensor unpacked as
                ``(batch_size, EDU_num, EDU_len, hidden)``.
            masks: Passed straight through to ``self.lstm``; presumably an
                EDU-level padding mask (TODO confirm against ``MyLSTM``).
            word_denominator: Tensor that is flattened and handed to
                ``AvgPooling`` as the per-EDU divisor; presumably the word
                count of each EDU (TODO confirm against ``AvgPooling``).

        Returns:
            The first output of ``self.lstm`` with its first two dimensions
            swapped.
        """
        n_batch, n_edu, n_tokens, dim = word_represents.size()

        # Collapse (batch, EDU) into one axis so every EDU is pooled
        # independently of the document it belongs to.
        flat_words = word_represents.view(n_batch * n_edu, n_tokens, dim)
        flat_counts = word_denominator.view(-1)
        pooled = AvgPooling(flat_words, flat_counts)

        # Restore the per-document layout: one vector per EDU.
        edu_inputs = pooled.view(n_batch, n_edu, dim)

        lstm_out, _ = self.lstm(edu_inputs, masks, None)

        # Swap the first two axes of the LSTM output before returning it.
        return lstm_out.transpose(1, 0)
