# Copyright (c) <anonymized for review>

from lama.generation import GenerationImpl
import logging

logger = logging.getLogger(__name__)


class TokenConfidence:
    """Score candidate token predictions by log-probability at the masked slot.

    Each prediction is scored either with a ``log_prob`` value it already
    carries (only when ``use_prediction_value`` is enabled) or by looking up
    the target token's log-probability in a fresh model generation for the
    sample's masked sentences.
    """

    def __init__(self, args):
        """Build the scorer from a config object.

        Args:
            args: Config object. ``args.generation`` is passed to
                ``GenerationImpl``. The optional ``args.use_prediction_value``
                (default ``False``) enables reusing ``prediction["log_prob"]``
                instead of querying the model.
        """
        self.use_prediction_value = getattr(args, "use_prediction_value", False)
        self.gen = GenerationImpl(args.generation)

    def get_metrics(self, sample, predictions):
        """Return one score per prediction, in the same order as ``predictions``.

        Args:
            sample: Mapping with a ``"masked_sentences"`` entry. That entry is
                wrapped in a single-element batch for generation.
            predictions: Sequence of dicts. A prediction scored by the model
                needs ``"token_word_form"``. It also needs ``"token_idx"`` if
                its label is outside the model vocabulary. A prediction may
                carry ``"log_prob"``.

        Returns:
            list: Either the stored ``log_prob`` values, or ``.item()`` of
            entries from the generation's log-prob row at the masked position
            (presumably tensor scalars; confirm against ``GenerationImpl``).

        Raises:
            KeyError: If a prediction scored by the model lacks a required key.
        """
        # Fast path: every score is already available, so skip the model.
        if self.use_prediction_value and all("log_prob" in p for p in predictions):
            return [p["log_prob"] for p in predictions]
        # Nothing to score: avoid a pointless forward pass.
        if not predictions:
            return []

        self.gen.model.eval()
        sentences_b = [sample["masked_sentences"]]
        genout = self.gen.get_batch_generation(sentences_b, logger)
        # The masked position and its log-prob row are identical for every
        # prediction, so look them up once. Only the first masked index of the
        # first batch item is used (presumably one mask per sample; confirm).
        masked_index = genout.masked_indices_list[0][0]
        log_probs = genout.original_log_probs_list[0][masked_index]

        scores = []
        for prediction in predictions:
            if self.use_prediction_value and "log_prob" in prediction:
                scores.append(prediction["log_prob"])
            else:
                label_id = self._target_label_id(sample, prediction)
                scores.append(log_probs[label_id].item())
        return scores

    def _target_label_id(self, sample, prediction):
        """Resolve the vocabulary id whose log-probability scores ``prediction``."""
        target_label = prediction["token_word_form"]
        if self.gen.is_in_model_vocabulary(target_label):
            # get_id returns an indexable of ids; only the first one is used.
            return self.gen.model.get_id(target_label)[0]
        # Out-of-vocabulary label: fall back to the index carried by the
        # prediction itself and warn, since the score may not match the word.
        token_idx = prediction["token_idx"]
        logger.warning(
            "Object label %s not in model vocabulary.\n"
            "Using token index (%s) from the output. \n"
            "Sample: %s",
            target_label,
            token_idx,
            sample,
        )
        return token_idx
