# Copyright (c) <anonymized for review>

import numpy as np

from lama.generation import GenerationImpl
import logging

logger = logging.getLogger(__name__)


class DropoutConfidence:
    """Score predictions by how stable their probability is across dropout runs.

    For one sample, the generator is run once per seed, each time after
    calling ``self.gen.model.apply_dropout(seed)``. For every prediction, the
    probability assigned to its target token at the masked position is
    collected from each run. That list is summarised by its mean and its
    negated (population) variance.

    NOTE(review): ``apply_dropout`` lives on the model elsewhere. It presumably
    enables dropout with a seeded RNG (Monte Carlo dropout); confirm against
    the model implementation.
    """

    def __init__(self, args):
        """Build the scorer from a config object.

        Args:
            args: Object with these attributes:
                ``generation``: passed to ``GenerationImpl``.
                ``num``: number of dropout runs.
                ``seeds`` (optional): one seed per run. When absent or
                ``None``, ``range(num)`` is used.

        Raises:
            ValueError: if ``seeds`` is given but its length differs from
                ``num``.
        """
        num = args.num
        seeds = getattr(args, "seeds", None)
        if seeds is None:
            seeds = range(num)
        elif len(seeds) != num:
            # Explicit check rather than ``assert``: asserts are stripped
            # under ``python -O`` and a mismatch would then go unnoticed.
            raise ValueError(
                f"Expected {num} dropout seeds (args.num), "
                f"got {len(seeds)}: {seeds!r}"
            )

        self.num = num
        self.seeds = seeds
        # Constructed after the cheap config validation so that a bad seed
        # list fails fast. The generator is presumably the expensive part
        # (model setup); confirm in GenerationImpl.
        self.gen = GenerationImpl(args.generation)

    def get_metrics(self, sample, predictions):
        """Compute dropout-based confidence scores for each prediction.

        Args:
            sample: Mapping with a ``"masked_sentences"`` entry. That entry is
                wrapped in a single-element batch for generation. The whole
                sample is also included in the out-of-vocabulary warning.
            predictions: Iterable of mappings, each with:
                ``"token_word_form"``: the target label.
                ``"token_idx"``: token id used when the label is not in the
                model vocabulary.

        Returns:
            A list with one dict per prediction, in input order. Each dict has:
            ``"output_probs"``: target-token probability from each run, in
            ``self.seeds`` order.
            ``"mean"``: their mean.
            ``"neg_var"``: their population variance, negated.
        """
        sentences_b = [sample["masked_sentences"]]

        # One generation per seed, each after re-applying dropout.
        genouts = {}
        for seed in self.seeds:
            self.gen.model.apply_dropout(seed)
            genouts[seed] = self.gen.get_batch_generation(sentences_b, logger)

        # The masked position and its log-prob row do not depend on the
        # prediction. Extract them once per seed rather than once per
        # (prediction, seed) pair. Only the first masked index of the first
        # (and only) batch item is used.
        masked_log_probs = []
        for seed in self.seeds:
            genout = genouts[seed]
            masked_index = genout.masked_indices_list[0][0]
            masked_log_probs.append(
                genout.original_log_probs_list[0][masked_index]
            )

        scores = []
        for prediction in predictions:
            target_label = prediction["token_word_form"]
            if self.gen.is_in_model_vocabulary(target_label):
                # Only the first id returned by ``get_id`` is scored.
                target_label_id = self.gen.model.get_id(target_label)[0]
            else:
                token_idx = prediction["token_idx"]
                logger.warning(
                    f"Object label {target_label} not in model vocabulary.\n"
                    f"Using token index ({token_idx}) from the output. \n"
                    f"Sample: {sample}"
                )
                target_label_id = token_idx

            # Convert each run's log-prob back to a probability before
            # computing the statistics.
            output_probs = [
                np.exp(log_probs[target_label_id].item())
                for log_probs in masked_log_probs
            ]

            scores.append({
                "output_probs": output_probs,
                "mean": np.mean(output_probs),
                # Negated so that, like "mean", a larger value presumably
                # means more confident (less spread across dropout runs).
                "neg_var": -np.var(output_probs),
            })

        return scores
