import torch


def batchwise_accuracy(lang_output, visn_output, lang_mask):
    """
    Calculate the accuracy of contextual word retrieval, averaged over all
    unmasked tokens of the batch.

    Every token is scored (dot product) against every visual vector in the batch.
    A token counts as correct when the best-scoring visual vector is the one that
    sits at the token's own batch index.

    :param lang_output: [batch_size, max_len, hid_dim]
    :param visn_output: [batch_size, hid_dim]
    :param lang_mask: Int Tensor [batch_size, max_len], 1 for tokens, 0 for paddings.
    :return: 0-dim float tensor, (#correct unmasked tokens) / (#unmasked tokens).
        The value is NaN when lang_mask selects no token at all.
    """
    batch_size = lang_output.shape[0]
    # assert batch_size % 2 == 0 and batch_size == visn_output.shape[0]
    assert batch_size == visn_output.shape[0]

    # All-pairs scores: scores[i, j, k] = <lang_output[i, k], visn_output[j]>.
    # einsum contracts the hidden dimension directly, so the
    # [batch, batch, max_len, hid_dim] broadcast product is never materialized.
    # Note that the diagonal (i == j) holds the positive scores.
    scores = torch.einsum('ikd,jd->ijk', lang_output, visn_output)    # [b(lang), b(visn), max_len]

    # For each token, the batch index of the max-aligned visual vector.
    best_visn_idx = scores.argmax(dim=1)                              # [b, max_len]
    target_idx = torch.arange(batch_size, dtype=torch.int64, device=lang_output.device)

    # Padding positions must neither count as hits nor enter the denominator.
    token_mask = lang_mask.bool()
    correct = (best_visn_idx == target_idx.unsqueeze(1)) & token_mask

    return correct.sum() * 1. / token_mask.sum()

def batchwise_accuracy2(lang_output, visn_output):
    """
    Calculate a sign-based matching accuracy for a batch that is split into two halves.

    The first half of the batch is called "pos" and the second half "neg".
    Four dot-product scores are computed per row of a half:
      * pos_lang . pos_visn and neg_lang . neg_visn are counted as correct when > 0;
      * pos_lang . neg_visn and neg_lang . pos_visn are counted as correct when < 0.
    The result is the number of correct scores divided by the 2 * batch_size
    scores that were computed.

    :param lang_output: [batch_size, hid_dim]
    :param visn_output: [batch_size, hid_dim]
    :return: 0-dim float tensor with the fraction of correctly signed scores.
    :raises ValueError: if batch_size is zero or odd (the batch cannot be
        split into two equal halves).
    """
    batch_size, _ = lang_output.shape
    assert batch_size == visn_output.shape[0]
    if batch_size == 0 or batch_size % 2 != 0:
        raise ValueError(
            "batchwise_accuracy2 needs a non-empty batch of even size, got %d" % batch_size
        )

    half_batch_size = batch_size // 2
    pos_lang, neg_lang = lang_output[:half_batch_size], lang_output[half_batch_size:]
    pos_visn, neg_visn = visn_output[:half_batch_size], visn_output[half_batch_size:]

    # Scores of the aligned halves (expected to be positive).
    true_pos_score = (pos_lang * pos_visn).sum(-1)           # [batch_size / 2]
    true_neg_score = (neg_lang * neg_visn).sum(-1)           # [batch_size / 2]
    # Scores of the crossed halves (expected to be negative).
    false_pos_score = (pos_lang * neg_visn).sum(-1)          # [batch_size / 2]
    false_neg_score = (neg_lang * pos_visn).sum(-1)          # [batch_size / 2]

    correct = (
        torch.sum(true_pos_score > 0)
        + torch.sum(true_neg_score > 0)
        + torch.sum(false_pos_score < 0)
        + torch.sum(false_neg_score < 0)
    )

    # Four score vectors of batch_size / 2 entries each -> 2 * batch_size scores.
    # Multiplying by 1. first forces a floating-point division of the integer count.
    return correct * 1. / (batch_size * 2)

