import torch 

class LogitsWarper:
    """Base interface for logit warpers applied during generation with multinomial sampling.

    Subclasses override ``__call__`` to transform the next-token scores. This base
    class carries no warping logic, so calling it directly is an error.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return the warped ``scores``; must be provided by a subclass.

        Raises:
            NotImplementedError: always, since the base class does not warp anything.
        """
        owner = self.__class__
        message = f"{owner} is an abstract class. Only classes inheriting this class can be called."
        raise NotImplementedError(message)
    
class DecimalWarper(LogitsWarper):
    """Logits warper that restricts sampling to the token ids collected for digits "0"-"9".

    Every vocabulary position that is not one of the collected digit ids is
    overwritten with ``filter_value``. The sampler can therefore only emit a digit token.
    """

    def __init__(self, tokenizer, eval_model, filter_value: float = -float("Inf")):
        """Collect the token ids the tokenizer produces for each decimal digit.

        Args:
            tokenizer: object exposing ``encode(text)`` that returns a sequence of token ids.
            eval_model: model family name. ``"llama2"`` keeps only the id at position 2
                of each digit's encoding. Any other value keeps every id ``encode`` returns.
            filter_value: score written to every non-digit position (default ``-inf``).
        """
        self.filter_value = filter_value
        self.decimal_ids = []
        for digit in "0123456789":
            encoded = tokenizer.encode(digit)
            if eval_model == "llama2":
                # NOTE(review): index 2 presumably skips a BOS token and a leading
                # whitespace piece emitted before the digit by the llama2 tokenizer — confirm.
                self.decimal_ids.append(encoded[2])
            else:
                # NOTE(review): this keeps *all* ids from encode(), including any special
                # tokens (e.g. BOS) the tokenizer may add — confirm that is intended.
                self.decimal_ids.extend(encoded)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask every non-digit vocabulary position of ``scores`` with ``filter_value``.

        Args:
            input_ids: not used by this warper; accepted to match the warper call signature.
            scores: logits whose last dimension indexes the vocabulary.

        Returns:
            A new tensor shaped like ``scores``. Digit ids keep their original scores;
            every other position is set to ``filter_value``.
        """
        keep = torch.zeros_like(scores, dtype=torch.bool)
        # Mark digit ids along the vocabulary (last) axis for every row. The previous
        # ``keep[0][ids]`` marked only batch row 0, which left all other rows entirely
        # filled with filter_value.
        keep[..., self.decimal_ids] = True
        return scores.masked_fill(~keep, self.filter_value)