import torch
from autoencoders.base_encoder import BaseEncoder
import numpy as np
import math
from autoencoder import Encoder

class ConcatEncoder(Encoder):
    """Encoder that joins the outputs of several sub-encoders.

    Every sub-encoder receives the same inputs; the resulting embeddings
    are concatenated along their last dimension.
    """

    def __init__(self, config, encoder_list):
        """Store the sub-encoders.

        Args:
            config: Passed unchanged to the ``Encoder`` base constructor.
            encoder_list: Iterable of sub-encoders, each exposing
                ``encode(x, lengths, train=..., reparameterize=...)``.
        """
        super(ConcatEncoder, self).__init__(config)

        # Held in a ModuleList rather than a plain Python list.
        # NOTE(review): presumably so the sub-encoders' parameters are
        # registered with this module -- this relies on Encoder being a
        # torch.nn.Module; confirm.
        self.encoder_list = torch.nn.ModuleList(encoder_list)

    def encode(self, x, lengths, train=False, reparameterize=True):
        """Encode ``x`` with every sub-encoder and concatenate the results.

        Args:
            x: Input forwarded as-is to each sub-encoder.
            lengths: Forwarded as-is to each sub-encoder.
            train: Forwarded as the ``train`` keyword (default ``False``).
            reparameterize: Forwarded as the ``reparameterize`` keyword
                (default ``True``).

        Returns:
            The sub-encoder outputs joined with ``torch.cat`` along
            ``dim=-1``, in the order the sub-encoders were supplied.
        """
        per_encoder_outputs = [
            sub_encoder.encode(
                x, lengths, train=train, reparameterize=reparameterize
            )
            for sub_encoder in self.encoder_list
        ]
        return torch.cat(per_encoder_outputs, dim=-1)