import torch
from utils.config import *
# USE_CUDA = False


def max_margin_loss(scores, target, delta=0.5):
    """Hinge loss between each row's positive triple and its hardest negative.

    For every row, the index that sorts first under a descending sort of
    ``target`` is taken as the positive triple; all other triples in that row
    are negatives. The per-row loss is
    ``max(0, -score_pos + max(score_neg) + delta)``, and the result is the
    mean over the batch.

    Args:
        scores: Tensor of size (batch, num_triples) with the unnormalized
            score of each triple. num_triples must be at least 2, otherwise
            there is no negative for ``torch.max`` to reduce over.
        target: LongTensor of size (batch, num_triples) marking the true
            triple of each row (presumably one-hot; if a row has several
            tied maximal entries, which one becomes the positive is not
            specified because ``torch.sort`` is not stable by default, and
            the others are treated as negatives).
        delta: Margin by which the positive score should exceed the
            highest-scoring negative before the row stops contributing.

    Returns:
        A 0-dim tensor holding the average loss. If ``scores`` or ``target``
        is None, a constant zero tensor is returned instead (moved to CUDA
        when ``USE_CUDA`` is set).
    """
    if scores is None or target is None:
        zero = torch.tensor(0.0)
        return zero.cuda() if USE_CUDA else zero

    # A descending sort of the target puts the positive column first;
    # the remaining columns are the negatives.
    _, indices = torch.sort(target, descending=True)
    pos_idx, neg_idx = indices[:, 0].unsqueeze(1), indices[:, 1:]
    pos_val = torch.gather(scores, dim=1, index=pos_idx)  # (batch, 1)
    neg_val = torch.gather(scores, dim=1, index=neg_idx)  # (batch, num_triples - 1)

    # Only the hardest (highest-scoring) negative of each row matters.
    max_neg_val, _ = torch.max(neg_val, dim=1)
    margin = -pos_val + max_neg_val.unsqueeze(1) + delta

    # Out-of-place hinge instead of masked in-place assignment on a tensor
    # that is part of the autograd graph; values and gradients are unchanged.
    loss = torch.clamp(margin, min=0)
    return torch.mean(loss)


if __name__ == '__main__':
    # Small smoke test: four rows of three candidate triples, one true
    # triple per row, printed as a 0-dim tensor.
    demo_scores = torch.tensor([
        [0.1, 0.2, 0.3],
        [0.2, 0.3, 0.4],
        [0.4, 0.5, 0.6],
        [0.7, 0.1, 0.3],
    ])
    demo_target = torch.tensor(
        [[1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1]],
        dtype=torch.long,
    )
    print(max_margin_loss(demo_scores, demo_target))
