# Copyright (c) <anonymized for review>

from lama.generation import GenerationImpl
import logging

# Module-level logger. It is also passed explicitly to the generation and
# sentence-scoring calls made by SentConfidence.get_metrics.
logger = logging.getLogger(__name__)


class SentConfidence:
    """Score candidate fillers of a masked sentence by whole-sentence score.

    For each prediction, the mask position of the sample is filled with the
    prediction's token id. The resulting sequence is scored with the model's
    ``get_sentence_score``.
    """

    def __init__(self, args):
        """Build the scorer from a config object.

        Args:
            args: config object. Reads ``args.generation`` (passed to
                ``GenerationImpl``), ``args.normalize`` and
                ``args.num_special_tokens``.
        """
        self.gen = GenerationImpl(args.generation)
        # When truthy, each score is divided by
        # (sequence length - num_special_tokens).
        self.normalize = args.normalize
        # Number of tokens subtracted from the sequence length when normalizing.
        self.num_special_tokens = args.num_special_tokens

    def _target_label_id(self, prediction, sample):
        """Return the token id to place at the mask for ``prediction``."""
        target_label = prediction["token_word_form"]
        if self.gen.is_in_model_vocabulary(target_label):
            # get_id presumably returns a list of ids; only the first is used.
            return self.gen.model.get_id(target_label)[0]

        # Out-of-vocabulary word form: fall back to the id the prediction
        # already carries.
        token_idx = prediction["token_idx"]
        logger.warning(
            "Object label %s not in model vocabulary.\n"
            "Using token index (%s) from the output. \n"
            "Sample: %s",
            target_label,
            token_idx,
            sample,
        )
        return token_idx

    def get_metrics(self, sample, predictions):
        """Score each prediction by substituting it into the masked sentence.

        Args:
            sample: dict. ``sample["masked_sentences"]`` is passed as the
                single entry of a one-element batch to
                ``get_batch_generation``.
            predictions: iterable of dicts with ``"token_word_form"`` and
                ``"token_idx"`` keys.

        Returns:
            List of scores from ``get_sentence_score``, one per prediction and
            in the same order. When ``self.normalize`` is truthy, each score
            is divided by (sequence length - ``num_special_tokens``).

        Raises:
            ValueError: if normalization is requested but the sequence length
                does not exceed ``num_special_tokens``.
        """
        self.gen.model.eval()
        sentences_b = [sample["masked_sentences"]]
        genout = self.gen.get_batch_generation(sentences_b, logger)

        # Materialize so generators are handled and emptiness can be tested.
        predictions = list(predictions)
        if not predictions:
            return []

        # Only the first mask of the single batch entry is filled in.
        masked_index = genout.masked_indices_list[0][0]
        # NOTE(review): this is the generation output's own sequence,
        # overwritten in place at the mask. Every iteration rewrites the same
        # position, so each score sees only its own candidate. Presumably
        # genout is not reused elsewhere; confirm.
        token_ids = genout.token_ids_list[0]

        # Overwriting one position does not change the length, so this is
        # loop-invariant.
        num_content_tokens = len(token_ids) - self.num_special_tokens
        if self.normalize and num_content_tokens <= 0:
            raise ValueError(
                "Cannot normalize sentence score: sequence has "
                f"{len(token_ids)} tokens but num_special_tokens="
                f"{self.num_special_tokens}."
            )

        scores = []
        for prediction in predictions:
            token_ids[masked_index] = self._target_label_id(prediction, sample)
            sent_score = self.gen.model.get_sentence_score(token_ids, logger=logger)
            if self.normalize:
                sent_score = sent_score / num_content_tokens
            scores.append(sent_score)

        return scores
