import torch



def hinge(x):
    """Elementwise hinge ``max(x, 0)``: negative entries are replaced by 0."""
    lower_bound = 0.0
    return torch.clamp(x, min=lower_bound)


def paired_hinge_rank_loss(
        lang_output: torch.Tensor,
        visn_output: torch.Tensor,
        lang_mask: torch.Tensor,
        margin: float,
):
    """
    Consider the first half as positive and the second half as negative.

    Sample ``i`` of the first half is paired with sample ``i`` of the second
    half, and each serves as the other's negative: every non-padding word must
    score higher (dot product) with its own image than with the paired
    sample's image, by at least ``margin``.

    :param lang_output: [batch_size, max_len, hid_dim]
    :param visn_output: [batch_size, hid_dim]
    :param lang_mask: Int Tensor [batch_size, max_len], 1 for tokens, 0 for paddings.
    :param margin: margin in the ranking loss (must be > 0)
    :return: a scalar loss, averaged over the non-padding words (0 when the
        mask selects no words at all)
    """
    batch_size, lang_len, dim = lang_output.shape
    assert batch_size % 2 == 0 and batch_size == visn_output.shape[0]
    assert margin > 0.

    # Expand the visn_output to match each word
    visn_output = visn_output.unsqueeze(1)      # [b, 1, hid_dim]

    # Split to positive and negative sets.
    half_batch_size = batch_size // 2
    pos_lang, neg_lang = torch.split(lang_output, half_batch_size, dim=0)
    pos_visn, neg_visn = torch.split(visn_output, half_batch_size, dim=0)

    # Calculate positive and negative scores.
    true_pos_score = (pos_lang * pos_visn).sum(-1)           # [batch_size / 2, max_len]
    true_neg_score = (neg_lang * neg_visn).sum(-1)           # [batch_size / 2, max_len]
    false_pos_score = (pos_lang * neg_visn).sum(-1)          # [batch_size / 2, max_len]
    false_neg_score = (neg_lang * pos_visn).sum(-1)          # [batch_size / 2, max_len]

    # Hinge loss; padding positions are zeroed out by the mask.
    float_lang_mask = lang_mask.type(lang_output.dtype)      # same dtype as lang_output
    pos_lang_mask, neg_lang_mask = torch.split(float_lang_mask, half_batch_size, dim=0)
    pos_loss = hinge(margin - true_pos_score + false_pos_score) * pos_lang_mask
    neg_loss = hinge(margin - true_neg_score + false_neg_score) * neg_lang_mask

    # Average over the number of words. The count is clamped so that a batch
    # whose mask is all zeros yields 0 instead of 0 / 0 = NaN.
    cnt = float_lang_mask.sum().clamp(min=1.)
    loss = (pos_loss.sum() + neg_loss.sum()) / cnt

    return loss

def paired_hinge_rank_loss2(
        lang_output: torch.Tensor,
        visn_output: torch.Tensor,
        neg_lang_output: torch.Tensor,
        neg_visn_output: torch.Tensor,
        margin: float,
        bertonly=False
):
    """
    Hinge ranking loss over explicitly supplied positive and negative pairs.

    Each positive score must exceed the corresponding negative score by at
    least ``margin``.

    :param lang_output: positive language features, [batch_size, hid_dim]; when
        ``bertonly`` is set it is used directly as the positive score and must
        still be 2-D (presumably [batch_size, 1] — confirm with caller).
    :param visn_output: [batch_size, hid_dim] visual features paired with ``lang_output``.
    :param neg_lang_output: negative language features, or negative scores when ``bertonly``.
    :param neg_visn_output: visual features paired with ``neg_lang_output``;
        ignored when ``bertonly`` is set.
    :param margin: margin in the ranking loss (must be > 0)
    :param bertonly: use the language tensors as scores instead of taking dot
        products with the visual features.
    :return: a scalar loss, the sum of hinge terms divided by batch_size
    """
    n_samples, _ = lang_output.shape
    assert n_samples == visn_output.shape[0]
    assert margin > 0.

    if not bertonly:
        good = (lang_output * visn_output).sum(-1)
        bad = (neg_lang_output * neg_visn_output).sum(-1)
    else:
        good, bad = lang_output, neg_lang_output

    violations = hinge(margin - good + bad)
    return violations.sum() / n_samples

