# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical


class Receiver(nn.Module):
    """Linear read-out head: projects an ``n_hidden`` vector to ``n_outputs`` scores."""

    def __init__(self, n_outputs, n_hidden):
        super().__init__()
        self.fc = nn.Linear(n_hidden, n_outputs)

    def reset_parameters(self):
        """Re-initialise the projection layer."""
        self.fc.reset_parameters()

    def forward(self, x, _input, _aux_input):
        """Return ``fc(x)``.

        ``_input`` and ``_aux_input`` are accepted only so the receiver fits
        the common ``(message, input, aux_input)`` calling convention; they
        are ignored.
        """
        projected = self.fc(x)
        return projected


class DiscriminationReceiver(nn.Module):
    """Score a target and shared distractors against a message embedding.

    The message and the visual candidates are projected into a common
    ``out_dim`` space; each candidate's score is its dot product with the
    projected message.
    """

    def __init__(self, out_dim, message_in_dim, visual_in_dim):
        super().__init__()
        self.message_fc = nn.Linear(message_in_dim, out_dim)
        self.visual_fc = nn.Linear(visual_in_dim, out_dim)

    def reset_parameters(self):
        """Re-initialise both projection layers (message first, then visual)."""
        for layer in (self.message_fc, self.visual_fc):
            layer.reset_parameters()

    def forward(self, x, input, _aux_input):
        """Compute candidate logits.

        Args:
            x: message features, (B, message_in_dim).
            input: pair ``(distractors, target)``. ``distractors`` is
                (C-1, visual_in_dim) and is shared across the batch;
                ``target`` is (B, visual_in_dim).
            _aux_input: ignored.

        Returns:
            (B, C) logits. Column 0 is the target and columns 1..C-1 are the
            distractors.
        """
        msg = self.message_fc(x)                      # (B, H')
        distractors, target = input
        dist_emb = self.visual_fc(distractors)        # (C-1, H')
        target_emb = self.visual_fc(target)           # (B, H')

        # Unpacking also enforces that both projections are 2-D.
        n_batch, _ = msg.size()
        _, width = dist_emb.size()

        # (B, 1, H') * (1, C-1, H') -> (B, C-1, H') -> (B, C-1)
        dist_scores = (msg.view(n_batch, 1, width) * dist_emb.unsqueeze(0)).sum(-1)
        # Row-wise dot product, kept as a (B, 1) column so it leads the concat.
        target_scores = (msg * target_emb).sum(-1, keepdim=True)

        return torch.cat([target_scores, dist_scores], dim=-1)


class ReceiverLogProbWrapper(nn.Module):
    """Wrap a receiver so it also returns the log-probability of the gold values.

    The receiver output and ``label`` are both viewed as
    (B, n_attributes, n_values) and restricted to the attributes selected by
    ``att_idx``. The gold value of each selected attribute is the argmax of
    ``label`` along the value axis.
    """

    def __init__(self, receiver, n_attributes, n_values, att_idx):
        super().__init__()
        self.receiver = receiver
        self.n_attributes = n_attributes
        self.n_values = n_values
        self.att_idx = att_idx

    def forward(self, message, label, *args):
        """Return ``(selected_scores, summed_gold_log_prob)``.

        ``self.receiver`` is called with ``message`` only; its first return
        value is used. Extra positional arguments are ignored.
        NOTE(review): with a list/tensor ``att_idx`` the log-prob is (B,). With
        a plain int ``att_idx``, the final ``sum(-1)`` reduces over the batch
        instead. Confirm callers always pass a sequence.
        """
        grid = (-1, self.n_attributes, self.n_values)

        # (B, A*V) -> (B, A, V) -> (B, A', V) -> (B, A')
        gold = label.view(*grid)[:, self.att_idx].argmax(dim=-1)

        # (B, A*V) -> (B, A, V) -> (B, A', V)
        scores = self.receiver(message)[0].view(*grid)[:, self.att_idx]

        # Log-prob of each gold value: (B, A') -> summed over attributes -> (B)
        per_attribute = (
            scores.log_softmax(dim=-1)
            .gather(-1, gold.unsqueeze(-1))
            .squeeze(-1)
        )
        return scores, per_attribute.sum(-1)


