from modules.Layer import *
from modules.ScaleMix import *

class PretrainedWordEncoder(nn.Module):
    """Mix the per-layer outputs of a pretrained model into one embedding.

    The wrapped model is run under ``torch.no_grad()``, one batch item at a
    time. Each of its ``layer_num`` outputs is passed through its own
    ``NonLinear`` projection (``input_dims`` -> ``word_dims``) and the
    projected layers are combined by a ``ScalarMix``.
    """

    def __init__(self, config, model, input_dims, layer_num):
        """Store the wrapped model and build the per-layer projections.

        Args:
            config: Configuration object; only ``config.word_dims`` is read
                here.
            model: Callable module invoked in ``forward`` as
                ``model(indices, segment_ids, piece_ids, mask)``. It must
                return a sequence of ``layer_num`` tensors.
            input_dims: Input size handed to every ``NonLinear`` projection
                (presumably the hidden size of ``model`` -- confirm).
            layer_num: Number of layer outputs expected from ``model``.
        """
        super(PretrainedWordEncoder, self).__init__()
        self.config = config
        self.word_dims = config.word_dims
        self.layer_num = layer_num
        self.input_dims = input_dims
        # Both attributes reference the same module object. ``forward`` only
        # uses ``pretrain_model``; ``model`` is kept so existing attribute
        # names (and therefore state-dict keys) stay unchanged.
        self.pretrain_model = model
        self.model = model

        # One independent projection per layer output of the wrapped model.
        self.mlp_words = nn.ModuleList([
            NonLinear(self.input_dims, self.word_dims, activation=GELU())
            for _ in range(self.layer_num)
        ])
        self.rescale = ScalarMix(mixture_size=self.layer_num)

    def forward(self, bert_indice_input, bert_segments_ids, bert_piece_ids, bert_mask=None):
        """Encode a batch and return the mixed, projected representation.

        Args:
            bert_indice_input: Tensor whose first axis is the batch; every
                batch item is reshaped to ``(sent_num, bert_len)``.
            bert_segments_ids: Same layout as ``bert_indice_input``.
            bert_piece_ids: 4-D tensor of shape
                ``(batch_size, sent_num, sent_len, bert_len)``; all sizes
                used in this method are read from it.
            bert_mask: Optional tensor with the same layout as
                ``bert_indice_input``. When ``None`` it is forwarded to the
                wrapped model as ``None``.
                NOTE(review): confirm the wrapped model accepts a ``None``
                mask; previously ``None`` failed here before reaching it.

        Returns:
            Whatever ``self.rescale`` produces from the list of projected
            layer tensors (one entry per layer, each with a leading batch
            axis).
        """
        batch_size, sent_num, sent_len, bert_len = bert_piece_ids.size()

        # The wrapped model is only used as a feature extractor: no gradients
        # are recorded for it. The projections below stay trainable.
        with torch.no_grad():
            per_layer = [[] for _ in range(self.layer_num)]
            # The model is fed one batch item at a time, with that item's
            # sentences acting as the model's own batch axis.
            for idx in range(batch_size):
                if bert_mask is None:
                    mask = None
                else:
                    mask = bert_mask[idx].reshape(sent_num, bert_len)
                outputs = self.pretrain_model(
                    bert_indice_input[idx].reshape(sent_num, bert_len),
                    bert_segments_ids[idx].reshape(sent_num, bert_len),
                    bert_piece_ids[idx].reshape(sent_num, sent_len, bert_len),
                    mask,
                )
                assert len(outputs) == self.layer_num, \
                    "pretrained model returned %d layers, expected %d" % (
                        len(outputs), self.layer_num)
                # Collect without mutating ``outputs`` so tuples work too.
                for collected, layer_output in zip(per_layer, outputs):
                    collected.append(layer_output)
            # Re-introduce the batch axis in front of each layer's outputs.
            bert_outputs = [torch.stack(chunks, dim=0) for chunks in per_layer]

        proj_hiddens = [
            mlp(hidden) for mlp, hidden in zip(self.mlp_words, bert_outputs)
        ]
        return self.rescale(proj_hiddens)
