import torch
from torch import nn
from torch.nn import init


class EventEncoder(nn.Module):
    """Encode an event mention from its type, subtype and argument fillers.

    Each argument is represented as ``[role_vector ; filler_encoding]``; the
    argument sequence is pooled by a GRU, and the pooled state is combined with
    the event type/subtype according to the ``enc_mention`` mode of
    :meth:`forward`.
    """

    def __init__(self, hidden_dim, resolve_author_function, event_metadata, bidirectional=False):
        """Build the encoder.

        Args:
            hidden_dim: Width of the role/type/subtype vectors and of the
                encodings produced by ``forward``.
            resolve_author_function: Callable invoked as
                ``resolve_author_function(encoded_value=..., post_author=...)`` on
                argument encodings taken from ``enc_mentions`` in ``forward``.
            event_metadata: ``(type_list, subtype_list, role_list)`` of names. Each
                name is registered as a parameter and used as a ``ModuleDict`` key,
                so it must be a valid module attribute name (e.g. no ``"."``).
            bidirectional: Pool arguments with a bidirectional GRU. Each direction
                then has ``hidden_dim // 2`` units so the concatenated final state
                is still ``hidden_dim`` wide.

        Raises:
            ValueError: If ``bidirectional`` is set and ``hidden_dim`` is odd.
        """
        super(EventEncoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        if hidden_dim % self.num_directions:
            raise ValueError(
                "hidden_dim must be even when bidirectional=True, got %d" % hidden_dim
            )
        rnn_hidden_dim = hidden_dim // self.num_directions
        # One GRU step consumes [role_vector ; filler], each hidden_dim wide.
        # This must not be halved for the bidirectional case: only the GRU's
        # per-direction hidden size shrinks.
        self.pair_dim_ = 2 * hidden_dim
        # "vector" mode concatenates type vector, subtype vector and pooled args.
        self.concat_dim = 3 * hidden_dim
        self.resolve_author = resolve_author_function
        # TODO Consider if num layers matters for event_pooling_rnn
        self.event_pooling_rnn = nn.GRU(self.pair_dim_, rnn_hidden_dim, bidirectional=bidirectional)
        self.concat = nn.Linear(self.concat_dim, self.hidden_dim)
        (self.type_list, self.subtype_list, self.role_list) = event_metadata
        # The per-name linear maps must be registered submodules; kept only in the
        # plain dicts below they would be invisible to parameters(), .to(device)
        # and state_dict(), i.e. never trained and never saved.
        self.type_maps = nn.ModuleDict()
        self.subtype_maps = nn.ModuleDict()
        self.role_maps = nn.ModuleDict()
        # name -> (learned vector, linear map); kept for existing callers.
        self.type_dict = self._build_lookup(self.type_list, self.type_maps, self.hidden_dim)
        self.subtype_dict = self._build_lookup(self.subtype_list, self.subtype_maps, self.hidden_dim)
        # Role maps are not used by forward(); they are built for parity only.
        self.role_dict = self._build_lookup(self.role_list, self.role_maps, self.pair_dim_)

    def _new_vector(self, name):
        """Create, initialise and register a ``hidden_dim`` learned vector as ``name``."""
        # NOTE(review): vectors are registered under their bare names, so a
        # subtype/role sharing a name with an earlier type silently replaces that
        # registration (the lookup dict still holds the first tensor). Names are
        # kept as-is so existing checkpoint keys stay valid.
        vector = nn.Parameter(torch.empty(self.hidden_dim))
        init.uniform_(vector, -0.01, 0.01)
        self.register_parameter(name, vector)
        return vector

    def _build_lookup(self, names, maps, in_dim):
        """Return ``{name: (vector, linear_map)}``, registering each map in ``maps``."""
        lookup = {}
        for name in names:
            # Vector before Linear keeps the original RNG consumption order.
            vector = self._new_vector(name)
            maps[name] = nn.Linear(in_dim, self.hidden_dim)
            lookup[name] = (vector, maps[name])
        return lookup

    def _pool_arguments(self, mention, enc_mentions, doc, encoded_doc, post_author):
        """Run the pooling GRU over ``[role_vector ; filler]`` for every argument.

        Returns:
            A ``(1, hidden_dim)`` tensor: the GRU's final hidden state with the
            directions concatenated. A mention without arguments yields zeros,
            which is the GRU's default initial state.
        """
        steps = []
        for arg_id, argument in mention.arguments.items():
            role_vector, _ = self.role_dict[argument.role]
            if arg_id in enc_mentions:
                filler = self.resolve_author(encoded_value=enc_mentions[arg_id], post_author=post_author)
            else:
                filler_start, filler_end = doc.offset_to_flat_tokens(
                    argument.entity.offset, argument.entity.length
                )
                filler = encoded_doc[filler_start:filler_end].mean(dim=0).squeeze()
            steps.append(torch.cat((role_vector, filler), 0))
        if not steps:
            return self.concat.weight.new_zeros(1, self.hidden_dim)
        # (seq_len, batch=1, pair_dim_) as expected by the default (non batch_first) GRU.
        _, last_hidden = self.event_pooling_rnn(torch.stack(steps, 0).unsqueeze(1))
        # last_hidden is (num_directions, 1, hidden_dim // num_directions).
        return last_hidden.transpose(0, 1).reshape(1, -1)

    def forward(self, mention_id, enc_mentions, doc, encoded_doc, enc_mention=""):
        """Encode one event mention.

        Args:
            mention_id: Key into ``doc.evaluator_ere.event_mentions``.
            enc_mentions: Mapping from argument id to an already computed
                encoding; such arguments go through ``resolve_author`` instead of
                being averaged from ``encoded_doc``.
            doc: Document object providing ``evaluator_ere.event_mentions``,
                ``event_mention_to_span``, ``offset_to_flat_tokens`` and
                ``get_author``.
            encoded_doc: Token encodings indexed by flat token position along
                dim 0 (presumably ``(num_tokens, hidden_dim)`` — confirm with caller).
            enc_mention: ``"naive"`` (mean of the trigger span's token encodings),
                ``"vector"`` (type/subtype vectors + pooled arguments through
                ``self.concat``) or ``"affine"`` (type/subtype linear maps applied
                to the pooled arguments).

        Returns:
            The mention encoding; ``tanh``-squashed and ``hidden_dim`` wide for
            ``"vector"``/``"affine"``.

        Raises:
            NotImplementedError: For any other ``enc_mention`` value.
            ValueError: In ``"naive"`` mode when the trigger span has no tokens.
        """
        if enc_mention not in ("naive", "vector", "affine"):
            raise NotImplementedError("unsupported enc_mention mode: %r" % (enc_mention,))

        mention = doc.evaluator_ere.event_mentions[mention_id]
        offset, length = doc.event_mention_to_span(mention)

        if enc_mention == "naive":
            # Needs only the trigger span, so skip argument encoding entirely.
            start, end = doc.offset_to_flat_tokens(offset, length)
            if start == end:
                raise ValueError("event mention %r covers no tokens" % (mention_id,))
            return encoded_doc[start:end].mean(dim=0).squeeze()

        type_vector, type_map = self.type_dict[mention.event_type]
        subtype_vector, subtype_map = self.subtype_dict[mention.event_subtype]
        pooled = self._pool_arguments(
            mention, enc_mentions, doc, encoded_doc, doc.get_author(offset)
        )

        if enc_mention == "vector":
            features = torch.cat((type_vector, subtype_vector, pooled.squeeze(0)), 0)
            return torch.tanh(self.concat(features.unsqueeze(0)).squeeze(0))
        # "affine"
        return torch.tanh(type_map(pooled) + subtype_map(pooled)).squeeze(0)
