import torch
from torch.autograd import Variable


def l2norm(X):
    """L2-normalize X along dim 1 (i.e. each row of a 2-D tensor).

    The Euclidean norm is computed over ``dim=1`` (keeping that dimension so
    it broadcasts back), and X is divided by it, so every row of a 2-D input
    ends up with unit length.

    Args:
        X: tensor with at least two dimensions; the norm is taken over dim 1.

    Returns:
        A new tensor with the same shape as ``X``. A row whose norm is zero
        yields NaN (0 / 0), exactly as a plain division would.
    """
    norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt()
    return torch.div(X, norm)


class PairwiseRankingLoss(torch.nn.Module):
    """Bidirectional hinge (pairwise ranking) cost for image-sentence pairs.

    Row ``i`` of the image batch and row ``i`` of the sentence batch form the
    matching pair (the diagonal of the score matrix); every other combination
    is treated as a contrastive pair.

    Args:
        margin: hinge margin a matching pair's score must exceed each
            contrastive pair's score by before its cost drops to zero.
    """

    def __init__(self, margin=1.0):
        super(PairwiseRankingLoss, self).__init__()
        self.margin = margin

    def forward(self, im, s):
        """Return the element-wise hinge cost matrix.

        Args:
            im: 2-D tensor of image embeddings, one row per item.
            s: 2-D tensor of sentence embeddings, one row per item. Assumed
                to have the same number of rows as ``im`` so the score
                matrix is square — TODO confirm callers never mix sizes.

        Returns:
            An ``(n, n)`` tensor on the inputs' device whose ``[i, j]`` entry
            is the hinge cost of (image i, sentence j) as a negative for
            sentence j plus as a negative for image i. The diagonal is zero.
            No reduction is applied; callers reduce it themselves.
        """
        margin = self.margin
        # compute image-sentence score matrix: scores[i, j] = <im[i], s[j]>
        scores = torch.mm(im, s.transpose(1, 0))
        diagonal = scores.diag()
        n = diagonal.size(0)

        # compare every diagonal score to scores in its column (i.e. all
        # contrastive images for each sentence): d[j] broadcast down rows
        cost_s = torch.clamp(margin - diagonal.view(1, n) + scores, min=0)
        # compare every diagonal score to scores in its row (i.e. all
        # contrastive sentences for each image): d[i] broadcast across columns
        cost_im = torch.clamp(margin - diagonal.view(n, 1) + scores, min=0)

        # Matching pairs are not negatives, so zero the diagonal. The mask is
        # built on the scores' device so this works on CPU and any GPU.
        mask = torch.eye(n, dtype=torch.bool, device=scores.device)
        cost_s = cost_s.masked_fill(mask, 0)
        cost_im = cost_im.masked_fill(mask, 0)

        return cost_s + cost_im


class PairwiseRankingScore(torch.nn.Module):
    """Bidirectional hinge score matrix for image-sentence pairs.

    Computes the same element-wise hinge costs as a pairwise ranking loss.
    Row ``i`` of the image batch and row ``i`` of the sentence batch form the
    matching pair (the diagonal); every other combination is contrastive.

    Args:
        margin: hinge margin a matching pair's score must exceed each
            contrastive pair's score by before its cost drops to zero.
    """

    def __init__(self, margin=1.0):
        super(PairwiseRankingScore, self).__init__()
        self.margin = margin

    def forward(self, im, s):
        """Return the element-wise hinge cost matrix.

        Args:
            im: 2-D tensor of image embeddings, one row per item.
            s: 2-D tensor of sentence embeddings, one row per item. Assumed
                to have the same number of rows as ``im`` so the score
                matrix is square — TODO confirm callers never mix sizes.

        Returns:
            An ``(n, n)`` tensor on the inputs' device. Entry ``[i, j]`` is
            the clamped cost of (image i, sentence j) against sentence j's
            matching score plus against image i's matching score. The
            diagonal is zero.
        """
        margin = self.margin
        # compute image-sentence score matrix: scores[i, j] = <im[i], s[j]>
        scores = torch.mm(im, s.transpose(1, 0))
        diagonal = scores.diag()
        n = diagonal.size(0)

        # compare every diagonal score to scores in its column (i.e. all
        # contrastive images for each sentence): d[j] broadcast down rows
        cost_s = torch.clamp(margin - diagonal.view(1, n) + scores, min=0)
        # compare every diagonal score to scores in its row (i.e. all
        # contrastive sentences for each image): d[i] broadcast across columns
        cost_im = torch.clamp(margin - diagonal.view(n, 1) + scores, min=0)

        # Matching pairs are not negatives, so zero the diagonal. The mask is
        # built on the scores' device so this works on CPU and any GPU.
        mask = torch.eye(n, dtype=torch.bool, device=scores.device)
        cost_s = cost_s.masked_fill(mask, 0)
        cost_im = cost_im.masked_fill(mask, 0)

        return cost_s + cost_im


