from typing import List
from collections import defaultdict

import torch
import torch.nn.functional as F


class DiffLoss(torch.nn.Module):
    """Per-attribute reconstruction loss over one-hot attribute/value encodings.

    The target (``labels`` when given, otherwise ``sender_input``) and
    ``receiver_output`` are both reshaped to ``(batch, n_attributes, n_values)``;
    each attribute is scored as an ``n_values``-way classification.

    Args:
        n_attributes: number of attributes per sample.
        n_values: number of values each attribute can take.
        generalization: if True, compute leave-one-attribute-out metrics
            instead of the standard per-sample loss. ``loss_type`` is ignored
            in that mode.
        loss_type: ``'cross_entropy'``, ``'task_success'`` (negative
            exact-match accuracy) or ``'mixed'`` (weighted cross entropy minus
            exact-match accuracy).
        cross_entropy_weight: weight of the cross-entropy term for ``'mixed'``.
    """

    def __init__(
        self, 
        n_attributes, 
        n_values, 
        generalization=False, 
        loss_type: str = 'cross_entropy',
        cross_entropy_weight: float = 1.0,
    ):
        super().__init__()
        self.n_attributes = n_attributes
        self.n_values = n_values
        self.test_generalization = generalization
        self.loss_type = loss_type
        self.cross_entropy_weight = cross_entropy_weight

    def forward(
        self,
        sender_input,
        _message,
        _receiver_input,
        receiver_output,
        labels,
        _aux_input,
    ):
        """Compute the loss and accuracy metrics.

        Returns:
            ``(loss, aux)``.

            Standard mode:
                ``loss``, ``aux['acc']`` and ``aux['acc_or']`` have shape
                ``(batch,)``. ``acc`` is 1.0 where every attribute's argmax
                matches. ``acc_or`` is the fraction of matching attributes.

            Generalization mode:
                the metrics are averaged over attributes. ``loss`` is the
                cross entropy summed over attributes.

            In both modes, ``aux`` also repeats both metrics under
            ``acc_{n_attributes}`` and ``acc_or_{n_attributes}``.

        Raises:
            ValueError: in standard mode, if ``loss_type`` is not recognised.
        """
        if labels is None:
            batch_size = sender_input.size(0)
            sender_input = sender_input.view(batch_size, self.n_attributes, self.n_values)
        else:
            batch_size = labels.size(0)
            sender_input = labels.view(batch_size, self.n_attributes, self.n_values)

        receiver_output = receiver_output.view(
            batch_size, self.n_attributes, self.n_values
        )

        if self.test_generalization:
            acc, acc_or, loss = 0.0, 0.0, 0.0
            # Number of attributes left after dropping the held-out one.
            n_attributes = self.n_attributes - 1

            for attr in range(self.n_attributes):
                # Samples whose slot 0 for `attr` is non-zero. squeeze(-1)
                # rather than squeeze() keeps the index 1-D even when exactly
                # one sample matches (squeeze() would give a 0-d tensor and
                # break .size(0) below).
                # NOTE(review): if no sample matches, the means below are NaN.
                zero_index = torch.nonzero(sender_input[:, attr, 0]).squeeze(-1)
                masked_size = zero_index.size(0)
                masked_input = torch.index_select(sender_input, 0, zero_index)
                masked_output = torch.index_select(receiver_output, 0, zero_index)

                # Drop the held-out attribute from both target and prediction.
                no_attribute_input = torch.cat(
                    [masked_input[:, :attr, :], masked_input[:, attr + 1 :, :]], dim=1
                )
                no_attribute_output = torch.cat(
                    [masked_output[:, :attr, :], masked_output[:, attr + 1 :, :]], dim=1
                )

                matches = (
                    no_attribute_output.argmax(dim=-1)
                    == no_attribute_input.argmax(dim=-1)
                )
                acc += (matches.sum(dim=1) == n_attributes).float().mean()
                acc_or += matches.float().mean()

                labels = no_attribute_input.argmax(dim=-1).view(
                    masked_size * n_attributes
                )
                predictions = no_attribute_output.view(
                    masked_size * n_attributes, self.n_values
                )
                # NB: THIS LOSS IS NOT SUITABLY SHAPED TO BE USED IN REINFORCE TRAINING!
                loss += F.cross_entropy(predictions, labels, reduction="mean")

            acc /= self.n_attributes
            acc_or /= self.n_attributes
        else:
            matches = receiver_output.argmax(dim=-1) == sender_input.argmax(dim=-1)
            # Exact match: every attribute of the sample is predicted correctly.
            acc = (matches.detach().sum(dim=1) == self.n_attributes).float()
            # Fraction of attributes predicted correctly per sample.
            acc_or = matches.float().mean(-1)

            flat_output = receiver_output.view(
                batch_size * self.n_attributes, self.n_values
            )
            flat_labels = sender_input.argmax(dim=-1).view(batch_size * self.n_attributes)

            if self.loss_type in ('cross_entropy', 'mixed'):
                loss = (
                    F.cross_entropy(flat_output, flat_labels, reduction="none")
                    .view(batch_size, self.n_attributes)
                    .mean(dim=-1)
                )
                if self.loss_type == 'mixed':
                    loss = self.cross_entropy_weight * loss + (-acc)
            elif self.loss_type == 'task_success':
                loss = -acc
            else:
                raise ValueError(f'Invalid loss type {self.loss_type}')

        return loss, {
            "acc": acc, 
            "acc_or": acc_or, 
            f'acc_{self.n_attributes}': acc,
            f'acc_or_{self.n_attributes}': acc_or,
        }


