from modules.Layer import *

class BiSent2Sent(nn.Module):
    """Average two gathered views of bi-sentence token embeddings.

    ``forward`` gathers token embeddings twice from a 4-D bi-sentence tensor
    (once with a "pre-sentence" offset tensor, once with a "sentence" offset
    tensor). It shifts the pre-sentence view left by one along dim 1, then
    returns the element-wise mean of the two views.

    NOTE(review): judging by the names, the shift presumably aligns sentence i
    with bi-sentence i+1 (where it appears as the previous sentence). Confirm
    against the data pipeline, including how the wrapped-around last position
    is meant to be used.
    """

    def __init__(self, vocab, config):
        """Create the module.

        Args:
            vocab: Unused by this module; kept for constructor-signature
                compatibility with callers.
            config: Stored as ``self.config``; not read by ``forward``.
        """
        super().__init__()
        self.config = config

    def forward(self, bisent_extword_embed,
                padding, bisent2pre_sent_offset, bisent2sent_offset):
        """Gather, align and average the two per-sentence embedding views.

        Args:
            bisent_extword_embed: 4-D float tensor. The unpacking of ``size()``
                below requires rank 4. Presumably
                (batch, num_bisents, bisent_len, hidden); TODO confirm.
            padding: Tensor concatenated onto ``bisent_extword_embed`` along
                dim 2. It must match that tensor on dims 0, 1 and 3.
            bisent2pre_sent_offset: 3-D integer (int64) index tensor. Its
                values index dim 2 of the concatenated embeddings.
            bisent2sent_offset: 3-D integer (int64) index tensor, same role as
                above. It is expected to share the shape of
                ``bisent2pre_sent_offset`` so that the two views can be
                averaged.

        Returns:
            Tensor of shape ``(*offset.shape, hidden)``. Element
            ``[:, i]`` is the mean of the sentence view at ``i`` and the
            pre-sentence view at ``i + 1``. The last index wraps around to
            the pre-sentence view at ``0``.
        """
        # Extend the token axis with the padding slots so offsets may address
        # them as well as real tokens.
        bisent_extword_embed = torch.cat([bisent_extword_embed, padding], dim=2)

        # Unpacking exactly four sizes also enforces a 4-D input.
        _, _, _, hidden_size = bisent_extword_embed.size()

        # torch.gather needs an index with the same rank as its input, so
        # broadcast each 3-D offset tensor over the hidden dimension.
        bisent2pre_sent_offset = bisent2pre_sent_offset.unsqueeze(-1).expand(-1, -1, -1, hidden_size)
        bisent2sent_offset = bisent2sent_offset.unsqueeze(-1).expand(-1, -1, -1, hidden_size)

        pre_sent_extword_embed = torch.gather(bisent_extword_embed, dim=-2, index=bisent2pre_sent_offset)
        sent_extword_embed = torch.gather(bisent_extword_embed, dim=-2, index=bisent2sent_offset)

        # Rotate left by one along dim 1: position i takes the value from
        # position i + 1, and the last position takes the value from position
        # 0. This is one vectorised call instead of splitting into per-index
        # chunks and re-concatenating them.
        pre_sent_extword_embed = torch.roll(pre_sent_extword_embed, shifts=-1, dims=1)

        return (pre_sent_extword_embed + sent_extword_embed) / 2



