# Copyright (c) <anonymized for review>
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import torch
import numpy as np
import scipy
import logging

logger = logging.getLogger(__name__)


def __max_probs_values_indices(masked_indices, log_probs, topk=1000):
    """Return the top-k scores of the first masked position.

    Only the first entry of ``masked_indices`` is scored; any further
    entries are ignored.

    Args:
        masked_indices: non-empty sequence of masked positions. An empty
            sequence raises ``IndexError``.
        log_probs: torch tensor indexed by position along its first axis.
            The slice selected by the first masked position is ranked
            along its dimension 0.
        topk: maximum number of predictions to return. It is clamped to
            the number of scores available so that a small (e.g. filtered)
            vocabulary does not make ``torch.topk`` raise.

    Returns:
        A tuple ``(log_probs, index_max_probs, value_max_probs)``:
        the tensor slice at the masked position, a NumPy int array with
        the indices of the best scores (best first) and a NumPy array with
        the matching scores.
    """
    # score only first mask
    masked_index = masked_indices[0]
    log_probs = log_probs[masked_index]

    # torch.topk raises if k is larger than the size of the ranked dimension
    k = min(topk, log_probs.shape[0])
    value_max_probs, index_max_probs = torch.topk(input=log_probs, k=k, dim=0)

    # detach + cpu so the conversion also works for tensors that track
    # gradients or live on an accelerator
    index_max_probs = index_max_probs.detach().cpu().numpy().astype(int)
    value_max_probs = value_max_probs.detach().cpu().numpy()

    return log_probs, index_max_probs, value_max_probs


def __print_top_k(value_max_probs, index_max_probs, vocab, mask_topk, index_list, max_printouts = 10, masked_token_probs = None):
    """Collect the ranked predictions and build a printable summary.

    If ``masked_token_probs`` is not None, ``value_max_probs`` holds sentence
    scores; otherwise it holds masked token log-probability scores.

    Args:
        value_max_probs: scores of the ranked predictions, best first.
        index_max_probs: indices of the ranked predictions, aligned with
            ``value_max_probs``. Elements must support ``.item()``.
        vocab: maps a vocabulary index to its word form.
        mask_topk: maximum number of predictions to report. Fewer are
            reported when fewer predictions are available.
        index_list: optional mapping from a filtered index back to the
            original vocabulary index, or None when no filtering was used.
        max_printouts: number of leading predictions written to the message.
        masked_token_probs: optional per-prediction masked-token scores,
            stored under ``'masked_token_log_prob'``.

    Returns:
        ``(result, msg)``: a list with one dict per prediction and the
        formatted message.
    """
    # Never walk past the predictions that are actually available
    n_predictions = min(mask_topk, len(index_max_probs), len(value_max_probs))

    result = []
    lines = ["\n| Top{} predictions\n".format(max_printouts)]
    for i in range(n_predictions):
        filtered_idx = index_max_probs[i].item()

        if index_list is not None:
            # the softmax layer has been filtered using the vocab_subset
            # the original idx should be retrieved
            idx = index_list[filtered_idx]
        else:
            idx = filtered_idx

        log_prob = value_max_probs[i].item()
        word_form = vocab[idx]

        if i < max_printouts:
            lines.append("{:<8d}{:<20s}{:<12.3f}\n".format(
                i,
                word_form,
                log_prob
            ))

        element = {'i' : i, 'token_idx': idx, 'log_prob': log_prob, 'token_word_form': word_form}
        if masked_token_probs is not None:
            element['masked_token_log_prob'] = masked_token_probs[i].item()
        result.append(element)

    return result, "".join(lines)


