import torch
from torch import nn
from torch.nn import init


class RelationEncoder(nn.Module):
    """Encode an ERE relation mention into a ``hidden_dim``-sized vector.

    Every relation type, relation subtype and argument role listed in
    ``relation_metadata`` gets a learned embedding vector and a learned
    ``nn.Linear`` map. ``forward`` combines them with encodings of the
    relation's two arguments according to the ``enc_mention`` mode:
    "naive", "vector", "vectoraffine" or "affine".
    """

    def __init__(self, hidden_dim, resolve_author_function, relation_metadata):
        """Build the per-label parameters and the combination layers.

        Args:
            hidden_dim: Size of every embedding vector and of the encoder output.
            resolve_author_function: Callable invoked as
                ``resolve_author_function(encoded_value=..., post_author=...)`` on
                pre-computed argument encodings taken from ``enc_mentions``.
            relation_metadata: 3-tuple ``(type_list, subtype_list, role_list)`` of
                label names. Each name is registered as a parameter on this module,
                so it must be a valid parameter name (no ".", not an existing
                attribute).
        """
        super(RelationEncoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.pair_dim_ = 2 * hidden_dim
        self.resolve_author = resolve_author_function
        (self.type_list, self.subtype_list, self.role_list) = relation_metadata

        # Lookup tables: label name -> (embedding vector, linear map). The plain
        # dicts are only for lookup; the tensors and modules themselves are
        # registered on the module by _add_label so that they are trained, saved
        # and moved together with the rest of the encoder.
        # NOTE(review): vectors are registered under the bare label name, so a
        # name shared between the type/subtype/role lists silently replaces the
        # earlier registration. Confirm the label sets are disjoint.
        self.type_dict = {}
        self.subtype_dict = {}
        self.role_dict = {}
        for relation_type in self.type_list:
            self.type_dict[relation_type] = self._add_label(relation_type, self.pair_dim_, "type_map_")
        for relation_subtype in self.subtype_list:
            self.subtype_dict[relation_subtype] = self._add_label(relation_subtype, self.pair_dim_, "subtype_map_")
        for role in self.role_list:
            self.role_dict[role] = self._add_label(role, self.hidden_dim, "role_map_")

        # Takes vector of form naive_relation_mention-type-subtype.
        # NOTE(review): not referenced by forward() in this file.
        self.naive_plus = nn.Linear(3 * self.hidden_dim, self.hidden_dim)

        # Takes the concatenation type-subtype-role1-mention1-role2-mention2 ("vector" mode).
        self.rel_all_in_one = nn.Linear(6 * self.hidden_dim, self.hidden_dim)

        # Takes the concatenation type-subtype-arg1-arg2 ("vectoraffine" mode).
        self.rel_concat_args = nn.Linear(4 * self.hidden_dim, self.hidden_dim)

    def _add_label(self, name, map_in_dim, map_prefix):
        """Create, register and return the (vector, linear map) pair for one label.

        The vector is registered under ``name``, the original parameter name, so
        existing state_dict keys are preserved. The map is registered as the
        submodule ``map_prefix + name``. Without that registration it would be
        invisible to ``parameters()``, ``state_dict()`` and ``.to()``.
        """
        vector = nn.Parameter(torch.empty(self.hidden_dim), requires_grad=True)
        init.uniform_(vector, -0.01, 0.01)
        self.register_parameter(name, vector)
        linear_map = nn.Linear(map_in_dim, self.hidden_dim)
        self.add_module(map_prefix + name, linear_map)
        return vector, linear_map

    def _encode_argument(self, arg, doc, encoded_doc, enc_mentions, post_author):
        """Return the encoding of one relation argument.

        Uses the pre-computed encoding from ``enc_mentions``, passed through
        ``resolve_author``, when the argument's mention id is present. Otherwise
        it averages ``encoded_doc`` over the token span of the argument's entity.
        """
        if arg.mention_id in enc_mentions:
            return self.resolve_author(encoded_value=enc_mentions[arg.mention_id], post_author=post_author)
        # NOTE(review): if the span is empty (start == end), mean() over zero rows
        # yields NaN. The "naive" path asserts against that; this fallback does not.
        filler_start, filler_end = doc.offset_to_flat_tokens(arg.entity.offset, arg.entity.length)
        return encoded_doc[filler_start:filler_end].mean(dim=0).squeeze()

    def forward(self, mention_id, doc, encoded_doc, enc_mentions, enc_mention=""):
        """Encode one relation mention.

        Args:
            mention_id: Key into ``doc.evaluator_ere.relation_mentions``.
            doc: Document object providing the relation mentions,
                ``relation_mention_to_span``, ``offset_to_flat_tokens`` and
                ``get_author``.
            encoded_doc: Token encodings indexed by flat token position along
                dim 0. Presumably shaped ``(num_tokens, hidden_dim)``; TODO confirm.
            enc_mentions: Mapping from entity-mention id to a pre-computed
                encoding. Arguments missing from it fall back to averaging their
                token span in ``encoded_doc``.
            enc_mention: Encoding mode, one of the following:
                - "naive": mean of ``encoded_doc`` over the relation mention's span.
                - "vector": ``tanh(Linear(type, subtype, role1, arg1, role2, arg2))``.
                - "vectoraffine": ``tanh(Linear(type, subtype, role1_map(arg1), role2_map(arg2)))``.
                - "affine": ``tanh(type_map(a) + subtype_map(a))``, where
                  ``a = tanh([role1_map(arg1); role2_map(arg2)])``.

        Returns:
            The squeezed tensor produced by the selected mode. For the expected
            inputs this is presumably 1-D of size ``hidden_dim``.

        Raises:
            NotImplementedError: If ``enc_mention`` is not one of the modes above.
        """
        mention = doc.evaluator_ere.relation_mentions[mention_id]
        arg1 = mention.rel_arg1
        arg1_role = arg1.role
        arg2 = mention.rel_arg2
        arg2_role = arg2.role
        relation_type = mention.relation.relation_type
        relation_subtype = mention.relation.relation_subtype

        # The relation-mention span is consumed only by the "naive" mode, a
        # fallback to the original way of encoding relations. Its offset also
        # identifies the post author used when resolving pre-computed encodings.
        offset, length = doc.relation_mention_to_span(mention)
        start, end = doc.offset_to_flat_tokens(offset, length)

        post_author = doc.get_author(offset)

        if enc_mention == "naive":
            assert start != end
            return encoded_doc[start:end].mean(dim=0).squeeze()

        # Encode the two mentions corresponding to args 1 and 2.
        enc_mention1 = self._encode_argument(arg1, doc, encoded_doc, enc_mentions, post_author)
        enc_mention2 = self._encode_argument(arg2, doc, encoded_doc, enc_mentions, post_author)

        type_vector, type_map = self.type_dict[relation_type]
        subtype_vector, subtype_map = self.subtype_dict[relation_subtype]
        role1_vector, role1_map = self.role_dict[arg1_role]
        role2_vector, role2_map = self.role_dict[arg2_role]

        if enc_mention == "vector":
            encoded_mention_big = torch.cat(
                (type_vector, subtype_vector, role1_vector, enc_mention1, role2_vector, enc_mention2), 0)
            encoded_mention = self.rel_all_in_one(encoded_mention_big.unsqueeze(0)).squeeze()
            return torch.tanh(encoded_mention)

        # Role-specific projections of the arguments; the output keeps a leading batch dim of 1.
        encoded_arg1 = role1_map(enc_mention1.unsqueeze(0))
        encoded_arg2 = role2_map(enc_mention2.unsqueeze(0))

        if enc_mention == "vectoraffine":
            encoded_mention_big = torch.cat(
                (type_vector, subtype_vector, encoded_arg1.squeeze(), encoded_arg2.squeeze()), 0)
            encoded_mention = self.rel_concat_args(encoded_mention_big.unsqueeze(0)).squeeze()
            return torch.tanh(encoded_mention)

        # Concatenate along the feature dim; the result has 2 * hidden_dim == pair_dim_ features.
        encoded_args = torch.tanh(torch.cat((encoded_arg1, encoded_arg2), 1))
        encoded_type = type_map(encoded_args)
        encoded_subtype = subtype_map(encoded_args)

        if enc_mention == "affine":
            return torch.tanh(encoded_type + encoded_subtype).squeeze()

        raise NotImplementedError(f"Unsupported relation encoding mode: {enc_mention!r}")