class DiscriminationReceiverLogProbWrapper(nn.Module):
    """Wrap a discrimination receiver so it also returns the label's log-probability."""

    def __init__(self, receiver):
        super().__init__()
        self.receiver = receiver

    def forward(self, message, label, receiver_input, aux, *args):
        """Return ``(logits, log_prob_of_label)``.

        ``receiver_input`` has its leading singleton axis squeezed away. It is
        then paired with ``aux['sender_input']`` as the receiver's candidate
        input. ``label`` holds one class index per row (shape (B, 1)); the
        returned log-prob is (B,). Extra positional arguments are ignored.
        """
        candidates = (receiver_input.squeeze(0), aux['sender_input'])
        scores = self.receiver(message, candidates)[0]     # (B, C)
        chosen = scores.log_softmax(dim=-1).gather(-1, label).squeeze(-1)
        return scores, chosen


class MultiRnnReceiverDeterministic(nn.Module):
    """Run several receivers on the same message and stack their results.

    Each receiver is called as ``receiver(message, input, aux_input, lengths)``
    and must return an indexable result with at least three items. Item ``k``
    of every receiver's result is concatenated along dim 0, so ``R`` receivers
    returning ``B`` rows each produce ``R * B`` rows.
    """

    def __init__(self, receivers: List):
        super().__init__()
        self.receivers = nn.ModuleList(receivers)

    def forward(self, message, input=None, aux_input=None, lengths=None):
        """Return ``(agent_output, logits, entropy)``, each stacked over receivers."""
        per_receiver = [
            receiver(message, input, aux_input, lengths)
            for receiver in self.receivers
        ]
        # Slot 0: agent output, slot 1: logits, slot 2: entropy.
        return tuple(
            torch.cat([result[slot] for result in per_receiver], dim=0)
            for slot in range(3)
        )


class MultiDiscriminationRnnReceiverDeterministic(nn.Module):
    """Run one discrimination receiver per distractor set and stack the results.

    ``input`` is iterated along its first axis; receiver ``i`` receives slice
    ``i`` as its distractors, paired with the shared target
    ``aux_input['sender_input']``. Iteration stops at the shorter of
    ``self.receivers`` and ``input`` (``zip`` semantics).
    """

    def __init__(self, receivers: List):
        super().__init__()
        self.receivers = nn.ModuleList(receivers)

    def forward(self, message, input=None, aux_input=None, lengths=None):
        """Return ``(agent_output, logits, entropy)`` stacked along dim 0.

        Per the original shape notes, ``input`` is (R, C-1, H) and the target
        is (B, H); each receiver yields (B, C) output, giving (R * B, C).
        """
        target = aux_input['sender_input']

        results = []
        for receiver, distractors in zip(self.receivers, input):
            # (C-1, H), (B, H) -> (B, C)
            results.append(receiver(message, (distractors, target), aux_input, lengths))

        def stack(slot):
            return torch.cat([res[slot] for res in results], dim=0)

        return stack(0), stack(1), stack(2)


class MultiDiscriminationRnnReceiverReinforce(nn.Module):
    """Stochastic variant of the multi-receiver discrimination wrapper.

    Receivers are run exactly as in the deterministic wrapper. Their logits
    are then stacked, a categorical choice is sampled per row, and the
    log-probability of each sampled choice is returned.

    NOTE(review): the sampled choices themselves are not returned, and the
    entropy is taken from the sub-receivers rather than from the sampling
    distribution. Confirm this matches what the training loss expects.
    """

    def __init__(self, receivers: List):
        super().__init__()
        self.receivers = nn.ModuleList(receivers)

    def forward(self, message, input=None, aux_input=None, lengths=None):
        """Return ``(agent_output, log_probs, entropy)``.

        agent_output is (R * B, C); log_probs and entropy are (R * B).
        """
        target = aux_input['sender_input']
        results = [
            receiver(message, (distractors, target), aux_input, lengths)
            for receiver, distractors in zip(self.receivers, input)
        ]

        agent_output = torch.cat([res[0] for res in results], dim=0)

        policy = Categorical(logits=agent_output.log_softmax(dim=-1))
        sampled = policy.sample()
        log_probs = policy.log_prob(sampled)

        entropy = torch.cat([res[2] for res in results], dim=0)
        return agent_output, log_probs, entropy


