from torch import nn


class EntityEncoder(nn.Module):
    """Combines per-mention encodings into a single per-entity encoding.

    Defines no parameters or submodules of its own.
    """

    def __init__(self):
        super().__init__()

    def encode_entity(self, entity_id, doc, encoded_mentions):
        """Encode an entity into a single vector, if possible.

        If any mention of the entity is encoded as a string-like value
        (anything with a ``lower`` attribute), the entity is treated as a
        post author. The first such value found, in mention order, is
        returned unchanged (the author username). Encoding is postponed
        until the entity is paired.

        Otherwise, return the arithmetic mean of the encodings of all
        mentions referring to the entity.

        Args:
            entity_id: Key into ``doc.evaluator_ere.entities``.
            doc: Document whose ``evaluator_ere.entities[entity_id]`` has a
                ``mentions`` iterable. Each mention exposes ``mention_id``.
            encoded_mentions: Mapping from mention id to that mention's
                encoding (presumably a tensor; TODO confirm against caller),
                or a username string for post-author mentions.

        Returns:
            The author username, or the mean of the mention encodings.

        Raises:
            ValueError: If the entity has no mentions, so no mean exists.
            Lookup errors raised by ``doc.evaluator_ere.entities`` or
            ``encoded_mentions`` propagate unchanged.
        """
        entity = doc.evaluator_ere.entities[entity_id]
        enc_mentions = []

        for mention in entity.mentions:
            enc_mention = encoded_mentions[mention.mention_id]

            # Duck-typed string check: post-author mentions are encoded as
            # their username rather than a vector. Postpone encoding until
            # the entity is paired.
            if hasattr(enc_mention, 'lower'):
                return enc_mention

            enc_mentions.append(enc_mention)

        if not enc_mentions:
            # Without this guard the division below would raise an opaque
            # ZeroDivisionError.
            raise ValueError(
                "entity {!r} has no mentions to encode".format(entity_id))
        return sum(enc_mentions) / len(enc_mentions)
