import warnings
from collections import Counter
from typing import Dict, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from page.const import PAD_ID


class TripletMarginLossOnDistance(nn.Module):
    """
    Computes triplet margin loss on a precomputed distance matrix, mining semi-hard negatives online.
    """

    def __init__(self, margin: float = 1.0):
        """
        Initialize Triplet Margin Loss.

        :param float margin: Margin to be applied.
        """
        super().__init__()
        self.margin = margin

    def forward(self, distances: torch.Tensor, positives: torch.Tensor) -> torch.Tensor:
        """
        Computes triplet margin loss on distance matrix.

        Since we already computed distance matrix, we only need to compute

        .. math::
            max(0, d(anchor, positive) - d(anchor, negative) + margin).

        We'll use online triplet mining with semi-hard negatives. Thus,

        .. math::
            d_p = max_p{d(anchor, p)}
            d_n^* = min_n{d(anchor, n)} s.t. d(anchor, n) > d_p
            max(0, d_p - d_n^* + margin), where p: positives, n: negatives.

        An anchor with no negative farther than d_p has d_n^* = +inf and contributes zero loss.

        :param torch.Tensor distances:
            Float tensor that indicates distance between anchor and indicated embeddings.
            Shape [B, M], where B = batch size, M = number of instances compared with anchors
        :param torch.Tensor positives:
            Bool tensor that indicates positive examples. Shape [B, M]

        :return:
            A Float Tensor of computed loss with shape [].
        """
        assert distances.shape == positives.shape

        # Compute max_p{d(anchor, p)}, shape [B].
        # Rows without positives get d_p = 0 (assumes distances are non-negative — confirm).
        max_positive = distances.masked_fill(~positives, 0.0).max(dim=-1).values

        # Compute min_n{d(anchor, n)} s.t. d > d_p, shape [B].
        # A candidate is excluded if it is a positive OR it is not farther than d_p.
        excluded = positives | (distances <= max_positive.unsqueeze(-1))
        min_negative = distances.masked_fill(excluded, float('inf')).min(dim=-1).values

        # Return triplet margin loss (an infinite d_n^* clamps to zero).
        return (max_positive - min_negative + self.margin).clamp_min(0).mean()


class TupletLoss(nn.Module):
    """
    Computes tuplet ranking loss on distance matrix.
    """

    def forward(self, distances: torch.Tensor, positives: torch.Tensor) -> torch.Tensor:
        """
        Computes tuplet ranking loss on distance matrix.

        For each anchor (row) b, with p ranging over that row's positives and n over its negatives,

        .. math::
            L_b = log(1 + sum_{p, n} exp(d(anchor, p) - d(anchor, n)))

        and the returned loss is the mean of L_b over the batch. A row that has no positives or no
        negatives contributes log(1) = 0.

        :param torch.Tensor distances:
            Float tensor that indicates distance between anchor and indicated embeddings.
            Shape [B, M], where B = batch size, M = number of instances compared with anchors
        :param torch.Tensor positives:
            Bool tensor that indicates positive examples. Shape [B, M]

        :return:
            A Float Tensor of computed loss with shape []. An empty batch yields 0.
        """
        assert distances.shape == positives.shape

        if distances.shape[0] == 0:
            # Nothing to average; return a zero that stays connected to the autograd graph.
            return distances.sum() * 0.0

        per_anchor = []
        for row_dist, row_pos in zip(distances, positives):
            # Mask with this row's own positives ([M]). Masking with the full [B, M] tensor would
            # broadcast the row across the batch and mix in other anchors' labels.
            pos = row_dist.masked_select(row_pos).unsqueeze(-1)  # [P, 1]
            neg = row_dist.masked_select(~row_pos).unsqueeze(0)  # [1, N]

            # d(a, p) - d(a, n) for every pair, shape [P, N]; summed then log1p'd to a scalar.
            per_anchor.append(torch.log1p((pos - neg).exp().sum()))

        # Return tuplet ranking loss averaged over anchors.
        return torch.stack(per_anchor).mean()


