import torch
from autoencoder import Encoder
from autoencoders.base_encoder import BaseEncoder

class BoWEncoder(BaseEncoder):
    """Bag-of-words encoder that pools token embeddings over dimension 1.

    The pooling mode comes from ``config.reduction``:

    * ``"sum"``  -- sum over dim 1.
    * ``"mean"`` -- sum over dim 1 divided by ``lengths`` (a true mean only if
      padding positions are zero -- NOTE(review): confirm with caller).
    * ``"max"``  -- elementwise max over dim 1, treating entries exactly equal
      to 0 as padding.

    Any other value leaves ``embedded`` unreduced.
    """

    def __init__(self, config):
        super().__init__(config)

        # Pooling mode name, consulted on every _to_hidden_representation call.
        self.reduction = config.reduction

    def _to_hidden_representation(self, embedded, lengths):
        """Pool ``embedded`` over dimension 1 according to ``self.reduction``.

        Args:
            embedded: Tensor to pool over dim 1; presumably (batch, seq, dim)
                with padding positions zeroed -- TODO confirm against
                BaseEncoder. It is never modified in place.
            lengths: Per-row token counts, used only by ``"mean"``; assumed to
                be 1-D with one entry per row of ``embedded``.

        Returns:
            The pooled tensor, or ``embedded`` itself for an unrecognised
            reduction.
        """
        h = embedded

        if self.reduction == "sum":
            h = h.sum(dim=1)
        elif self.reduction == "mean":
            # Clamp the divisor so an empty (length-0) row yields 0 instead of
            # the NaN produced by 0 / 0.
            counts = lengths.unsqueeze(1).float().clamp(min=1)
            h = h.sum(dim=1) / counts
        elif self.reduction == "max":
            # Exact zeros are treated as padding and pushed far below any real
            # value so they never win the max. masked_fill is out-of-place, so
            # the caller's `embedded` tensor is not mutated and autograd never
            # sees an in-place write on it.
            h = h.masked_fill(h == 0, -1e9)
            h = h.max(dim=1).values

        return h