# Copyright (c) <anonymized for review>

import re
import numpy as np

from lama.generation import GenerationImpl

import logging

logger = logging.getLogger(__name__)


def kl_divergence_from_log_probs(logp, logq):
    """Return the KL divergence KL(P || Q) computed from log-probabilities.

    Computes ``sum_i p_i * (logp_i - logq_i)`` with ``p = exp(logp)``.
    Entries where ``p_i == 0`` contribute exactly 0 to the sum.

    Args:
        logp: Array-like of log-probabilities of P.
        logq: Array-like of log-probabilities of Q, broadcastable with
            ``logp``.

    Returns:
        The summed divergence as a NumPy scalar. The result is ``inf`` if Q
        assigns zero probability (``logq == -inf``) where P does not.
    """
    logp = np.asarray(logp)
    logq = np.asarray(logq)
    p = np.exp(logp)

    # Evaluate the log-ratio only where p != 0. Using np.where on the full
    # product would still compute 0 * -inf (or -inf - -inf) for the discarded
    # entries, which emits "invalid value" warnings / floating point errors.
    log_ratio = np.zeros(
        np.broadcast(logp, logq).shape, dtype=np.result_type(p, logq)
    )
    np.subtract(logp, logq, out=log_ratio, where=(p != 0.))

    # Masked positions hold p == 0 and log_ratio == 0, so they add 0.
    return np.sum(p * log_ratio)


def entropy_diff(logp, logq):
    """Return the entropy of Q minus the entropy of P, from log-probabilities.

    Each entropy is ``sum_i -p_i * log(p_i)`` with ``p = exp(log_probs)``.
    Zero-probability entries use the convention ``0 * log(0) = 0`` so that a
    ``-inf`` log-probability does not turn the whole result into NaN.

    Args:
        logp: Array-like of log-probabilities of P.
        logq: Array-like of log-probabilities of Q.

    Returns:
        ``H(Q) - H(P)`` as a NumPy scalar.
    """

    def _entropy(log_probs):
        # Entropy of one distribution given its log-probabilities.
        log_probs = np.asarray(log_probs)
        probs = np.exp(log_probs)
        # Replace log-probs of impossible outcomes by 0 so the product below
        # is 0 instead of 0 * -inf = nan.
        finite_log_probs = np.where(probs != 0., log_probs, 0)
        return np.sum(-probs * finite_log_probs)

    return _entropy(logq) - _entropy(logp)


class PromptDiffConfidence():
    """Score predictions by comparing a full prompt with a subject-masked one.

    The sample sentence is run through the generation backend twice: once as
    given and once with the subject label replaced by the mask string. The
    output distributions at the target mask position are then compared.
    """

    def __init__(self, args):
        """Build the generation backend and read the mask string.

        Args:
            args: Configuration object. ``args.generation`` is passed to
                ``GenerationImpl``; the optional ``args.mask`` attribute gives
                the mask string (default ``"[MASK]"``).
        """
        self.gen = GenerationImpl(args.generation)
        self.mask_str = getattr(args, "mask", "[MASK]")

    def _mask_subject(self, sentence, sub_label):
        """Return ``sentence`` with the first standalone subject occurrence masked.

        The match is case-insensitive and must not be directly preceded or
        followed by an ASCII letter or digit. If nothing matches, ``sentence``
        is returned unchanged.
        """
        subject_pattern = re.compile(
            f"(?<![0-9a-zA-Z]){re.escape(sub_label)}(?![0-9a-zA-Z])",
            re.IGNORECASE,
        )
        # A callable replacement inserts the mask string literally; a plain
        # string would be parsed as a template (backslash escapes, group refs).
        return subject_pattern.sub(lambda _match: self.mask_str, sentence, count=1)

    def get_metrics(self, sample, predictions, debug=False):
        """Compute prompt-difference confidence metrics for each prediction.

        Args:
            sample: Mapping with ``"masked_sentences"`` (only the first entry
                is used) and ``"sub_label"`` (the subject label).
            predictions: Iterable of mappings with ``"token_word_form"`` and
                ``"token_idx"``.
            debug: If true, also return the raw generation output.

        Returns:
            A list with one dict per prediction holding ``"kl_divergence"``
            (full-prompt distribution against subject-masked distribution),
            ``"prob_diff"`` (target probability with the subject minus without
            it, clipped at 0) and ``"entropy_diff"`` (subject-masked entropy
            minus full-prompt entropy). With ``debug`` set, a tuple
            ``(scores, genout)``.

        Raises:
            RuntimeError: If the sentence does not contain the subject label.
            AssertionError: If the original sentence does not yield exactly
                one mask position or the subject-masked one exactly two.
        """
        self.gen.model.eval()
        sentence = sample["masked_sentences"][0]
        sub_label = sample["sub_label"]
        if sub_label.lower() not in sentence.lower():
            raise RuntimeError(
                f"Sentence '{sentence}' does not contain subject label '{sub_label}'."
            )
        prompt_only_sentence = self._mask_subject(sentence, sub_label)

        sentences_b = [[sentence], [prompt_only_sentence]]
        genout = self.gen.get_batch_generation(sentences_b, logger)

        # Check number of mask tokens. Raised explicitly (instead of a bare
        # ``assert``) so the check survives ``python -O``; the exception type
        # and message are unchanged.
        masked_indices_orig = genout.masked_indices_list[0]
        masked_indices_prompt_only = genout.masked_indices_list[1]
        if len(masked_indices_orig) != 1 or len(masked_indices_prompt_only) != 2:
            raise AssertionError(
                f"Input: {sample}, Subject: {sub_label}, Output token ids: {genout.token_ids_list}"
            )

        # Compare mask positions to identify original target mask in
        # prompt-only sample: if the first mask sits where the original mask
        # was, it is the target; otherwise the first mask is the subject.
        masked_index_orig = masked_indices_orig[0]
        if masked_index_orig == masked_indices_prompt_only[0]:
            masked_index_prompt_only = masked_indices_prompt_only[0]
        else:
            masked_index_prompt_only = masked_indices_prompt_only[1]

        # The distributions and the distribution-level metrics do not depend
        # on the individual prediction, so compute them once.
        log_probs_pred = genout.original_log_probs_list[0][masked_index_orig].detach().numpy()
        log_probs_prompt_only = genout.original_log_probs_list[1][masked_index_prompt_only].detach().numpy()

        kld = kl_divergence_from_log_probs(log_probs_pred, log_probs_prompt_only)
        ent_diff = entropy_diff(log_probs_pred, log_probs_prompt_only)

        scores = []
        for prediction in predictions:
            target_label = prediction["token_word_form"]
            if not self.gen.is_in_model_vocabulary(target_label):
                # Fall back to the token index reported with the prediction.
                token_idx = prediction["token_idx"]
                logger.warning(
                    f"Object label {target_label} not in model vocabulary.\n" +
                    f"Using token index ({token_idx}) from the output. \n" +
                    f"Sample: {sample}"
                )
                target_label_id = token_idx
            else:
                target_label_id = self.gen.model.get_id(target_label)[0]

            target_prob_pred = np.exp(log_probs_pred[target_label_id])
            target_prob_prompt_only = np.exp(log_probs_prompt_only[target_label_id])

            # Only a probability gain from seeing the subject counts.
            prob_diff = max(target_prob_pred - target_prob_prompt_only, 0)

            scores.append({
                "kl_divergence": kld,
                "prob_diff": prob_diff,
                "entropy_diff": ent_diff,
            })

        if debug:
            return scores, genout

        return scores