class MarginBasedLoss(nn.Module):
    """
    Hinge loss on a distance matrix: positives are penalized unless their distance is at most
    ``-margin``, negatives unless their distance is at least ``margin``.
    """

    def __init__(self, margin: float = 1.0):
        """
        Build the margin-based loss.

        :param float margin: Margin added inside the hinge for every entry.
        """
        super().__init__()
        self.margin = margin

    def forward(self, distances: torch.Tensor, positives: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the hinge on every (anchor, example) entry and average.

        .. math::
            RELU(margin + y(anchor, ex) * d(anchor, ex))

        with ``y = +1`` on positive entries and ``y = -1`` on negative entries.

        :param torch.Tensor distances:
            Float tensor of anchor-to-example distances. Shape [B, M], where B = batch size and
            M = number of instances compared with each anchor.
        :param torch.Tensor positives:
            Bool tensor marking positive examples. Shape [B, M]

        :return:
            Scalar Float Tensor: the mean hinge value over all B * M entries.
        """
        assert distances.shape == positives.shape

        # Map True -> +1.0 and False -> -1.0.
        sign = positives.float().mul(2).sub(1)
        hinge = sign.mul(distances).add(self.margin)

        return torch.relu(hinge).mean()


class SmoothedCrossEntropyLoss(nn.Module):
    """
    Computes cross entropy loss with uniformly smoothed targets.
    """

    def __init__(self, smoothing: float = 0.1, ignore_index: int = PAD_ID, reduction: str = 'batchmean'):
        """
        Cross entropy loss with uniformly smoothed targets.

        :param float smoothing: Label smoothing factor, between 0 and 1 (exclusive; default is 0.1)
        :param int ignore_index: Target value whose rows are excluded from the loss.
        :param str reduction:
            'batchmean' (total divided by number of rows), 'sum', or 'none' (per-entry loss, shape [B, C]).
        """
        assert 0 < smoothing < 1, "Smoothing factor should be in (0.0, 1.0)"
        assert reduction in {'batchmean', 'none', 'sum'}
        super().__init__()

        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.LongTensor) -> torch.Tensor:
        """
        Computes cross entropy loss with uniformly smoothed targets.

        For a non-ignored row, the smoothed target puts ``1 - smoothing`` on the gold class and
        ``smoothing / n`` on every other class whose input is finite, where ``n`` is the number of
        finite entries in that row. Classes masked with -inf get probability 0. Rows whose target
        equals ``ignore_index`` get an all-zero target and contribute nothing. The per-entry loss
        is ``-p * log q``, with non-finite log-probabilities treated as 0.

        :param torch.Tensor input: Log probability for each class. This is a Tensor with shape [B, C]
        :param torch.LongTensor target: List of target classes. This is a LongTensor with shape [B]
        :return torch.Tensor: Computed loss
        """
        target = target.view(-1, 1)
        ignored = target == self.ignore_index  # [B, 1]

        finites = torch.isfinite(input)  # [B, C]
        n_cls = finites.sum(dim=-1, keepdim=True)  # [B, 1]
        # Every non-ignored row needs at least one finite class to spread probability over.
        assert bool((n_cls[~ignored] > 0).all())

        # Uniform smoothing mass over finite classes, built in one shot (no per-row host syncs).
        # clamp_min(1) only matters for ignored all-masked rows, which are zeroed below anyway.
        smoothed_target = finites.to(input.dtype) * (self.smoothing / n_cls.clamp_min(1).to(input.dtype))

        # The gold class gets 1 - smoothing, overwriting its uniform share. Ignored rows use a dummy
        # in-range index 0 here and are cleared on the next line.
        smoothed_target.scatter_(1, target.masked_fill(ignored, 0), 1.0 - self.smoothing)
        smoothed_target.masked_fill_(ignored, 0.0)

        # Compute loss: - p log q
        loss = - smoothed_target * input.masked_fill(~finites, 0.0)

        if self.reduction == 'batchmean':
            return loss.sum() / input.shape[0]
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class FocalLoss(nn.Module):
    """
    Computes focal loss: the gold-class negative log-likelihood scaled by ``(1 - q)^gamma``.
    """

    def __init__(self, ignore_index: int = PAD_ID, gamma: float = 1.0, reduction: str = 'batchmean'):
        """
        Focal loss over log-probabilities.

        :param int ignore_index: Target value marking entries that are skipped.
        :param float gamma: Focusing exponent applied to ``1 - q`` (gamma = 0 gives plain NLL).
        :param str reduction: 'batchmean', 'sum', or 'none'.
        """
        assert reduction in {'batchmean', 'none', 'sum'}
        super().__init__()

        self.ignore_index = ignore_index
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.LongTensor) -> torch.Tensor:
        """
        Computes ``-(1 - q)^gamma * log q`` for the gold class of each non-ignored target.

        :param torch.Tensor input: Log probability for each class. This is a Tensor with shape [B, C]
        :param torch.LongTensor target: List of target classes. This is a LongTensor with shape [B]
        :return torch.Tensor:
            Computed loss. With reduction 'none', a 1-D tensor holding one value per non-ignored
            target (ignored entries are dropped, not zero-filled). 'batchmean' divides the total by
            ``input.shape[0]``, counting ignored rows as well.
        """
        keep = target != self.ignore_index

        # Boolean [B, C] selector of the gold class. Ignored rows temporarily point at class 0 so
        # one_hot stays in range, then are cleared by `& keep`.
        safe_target = target.masked_fill(~keep, 0)
        gold_mask = F.one_hot(safe_target, num_classes=input.shape[-1]).bool() & keep.unsqueeze(-1)

        # log q of the gold class, one value per kept entry.
        gold_log_prob = input.masked_select(gold_mask)

        # (1 - q)^gamma, with 1 - q computed as -expm1(log q) for precision near q = 1.
        modulator = gold_log_prob.expm1().neg().pow(self.gamma)
        loss = modulator.neg() * gold_log_prob

        if self.reduction == 'batchmean':
            return loss.sum() / input.shape[0]
        if self.reduction == 'sum':
            return loss.sum()
        return loss


class SmoothedFocalLoss(SmoothedCrossEntropyLoss):
    """
    Computes focal loss on top of label-smoothed cross entropy.
    """

    def __init__(self, ignore_index: int = PAD_ID, gamma: float = 1.0, reduction: str = 'batchmean',
                 smoothing: float = 0.1):
        """
        Focal-weighted cross entropy with uniformly smoothed targets.

        :param int ignore_index: Target value whose rows are excluded from the loss.
        :param float gamma: Focusing exponent applied to ``1 - q``.
        :param str reduction: Reduction of the focal-weighted loss: 'batchmean', 'sum', or 'none'.
        :param float smoothing: Label smoothing factor, between 0 and 1 (exclusive; default is 0.1)
        """
        assert 0 < smoothing < 1, "Smoothing factor should be in (0.0, 1.0)"
        assert reduction in {'batchmean', 'none', 'sum'}
        super().__init__(ignore_index=ignore_index, reduction='none', smoothing=smoothing)

        self.gamma = gamma
        # Reduction of the final, focal-weighted loss. The parent's forward() also reads
        # self.reduction, so forward() switches it to 'none' around that call.
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.LongTensor) -> torch.Tensor:
        """
        Computes focal-weighted, label-smoothed cross entropy.

        Each entry of the unreduced smoothed cross entropy ``-p log q`` (shape [B, C]) is scaled by
        ``(1 - q)^gamma``. Rows whose target equals ``ignore_index`` get weight 0.

        :param torch.Tensor input: Log probability for each class. This is a Tensor with shape [B, C]
        :param torch.LongTensor target: List of target classes. This is a LongTensor with shape [B]
        :return torch.Tensor: Computed loss ([B, C] when reduction is 'none')
        """
        # Build masks, shape [B, 1]
        mask = (target == self.ignore_index).unsqueeze(-1)

        # Compute focal term (1 - q)^gamma. Shape [B, C]
        focal_weight = (-input.expm1()).pow(self.gamma).masked_fill(mask, 0.0)

        # Per-entry smoothed loss. Shape [B, C]. The parent honours self.reduction, so force 'none'
        # for this call. Otherwise it would return an already reduced scalar and the focal weight
        # would scale the total instead of each entry.
        outer_reduction = self.reduction
        self.reduction = 'none'
        try:
            smoothed_loss = super().forward(input, target)
        finally:
            self.reduction = outer_reduction

        # Shape [B, C]: (1-q)^gamma * (-p log(q))
        loss = focal_weight * smoothed_loss

        if self.reduction == 'batchmean':
            return loss.sum() / input.shape[0]
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class PaddedBCELoss(nn.BCEWithLogitsLoss):
    """
    Binary cross entropy with logits over targets that may contain padding.

    Target entries equal to 1 are positives. Entries equal to ``ignore_index`` are excluded. Any
    other value counts as a negative. The loss is averaged over the non-ignored entries of each
    sample (all dimensions after the first are flattened). It is then averaged over the samples
    that have at least one non-ignored entry.
    """

    def __init__(self, ignore_index: int = PAD_ID, pos_weight=None):
        """
        :param int ignore_index: Target value marking padded entries.
        :param pos_weight: Optional positive-class weight, forwarded to ``nn.BCEWithLogitsLoss``.
        """
        super().__init__(reduction='none', pos_weight=pos_weight)
        self.ignore_index = ignore_index

    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        :param torch.Tensor predicted: Float logits (before applying sigmoid), same shape as target.
        :param torch.Tensor target: Long tensor of 1 (positive), ignore_index (padding) or other (negative).
        :return torch.Tensor: Scalar loss.
        """
        ignored = target == self.ignore_index

        # Compute element-wise loss; binary labels are cast to the logits' dtype.
        loss = super().forward(predicted, (target == 1).to(predicted.dtype))

        # Ignore losses come from ignored positions
        loss = loss.masked_fill(ignored, 0.0)

        # Per-sample mean over non-ignored entries. A sample that is entirely padding would give 0/0;
        # clamping the count to 1 makes it 0, and it is excluded from the batch average below.
        per_sample_sum = loss.flatten(start_dim=1).sum(dim=1)
        per_sample_count = (~ignored).flatten(start_dim=1).sum(dim=1)
        per_sample_mean = per_sample_sum / per_sample_count.clamp_min(1)

        # Average over samples that contain at least one real entry (0 if none do).
        n_valid = (per_sample_count > 0).sum().clamp_min(1)
        return per_sample_mean.sum() / n_valid


def accuracy(greedy_choice_correct: torch.Tensor, target_focus: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute accuracy by comparing two Bool Tensors

    :param torch.Tensor greedy_choice_correct:
        Bool Tensor indicating whether prediction is correct or not.
        `True` if greedy choice based on prediction is correct on the entry.
    :param torch.Tensor target_focus:
        Bool Tensor indicating whether we interested in the position or not.
        `True` if we don't ignore the position.
    :rtype: Tuple[torch.Tensor, torch.Tensor]
    :return:
        Tuple of two Float Tensors.
        - [0] indicates token level accuracy: correct focused entries / focused entries
          (1.0 when there is no focused entry).
        - [1] indicates sequence level accuracy: fraction of sequences (first dimension) whose
          focused entries are all correct. For 1-D inputs this equals the token level accuracy.
    """
    with torch.no_grad():
        token_lv_acc = (greedy_choice_correct & target_focus).sum().float()
        token_lv_acc /= target_focus.sum().float()

        # Set NaN as 1.0 since there are no such token to measure the accuracy.
        token_lv_acc.masked_fill_(torch.isnan(token_lv_acc), 1.0)

        if target_focus.dim() > 1:
            # Case of [B, *]. (multiple values per a sequence)
            # Flatten every non-batch dimension so that [B, T, ...] still yields one verdict per
            # sequence. Without this, the count could exceed B and the accuracy exceed 1.
            wrong_per_seq = (~greedy_choice_correct & target_focus).flatten(start_dim=1).sum(dim=-1)
            seq_lv_acc = (wrong_per_seq == 0).sum().float()
            seq_lv_acc /= greedy_choice_correct.shape[0]

            # An empty batch gives 0/0; follow the token-level convention.
            seq_lv_acc.masked_fill_(torch.isnan(seq_lv_acc), 1.0)
        else:
            # Case of predicting a single value per a sequence.
            seq_lv_acc = token_lv_acc

    return token_lv_acc, seq_lv_acc


def get_class_weights(target: torch.Tensor, n_class: int, smoothing_factor: float = 1.0,
                      device: torch.device = None, ignore_index: int = PAD_ID) -> torch.Tensor:
    """
    Compute inverse-frequency weights for the values ``0 .. n_class - 1`` occurring in ``target``.

    ``weight[i] = (count of the most frequent value) / (count of value i)``. A value that never
    occurs uses ``smoothing_factor`` as its count. Entries equal to ``ignore_index`` are not counted.

    :param torch.Tensor target: Integer tensor of any shape.
    :param int n_class: Number of weights to return.
    :param float smoothing_factor: Count assumed for values that never occur.
    :param torch.device device: Device of the returned tensor.
    :param int ignore_index: Value excluded from counting (defaults to PAD_ID).
    :return torch.Tensor: Float tensor of shape [n_class]. All ones when every entry is ignored.
    """
    counter = Counter(target.flatten().tolist())
    counter.pop(ignore_index, None)

    if not counter:
        # Only padding: there is no frequency information (and max() would fail), so use uniform weights.
        return torch.ones(n_class, device=device)

    max_class = max(counter.values())
    return torch.tensor([max_class / counter.get(i, smoothing_factor) for i in range(n_class)],
                        device=device, requires_grad=False)


def loss_and_accuracy(predicted: torch.Tensor, target: torch.Tensor, prefix='TrainToken',
                      loss_factor: float = 1.0) -> Dict[str, torch.Tensor]:
    """
    Compute loss and accuracy. Loss will be selected by following rules.
    - target: LongTensor and target.dim + 1 == predicted.dim -> smoothed Cross-Entropy
      (`predicted` holds log-probabilities over its last dimension; PAD_ID targets are ignored).
    - target: LongTensor and same dimensions -> padded Binary Cross-Entropy with logits
      (target entries equal to 1 are positives, PAD_ID entries are ignored).
    - target: FloatTensor and same dimensions -> KL-divergence (`predicted` holds log-probabilities).
    - target: BoolTensor and same dimensions -> margin-based loss on distances
      (`True` marks positives; the smallest predicted value is taken as the greedy choice).

    :param torch.Tensor predicted: Tensor of predicted result.
    :param torch.Tensor target: Tensor of targeted result.
    :param str prefix: String prefix for dictionary keys.
    :param float loss_factor: Multiplier applied to the computed loss.
    :rtype: Dict[str, torch.Tensor]
    :return: Dictionary whose keys are `[prefix]/<name>`, containing
        - loss: Loss value (always present)
        - acc_seq, acc_token: Sequence/token level accuracy (long and bool targets)
        - mse_token: Squared error between exp(predicted) and target (float targets)
        - border, margin, max_positive, min_negative, num_positives: distance statistics (bool targets)
    """

    tdim = target.dim()
    pdim = predicted.dim()
    tdtype = target.dtype

    result = {}

    if tdtype == torch.long:
        if tdim + 1 == pdim:
            # This is the case for Cross-Entropy.
            # Compute accuracy.
            target_focus = target != PAD_ID
            greedy_choice_correct = predicted.argmax(dim=-1) == target
            token_lv_acc, seq_lv_acc = accuracy(greedy_choice_correct, target_focus)

            # Flatten predicted to [*, C] and target to [*]
            predicted = predicted.flatten(0, -2)
            target = target.flatten()

            # Prepare loss function
            loss_fct = SmoothedCrossEntropyLoss(ignore_index=PAD_ID)
        elif tdim == pdim:
            # This is the case for Binary Cross-Entropy.
            target_focus = (target != PAD_ID).any(dim=-1)
            greedy_choice_correct = ((predicted.sigmoid() > 0.75) == (target == 1)).all(dim=-1)
            token_lv_acc, seq_lv_acc = accuracy(greedy_choice_correct, target_focus)

            # NOTE(review): get_class_weights counts how often each *value* (0, 1, ...) occurs in
            # `target`, while pos_weight is applied per entry of the last dimension of `predicted`.
            # Confirm that these two indexings are meant to coincide.
            class_weights = get_class_weights(target, predicted.shape[-1], device=predicted.device)
            loss_fct = PaddedBCELoss(pos_weight=class_weights)
        else:
            raise NotImplementedError('If target has long type of dimension %s, '
                                      'predicted tensor should be %s- or %s-dimensional, not %s-dimensional!' %
                                      (tdim, tdim, tdim+1, pdim))

        result.update({
            'acc_token': token_lv_acc,
            'acc_seq': seq_lv_acc
        })
    elif tdtype == torch.float:
        assert tdim == pdim, 'If target has float type, target and predicted tensor should have the same dimensions!'

        # This is the case for KL-divergence.
        # Compute MSE
        diff = predicted.exp() - target
        mse = diff.pow(2).flatten(1).sum(dim=1).mean()

        # KL-Divergence loss does not need to flatten the target.
        # Prepare loss function.
        loss_fct = nn.KLDivLoss(reduction='batchmean')
        result['mse_token'] = mse
    elif tdtype == torch.bool:
        assert tdim == pdim, 'If target has bool type, target and predicted tensor should have the same dimensions!'

        # This is the case for a metric (distance-based) loss.
        # Compute accuracy. Target indicates positive examples; the smallest distance is the greedy choice.
        greedy_choice_correct = target.gather(dim=-1, index=predicted.argmin(dim=-1, keepdim=True))
        target_focus = target.sum(dim=-1, keepdim=True) > 0
        token_lv_acc, seq_lv_acc = accuracy(greedy_choice_correct, target_focus)

        # Set metric loss
        loss_fct = MarginBasedLoss()

        # Distance statistics per sample. Positives should be small and negatives large, so the
        # hardest positive is the LARGEST positive distance and the hardest negative the SMALLEST one.
        batch_sz = target.shape[0]
        flat_pred = predicted.flatten(start_dim=1)
        flat_tgt = target.flatten(start_dim=1)
        max_positive = flat_pred.masked_fill(~flat_tgt, float('-inf')).max(dim=-1).values
        min_negative = flat_pred.masked_fill(flat_tgt, float('inf')).min(dim=-1).values

        # Samples lacking positives or negatives have no meaningful gap (their values are +-inf/NaN);
        # zero them out and average over the remaining samples only.
        valid = flat_tgt.any(dim=-1) & (~flat_tgt).any(dim=-1)
        n_valid = valid.sum().clamp_min(1)
        distances = {
            'border': (max_positive + min_negative) / 2,
            'margin': (min_negative - max_positive) / 2,
            'max_positive': max_positive,
            'min_negative': min_negative,
        }

        result.update({key: value.masked_fill(~valid, 0.0).sum() / n_valid for key, value in distances.items()})
        result.update({
            'num_positives': target.sum().float() / batch_sz,
            'acc_token': token_lv_acc,
            'acc_seq': seq_lv_acc
        })
    else:
        raise NotImplementedError('There are no such rules for computing loss between %s-dim predicted %s tensor '
                                  'and %s-dim target %s tensor' % (pdim, predicted.dtype, tdim, tdtype))

    # Compute loss
    loss = loss_fct(predicted, target)

    if loss_factor != 1.0:
        loss = loss * loss_factor

    # Surface non-finite losses (NaN or inf) without interrupting training.
    if not torch.isfinite(loss).all().item():
        warnings.warn('%s: non-finite loss value encountered' % prefix, RuntimeWarning)

    result['loss'] = loss

    return {prefix + '/' + key: value for key, value in result.items()}


# Public API for `from <module> import *`: only `loss_and_accuracy` is exported this way. The loss
# classes and helpers above are still importable by explicit name.
__all__ = ['loss_and_accuracy']