class MultiDiffLoss(torch.nn.Module):
    """Score several receiver heads, each on its own subset of attributes.

    Head ``r`` is scored by a ``DiffLoss`` restricted to the attributes in
    ``att_indices[r]``.

    When ``group_size > 1``, heads are randomly partitioned (``torch.randperm``)
    into groups of ``group_size``. Within a group, each sample's ``acc`` is
    replaced by the logical AND over the group's heads. For ``'mixed'`` and
    ``'task_success'``, the loss is rebuilt from that joint accuracy; in that
    case the per-head losses are plain cross entropy.

    Args:
        att_indices: one sequence of attribute indices per receiver head.
        n_values: number of values each attribute can take.
        loss_type: ``'cross_entropy'``, ``'task_success'`` or ``'mixed'``.
        group_size: number of heads whose accuracy is combined jointly.
        cross_entropy_weight: weight of the cross-entropy term for ``'mixed'``.
    """

    def __init__(
        self, 
        att_indices: List, 
        n_values: int, 
        loss_type: str = 'cross_entropy',
        group_size: int = 1,
        cross_entropy_weight: float = 1.0,
    ):
        super().__init__()
        self.att_indices = att_indices
        self.n_values = n_values
        self.loss_type = loss_type
        self.group_size = group_size
        self.cross_entropy_weight = cross_entropy_weight
        # With grouping, the task-success / mixed terms are applied jointly per
        # group in forward(), so each head only computes plain cross entropy.
        inner_loss_type = self.loss_type if self.group_size <= 1 else 'cross_entropy'
        self.losses = [
            DiffLoss(
                n_attributes=len(att_idx), 
                n_values=self.n_values, 
                generalization=False,
                loss_type=inner_loss_type,
                # Forwarded so 'mixed' honours the weight without grouping too
                # (previously the inner loss silently used 1.0).
                cross_entropy_weight=self.cross_entropy_weight,
            )
            for att_idx in self.att_indices
        ]
    
    def forward(
        self,
        sender_input,
        _message,
        _receiver_input,
        receiver_output,
        labels,
        _aux_input,
    ):
        """Return ``(loss, aux)`` with per-head results concatenated along dim 0.

        ``receiver_output`` is reshaped to ``(n_heads, size(0) // n_heads, -1)``.
        This presumes the heads are stacked along dim 0; confirm with the caller.
        ``labels`` (or ``sender_input`` when ``labels`` is None) is reshaped to
        ``(batch, -1, n_values)``.
        """
        if labels is None:
            labels = sender_input
        labels = labels.view(sender_input.size(0), -1, self.n_values)

        n_heads = len(self.losses)
        receiver_output = receiver_output.view(
            n_heads, 
            receiver_output.size(0) // n_heads,
            -1,
        )

        outputs = defaultdict(list)
        for rec_output, att_idx, loss in zip(receiver_output, self.att_indices, self.losses):
            rec_output = rec_output.view(rec_output.size(0), -1, self.n_values)
            # Keep only this head's attributes, then flatten back to 2-D.
            p_rec_output = rec_output[:, att_idx].view(rec_output.size(0), -1)
            p_labels = labels[:, att_idx].view(labels.size(0), -1)
            loss_val, aux = loss(
                _message=_message,
                _receiver_input=_receiver_input,
                sender_input=sender_input,
                receiver_output=p_rec_output,
                labels=p_labels,
                _aux_input=_aux_input,
            )
            outputs['loss'].append(loss_val)
            for key, value in aux.items():
                outputs[key].append(value)

        aux = {key: torch.cat(value) for key, value in outputs.items()}

        if self.group_size > 1:
            for idx in torch.randperm(n_heads).split(self.group_size):
                # (N * B) -> (N, B); writes through these views update aux in place.
                acc_view = aux['acc'].view(n_heads, -1)
                loss_view = aux['loss'].view(n_heads, -1)
                # (N', B) -> (1, B): correct only if every head in the group is.
                joint_acc = torch.all(acc_view[idx], dim=0).float().unsqueeze(0)
                acc_view[idx] = joint_acc
                if self.loss_type == 'mixed':
                    # (N', B) combined with (1, B) by broadcasting.
                    loss_view[idx] *= self.cross_entropy_weight
                    loss_view[idx] += -joint_acc
                elif self.loss_type == 'task_success':
                    loss_view[idx] = -joint_acc

        loss = aux.pop('loss')

        return loss, aux