class MultiRnnReceiverReinforce(nn.Module):
    """Stochastic multi-receiver wrapper for attribute/value reconstruction.

    Every receiver sees the same message. Receiver ``i``'s output is viewed as
    (B, n_attributes, V) and restricted to the attributes ``att_indices[i]``.
    A value is sampled per selected attribute, and the mean log-probability
    over those attributes is returned per row.

    NOTE(review): the constructor keyword is spelled ``n_attribuits``; it is
    kept as-is because callers may pass it by name. If an ``att_indices``
    entry is a plain int, the per-receiver log-prob collapses to a scalar and
    ``torch.cat`` would fail. Entries are presumably sequences; confirm.
    """

    def __init__(self, receivers: List, att_indices, n_attribuits):
        super().__init__()
        self.receivers = nn.ModuleList(receivers)
        self.att_indices = att_indices
        self.n_attributes = n_attribuits

    def forward(self, message, input=None, aux_input=None, lengths=None):
        """Return ``(agent_output, log_probs, entropy)`` stacked along dim 0."""
        results = [
            receiver(message, input, aux_input, lengths)
            for receiver in self.receivers
        ]
        agent_output = torch.cat([res[0] for res in results], dim=0)

        per_receiver = []
        for res, att in zip(results, self.att_indices):
            scores = res[0]
            # (B, A*V) -> (B, A, V) -> (B, A', V)
            scores = scores.view(scores.size(0), self.n_attributes, -1)[:, att]
            policy = Categorical(logits=scores.log_softmax(dim=-1))
            choice = policy.sample()
            # (B, A') -> (B): mean over the selected attributes.
            per_receiver.append(policy.log_prob(choice).mean(-1))

        log_probs = torch.cat(per_receiver)
        entropy = torch.cat([res[2] for res in results], dim=0)
        return agent_output, log_probs, entropy


class Sender(nn.Module):
    """Linear encoder mapping ``n_inputs`` features to an ``n_hidden`` vector."""

    def __init__(self, n_inputs, n_hidden):
        super().__init__()
        self.fc1 = nn.Linear(n_inputs, n_hidden)

    def reset_parameters(self):
        """Re-initialise the encoder layer."""
        self.fc1.reset_parameters()

    def forward(self, x, _aux_input):
        """Return ``fc1(x)``; ``_aux_input`` is accepted for API compatibility and ignored."""
        return self.fc1(x)


class SenderExpansionWrapper(nn.Module):
    """Tile a sender's output ``ratio`` times along the batch dimension.

    Given a wrapped output of shape (B, ...), the result is
    (ratio * B, ...): ``ratio`` stacked copies of the whole batch.
    """

    def __init__(self, sender, ratio):
        super().__init__()
        self.sender = sender
        self.ratio = ratio

    def forward(self, x, _aux_input):
        x = self.sender(x, _aux_input)
        # Repeat only dim 0. For outputs of rank <= 2 this is exactly
        # ``x.repeat(self.ratio, 1)``: a 1-D output still becomes
        # (ratio, N). Higher-rank outputs such as (B, L, H) are now tiled
        # instead of raising, which a fixed two-element repeat spec would do.
        trailing = [1] * max(x.dim() - 1, 1)
        return x.repeat(self.ratio, *trailing)


class NonLinearReceiver(nn.Module):
    """Two-layer MLP receiver over one-hot-encoded symbol sequences.

    Each symbol is mapped through an identity-initialised embedding (a
    one-hot lookup) under ``torch.no_grad()``, so no gradient reaches that
    table. The flattened (B, max_length * vocab_size) encoding then goes
    through ``fc_1 -> leaky_relu -> fc_2``.
    """

    def __init__(self, n_outputs, vocab_size, n_hidden, max_length):
        super().__init__()
        self.max_length = max_length
        self.vocab_size = vocab_size

        # Creation order is kept so parameter initialisation under a fixed
        # seed is unchanged.
        self.fc_1 = nn.Linear(vocab_size * max_length, n_hidden)
        self.fc_2 = nn.Linear(n_hidden, n_outputs)

        self.diagonal_embedding = nn.Embedding(vocab_size, vocab_size)
        nn.init.eye_(self.diagonal_embedding.weight)

    def forward(self, x, _input, _aux_input):
        """Return ``(scores, zeros, zeros)``; both trailing items are the same zero tensor."""
        with torch.no_grad():
            one_hot = self.diagonal_embedding(x)
            flat = one_hot.view(x.size(0), -1)

        hidden = F.leaky_relu(self.fc_1(flat))
        scores = self.fc_2(hidden)

        placeholder = torch.zeros(scores.size(0), device=scores.device)
        return scores, placeholder, placeholder