def get_ppl_reranking_result(model, vocab, token_ids, masked_indices, index_max_probs, index_list = None):
    """Score every candidate token by the sentence it produces and rerank.

    For each candidate, a copy of ``token_ids`` is made, the first masked
    position is replaced by the candidate and the result is scored with
    ``model.get_sentence_score``. ``token_ids`` itself is not modified.

    Args:
        model: object exposing ``get_sentence_score(token_ids, logger=...)``
            that returns a numeric score.
        vocab: unused here; kept for interface compatibility.
        token_ids: token id sequence supporting ``.copy()`` and item
            assignment.
        masked_indices: non-empty sequence of masked positions; only the
            first one is replaced.
        index_max_probs: candidate indices to try, in their current ranking.
        index_list: optional mapping from a filtered candidate index to the
            original vocabulary index.

    Returns:
        ``(sent_scores, reranked_candidate_indices)``: the score of each
        candidate (aligned with ``index_max_probs``) and the positions of
        the candidates sorted by decreasing score. Candidates with equal
        scores keep their incoming order.

    Raises:
        Exception: whatever ``model.get_sentence_score`` raises, re-raised
            after the failing inputs have been logged.
    """
    # Replace only first mask
    masked_index = masked_indices[0]

    sent_scores = np.zeros(len(index_max_probs))
    for i, predicted_token_index in enumerate(index_max_probs):
        if index_list is not None:
            idx = index_list[predicted_token_index]
        else:
            idx = predicted_token_index

        # Replace the mask with the predicted token
        new_token_ids = token_ids.copy()
        new_token_ids[masked_index] = idx

        try:
            sent_scores[i] = model.get_sentence_score(new_token_ids, logger=logger)
        except Exception as e:
            # Log the inputs that triggered the failure before propagating
            logger.error(e)
            logger.error("token_ids: %s", token_ids)
            logger.error("predicted_token_index: %s", predicted_token_index)
            logger.error("original idx: %s", idx)
            logger.error("new_token_ids: %s", new_token_ids)
            raise

    # A stable sort keeps the incoming ranking among candidates whose
    # sentence scores tie, which makes the reranking deterministic.
    reranked_candidate_indices = np.argsort(-sent_scores, kind="stable")

    return sent_scores, reranked_candidate_indices


def get_ranking(log_probs, masked_indices, vocab, label_index = None, index_list = None, topk = 1000, P_AT = 10, print_generation=True, ppl_reranking=False, token_ids=None, model=None):
    """Rank the predictions for the first mask and score them against a label.

    Args:
        log_probs: torch tensor of scores, indexed by position along its
            first axis.
        masked_indices: non-empty sequence of masked positions; only the
            first one is evaluated.
        vocab: maps a vocabulary index to its word form.
        label_index: vocabulary index of the expected token, or None to
            skip the metrics.
        index_list: optional list of the vocabulary indices kept by a
            vocabulary filter. When given, ``label_index`` is translated
            to its position in this list (``ValueError`` if absent).
        topk: number of predictions to rank.
        P_AT: rank cut-off for the precision-at-X metric.
        print_generation: print the formatted top predictions when True.
        ppl_reranking: rerank the candidates by sentence score; requires
            ``token_ids`` and ``model``.
        token_ids: token id sequence used for reranking.
        model: scorer passed to ``get_ppl_reranking_result``.

    Returns:
        ``(MRR, P_AT_X, experiment_result, return_msg)``.
        ``experiment_result`` holds the keys ``'topk'``, ``'MRR'``,
        ``'P_AT_X'``, ``'P_AT_1'`` and ``'PERPLEXITY'``. ``'PERPLEXITY'``
        is the entry of the masked position's scores at the label index
        (None when no label is given). The metrics stay at 0 when the
        label is not among the ranked predictions.
    """
    experiment_result = {}

    log_probs, index_max_probs, value_max_probs = __max_probs_values_indices(masked_indices, log_probs, topk=topk)

    masked_token_probs = None
    if ppl_reranking:
        assert token_ids is not None, "ppl_reranking requires token_ids"
        assert model is not None, "ppl_reranking requires model"

        sent_scores, reranked_indices = get_ppl_reranking_result(model, vocab, token_ids, masked_indices, index_max_probs, index_list)

        # Keep token prediction scores
        masked_token_probs = np.take(value_max_probs, reranked_indices)

        # From here on the ranking is by sentence score
        index_max_probs = np.take(index_max_probs, reranked_indices)
        value_max_probs = np.take(sent_scores, reranked_indices)

    # Report exactly as many predictions as were ranked
    result_masked_topk, return_msg = __print_top_k(
        value_max_probs, index_max_probs, vocab, len(index_max_probs),
        index_list, masked_token_probs=masked_token_probs)

    experiment_result['topk'] = result_masked_topk

    if print_generation:
        print(return_msg)

    MRR = 0.
    P_AT_X = 0.
    P_AT_1 = 0.
    PERPLEXITY = None

    if label_index is not None:

        # check if the label_index should be converted to the vocab subset
        if index_list is not None:
            label_index = index_list.index(label_index)

        # positions (0-based) at which the label occurs in the ranking
        label_hits = np.flatnonzero(index_max_probs == label_index)

        # LABEL PERPLEXITY
        # gather needs an int64 index on the same device as the scores
        label_position = torch.as_tensor(label_index, dtype=torch.long, device=log_probs.device)
        PERPLEXITY = log_probs.gather(dim=0, index=label_position).item()

        if label_hits.size > 0:
            # ranks are 1-based, so rank is always >= 1
            rank = int(label_hits[0]) + 1

            MRR = 1.0 / rank
            if rank <= P_AT:
                P_AT_X = 1.
            if rank == 1:
                P_AT_1 = 1.

    experiment_result["MRR"] = MRR
    experiment_result["P_AT_X"] = P_AT_X
    experiment_result["P_AT_1"] = P_AT_1
    experiment_result["PERPLEXITY"] = PERPLEXITY

    return MRR, P_AT_X, experiment_result, return_msg