class PairwiseRankingScore_2(torch.nn.Module):
    """Raw (un-hinged) bidirectional ranking margins for image-sentence pairs.

    Like a pairwise ranking score, but with neither the ``max(0, .)`` hinge
    nor the zeroing of the diagonal. Negative margins are therefore kept, and
    each diagonal entry equals ``2 * margin``.

    Args:
        margin: constant added to every margin term.
    """

    def __init__(self, margin=1.0):
        # Must name this class: super(PairwiseRankingScore, self) raised
        # TypeError because self is not a PairwiseRankingScore instance.
        super(PairwiseRankingScore_2, self).__init__()
        self.margin = margin

    def forward(self, im, s):
        """Return the raw margin matrix.

        Args:
            im: 2-D tensor of image embeddings, one row per item.
            s: 2-D tensor of sentence embeddings, one row per item. Assumed
                to have the same number of rows as ``im`` so the score
                matrix is square — TODO confirm callers never mix sizes.

        Returns:
            An ``(n, n)`` tensor whose ``[i, j]`` entry is
            ``(margin - d[j] + scores[i, j]) + (margin - d[i] + scores[i, j])``,
            where ``d`` is the diagonal of ``im @ s.T``.
        """
        margin = self.margin
        # compute image-sentence score matrix: scores[i, j] = <im[i], s[j]>
        scores = torch.mm(im, s.transpose(1, 0))
        diagonal = scores.diag()
        n = diagonal.size(0)
        # compare every diagonal score to scores in its column (i.e. all
        # contrastive images for each sentence)
        cost_s = margin - diagonal.view(1, n) + scores
        # compare every diagonal score to scores in its row (i.e. all
        # contrastive sentences for each image)
        cost_im = margin - diagonal.view(n, 1) + scores
        return cost_s + cost_im


class PairwiseRankingScore_3(torch.nn.Module):
    """One-directional score differences for image-sentence pairs.

    For each sentence (column), subtracts that sentence's matching score
    (the diagonal entry) from every image's score for it.

    Args:
        margin: stored on the module for interface parity with the sibling
            ranking classes; ``forward`` does not use it.
    """

    def __init__(self, margin=1.0):
        # Must name this class: super(PairwiseRankingScore, self) raised
        # TypeError because self is not a PairwiseRankingScore instance.
        super(PairwiseRankingScore_3, self).__init__()
        self.margin = margin

    def forward(self, im, s):
        """Return ``scores[i, j] - scores[j, j]`` for all i, j.

        Args:
            im: 2-D tensor of image embeddings, one row per item.
            s: 2-D tensor of sentence embeddings, one row per item. Assumed
                to have the same number of rows as ``im`` so the score
                matrix is square — TODO confirm callers never mix sizes.

        Returns:
            An ``(n, n)`` tensor; its diagonal is zero.
        """
        # compute image-sentence score matrix: scores[i, j] = <im[i], s[j]>
        scores = torch.mm(im, s.transpose(1, 0))
        diagonal = scores.diag()
        # compare every diagonal score to scores in its column (i.e. all
        # contrastive images for each sentence)
        return scores - diagonal.view(1, diagonal.size(0))