def paired_hinge_rank_loss3(
        lang_output: torch.Tensor,
        visn_output: torch.Tensor,
        margin=1,
):
    """
    Consider the first half as positive and the second half as negative.

    Both inputs are L2-normalised along the feature dimension, so every score
    is a cosine similarity. For each index ``i`` in the first half, the score
    of ``(lang_i, visn_i)`` must exceed the score of the corresponding pair
    in the second half by at least ``margin``.

    :param lang_output: [batch_size, hid_dim]
    :param visn_output: [batch_size, hid_dim]
    :param margin: margin in the ranking loss (must be > 0)
    :return: a scalar loss. The batch_size / 2 hinge terms are summed and
        divided by batch_size, i.e. this is half of their mean.
    """
    batch_size, dim = lang_output.shape
    # An even batch is required so that torch.split yields exactly two halves.
    assert batch_size % 2 == 0 and batch_size == visn_output.shape[0]
    assert margin > 0.

    half_batch_size = batch_size // 2
    # normalize divides by max(norm, eps). This is identical to x / ||x|| for
    # any row whose norm exceeds eps, but an all-zero row yields zeros instead
    # of NaN.
    lang_unit = torch.nn.functional.normalize(lang_output, p=2, dim=-1)
    visn_unit = torch.nn.functional.normalize(visn_output, p=2, dim=-1)
    pos_lang, neg_lang = torch.split(lang_unit, half_batch_size, dim=0)
    pos_visn, neg_visn = torch.split(visn_unit, half_batch_size, dim=0)

    # Cosine-similarity scores of the positive and negative pairs.
    pos_score = (pos_lang * pos_visn).sum(-1)                # [batch_size / 2]
    neg_score = (neg_lang * neg_visn).sum(-1)                # [batch_size / 2]

    loss = hinge(margin - pos_score + neg_score)
    # Divided by batch_size (not half_batch_size) to keep the existing loss scale.
    loss = loss.sum() / batch_size

    return loss

def binary_classification_loss(
        lang_output: torch.Tensor,
        visn_output: torch.Tensor,
):
    """
    Binary cross-entropy matching loss built from a batch split into halves.

    The aligned pairs ``(lang_i, visn_i)`` are labelled 1. Crossing the halves
    (first-half language with second-half vision and vice versa) gives
    batch_size mismatched pairs, which are labelled 0. Each pair is scored by
    the dot product of its features.

    :param lang_output: [batch_size, hid_dim]
    :param visn_output: [batch_size, hid_dim]
    :return: scalar mean BCE-with-logits over the 2 * batch_size pairs
    """
    batch_size, _ = lang_output.shape
    # An even batch is required so that torch.split yields exactly two halves.
    assert batch_size % 2 == 0 and batch_size == visn_output.shape[0]
    half_batch_size = batch_size // 2
    first_visn, second_visn = torch.split(visn_output, half_batch_size, dim=0)

    # Matched pairs, in the original batch order.
    pos_scores = (lang_output * visn_output).sum(-1)
    # Mismatched pairs: pair each sample with the visual feature of the same
    # index in the other half.
    swapped_visn = torch.cat((second_visn, first_visn), dim=0)
    neg_scores = (lang_output * swapped_visn).sum(-1)

    scores = torch.cat((pos_scores, neg_scores), dim=0)
    # Create the targets directly with the scores' device and dtype, rather
    # than allocating on the CPU and copying on every call.
    label = torch.cat((torch.ones_like(pos_scores), torch.zeros_like(neg_scores)), dim=0)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(scores, label)

    return loss

