from torch import nn
from torch.nn import functional as F


class ActionConceptLoss(nn.Module):
    """Sum of two binary cross-entropy terms sharing one prediction tensor.

    The same ``similarity`` input is scored against an action target and a
    concept target; both terms use ``reduction='sum'`` and are added together.
    """

    def forward(self, similarity, intent_action, intent_concept):
        """Return the combined action + concept BCE loss.

        Args:
            similarity: Predictions handed to ``F.binary_cross_entropy`` as
                its input (that function expects probabilities in [0, 1]).
            intent_action: Target tensor for the action term.
            intent_concept: Target tensor for the concept term.

        Returns:
            The sum of the two sum-reduced BCE losses.
        """
        def summed_bce(target):
            # Both terms share the prediction tensor and the 'sum' reduction.
            return F.binary_cross_entropy(similarity, target, reduction='sum')

        action_term = summed_bce(intent_action)
        concept_term = summed_bce(intent_concept)
        return concept_term + action_term


def get_loss(cfg):
    """Build the loss module selected by ``cfg.model.model_type``.

    Args:
        cfg: Configuration object exposing ``cfg.model.model_type``.

    Returns:
        ``ActionConceptLoss()`` for ``'nli_ca'``, or
        ``nn.BCELoss(reduction='sum')`` for ``'nli_strict'``.

    Raises:
        ValueError: If the model type is not one of the supported values.
    """
    model_type = cfg.model.model_type
    if model_type == 'nli_ca':
        return ActionConceptLoss()
    if model_type == 'nli_strict':
        return nn.BCELoss(reduction='sum')
    # NOTE: 'nli_contrastive' (ContrastiveLoss) is not wired up yet.
    # Raise instead of merely constructing the exception, so an unsupported
    # model type fails here rather than returning None to the caller.
    raise ValueError("Unknown model type")