class ContinuousDiffLoss(torch.nn.Module):
    """Regression loss between ``receiver_output`` and ``sender_input``.

    Args:
        loss_type: only ``'mse'`` is supported (mean squared error over the
            last dimension, one value per sample).
    """

    def __init__(
        self, 
        loss_type: str = 'mse',
    ):
        super().__init__()
        self.loss_type = loss_type

    def forward(
        self,
        sender_input,
        _message,
        _receiver_input,
        receiver_output,
        _labels,
        _aux_input,
    ):
        """Return ``(loss, {f'loss_{n_att}': loss})``.

        ``n_att`` is ``sender_input.size(-1)``.

        Raises:
            ValueError: if ``loss_type`` is not recognised.
        """
        n_att = sender_input.size(-1)

        if self.loss_type == 'mse':
            loss = F.mse_loss(receiver_output, sender_input, reduction='none').mean(-1)
        else:
            raise ValueError(f'Invalid loss type {self.loss_type}')

        return loss, {f'loss_{n_att}': loss}


class MultiContinuousDiffLoss(torch.nn.Module):
    """Score several receiver heads with ``ContinuousDiffLoss`` on attribute subsets.

    Head ``r`` is compared with ``sender_input[:, att_indices[r]]``.

    Args:
        att_indices: one sequence of attribute indices per receiver head.
        loss_type: forwarded to each ``ContinuousDiffLoss`` (only ``'mse'``
            is supported there).
        group_size: number of heads whose accuracy is combined jointly. This
            only has an effect when the per-head losses report an ``'acc'``
            metric, which ``ContinuousDiffLoss`` does not.
        cross_entropy_weight: weight used by the grouped ``'mixed'`` path.
    """

    def __init__(
        self, 
        att_indices: List, 
        loss_type: str = 'mse',
        group_size: int = 1,
        cross_entropy_weight: float = 1.0,
    ):
        super().__init__()
        self.att_indices = att_indices
        self.loss_type = loss_type
        self.group_size = group_size
        # Previously referenced in forward() without ever being set.
        self.cross_entropy_weight = cross_entropy_weight
        self.losses = [
            ContinuousDiffLoss(loss_type=self.loss_type)
            for _ in self.att_indices
        ]
    
    def forward(
        self,
        sender_input,
        _message,
        _receiver_input,
        receiver_output,
        _labels,
        _aux_input,
    ):
        """Return ``(loss, aux)`` with per-head results concatenated along dim 0."""
        n_heads = len(self.losses)
        # (R * B, H) -> (R, B, H)
        receiver_output = receiver_output.view(
            n_heads, 
            receiver_output.size(0) // n_heads,
            -1,
        )

        outputs = defaultdict(list)
        for rec_output, att_idx, loss in zip(receiver_output, self.att_indices, self.losses):
            p_rec_output = rec_output[:, att_idx]
            p_sen_input = sender_input[:, att_idx]
            loss_val, aux = loss(
                _message=_message,
                _receiver_input=_receiver_input,
                sender_input=p_sen_input,
                receiver_output=p_rec_output,
                _labels=_labels,
                _aux_input=_aux_input,
            )
            outputs['loss'].append(loss_val)
            for key, value in aux.items():
                outputs[key].append(value)

        aux = {key: torch.cat(value) for key, value in outputs.items()}

        # ContinuousDiffLoss reports no 'acc', so there is nothing to combine.
        # Without this guard, group_size > 1 used to raise KeyError('acc').
        if self.group_size > 1 and 'acc' in aux:
            for idx in torch.randperm(n_heads).split(self.group_size):
                # (N * B) -> (N, B); writes through these views update aux in place.
                acc_view = aux['acc'].view(n_heads, -1)
                loss_view = aux['loss'].view(n_heads, -1)
                # (N', B) -> (1, B)
                joint_acc = torch.all(acc_view[idx], dim=0).float().unsqueeze(0)
                acc_view[idx] = joint_acc
                if self.loss_type == 'mixed':
                    loss_view[idx] *= self.cross_entropy_weight
                    loss_view[idx] += -joint_acc
                elif self.loss_type == 'task_success':
                    loss_view[idx] = -joint_acc

        loss = aux.pop('loss')

        return loss, aux