def binary_classification_loss_with_neg(
        lang_output: torch.Tensor,
        visn_output: torch.Tensor,
        neg_lang_output: torch.Tensor,
        with_random_neg=False
):
    """
    Binary cross-entropy matching loss with explicitly supplied negatives.

    ``(lang_output, visn_output)`` pairs are labelled 1 and
    ``(neg_lang_output, visn_output)`` pairs are labelled 0. Each pair is
    scored by the dot product of its features.

    :param lang_output: [batch_size, hid_dim]
    :param visn_output: [batch_size, hid_dim]
    :param neg_lang_output: negative language features; must broadcast against
        ``visn_output`` and produce batch_size scores.
    :param with_random_neg: also add, as extra negatives, the cross-half
        mismatched pairs (first-half language with second-half vision and vice
        versa). This requires an even batch size.
    :return: scalar mean BCE-with-logits over all scored pairs
    """
    batch_size, _ = lang_output.shape
    assert batch_size == visn_output.shape[0]

    pos_scores = (lang_output * visn_output).sum(-1)
    neg_scores = (neg_lang_output * visn_output).sum(-1)
    # Targets share the scores' device and dtype; no CPU allocation and copy.
    score_parts = [pos_scores, neg_scores]
    label_parts = [torch.ones_like(pos_scores), torch.zeros_like(neg_scores)]

    if with_random_neg:
        # An even batch is required so that torch.split yields exactly two halves.
        assert batch_size % 2 == 0
        half_batch_size = batch_size // 2
        first_visn, second_visn = torch.split(visn_output, half_batch_size, dim=0)
        swapped_visn = torch.cat((second_visn, first_visn), dim=0)
        neg_scores_random = (lang_output * swapped_visn).sum(-1)
        score_parts.append(neg_scores_random)
        label_parts.append(torch.zeros_like(neg_scores_random))

    scores = torch.cat(score_parts, dim=0)
    label = torch.cat(label_parts, dim=0)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(scores, label)

    return loss

def contrastive_loss(z1, z2, sim, z3=None, weight = 0, neg_weight = 0):
    """
    Symmetric in-batch contrastive loss between two sets of embeddings.

    Row ``i`` of ``z1`` is the positive for row ``i`` of ``z2``: the labels are
    ``arange``, and the loss is computed in both directions, so the two sets
    must have the same number of rows. The result averages two cross-entropies:
    z1 -> z2 (optionally with extra negative columns from ``z3``) and z2 -> z1
    (never including ``z3``).

    :param z1: vision embeddings, presumably [N, dim] — confirm with caller
    :param z2: language embeddings, presumably [N, dim] — confirm with caller
    :param sim: callable that returns pairwise similarity logits for the
        broadcastable inputs ``z1.unsqueeze(1)`` and ``z2.unsqueeze(0)``;
        presumably a (temperature-scaled) cosine similarity — verify.
    :param z3: optional negative language embeddings. Their similarities to
        ``z1`` are appended as extra columns for the z1 -> z2 direction only.
    :param weight: unused by this function.
    :param neg_weight: constant added to every logit when ``z3`` is given
        (see the NOTE in the body).
    :return: scalar loss ``loss1 / 2 + loss2 / 2``
    """
    # z1: vision, z2: lang, z3: neg lang
    # Pairwise similarity logits between every z1 row and every z2 row.
    cos_sim_ori = sim(z1.unsqueeze(1), z2.unsqueeze(0))

    if z3 is not None:
        # NOTE(review): .squeeze() drops *every* size-1 dim. Assuming sim returns
        # [N, K] here, a single z1 row or a single z3 row would collapse the
        # result, and torch.cat along dim 1 below would then fail. Confirm why
        # the squeeze is needed.
        z1_z3_cos = sim(z1.unsqueeze(1), z3.unsqueeze(0)).squeeze()
        # Append the extra negatives as additional columns (z1 -> z2 direction only).
        cos_sim = torch.cat([cos_sim_ori, z1_z3_cos], 1)
        z3_weight = neg_weight
        # NOTE(review): every entry of cos_sim (not only the z3 columns) gets the
        # same offset neg_weight. Cross-entropy is invariant to adding a
        # constant to a whole row of logits, so this does not change loss1
        # beyond floating-point effects. neg_weight only takes effect if a
        # non-uniform (e.g. per-column) weights tensor is built instead.
        weights = torch.ones(cos_sim.shape).to(cos_sim.device) * z3_weight
        cos_sim = cos_sim + weights
    else:
        cos_sim = cos_sim_ori
    # Target for row i is column i, i.e. the matching z2 row.
    labels = torch.arange(cos_sim.size(0)).long().to(cos_sim.device)
    loss_fct = torch.nn.CrossEntropyLoss()

    # z1 -> z2 direction (plus z3 negatives, if any).
    loss1 = loss_fct(cos_sim, labels)

    # z2 -> z1 direction, using the similarities without the z3 columns.
    cos_sim_ori = cos_sim_ori.transpose(0,1)
    labels2 = torch.arange(cos_sim_ori.size(0)).long().to(cos_sim_ori.device)
    loss2 = loss_fct(cos_sim_ori, labels2)
    loss = loss1/2 + loss2/2
    
    return loss



