#!/usr/bin/python3
#-*- coding:utf-8 -*-
############################
#File Name: losses.py
#Created Time:
############################
from torch.nn import CrossEntropyLoss
from torch.nn import BCEWithLogitsLoss
from torch.nn import MultiLabelMarginLoss
from torch.nn import KLDivLoss
import torch.nn.functional as F
import torch.nn as nn

# Public API of this module, i.e. the names exported by ``from losses import *``.
# (Previously misspelled as ``__call__``, which Python ignores at module level.)
__all__ = ['CrossEntropy', 'BCELogitsLoss', 'MyMultiLabelMarginLoss', 'KLLoss']

class CrossEntropy(object):
    """Callable wrapper around ``torch.nn.CrossEntropyLoss``.

    Casts the scores to float32 and the targets to int64 class indices
    before delegating, so callers may pass e.g. float64 scores or int32
    labels.
    """

    def __init__(self):
        # No tensors exist at construction time; the dtype casts are done
        # per call in __call__ (doing them here raised NameError).
        self.loss_f = CrossEntropyLoss()

    def __call__(self, outputs, targets):
        """Return the cross-entropy loss of ``outputs`` against ``targets``.

        Args:
            outputs: unnormalised class scores; cast with ``.float()``.
            targets: class indices; cast with ``.long()``. Because of this
                cast, probability-style (soft) targets are not supported.

        Returns:
            The tensor produced by ``CrossEntropyLoss`` with its default
            reduction.
        """
        outputs = outputs.float()
        targets = targets.long()
        loss = self.loss_f(input=outputs, target=targets)
        return loss

class BCELogitsLoss:
    """Callable wrapper around ``torch.nn.BCEWithLogitsLoss``.

    Both the logits and the targets are cast with ``.float()`` first, so
    integer or boolean label tensors can be passed in directly.
    """

    def __init__(self):
        self.loss_f = BCEWithLogitsLoss()

    def __call__(self, outputs, targets):
        """Return the binary cross-entropy of logits ``outputs`` vs ``targets``."""
        return self.loss_f(input=outputs.float(), target=targets.float())

class MyMultiLabelMarginLoss(object):
    """Callable wrapper around ``torch.nn.MultiLabelMarginLoss``.

    Targets are cast to int64 with ``.long()`` because the underlying loss
    expects integer class indices; ``outputs`` are passed through with
    their original dtype.
    """

    def __init__(self):
        self.loss_f = MultiLabelMarginLoss()

    def __call__(self, outputs, targets):
        """Return the multi-label margin loss of ``outputs`` vs ``targets``.

        Args:
            outputs: class scores, used unchanged.
            targets: per-sample class indices in the format documented for
                ``torch.nn.MultiLabelMarginLoss`` (valid indices first,
                padded with ``-1``); cast with ``.long()``.

        Returns:
            The tensor produced by ``MultiLabelMarginLoss`` with its
            default reduction.
        """
        targets = targets.long()
        return self.loss_f(input=outputs, target=targets)


#def KLLoss():
    #def __init__(self):
    #    super(KLLoss,self).__init__()
    #    self.loss_f = KLDivLoss()

def KLLoss(outputs, targets):
    """Return a scaled KL divergence between two sets of logits.

    Both arguments are turned into distributions along dim 1: ``outputs``
    via ``log_softmax`` (log-probabilities ``log p``) and ``targets`` via
    ``softmax`` (probabilities ``q``). The element-wise terms
    ``q * (log q - log p)`` are averaged over *all* elements and then
    divided once more by ``outputs.shape[1]``.

    NOTE(review): because of the extra division the value is
    ``sum / (numel * C)``, not the textbook per-sample KL(q || p) (that
    would be ``reduction='batchmean'``). Kept as-is so existing loss
    scales / hyper-parameters are unaffected.

    Args:
        outputs: logits whose softmax over dim 1 is ``p``.
        targets: logits whose softmax over dim 1 is ``q``; presumably the
            same shape as ``outputs`` (not checked here).

    Returns:
        A 0-dim tensor.
    """
    log_p = F.log_softmax(outputs, 1)
    q = F.softmax(targets, 1)
    # Reduce explicitly instead of relying on KLDivLoss's default
    # reduction='mean': that default warns on every call and PyTorch has
    # announced it will change to 'batchmean' semantics. A plain mean over
    # the element-wise terms is exactly what 'mean' computes today.
    pointwise = F.kl_div(log_p, q, reduction='none')
    return pointwise.mean() / log_p.shape[1]