class Freezer(nn.Module):
    """Wrap a module so it stays in eval mode and never records gradients.

    The wrapped module is switched to eval mode at construction. Later
    ``train()``/``eval()`` calls are ignored, and ``forward`` runs under
    ``torch.no_grad()``.
    """

    def __init__(self, wrapped):
        super().__init__()
        self.wrapped = wrapped
        self.wrapped.eval()

    def train(self, mode=True):
        """Ignore mode switches so the wrapped module stays in eval mode.

        Mirrors ``nn.Module.train``'s signature: ``mode`` defaults to True
        and ``self`` is returned. ``nn.Module.eval()`` is implemented as
        ``return self.train(False)``, so returning ``self`` here also makes
        ``freezer.eval()`` return the module rather than ``None``.
        """
        return self

    def forward(self, *input):
        with torch.no_grad():
            return self.wrapped(*input)


class PlusOneWrapper(nn.Module):
    """Shift the symbols produced by a wrapped module up by one.

    The wrapped module must return a 3-tuple; only the first item is shifted.
    With ``preserve_eos=True`` every column except the last is shifted, and
    the last column is left as produced. Otherwise the whole tensor is shifted.
    """

    def __init__(self, wrapped, preserve_eos: bool = True):
        super().__init__()
        self.wrapped = wrapped
        self.preserve_eos = preserve_eos

    def forward(self, *input):
        r1, r2, r3 = self.wrapped(*input)
        if self.preserve_eos:
            # Work on a copy. Incrementing ``r1`` in place would silently
            # modify a tensor the wrapped module (or another caller) may
            # still hold, and in-place ops on grad-tracked tensors can break
            # autograd.
            shifted = r1.clone()
            shifted[:, :-1] += 1
            return shifted, r2, r3
        return r1 + 1, r2, r3

    def reset_parameters(self):
        self.wrapped.reset_parameters()


class VisualAttributeValues(nn.Module):
    """Small CNN image encoder producing an ``out_dim`` feature vector.

    Two stages of ``conv -> ELU -> 2x2 max-pool`` are followed by a flatten
    and ``linear -> ELU``, with dropout after each stage.

    Args:
        filter_size: number of output channels of each convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride. Both convolutions use ``padding='same'``,
            which PyTorch only supports with ``stride=1``.
        out_dim: size of the output feature vector.
        dropout_rate: dropout probability.
        input_size: height/width of the square 3-channel input. The default
            of 64 reproduces the previous fixed ``filter_size * 16 * 16``
            flattened size.
    """

    def __init__(
        self,
        filter_size: int = 8,
        kernel_size: int = 3,
        stride: int = 1,
        out_dim: int = 500,
        dropout_rate: float = 0.0,
        input_size: int = 64,
    ):
        super().__init__()
        self.conv_1 = torch.nn.Conv2d(
            in_channels=3,
            out_channels=filter_size,
            kernel_size=kernel_size,
            stride=stride,
            padding='same',
        )
        self.conv_2 = torch.nn.Conv2d(
            in_channels=filter_size,
            out_channels=filter_size,
            kernel_size=kernel_size,
            stride=stride,
            padding='same',
        )
        self.max_pool = torch.nn.MaxPool2d(
            kernel_size=2,
            stride=2,
        )
        self.dropout = nn.Dropout(dropout_rate)

        # 'same' padding keeps the spatial size, and each 2x2/stride-2
        # max-pool floors it by half. Two pools therefore give input_size // 4.
        pooled = (input_size // 2) // 2
        in_dim = int(filter_size * pooled * pooled)
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x, *args, **kwargs):
        """Encode ``x`` of shape (B, 3, input_size, input_size) into (B, out_dim).

        Extra positional and keyword arguments are ignored.
        """
        x = self.conv_1(x)
        x = self.max_pool(F.elu(x))
        x = self.dropout(x)
        x = self.conv_2(x)
        x = self.max_pool(F.elu(x))
        # (B, filter_size, H', W') -> (B, filter_size * H' * W')
        x = torch.flatten(x, 1, -1)
        x = self.dropout(x)
        x = self.linear(x)
        x = F.elu(x)
        x = self.dropout(x)

        return x
    