# Disabled variant of batchwise_accuracy for perturbed sentences (takes separate negative outputs); kept for reference only.
# def batchwise_accuracy(lang_output, visn_output, neg_lang_output, neg_visn_output, hinge=True, bertonly=False):
#     """
#     Calculate the accuracy of contextual word retrieval, average by batch.
#     :param lang_output: [batch_size, max_len, hid_dim]
#     :param visn_output: [batch_size, hid_dim]
#     :param lang_mask: Int Tensor [batch_size, max_len], 1 for tokens, 0 for paddings.
#     :return:
#     """
#     batch_size, dim = lang_output.shape
#     assert batch_size == visn_output.shape[0]

#     if hinge:
#         # The score of negative pairs. Note that the diagonal is actually the positive score,
#         # but it would be zero-graded in calculating the loss below.
        
#         if bertonly:
#             pos_score = lang_output
#             neg_score = neg_lang_output
#         else:
#             pos_score = (lang_output * visn_output).sum(-1) 
#             neg_score = (neg_lang_output * neg_visn_output).sum(-1) 
#         accuracy = torch.sum(pos_score > neg_score)
#     else:
#         if bertonly:
#             pos_score = lang_output
#             neg_score = neg_lang_output
#         else:
#             pos_score = (lang_output * visn_output).sum(-1) 
#             neg_score = (neg_lang_output * neg_visn_output).sum(-1) 
#         accuracy = torch.sum(pos_score >= 0) + torch.sum(neg_score < 0)
#         # print("pos",pos_score)
#         # print("neg",neg_score)
#         batch_size = 2 * batch_size

#     return accuracy, batch_size


def batchwise_recall(lang_output, visn_output, lang_mask, recalls=(1,)):
    """
    Calculate recall@k of contextual word retrieval, averaged over all unmasked
    tokens of the batch.

    Every token is scored (dot product) against every visual vector in the batch.
    A token is a hit at recall@k when fewer than k visual vectors score strictly
    higher than the token's own (same batch index) visual vector, i.e. when the
    positive score is at least the k-th largest score. Ties count in favour of
    the positive.

    :param lang_output: [batch_size, max_len, hid_dim]
    :param visn_output: [batch_size, hid_dim]
    :param lang_mask: Int Tensor [batch_size, max_len], 1 for tokens, 0 for paddings.
    :param recalls: an iterable of ints, the values of k to be evaluated.
    :return: dict mapping each k in `recalls` to a Python float,
        (#hit unmasked tokens) / (#unmasked tokens). The value is NaN when
        lang_mask selects no token at all.
    """
    batch_size = lang_output.shape[0]
    assert batch_size % 2 == 0 and batch_size == visn_output.shape[0]

    # All-pairs scores: scores[i, j, k] = <lang_output[i, k], visn_output[j]>.
    # einsum contracts the hidden dimension directly, so the
    # [batch, batch, max_len, hid_dim] broadcast product is never materialized.
    scores = torch.einsum('ikd,jd->ijk', lang_output, visn_output)    # [b(lang), b(visn), max_len]

    # The positive scores are the i == j entries. They are read from the score
    # tensor itself so that they compare exactly equal to their own entries.
    positive_score = torch.diagonal(scores, dim1=0, dim2=1).transpose(0, 1)    # [b, max_len]

    # Number of candidates that beat the positive one, for every token.
    # Using this rank instead of torch.kthvalue(scores, batch_size - k) avoids an
    # off-by-one: kthvalue is 1-indexed over ascending order, so that call picked
    # the (k + 1)-th largest score and accepted positives ranked one place too low.
    num_better = (scores > positive_score.unsqueeze(1)).sum(dim=1)    # [b, max_len]

    # Padding positions must neither count as hits nor enter the denominator.
    token_mask = lang_mask.bool()
    token_count = token_mask.sum()

    result = {}
    for recall in recalls:
        hits = (num_better < recall) & token_mask                     # [b, max_len]
        result[recall] = (hits.sum() * 1. / token_count).item()

    return result


# Placeholder entry point: running this module directly only prints a dash.
if __name__ == "__main__":
    print("-")