# Copyright (c) <anonymized for review>

from dataclasses import dataclass
import torch
from typing import List

from lama.modules import build_model_by_name
from lama.utils import load_vocab

@dataclass
class GenerationOuptut:
    """Result of ``GenerationImpl.get_batch_generation`` for one batch.

    The class name is misspelled ("Ouptut"); it is kept unchanged for
    backward compatibility, and ``GenerationOutput`` below is a correctly
    spelled alias referring to the very same class.
    """

    # Log-probabilities exactly as returned by the model's
    # ``get_batch_generation`` (no vocabulary filtering applied).
    original_log_probs_list: torch.Tensor
    # Token ids as returned by the model's ``get_batch_generation``.
    token_ids_list: List
    # Masked indices as returned by the model's ``get_batch_generation``.
    masked_indices_list: List
    # Log-probabilities restricted to the common vocabulary when one is
    # configured; otherwise the very same object as
    # ``original_log_probs_list``.
    # NOTE(review): in the unfiltered case this holds whatever the model
    # returned (annotated above as a tensor), so the ``List`` annotation
    # may be too narrow — confirm.
    filtered_log_probs_list: List


# Correctly spelled alias; ``GenerationOuptut`` stays the defining name so
# existing imports and isinstance checks keep working.
GenerationOutput = GenerationOuptut


class GenerationImpl:
    """
    This class is for performing generation considering
    vocabulary filtering using given common vocabulary.

    It wraps a single language model built by ``build_model_by_name``.
    When ``args.common_vocab_filename`` is set, it additionally restricts
    the model's log-probabilities to that common vocabulary.
    """

    def __init__(self, args, logger=None):
        """Build the model and, optionally, the common-vocabulary filter.

        Args:
            args: configuration object. ``args.models_names`` must contain
                exactly one model name. ``args.common_vocab_filename`` is
                either None or a path handed to ``load_vocab``. ``args``
                itself is forwarded unchanged to ``build_model_by_name``.
            logger: optional logger forwarded to the model's
                ``init_indices_for_filter_logprobs``.

        Raises:
            ValueError: if ``args.models_names`` does not contain exactly
                one entry.
        """
        # ``!= 1`` rather than ``> 1``: an empty list must also get this
        # clear message instead of failing in the unpacking below with an
        # opaque "not enough values to unpack" error.
        if len(args.models_names) != 1:
            raise ValueError(
                'Please specify a single language model '
                '(e.g., --lm "bert").')
        [model_type_name] = args.models_names
        self.model = build_model_by_name(model_type_name, args)

        # deal with vocab subset; all three stay None when no common
        # vocabulary file is configured.
        self.vocab_subset = None
        self.filter_logprob_indices = None
        self.index_list = None
        if args.common_vocab_filename is not None:
            self.vocab_subset = load_vocab(args.common_vocab_filename)

            # optimization for some LM (such as ELMo); skipped when the
            # model's ``_optimize_top_layer_is_called`` flag is already set.
            if not self.model._optimize_top_layer_is_called:
                self.model.optimize_top_layer(self.vocab_subset)

            (
                self.filter_logprob_indices, self.index_list
            ) = self.model.init_indices_for_filter_logprobs(
                self.vocab_subset, logger
            )

    def get_batch_generation(self, sentences_b, logger=None):
        """Run the model on a batch and attach vocabulary-filtered log-probs.

        Args:
            sentences_b: batch of inputs forwarded unchanged to
                ``self.model.get_batch_generation``.
            logger: optional logger forwarded to the model.

        Returns:
            GenerationOuptut with the model's original log-probs, token ids
            and masked indices, plus ``filtered_log_probs_list``. That last
            field is the result of ``self.model.filter_logprobs`` when a
            common vocabulary is configured. Otherwise it is the very same
            object as the original log-probs (not a copy).
        """
        (
            original_log_probs_list,
            token_ids_list,
            masked_indices_list,
        ) = self.model.get_batch_generation(sentences_b, logger=logger)

        if self.vocab_subset is not None:
            # filter log_probs
            filtered_log_probs_list = self.model.filter_logprobs(
                original_log_probs_list, self.filter_logprob_indices
            )
        else:
            filtered_log_probs_list = original_log_probs_list

        return GenerationOuptut(
            original_log_probs_list,
            token_ids_list,
            masked_indices_list,
            filtered_log_probs_list,
        )

    def is_in_model_vocabulary(self, token_label: str) -> bool:
        """Return whether ``token_label`` is a usable vocabulary entry.

        The label qualifies only when all of the following hold:

        - ``self.model.get_id`` maps it to at least one id.
        - The vocabulary entry at the first id spells exactly
          ``token_label``.
        - If a common vocabulary is configured, the label is in it.
        """
        token_id = self.model.get_id(token_label)

        # An empty id sequence would otherwise crash with IndexError on
        # ``token_id[0]``. ``len(...) == 0`` is used instead of truthiness
        # so the check also works if get_id returns an array/tensor.
        if token_id is None or len(token_id) == 0:
            return False
        # The first id must map back to the exact label; presumably this
        # rejects labels the tokenizer splits or normalizes.
        if self.model.vocab[token_id[0]] != token_label:
            return False
        if (
            self.vocab_subset is not None and
            token_label not in self.vocab_subset
        ):
            return False

        return True