def batchwise_hinge_rank_loss(
        lang_output: torch.Tensor,
        visn_output: torch.Tensor,
        lang_mask: torch.Tensor,
        margin: float,
):
    """
    Consider all un-matched pairs in the batch as negative samples.

    For every non-padding word k of sentence i and every other sample j:
      * language direction: word (i, k) must score higher with image i than
        with image j, by ``margin``;
      * vision direction: image j must score higher with word (j, k) of its own
        sentence than with word (i, k), by ``margin``.
    Scores are dot products.

    :param lang_output: [batch_size, max_len, hid_dim]
    :param visn_output: [batch_size, hid_dim]
    :param lang_mask: Int Tensor [batch_size, max_len], 1 for tokens, 0 for paddings.
    :param margin: margin in the ranking loss (must be > 0)
    :return: a scalar loss, the sum of the two directional losses, each
        averaged over (number of words * batch_size)
    """
    batch_size, lang_len, dim = lang_output.shape
    assert batch_size % 2 == 0 and batch_size == visn_output.shape[0]
    assert margin > 0.

    # Expand the visn_output to match each word
    visn_output = visn_output.unsqueeze(1)                  # [b, 1, dim]

    # The score of positive pairs. visn_output is already [b, 1, dim], so it
    # broadcasts against [b, max_len, dim] directly. Another unsqueeze here
    # would broadcast to [b, b, max_len] and corrupt the loss below.
    positive_score = (lang_output * visn_output).sum(-1)    # [b, max_len]

    # The score of negative pairs. Note that the diagonal is actually the positive score,
    # but it would be zero-graded in calculating the loss below.
    negative_scores = (lang_output.reshape(batch_size, 1, lang_len, dim) *
                       visn_output.reshape(1, batch_size, 1, dim)).sum(-1)    # [b(lang), b(visn), max_len]

    # Why the diagonal (positive) entries contribute nothing: subtracting
    # 2 * margin from them gives
    #   max(0., margin - x^T y + (x^T y - 2 margin)) = max(0., -margin) = 0.
    # because margin > 0. The clamp also passes zero gradient there.
    float_lang_mask = lang_mask.type(lang_output.dtype)       # same dtype as lang_output
    # Build the diagonal mask on the scores' device and dtype; a CPU float32
    # eye cannot be combined with CUDA tensors.
    positive_mask = torch.eye(batch_size, device=lang_output.device, dtype=lang_output.dtype)
    negative_scores = negative_scores - positive_mask.unsqueeze(-1) * margin * 2
    lang_loss = hinge(margin - positive_score.unsqueeze(1) + negative_scores) * float_lang_mask.unsqueeze(1)
    # NOTE(review): the vision term is masked only by word (i, k) of the negative
    # sentence. The positive word (j, k) may be padding — confirm whether
    # lang_mask[j, k] should also be applied.
    visn_loss = hinge(margin - positive_score.unsqueeze(0) + negative_scores) * float_lang_mask.unsqueeze(1)

    # Averaging
    # Each sentence is duplicated by batch_size thus the total length is also multiplied by this term.
    # The count is clamped to at least 1 to avoid 0 / 0 on an all-padding batch.
    cnt = (float_lang_mask.sum() * batch_size).clamp(min=1.)    # Number of words.
    lang_loss = lang_loss.sum() / cnt
    visn_loss = visn_loss.sum() / cnt

    return lang_loss + visn_loss

