import torch
import torch.nn.functional as F

def focal_loss(labels, logits, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.

    Focal loss = -alpha_t * (1 - pt)^gamma * log(pt), where pt is the
    probability assigned to the true class: pt = p if the label is 1,
    otherwise pt = 1 - p, with p = sigmoid(logit).

    Args:
      labels: A float tensor of size [batch, num_classes]. The modulator
        formula below assumes targets in {0, 1}.
      logits: A float tensor of size [batch, num_classes].
      alpha: A float tensor (or scalar) broadcastable to
        [batch, num_classes] that weights each element of the loss. For a
        per-example weight use shape [batch, 1]; a tensor of shape [batch]
        would broadcast along the class axis instead.
      gamma: A float scalar modulating loss from hard and easy examples.
        ``gamma == 0`` reduces to (weighted) binary cross entropy.

    Returns:
      A scalar tensor: the weighted loss summed over all elements and divided
      by ``labels.sum()``. If ``labels`` sums to zero the result is inf/nan.
    """
    bce = F.binary_cross_entropy_with_logits(
        input=logits, target=labels, reduction="none"
    )

    if gamma == 0.0:
        modulator = 1.0
    else:
        # (1 - pt)^gamma evaluated in log space:
        #   exp(-gamma * y * x - gamma * log(1 + exp(-x))).
        # softplus(-x) == log(1 + exp(-x)) but stays finite for very negative
        # logits; the naive form overflows to inf there and drives the
        # modulator to 0 instead of ~1.
        modulator = torch.exp(
            -gamma * labels * logits - gamma * F.softplus(-logits)
        )

    weighted_loss = alpha * (modulator * bce)
    total = torch.sum(weighted_loss)

    # Normalize by the total label mass (number of positives for 0/1 labels).
    return total / torch.sum(labels)



def cb_loss(labels, logits, samples_per_cls, beta=0.999, gamma=1, *, eps=5e-1):
    """Compute the Class-Balanced focal loss between `logits` and `labels`.

    Class Balanced Loss: ((1 - beta) / (1 - beta^n)) * Loss(labels, logits),
    where n is the per-class sample count and Loss is the sigmoid focal loss
    computed by `focal_loss`.

    Args:
      labels: Either a float tensor of size [batch, no_of_classes] with
        one-/multi-hot targets, or an integer tensor of size [batch] holding
        class indices (one-hot encoded here).
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: Per-class sample counts of size [no_of_classes]; a
        tensor or a Python sequence.
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss.
      eps: Lower bound applied to the effective-number term (1 - beta^n)
        before dividing by it. Keeps weights finite when n == 0 and caps the
        weight of rare classes. Defaults to 0.5, the previously hard-coded
        value.

    Returns:
      cb_loss: A scalar tensor representing the class-balanced loss.
    """
    no_of_classes = logits.size(-1)

    # Accept integer class-index labels by expanding them to one-hot targets;
    # the focal loss below needs float targets shaped like the logits.
    if labels.dim() == 1 and not labels.is_floating_point():
        labels = F.one_hot(labels.long(), no_of_classes).to(logits.dtype)

    # Python sequences (or tensors on another device) are moved next to the
    # logits so the arithmetic below works.
    samples = torch.as_tensor(samples_per_cls, device=logits.device)
    if not samples.is_floating_point():
        samples = samples.to(logits.dtype)

    effective_num = 1.0 - beta ** samples
    weights = (1.0 - beta) / torch.clamp(effective_num, min=eps)
    # Rescale so the class weights sum to no_of_classes.
    weights = weights / torch.sum(weights) * no_of_classes

    # Per-example weight = sum of the weights of its positive classes,
    # shaped [batch, 1] so it broadcasts across the class axis.
    alpha = (labels * weights.unsqueeze(0)).sum(dim=1, keepdim=True)

    return focal_loss(labels, logits, alpha, gamma)