from torch import nn


class EntityMentionEncoder(nn.Module):
    """Maps entity mentions of a document to vector encodings.

    The module has no parameters of its own; it pools an already-encoded
    document to produce a representation for a single mention.
    """

    def __init__(self):
        super().__init__()

    def encode_entity_mention(self, mention_id, doc, encoded_doc):
        """Encode one entity mention, falling back to its text when needed.

        Args:
            mention_id: Key into ``doc.evaluator_ere.entity_mentions``.
            doc: Document object exposing ``evaluator_ere.entity_mentions``
                and ``offset_to_flat_tokens(offset, length)``.
            encoded_doc: Token encodings of ``doc``, sliced along dim 0 by
                flat token index (assumes dim 0 is the token axis - TODO
                confirm against the encoder that produces it).

        Returns:
            The mean of ``encoded_doc`` over the mention's token span, with
            singleton dimensions squeezed away. When the span is empty
            (a post author, per the original comment a DF username), the
            mention's ``mention_text`` string is returned instead.
        """
        ere_mention = doc.evaluator_ere.entity_mentions[mention_id]
        first_tok, stop_tok = doc.offset_to_flat_tokens(
            ere_mention.offset, ere_mention.length
        )

        # A non-empty span is real document text: average-pool its tokens.
        if first_tok != stop_tok:
            span_vectors = encoded_doc[first_tok:stop_tok]
            pooled = span_vectors.mean(dim=0)
            # squeeze() drops any singleton axes left after pooling
            # (e.g. a size-1 batch axis, if encoded_doc carries one).
            return pooled.squeeze()

        # Empty span: no tokens back this mention, so return its raw text.
        return ere_mention.mention_text