def __overlap_negation(index_max_probs__negated, index_max_probs):
    """Tell whether the negated and affirmative top predictions coincide.

    Both arguments are the first ranked prediction of the respective
    statement. Returns 1 when they compare equal and 0 otherwise.
    """
    same_prediction = index_max_probs__negated == index_max_probs
    return int(same_prediction)


def get_negation_metric(log_probs, masked_indices, log_probs_negated,
                        masked_indices_negated, vocab, index_list=None,
                        topk = 1):
    """Compare the predictions for an affirmative and a negated statement.

    Args:
        log_probs: torch tensor of scores for the affirmative statement,
            indexed by position along its first axis.
        masked_indices: masked positions of the affirmative statement;
            only the first one is used.
        log_probs_negated: torch tensor of scores for the negated statement.
        masked_indices_negated: masked positions of the negated statement.
            When empty, no negated statement is present.
        vocab: unused here; kept for interface compatibility.
        index_list: unused here; kept for interface compatibility.
        topk: number of predictions extracted per statement; only the
            first ranked one is compared.

    Returns:
        ``(overlap, spearman_rank_corr, return_msg)``. ``overlap`` is 1
        when both statements share the same first ranked prediction and 0
        otherwise. ``spearman_rank_corr`` is the Spearman rank correlation
        between the two masked-position score vectors. Both are ``nan``
        when there is no negated statement. ``return_msg`` is always an
        empty string.
    """
    return_msg = ""

    # no negated sentence present
    if len(masked_indices_negated) == 0:
        return np.nan, np.nan, return_msg

    # A bare `import scipy` does not guarantee that the `stats` submodule
    # is loaded on every SciPy version, so import it explicitly.
    from scipy import stats

    log_probs, index_max_probs, _ = \
        __max_probs_values_indices(masked_indices, log_probs, topk=topk)
    log_probs_negated, index_max_probs_negated, _ = \
        __max_probs_values_indices(masked_indices_negated,
                                   log_probs_negated, topk=topk)

    # overlap btw. affirmative and negated first ranked prediction: 0 or 1
    overlap = __overlap_negation(index_max_probs_negated[0],
                                 index_max_probs[0])

    # rank corrl. btw. affirmative and negated predicted log_probs;
    # hand SciPy plain arrays so tensors tracking gradients also work
    spearman_rank_corr = stats.spearmanr(
        log_probs.detach().cpu().numpy(),
        log_probs_negated.detach().cpu().numpy(),
    )[0]

    return overlap, spearman_rank_corr, return_msg