class DiscriminationLoss(torch.nn.Module):
    """Classification loss for a discrimination game.

    Args:
        loss_type: one of the following.
            ``'cross_entropy'``: per-sample cross entropy.
            ``'task_success'``: negative argmax accuracy.
            ``'mixed'``: weighted cross entropy minus accuracy.
            ``'mixed_wo_ce'``: like ``'mixed'``, but the cross-entropy term
                is detached from the graph.
        cross_entropy_weight: weight of the cross-entropy term for
            ``'mixed'`` / ``'mixed_wo_ce'``.
    """

    def __init__(
        self, 
        loss_type: str = 'cross_entropy',
        cross_entropy_weight: float = 1.0,
    ):
        super().__init__()
        self.loss_type = loss_type
        self.cross_entropy_weight = cross_entropy_weight

    def forward(
        self,
        sender_input,
        _message,
        _receiver_input,
        receiver_output,
        labels,
        _aux_input,
    ):
        """Return ``(loss, {"acc": acc})`` with one value per sample.

        Assumes ``receiver_output`` holds class scores on its last dim and
        ``labels`` holds the matching class indices; TODO confirm with callers.

        Raises:
            ValueError: if ``loss_type`` is not recognised.
        """
        acc = (torch.argmax(receiver_output, dim=-1) == labels).float()
        
        if self.loss_type == 'cross_entropy':
            loss = F.cross_entropy(receiver_output, labels, reduction='none')
        elif self.loss_type == 'task_success':
            loss = -acc
        elif self.loss_type == 'mixed':
            ce = F.cross_entropy(receiver_output, labels, reduction='none')
            loss = self.cross_entropy_weight * ce + (-acc)
        elif self.loss_type == 'mixed_wo_ce':
            ce = F.cross_entropy(receiver_output, labels, reduction='none').detach()
            loss = self.cross_entropy_weight * ce + (-acc)
        else:
            raise ValueError(f'Invalid loss type {self.loss_type}')

        return loss, {
            "acc": acc, 
        }


