from modules.Layer import *
from modules.ScaleMix import *

class Sent2Span(nn.Module):
    """Re-group sentence-level token embeddings into a per-span layout.

    The module has no learnable parameters: it flattens the sentence/token
    axes of its input and uses ``torch.gather`` to pick, for every span slot,
    the token vector addressed by ``sent2span_index``.
    """

    def __init__(self, vocab, config):
        """
        Args:
            vocab: Not used by this module; accepted so the constructor
                signature stays unchanged for callers.
            config: Stored as ``self.config``; not read by ``forward``.
        """
        super().__init__()
        self.config = config

    def forward(self, x_embed, padding, sent2span_index):
        """Gather token embeddings from sentences into spans.

        Args:
            x_embed: 4-D tensor ``(batch, sent_num, sent_len, hidden)``.
            padding: Tensor concatenated to ``x_embed`` along dim 2, so it
                must be ``(batch, sent_num, pad_len, hidden)``. Presumably
                padding vectors that unused span slots point at -- TODO
                confirm with the caller.
            sent2span_index: int64 tensor ``(batch, edu_num, edu_len)``. Each
                value is a flat position ``s * (sent_len + pad_len) + t`` into
                the concatenated sentences. It does not need to be contiguous.

        Returns:
            Tensor ``(batch, edu_num, edu_len, hidden)``.
        """
        x_embed = torch.cat([x_embed, padding], dim=2)
        batch_size, sent_num, sent_len, hidden_size = x_embed.size()

        # Flatten sentences so every token is addressable by a single index
        # on dim 1.
        x_embed = x_embed.view(batch_size, sent_num * sent_len, hidden_size)

        _, edu_num, edu_len = sent2span_index.size()
        # gather requires an index of the same rank as its source, so
        # broadcast each position across the hidden dimension, then flatten
        # the span axes to line up with dim 1 of x_embed.
        # reshape (not view): view() raises when the caller's index is
        # non-contiguous; reshape copies only in that case.
        index = sent2span_index.unsqueeze(-1).expand(-1, -1, -1, hidden_size)
        index = index.reshape(batch_size, -1, hidden_size)

        spans = torch.gather(x_embed, dim=1, index=index)

        # NOTE: truncating edu_len to self.config.max_edu_len was once sketched
        # here but is disabled; the output keeps the full edu_len.
        return spans.view(batch_size, edu_num, edu_len, -1)