class MultiDiscriminationLoss(torch.nn.Module):
    """Apply ``DiscriminationLoss`` to several receiver heads.

    Head ``r`` is scored against column ``r`` of ``labels``.

    When ``group_size > 1``, heads are randomly partitioned (``torch.randperm``)
    into groups of ``group_size``. Within a group, each sample's ``acc`` is
    replaced by the logical AND over the group's heads.

    For ``'mixed'`` / ``'mixed_wo_ce'``:
        the loss becomes ``cross_entropy_weight * ce - joint_acc``. If
        ``ce_group_mean`` is set, ``ce`` is first averaged over the batch
        for each head.

    For ``'task_success'``:
        the loss becomes ``-joint_acc``.

    Args:
        att_indices: one sequence of attribute indices per receiver head.
        loss_type: ``'cross_entropy'``, ``'task_success'``, ``'mixed'`` or
            ``'mixed_wo_ce'``. For ``'mixed_wo_ce'``, the returned loss is
            detached.
        group_size: number of heads whose accuracy is combined jointly.
        cross_entropy_weight: weight of the cross-entropy term.
        ce_group_mean: average the grouped cross entropy over the batch.
    """

    def __init__(
        self, 
        att_indices: List, 
        loss_type: str = 'cross_entropy',
        group_size: int = 1,
        cross_entropy_weight: float = 1.0,
        ce_group_mean: bool = False,
    ):
        super().__init__()
        self.att_indices = att_indices
        self.loss_type = loss_type
        self.group_size = group_size
        self.cross_entropy_weight = cross_entropy_weight
        self.ce_group_mean = ce_group_mean
        # With grouping, the per-head loss is plain cross entropy; the mixed /
        # task-success terms are applied per group in forward(). The weight is
        # forwarded so the ungrouped 'mixed' path honours it too (previously
        # the inner loss silently used 1.0).
        self.loss = DiscriminationLoss(
            loss_type=self.loss_type if self.group_size <= 1 else 'cross_entropy',
            cross_entropy_weight=self.cross_entropy_weight,
        )
    
    def forward(
        self,
        sender_input,
        _message,
        _receiver_input,
        receiver_output,
        labels,
        _aux_input,
    ):
        """Return ``(loss, aux)`` with per-head results concatenated along dim 0."""
        n_heads = len(self.att_indices)
        outputs = defaultdict(list)
        # (B, R) -> (R, B)
        labels = labels.permute(1, 0)

        # (R * B, C) -> (R, B, C)
        receiver_output = receiver_output.view(
            n_heads, -1, receiver_output.size(1)
        )
        for rec_output, label, att_idx in zip(receiver_output, labels, self.att_indices):
            loss_val, aux = self.loss(
                _message=_message,
                _receiver_input=_receiver_input,
                sender_input=sender_input,
                receiver_output=rec_output, # (B, C)
                labels=label, # (B)
                _aux_input=_aux_input,
            )
            outputs['loss'].append(loss_val)
            # Per-head copy that grouping below does not overwrite.
            aux[f'acc_{len(att_idx)}'] = torch.clone(aux['acc'])
            for key, value in aux.items():
                outputs[key].append(value)

        # [R, (B)] -> (R * B)
        aux = {key: torch.cat(value) for key, value in outputs.items()}

        if self.group_size > 1:
            for idx in torch.randperm(n_heads).split(self.group_size):
                # (N * B) -> (N, B); writes through these views update aux in place.
                acc_view = aux['acc'].view(n_heads, -1)
                loss_view = aux['loss'].view(n_heads, -1)
                # (N', B) -> (1, B)
                joint_acc = torch.all(acc_view[idx], dim=0).float().unsqueeze(0)
                acc_view[idx] = joint_acc
                if self.loss_type in ['mixed', 'mixed_wo_ce']:
                    if self.ce_group_mean:
                        # (N', B) -> (N', 1), broadcast back over the batch.
                        loss_view[idx] = loss_view[idx].mean(-1, keepdim=True)

                    # (N', B) combined with (1, B) by broadcasting.
                    loss_view[idx] *= self.cross_entropy_weight
                    loss_view[idx] += -joint_acc
                elif self.loss_type == 'task_success':
                    loss_view[idx] = -joint_acc

        loss = aux.pop('loss')
        if self.loss_type == 'mixed_wo_ce':
            loss = loss.detach()

        return loss, aux