import copy
import inspect
import warnings
from typing import Callable, List, Optional, Union

import numpy as np
import ray
import torch
import torch.distributed as dist  # If using a synchronous GPU
import torch.nn.functional as F
from utils.logger import setup_custom_logger
from tqdm import tqdm
from transformers.generation import (
    GenerationConfig,
    LogitsProcessorList,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from transformers.generation.utils import (
    GenerateOutput,
    GenerationMode,
    GreedySearchOutput,
)

logger = setup_custom_logger("TSP")


def generate_ensemnble_response(
    model_actors_list,
    model_name_list,
    tokenizers,
    vocab_union,
    mapping_matrices,
    index_to_vocab,
    special_prefix_tokens_dict,
    byte_mappings_list,
    primary_index,
    threshold,
    until,
    **kwargs,
):
    """
    Drive token-by-token ensemble generation across a list of Ray model actors.

    Every iteration asks each actor for its next-token output, merges those
    outputs into one token per sample with merge_and_convert_tokens, applies the
    `until` early-stopping check, feeds the chosen tokens back to every actor and
    keeps the actors' unfinished-sequence flags in sync. The loop ends as soon as
    any actor's check_if_stop() returns a truthy value.

    Args:
        model_actors_list (list): Ray actor handles. Each one must expose the remote
            methods generate_prepare, get_one_token, get_input_ids,
            update_input_ids_and_model_kwargs, get_unfinished_sequences,
            update_unfinished_sequences and check_if_stop.
        model_name_list (list): Model names, used only for logging.
        tokenizers (list): One tokenizer per actor, in the same order as the actors.
        vocab_union, mapping_matrices, index_to_vocab, special_prefix_tokens_dict,
        byte_mappings_list, primary_index, threshold: Forwarded unchanged to
            merge_and_convert_tokens.
        until (list of str or None): Stop strings forwarded to check_until.
        **kwargs: Forwarded to every actor's generate_prepare.

    Returns:
        The value returned by the first actor's get_input_ids() once the loop ends.
    """
    # Let every actor set up its generation state; submit all calls first so the
    # actors prepare concurrently, then wait for all of them.
    ray.get(
        [model_actor.generate_prepare.remote(**kwargs) for model_actor in model_actors_list]
    )

    # One (initially empty) id cache per sample, used by the `until` check.
    cached_output_ids = [[] for _ in ray.get(model_actors_list[0].get_input_ids.remote())]

    while True:
        # Ask every actor for its next-token output. Each result is indexed as a
        # pair: element 0 is the model output, element 1 a timing value.
        step_results = ray.get(
            [model_actor.get_one_token.remote() for model_actor in model_actors_list]
        )
        tmp_outputs = [result[0] for result in step_results]
        tmp_outputs_times = [result[1] for result in step_results]

        # Log each model's own most confident token.
        process_and_log_model_outputs(tokenizers, model_name_list, tmp_outputs)

        # Merge the per-model outputs into one token per sample and express that
        # token as ids for every tokenizer.
        merged_token_ids = merge_and_convert_tokens(
            tmp_outputs,
            tokenizers,
            mapping_matrices,
            vocab_union,
            index_to_vocab,
            special_prefix_tokens_dict,
            byte_mappings_list,
            primary_index,
            threshold,
            tmp_outputs_times,
        )

        # Early stopping on user-provided stop strings.
        cached_output_ids, merged_token_ids = check_until(
            until, cached_output_ids, tokenizers, merged_token_ids
        )

        # Feed the chosen tokens back so each actor can advance its own state.
        ray.get(
            [
                model_actor.update_input_ids_and_model_kwargs.remote(
                    next_tokens_list=merged_token_ids[i]
                )
                for i, model_actor in enumerate(model_actors_list)
            ]
        )

        # Submit all requests before resolving any of them: calling ray.get inside
        # the comprehension would wait for each actor in turn.
        unfinished_sequences_list = ray.get(
            [
                model_actor.get_unfinished_sequences.remote()
                for model_actor in model_actors_list
            ]
        )

        # A sample counts as finished for everyone once any model finished it.
        synced_unfinished_sequences = synchronize_unfinished_sequences(
            unfinished_sequences_list
        )
        ray.get(
            [
                model_actor.update_unfinished_sequences.remote(synced_unfinished_sequences)
                for model_actor in model_actors_list
            ]
        )

        # Stop as soon as any actor signals that generation should end.
        finish = any(
            ray.get(
                [model_actor.check_if_stop.remote() for model_actor in model_actors_list]
            )
        )
        if finish:
            break

    return ray.get(model_actors_list[0].get_input_ids.remote())


def process_and_log_model_outputs(tokenizers, model_name_list, model_outputs):
    """
    Log, for every model, the token it scores highest for each sample.

    Args:
        tokenizers (list): Tokenizer objects, one per model.
        model_name_list (list): Model names, one per model.
        model_outputs (list): Output tensors, one per model; the maximum is taken
            over the last dimension.
    """
    per_model = zip(model_name_list, tokenizers, model_outputs)
    for model_name, tokenizer, output in per_model:
        # Best score and its token id along the last dimension.
        max_scores, max_indices = output.max(dim=-1)

        # Decode every winning id on its own, keeping special tokens visible.
        decoded_tokens = [
            tokenizer.decode([token_id], skip_special_tokens=False)
            for token_id in max_indices.tolist()
        ]
        max_scores_list = [round(value, 4) for value in max_scores.tolist()]

        logger.info(
            f"Token from Model {model_name}: {decoded_tokens} (token id {max_indices.tolist()}) with Conf {max_scores_list}"
        )


def synchronize_unfinished_sequences(unfinished_sequences_list):
    """
    Combine several unfinished_sequences tensors into one shared tensor.

    A position stays 1 only if it is non-zero in every input tensor; as soon as
    one tensor holds 0 at a position, the result holds 0 there as well.

    Args:
        unfinished_sequences_list (list): Tensors of identical shape.

    Returns:
        A long tensor of 1/0 flags on the device of the first input tensor.

    Raises:
        ValueError: If the tensors do not all share the same shape.
    """
    reference = unfinished_sequences_list[0]
    target_device = reference.device
    expected_shape = reference.shape

    # Element-wise combination only makes sense for identically shaped tensors.
    if any(flags.shape != expected_shape for flags in unfinished_sequences_list):
        raise ValueError(
            "All 'unfinished_sequences' tensors must have the same shape."
        )

    # Start from "everything still running" and AND in each model's flags, so a
    # single 0 anywhere clears that position.
    still_running = torch.ones_like(reference).to(target_device)
    for flags in unfinished_sequences_list:
        still_running = torch.logical_and(still_running, flags.to(target_device))

    # logical_and yields booleans; callers expect 1/0 integers.
    return still_running.long()


def _count_leading_pads(row, pad_token_id):
    """Return the index of the first non-pad entry of a 1-D id tensor (0 if the row is all pads)."""
    non_pad_positions = row.ne(pad_token_id).nonzero(as_tuple=True)[0]
    if non_pad_positions.numel() == 0:
        # No real token to anchor on: report 0 so this row never licenses trimming.
        return 0
    return non_pad_positions[0].item()


def update_input_ids_and_model_kwargs(model, state):
    """
    Append the chosen next tokens to input_ids and refresh model_kwargs for the
    next generation step.

    Because different samples may receive a different number of new tokens, every
    row is left-padded with pad_token_id so that all rows end up with the same
    length; the attention mask (if present) is extended to match.

    Args:
        model: The generation model. Must provide
            `_update_model_kwargs_for_generation` and `config.is_encoder_decoder`.
        state (dict): Generation state with the keys:
            - "outputs": Output of the previous forward pass.
            - "input_ids": 2-D tensor of ids used in the previous step.
            - "next_tokens_list": One list of token ids per sample,
              e.g. [[314], [50256, 100]].
            - "model_kwargs": Keyword arguments for the model (may contain
              "attention_mask" and "past_key_values").
            - "unfinished_sequences": Tensor of 1/0 flags, one per sample.
              It is updated in place.
            - "pad_token_id": Id used for padding; must not be None.
            - "eos_token_id_tensor": Tensor holding one or more EOS ids, or None.

    Returns:
        tuple: (padded_input_ids_tensor, model_kwargs, unfinished_sequences).

    Raises:
        ValueError: If pad_token_id is None.
    """
    outputs = state["outputs"]
    input_ids = state["input_ids"]
    next_tokens = state["next_tokens_list"]
    model_kwargs = state["model_kwargs"]
    unfinished_sequences = state["unfinished_sequences"]
    pad_token_id = state["pad_token_id"]
    eos_token_id_tensor = state["eos_token_id_tensor"]

    if pad_token_id is None:
        raise ValueError("pad_token_id must be defined.")

    # Finished samples keep growing in lockstep with the batch, but only with pads.
    next_tokens = [
        tokens if unfinished else [pad_token_id] * len(tokens)
        for tokens, unfinished in zip(next_tokens, unfinished_sequences)
    ]

    device = input_ids.device
    prompt_length = input_ids.shape[1]

    # Length every row must reach once its new tokens are appended.
    max_length = max(prompt_length + len(tokens) for tokens in next_tokens)

    padded_input_ids = []
    attention_masks = []
    for i, tokens in enumerate(next_tokens):
        # Rows that receive fewer tokens than the longest one are padded on the left.
        input_padding_size = max_length - prompt_length - len(tokens)

        padded_input = torch.cat(
            [
                torch.full(
                    (1, input_padding_size),
                    pad_token_id,
                    dtype=torch.long,
                    device=device,
                ),
                input_ids[i].unsqueeze(0),
                torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0),
            ],
            dim=1,
        )
        padded_input_ids.append(padded_input)

        # Mirror the layout in the mask: 0 for the new left pads, 1 for new tokens.
        if "attention_mask" in model_kwargs:
            original_attention_mask = model_kwargs["attention_mask"][i]
            updated_attention_mask = torch.cat(
                [
                    torch.zeros(input_padding_size, dtype=torch.long, device=device),
                    original_attention_mask,
                    torch.ones(len(tokens), dtype=torch.long, device=device),
                ]
            )
            attention_masks.append(updated_attention_mask)

    padded_input_ids_tensor = torch.cat(padded_input_ids, dim=0)
    model_kwargs = model._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
    )

    # The masks built above replace whatever the model helper produced.
    if attention_masks:
        model_kwargs["attention_mask"] = torch.stack(attention_masks)

    # When any sample received more than one token the cached key/values are
    # dropped, so the next step starts without a cache.
    if any(len(tokens) > 1 for tokens in next_tokens):
        model_kwargs["past_key_values"] = None

        # Leading pads shared by every row carry no information; drop them.
        max_pads_to_remove = min(
            _count_leading_pads(row, pad_token_id) for row in padded_input_ids_tensor
        )
        if max_pads_to_remove > 0:
            padded_input_ids_tensor = padded_input_ids_tensor[:, max_pads_to_remove:]
            if "attention_mask" in model_kwargs:
                model_kwargs["attention_mask"] = model_kwargs["attention_mask"][
                    :, max_pads_to_remove:
                ]

    # A sample is finished once any of its new tokens is an EOS id.
    if eos_token_id_tensor is not None:
        for i, tokens in enumerate(next_tokens):
            for token in tokens:
                # eos_token_id_tensor may hold several ids; reduce the comparison
                # to a single flag so it fits the per-sample slot.
                not_eos = (eos_token_id_tensor != token).all()
                unfinished_sequences[i] = unfinished_sequences[i] & not_eos

    return padded_input_ids_tensor, model_kwargs, unfinished_sequences



def check_byte_mappings(tokenizer):
    """
    Build the mapping from byte tokens ('<0x00>' style) to the tokenizer's own ids.

    The tokenizer family is guessed from its vocabulary: if more tokens start
    with 'Ġ' than with '▁' it is treated as BBPE, otherwise as a tokenizer that
    stores byte tokens literally in its vocabulary.

    Args:
        tokenizer: Object providing `get_vocab()` (token -> id dict). For the
            BBPE case it must also provide `tokenize()` and
            `convert_tokens_to_ids()`.

    Returns:
        dict: hex token string -> token id.
            - BBPE: 128 entries ('<0x00>'..'<0x7F>'), each holding the id of the
              first token produced when tokenizing the corresponding character.
            - Otherwise: one entry per byte token found in the vocabulary
              ('<0x00>'..'<0xFF>'); '<0x09>' is skipped when absent.

    Raises:
        ValueError: In the non-BBPE case, if a byte token other than '<0x09>'
            is missing from the vocabulary.
    """
    vocab = tokenizer.get_vocab()
    g_style_count = sum(1 for token in vocab if token.startswith("Ġ"))
    underscore_style_count = sum(1 for token in vocab if token.startswith("▁"))

    if g_style_count > underscore_style_count:
        # BBPE has no literal byte tokens: use the first token id obtained by
        # tokenizing each of the first 128 code points.
        return {
            f"<0x{code:02X}>": tokenizer.convert_tokens_to_ids(
                tokenizer.tokenize(chr(code))
            )[0]
            for code in range(128)
        }

    byte_mapping = {}
    for code in range(256):
        hex_token = f"<0x{code:02X}>"
        if hex_token in vocab:
            byte_mapping[hex_token] = vocab[hex_token]
        elif hex_token != "<0x09>":
            # Only the tab byte token is allowed to be absent from the vocabulary.
            raise ValueError(f"Token {hex_token} not found in tokenizer's vocabulary.")

    return byte_mapping


def get_vocab_union_and_mapping(tokenizers):
    """
    Build the union vocabulary of several tokenizers and, for each tokenizer, a
    mapping from its own token ids to indices in that union.

    Word-boundary markers are normalised so equivalent tokens merge: for
    tokenizers whose vocabulary is dominated by 'Ġ'-prefixed tokens, 'Ġ' becomes
    a space and 'Ċ' a newline; otherwise a leading '▁' marks the token and every
    '▁' in it becomes a space. Byte tokens '<0x00>'..'<0xFF>' are always part of
    the union and are mapped through each tokenizer's byte mapping.

    Args:
        tokenizers (list): Tokenizer objects, each with a `get_vocab()` method
            returning a token -> id dict (and whatever check_byte_mappings needs).

    Returns:
        tuple: Four elements:
            - vocab_union (set): All normalised tokens of all tokenizers plus the
              256 byte tokens.
            - tokenizers_mapping (list): One dict per tokenizer mapping its token
              ids to indices in the union vocabulary.
            - index_to_vocab (dict): Union index -> token string.
            - byte_mappings_list (list): One dict per tokenizer, as returned by
              check_byte_mappings (byte token string -> original token id).

    Raises:
        ValueError: If a token id from a byte mapping does not correspond to
            exactly one remaining token in that tokenizer's vocabulary.
    """
    # Seed the union with every byte token so all tokenizers share them.
    vocab_union = {f"<0x{byte_val:02X}>" for byte_val in range(256)}
    tokenizers_mapping = []
    byte_mappings_list = []

    for tokenizer in tokenizers:
        vocab = tokenizer.get_vocab()
        token_set = set()
        mapping = {}

        # Record how this tokenizer represents the byte tokens.
        byte_mapping = check_byte_mappings(tokenizer)
        byte_mappings_list.append(byte_mapping)

        if len(byte_mapping) == 128:
            logger.info(
                "BBPE detected."
            )

        # Reverse index built once, instead of rescanning the whole vocabulary
        # for every byte token.
        tokens_by_id = {}
        for token, token_id in vocab.items():
            tokens_by_id.setdefault(token_id, []).append(token)

        # Byte tokens are handled separately below, so take them out of the
        # regular vocabulary. Popping makes a repeated id show up as "zero tokens".
        for token_id in byte_mapping.values():
            actual_tokens = tokens_by_id.pop(token_id, [])
            if len(actual_tokens) != 1:
                raise ValueError(
                    f"Multiple tokens/ Zero token found for token ID {token_id} in tokenizer's vocabulary."
                )
            del vocab[actual_tokens[0]]

        # Decide which word-boundary marker this tokenizer uses.
        g_prefix_count = sum(token.startswith("Ġ") for token in vocab)
        u_prefix_count = sum(token.startswith("▁") for token in vocab)

        if g_prefix_count > u_prefix_count:
            # 'Ġ' style: normalise the space and newline markers.
            for token, token_id in vocab.items():
                processed_token = token.replace("Ġ", " ").replace("Ċ", "\n")
                token_set.add(processed_token)
                mapping[token_id] = processed_token
        else:
            # '▁' style: only tokens that start with the marker are rewritten.
            for token, token_id in vocab.items():
                if token.startswith("▁"):
                    processed_token = token.replace("▁", " ")
                else:
                    processed_token = token
                token_set.add(processed_token)
                mapping[token_id] = processed_token

        vocab_union |= token_set
        tokenizers_mapping.append(mapping)

    # Sort before numbering: set iteration order for strings differs between
    # interpreter runs, which would make the union indices irreproducible.
    vocab_to_index = {token: i for i, token in enumerate(sorted(vocab_union))}
    index_to_vocab = {index: token for token, index in vocab_to_index.items()}

    # Byte tokens for '0'-'9', 'A'-'Z' and 'a'-'z' and the plain characters they
    # stand for.
    ascii_alnum_codes = (*range(0x30, 0x3A), *range(0x41, 0x5B), *range(0x61, 0x7B))
    bbpe_mapping = {f"<0x{code:02X}>": chr(code) for code in ascii_alnum_codes}

    # True when at least one tokenizer is not BBPE (i.e. the ensemble is mixed).
    mixed_families = not all(len(bm) == 128 for bm in byte_mappings_list)

    for byte_mapping, mapping in zip(byte_mappings_list, tokenizers_mapping):
        # Replace the normalised token strings by their union indices.
        for token_id, token in mapping.items():
            mapping[token_id] = vocab_to_index[token]

        # In a mixed ensemble, a BBPE tokenizer's alphanumeric byte entries are
        # redirected to the plain characters in the union vocabulary.
        force_plain_chars = mixed_families and len(byte_mapping) == 128

        for hex_token, original_token_id in byte_mapping.items():
            if force_plain_chars and hex_token in bbpe_mapping:
                logger.warning(f"We force-mapped the BBPE {hex_token} to {bbpe_mapping[hex_token]} in union vocab")
                mapping[original_token_id] = vocab_to_index[bbpe_mapping[hex_token]]
                continue
            # Default: the byte token maps to its own entry in the union.
            mapping[original_token_id] = vocab_to_index[hex_token]

    return vocab_union, tokenizers_mapping, index_to_vocab, byte_mappings_list


def merge_and_convert_tokens_cpu(
    outputs,
    tokenizers,
    tokenizers_mapping,
    vocab_union,
    index_to_vocab,
    special_prefix_token,
    byte_mappings_list,
):
    """
    NumPy variant of the ensemble merge: sum every model's output in the unified
    vocabulary, pick the best token per sample and express it as token ids for
    each tokenizer.

    Args:
        outputs (list): One output tensor per model, indexed as [sample, token_id].
        tokenizers (list): Tokenizer objects, one per model.
        tokenizers_mapping (list): One dict per tokenizer mapping its token ids to
            indices in the unified vocabulary.
        vocab_union (set): The unified vocabulary; only its size is used here.
        index_to_vocab (dict): Unified index -> token string.
        special_prefix_token (dict): Tokenizer -> its special prefix token,
            passed on to get_token_ids.
        byte_mappings_list (list): One byte mapping per tokenizer, passed on to
            get_token_ids.

    Returns:
        list: `result[model][sample]` is the list of token ids that the chosen
        token corresponds to for that model's tokenizer.
    """
    sample_count = len(outputs[0])
    merged_probs = np.zeros((sample_count, len(vocab_union)))
    eos_token_list = [tok.eos_token for tok in tokenizers]

    # Accumulate each model's scores into the unified vocabulary, one mapped
    # token column at a time.
    for model_output, id_to_union in zip(outputs, tokenizers_mapping):
        scores = model_output.detach().cpu().numpy()
        for token_id, union_index in id_to_union.items():
            merged_probs[:, union_index] += scores[:, token_id]

    # Winning unified token for every sample.
    max_tokens = [index_to_vocab[index] for index in np.argmax(merged_probs, axis=1)]

    logger.info(f"Token from ALL Model: {str(max_tokens)}\n")

    # Translate the winners back into ids for each tokenizer. EOS tokens map
    # straight to the tokenizer's own EOS id; everything else goes through
    # get_token_ids.
    return [
        [
            [tokenizer.eos_token_id]
            if token in eos_token_list
            else get_token_ids(
                tokenizer,
                token,
                special_prefix_token[tokenizer],
                byte_mappings_list[i],
            )
            for token in max_tokens
        ]
        for i, tokenizer in enumerate(tokenizers)
    ]


def create_mapping_matrix(mapping, union_vocab_size, model_vocab_size, device="cuda"):
    """
    Creates a sparse tensor mapping matrix for vocabulary translation.

    Args:
    - mapping (dict): Maps model token IDs to unified vocabulary indexes.
    - union_vocab_size (int): Size of the unified vocabulary.
    - model_vocab_size (int): Size of the model's vocabulary. The value 151646 is
                              replaced by 151936 (see the warning below).
    - device: Device on which the sparse tensor is created. Defaults to "cuda".

    Returns:
    - torch.sparse_coo_tensor: Sparse tensor in COO format with shape [model_vocab_size, union_vocab_size].
                               Each non-zero element (i, j) is 1.0 and indicates a mapping from the i-th
                               token in the model's vocabulary to the j-th token in the unified vocabulary.
    """

    if model_vocab_size == 151646:
        logger.warning(
            "The qwen1.5 series has been detected, where the length of tokenizer.get_vocab() and the vocab_size in the model config are inconsistent. We have forcefully set it to the latter. https://github.com/QwenLM/Qwen1.5/issues/29"
        )
        model_vocab_size = 151936

    pairs = [
        [model_token_id, unified_token_index]
        for model_token_id, unified_token_index in mapping.items()
    ]

    # COO format wants indices shaped (2, nnz). reshape(-1, 2) keeps the pair
    # layout even when the mapping is empty, so the transpose is still (2, 0).
    indices = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t()
    values = torch.ones(len(pairs), dtype=torch.float)

    size = torch.Size([model_vocab_size, union_vocab_size])
    return torch.sparse_coo_tensor(indices, values, size, device=device)

def check_until(
    until,
    cached_batch_output_ids,
    tokenizers,
    merged_token_ids,
):
    """
    Extend the per-sample output cache and force EOS on samples that produced a
    stop string.

    The cache tracks the ids of the first tokenizer only; its decoded text is
    what the stop strings are matched against.

    Args:
    until (list of str or None): Stop strings. When empty or None, the cache is
        still extended but nothing is decoded or stopped.
    cached_batch_output_ids (list of list of int): Ids generated so far for each
        sample, expressed with tokenizers[0]. Updated in place.
    tokenizers (list): Tokenizer objects; each must provide `eos_token_id`, and
        the first one `decode()`.
    merged_token_ids (list): `merged_token_ids[model][sample]` is the list of ids
        chosen this step for that model's tokenizer.

    Returns:
    tuple: (cached_batch_output_ids, merged_token_ids). For every sample whose
    decoded cache contains a stop string, each model's id list for that sample
    gets its tokenizer's eos_token_id appended.

    Raises:
    ValueError: If the cache and merged_token_ids[0] disagree on the batch size.
    """
    if len(cached_batch_output_ids) != len(merged_token_ids[0]):
        raise ValueError(f"len(cached_batch_output_ids):{len(cached_batch_output_ids)} != len(merged_token_ids[0]): {len(merged_token_ids[0])}")

    for i, cached_ids in enumerate(cached_batch_output_ids):
        cached_batch_output_ids[i] = cached_ids + merged_token_ids[0][i]

        # Decoding the whole cache is only needed when there is something to match.
        if not until:
            continue

        tmp_text = tokenizers[0].decode(cached_batch_output_ids[i])
        if any(stop_txt in tmp_text for stop_txt in until):
            # One EOS per sample is enough, no matter how many stop strings match.
            for j, tokenizer in enumerate(tokenizers):
                merged_token_ids[j][i] = merged_token_ids[j][i] + [tokenizer.eos_token_id]

    return cached_batch_output_ids, merged_token_ids


def merge_and_convert_tokens(
    outputs,
    tokenizers,
    mapping_matrices,
    vocab_union,
    index_to_vocab,
    special_prefix_token,
    byte_mappings_list,
    primary_index,
    threshold,
    tmp_outputs_times,
):
    """
    Merges the outputs of multiple models in the unified vocabulary, picks the
    highest scoring token per sample and converts it into token IDs for each
    tokenizer.

    Args:
    outputs (list): A list of model output tensors, each indexed as [sample, token_id].
    tokenizers (list): A list of tokenizer objects used by the corresponding models.
    mapping_matrices (List[torch.sparse_coo_tensor]): One sparse matrix per model,
        of shape [model_vocab_size, len(vocab_union)], translating that model's
        token IDs into the unified vocabulary.
    vocab_union (set): The unified vocabulary; only its size is used here.
    index_to_vocab (dict): A dictionary mapping from unique index to tokens in the vocab_union.
    special_prefix_token (dict): A dictionary mapping each tokenizer to its special prefix token,
                                 passed on to get_token_ids.
    byte_mappings_list (list): One byte mapping per tokenizer, passed on to get_token_ids.
    primary_index (int): -1 ensembles every token. Otherwise the index of the primary
        model: samples where its top score exceeds `threshold` use the primary
        model alone, all other samples are ensembled.
    threshold (float): Confidence threshold applied to the primary model's top score.
    tmp_outputs_times (list of float): Consumed time for each sample in the batch.
        Currently unused by this function.

    Returns:
    list: A nested list of token IDs, where `result[model][sample]` holds the token IDs
          that the chosen token corresponds to for that model's tokenizer.
    """
    eos_token_list = [tokenizer.eos_token for tokenizer in tokenizers]
    eos_token_list.extend(["<|end_of_text|>","<|endoftext|>","<|im_end|>","<|end|>"])

    # Samples where the primary model is confident enough to decide alone.
    # None means "no primary model": every sample is ensembled.
    primary_confident = None
    if primary_index != -1:
        max_probs, _ = torch.max(outputs[primary_index], dim=1)
        primary_confident = max_probs > threshold
        logger.info(primary_confident)

    # Accumulate on the device the outputs already live on instead of assuming CUDA.
    merged_probs = torch.zeros(
        (outputs[0].size(0), len(vocab_union)), device=outputs[0].device
    )
    for i, (output, mapping_matrix) in enumerate(zip(outputs, mapping_matrices)):
        if primary_confident is not None and i != primary_index:
            # Silence the secondary models on samples the primary model is sure
            # about; masked_fill returns a copy, the caller's tensor is untouched.
            output = output.masked_fill(primary_confident.unsqueeze(1), 0)
        merged_probs += torch.sparse.mm(output, mapping_matrix)

    max_token_indices = torch.argmax(merged_probs, dim=1)
    max_tokens = [index_to_vocab[index.item()] for index in max_token_indices]
    logger.info(f"Token from ALL Model: {str(max_tokens)}\n")

    # Convert the winning tokens to token IDs for each tokenizer.
    batch_token_ids = [[] for _ in range(len(tokenizers))]
    for i, tokenizer in enumerate(tokenizers):
        for token in max_tokens:
            if token in eos_token_list:
                # Any known EOS spelling maps to this tokenizer's own EOS id.
                token_id = [tokenizer.eos_token_id]
            else:
                token_id = get_token_ids(
                    tokenizer,
                    token,
                    special_prefix_token[tokenizer],
                    byte_mappings_list[i],
                )

            batch_token_ids[i].append(token_id)

    return batch_token_ids


def get_token_ids(tokenizer, token, special_prefix_token, byte_mapping):
    """
    Convert one vocabulary token (in string form) into token IDs of ``tokenizer``.

    Resolution order:

    1. If ``token`` is a key of ``byte_mapping`` its mapped ID is returned directly.
    2. Otherwise the token is encoded behind a prefix (first ``special_prefix_token``,
       then ``';'``). When the encoding of ``prefix + token`` starts with the encoding
       of the prefix alone and something is left after removing that leading part,
       the leftover IDs are returned.
    3. If neither prefix yields a non-empty remainder, a warning is logged and the
       plain encoding of ``token`` is returned.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)`` that
            returns a list of token IDs.
        token (str): The token text to convert.
        special_prefix_token (str): Preferred prefix text placed in front of ``token``
            before encoding (may be an empty string).
        byte_mapping (dict): Maps byte-token strings (e.g. ``'<0x0A>'``) to their
            token IDs in this tokenizer's vocabulary.

    Returns:
        list: Token IDs representing ``token`` for this tokenizer.
    """

    # Byte tokens have a direct ID; no tokenization is needed for them.
    if token in byte_mapping:
        return [byte_mapping[token]]

    # NOTE(review): `byte_mapping` is used as a dict above, so comparing it with the
    # integer 128 is always True. `len(byte_mapping) != 128` may have been intended;
    # the comparison is kept unchanged to preserve behavior - confirm with the author.
    if byte_mapping != 128:
        for candidate_prefix in (special_prefix_token, ";"):
            prefix_ids = tokenizer.encode(candidate_prefix, add_special_tokens=False)
            combined_ids = tokenizer.encode(
                candidate_prefix + token, add_special_tokens=False
            )
            prefix_len = len(prefix_ids)

            # The prefix must survive tokenization unchanged, otherwise the split
            # point between prefix and token is unknown.
            if combined_ids[:prefix_len] != prefix_ids:
                continue

            remainder = combined_ids[prefix_len:]
            if remainder:
                return remainder

        # Neither prefix produced a usable split; fall through to plain encoding.
        logger.warning(f"Warning: Token '{token}' may not be tokenized as expected.")

    return tokenizer.encode(token, add_special_tokens=False)


def find_special_underscore_token(tokenizer):
    """
    Pick a short, self-contained '▁'-prefixed token from the tokenizer's vocabulary
    and return it with the leading '▁' removed.

    If the vocabulary holds more tokens starting with 'Ġ' than with '▁', an empty
    string is returned immediately. Otherwise every vocabulary token that starts
    with '▁' (other than the bare '▁') is a candidate, and it is kept only when:

    * it is not a substring of any other candidate,
    * it contains exactly one '▁', and
    * the text after the leading '▁' is not empty or whitespace-only.

    Among the kept tokens the shortest one is returned; ties are broken by
    lexicographic order so that the result is deterministic.

    Args:
        tokenizer: Object exposing ``get_vocab()`` whose result can be iterated to
            obtain the token strings (a token -> ID dictionary).

    Returns:
        str: ``""`` when 'Ġ' prefixes dominate, otherwise the selected token text
        without its leading '▁'.

    Raises:
        ValueError: If no candidate satisfies the criteria above.
    """

    vocab = tokenizer.get_vocab()

    # Decide which word-boundary marker this vocabulary predominantly uses.
    g_marker_total = sum(entry.startswith("Ġ") for entry in vocab)
    underscore_marker_total = sum(entry.startswith("▁") for entry in vocab)
    if g_marker_total > underscore_marker_total:
        return ""

    # Candidates: everything beginning with '▁' apart from the bare marker itself.
    marked_tokens = [
        entry for entry in vocab if entry != "▁" and entry.startswith("▁")
    ]

    standalone = []
    for candidate in tqdm(marked_tokens, desc="Searching tokens"):
        # Quadratic scan: reject candidates embedded inside another candidate.
        embedded_elsewhere = any(
            candidate in other for other in marked_tokens if other != candidate
        )
        if embedded_elsewhere:
            continue
        if candidate.count("▁") != 1:
            continue

        remainder = candidate[1:]  # text after the leading '▁'
        if remainder.strip() == "":
            continue

        standalone.append(remainder)

    if not standalone:
        raise ValueError("No special underscore token found that meets the criteria.")

    # Shortest first, then dictionary order, so the choice is reproducible.
    def _length_then_text(text):
        return (len(text), text)

    return min(standalone, key=_length_then_text)


def get_special_prefix_tokens_for_all(tokenizers):
    """
    Build a lookup from each tokenizer to the prefix text used when re-tokenizing
    individual tokens.

    A tokenizer whose ``vocab_size`` equals 256000 is treated as gemma-it (per the
    log message below) and is assigned the fixed prefix '¢'. For every other
    tokenizer the prefix is computed by ``find_special_underscore_token``.

    Args:
        tokenizers (list): Tokenizer objects. They are used as dictionary keys, so
            they must be hashable, and each must expose ``vocab_size``.

    Returns:
        dict: ``{tokenizer: special_prefix_token}`` with one entry per tokenizer.

    Example:
        tokenizers = [tokenizer1, tokenizer2, ...]
        special_prefix_tokens = get_special_prefix_tokens_for_all(tokenizers)
        print(special_prefix_tokens)  # {tokenizer1: prefix1, tokenizer2: prefix2, ...}
    """

    prefix_by_tokenizer = {}

    for current in tokenizers:
        if current.vocab_size == 256000:
            # Vocabulary size is used as the gemma-it fingerprint; skip the search.
            logger.info("gemma-it detected, use '¢' as special_prefix_token")
            prefix_by_tokenizer[current] = "¢"
        else:
            prefix_by_tokenizer[current] = find_special_underscore_token(current)

    return prefix_by_tokenizer


def greedy_search(
    model,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> dict:
    r"""
    Build the initial state for step-wise greedy decoding.

    The signature mirrors Hugging Face's `GenerationMixin.greedy_search`, but this function does **not**
    run a decoding loop and does not generate any token. It only resolves defaults from
    `model.generation_config` and returns a plain dictionary of the values a step-wise decoder needs.
    Only decoder-only models are supported.

    Parameters:
        model:
            The language model. `model.config.is_encoder_decoder` is checked, and `model.generation_config`
            supplies defaults for arguments left as `None`.
        input_ids (`torch.LongTensor`):
            The prompt token ids. The first dimension is used as the number of sequences, and the tensor's
            device is used for the tensors created here.
        logits_processor (`LogitsProcessorList`, *optional*):
            Processors applied to next-token scores. Defaults to an empty `LogitsProcessorList`.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            Criteria that decide when generation stops. Defaults to an empty `StoppingCriteriaList`.
        max_length (`int`, *optional*):
            **DEPRECATED**. When given, a `UserWarning` is emitted and `stopping_criteria` is passed through
            `validate_stopping_criteria` together with this value.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token. Defaults to `model.generation_config.pad_token_id`.
        eos_token_id (`Union[int, List[int]]`, *optional*):
            The id(s) of the *end-of-sequence* token. Defaults to `model.generation_config.eos_token_id`.
            A single int is wrapped into a list.
        output_attentions (`bool`, *optional*):
            Defaults to `model.generation_config.output_attentions`; stored in the returned state.
        output_hidden_states (`bool`, *optional*):
            Defaults to `model.generation_config.output_hidden_states`; stored in the returned state.
        output_scores (`bool`, *optional*):
            Defaults to `model.generation_config.output_scores`. Together with `return_dict_in_generate` it
            decides whether `scores` starts as an empty tuple or as `None`.
        return_dict_in_generate (`bool`, *optional*):
            Defaults to `model.generation_config.return_dict_in_generate`.
        synced_gpus (`bool`, *optional*, defaults to `False`):
            Accepted for signature compatibility; not used by this function.
        streamer (`BaseStreamer`, *optional*):
            Accepted for signature compatibility; not used by this function.
        model_kwargs:
            Additional model-specific keyword arguments, stored unchanged under `"model_kwargs"`.

    Return:
        `dict` with the keys `"input_ids"`, `"model_kwargs"`, `"output_attentions"`, `"output_hidden_states"`,
        `"stopping_criteria"`, `"logits_processor"`, `"scores"`, `"pad_token_id"`, `"eos_token_id_tensor"`
        (a tensor on `input_ids.device`, or `None` when no EOS id is known), `"unfinished_sequences"`
        (a `torch.long` tensor of ones with one entry per input sequence) and `"this_peer_finished"` (`False`).

    Raises:
        NotImplementedError: If `model.config.is_encoder_decoder` is true.
    """

    # Fail fast: nothing below applies to encoder-decoder models.
    if model.config.is_encoder_decoder:
        raise NotImplementedError("We only support decorder arch!")

    # init values
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    if stopping_criteria is None:
        stopping_criteria = StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)

    # Explicit arguments win; `model.generation_config` is only consulted for values left as None.
    if pad_token_id is None:
        pad_token_id = model.generation_config.pad_token_id
    if eos_token_id is None:
        eos_token_id = model.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = (
        torch.tensor(eos_token_id, device=input_ids.device)
        if eos_token_id is not None
        else None
    )

    if output_scores is None:
        output_scores = model.generation_config.output_scores
    if output_attentions is None:
        output_attentions = model.generation_config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = model.generation_config.output_hidden_states
    if return_dict_in_generate is None:
        return_dict_in_generate = model.generation_config.return_dict_in_generate

    # Only the scores accumulator is initialised here; attentions and hidden states
    # are not accumulated by this function.
    scores = () if (return_dict_in_generate and output_scores) else None

    # keep track of which sequences are already finished (1 = still generating)
    unfinished_sequences = torch.ones(
        input_ids.shape[0], dtype=torch.long, device=input_ids.device
    )

    this_peer_finished = False  # used by synced_gpus only

    return {
        "input_ids": input_ids,
        "model_kwargs": model_kwargs,
        "output_attentions": output_attentions,
        "output_hidden_states": output_hidden_states,
        "stopping_criteria": stopping_criteria,
        "logits_processor": logits_processor,
        "scores": scores,
        "pad_token_id": pad_token_id,
        "eos_token_id_tensor": eos_token_id_tensor,
        "unfinished_sequences": unfinished_sequences,
        "this_peer_finished": this_peer_finished,
    }


def get_one_token(model, state):
    """
    Run one forward pass and return the probability distribution over the next token.

    The state dictionary is only read here; it is not modified.

    Args:
        model: The language model. It must provide ``prepare_inputs_for_generation``
            and be callable with the prepared inputs.
        state (dict): Decoding state. The keys read are:
            - input_ids: Token ids generated so far (passed to the model and to the
              logits processor).
            - model_kwargs: Extra keyword arguments for ``prepare_inputs_for_generation``.
            - output_attentions: Forwarded to the model call.
            - output_hidden_states: Forwarded to the model call.
            - logits_processor: Callable ``(input_ids, logits) -> scores`` applied
              before the softmax.

    Returns:
        tuple: ``(next_tokens_scores, outputs)`` where
            - next_tokens_scores is the softmax (over the last dimension) of the
              processed logits taken at the last position of ``outputs.logits``
              (presumably shaped ``(batch_size, vocabulary_size)``).
            - outputs is the raw object returned by the model call.
    """
    # Read every required key up front so a malformed state fails before the forward pass.
    current_ids, extra_kwargs, want_attentions, want_hidden_states, process_logits = (
        state[key]
        for key in (
            "input_ids",
            "model_kwargs",
            "output_attentions",
            "output_hidden_states",
            "logits_processor",
        )
    )

    forward_inputs = model.prepare_inputs_for_generation(current_ids, **extra_kwargs)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        outputs = model(
            **forward_inputs,
            return_dict=True,
            output_attentions=want_attentions,
            output_hidden_states=want_hidden_states,
        )

    # Keep only the logits of the final position, then let the processors adjust them.
    last_position_logits = outputs.logits[:, -1, :]
    processed_scores = process_logits(current_ids, last_position_logits)

    return F.softmax(processed_scores, dim=-1), outputs


def generate_prepare(
    model,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> dict:
    """
    Validate and normalise generation arguments for ``model`` and return the initial
    decoding state produced by ``greedy_search``.

    The body follows the preparation steps of a Hugging Face style ``generate()`` call
    (generation config resolution, model input preparation, attention mask, max length,
    logits processors, stopping criteria) but, instead of running a decoding loop, it
    ends by calling this module's ``greedy_search``, which returns a state dictionary.
    It relies on several private helpers of ``model`` (``_prepare_model_inputs``,
    ``_get_logits_processor``, ``_get_stopping_criteria``, ...), so it presumably
    depends on the installed transformers version - TODO confirm the supported range.

    Args:
        model: The model to prepare generation for.
        inputs: Prompt tensor handed to ``model._prepare_model_inputs``.
        generation_config: Generation settings. When ``None``, ``model.generation_config``
            is used. The config is deep-copied before ``kwargs`` are applied to it.
        logits_processor: Custom processors, merged by ``model._get_logits_processor``.
        stopping_criteria: Custom criteria, merged by ``model._get_stopping_criteria``.
        prefix_allowed_tokens_fn: Forwarded to ``model._get_logits_processor``.
        synced_gpus: Forwarded unchanged to ``greedy_search`` (may be ``None``).
        assistant_model: Only used to determine the generation mode and inside the
            assisted-generation branch.
        streamer: If given, receives the prompt ids via ``streamer.put``; also
            forwarded to ``greedy_search``.
        negative_prompt_ids: Forwarded to ``model._get_logits_processor``.
        negative_prompt_attention_mask: Forwarded to ``model._get_logits_processor``.
        **kwargs: Passed to ``generation_config.update``; whatever it returns is
            treated as model keyword arguments.

    Returns:
        dict: The state dictionary returned by ``greedy_search``.

    Raises:
        ValueError: For the explicit checks below (streamer with ``num_beams > 1``,
            invalid assisted-generation settings). The validation helpers called on
            ``model`` and ``generation_config`` may raise as well.
    """

    # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
    model._validate_model_class()

    # priority: `generation_config` argument > `model.generation_config` (the default generation config)
    if generation_config is None:
        # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
        # two conditions must be met
        # 1) the generation config must have been created from the model config (`_from_model_config` field);
        # 2) the generation config must have seen no modification since its creation (the hash is the same).
        if (
            model.generation_config._from_model_config
            and model.generation_config._original_object_hash
            == hash(model.generation_config)
        ):
            new_generation_config = GenerationConfig.from_model_config(model.config)
            if new_generation_config != model.generation_config:
                warnings.warn(
                    "You have modified the pretrained model configuration to control generation. This is a"
                    " deprecated strategy to control generation and will be removed soon, in a future version."
                    " Please use and modify the model generation configuration (see"
                    " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
                )
                model.generation_config = new_generation_config
        generation_config = model.generation_config

    # Work on a private copy so the updates below never leak into the caller's / model's config.
    generation_config = copy.deepcopy(generation_config)

    model_kwargs = generation_config.update(
        **kwargs
    )  # All unused kwargs must be model kwargs

    generation_config.validate()
    model._validate_model_kwargs(model_kwargs.copy())

    # 2. Set generation parameters if not already defined
    logits_processor = (
        logits_processor if logits_processor is not None else LogitsProcessorList()
    )
    stopping_criteria = (
        stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    )

    # Fall back to the (first) EOS id as padding id when no padding id is configured.
    if (
        generation_config.pad_token_id is None
        and generation_config.eos_token_id is not None
    ):
        if model_kwargs.get("attention_mask", None) is None:
            logger.warning(
                "The attention mask and the pad token id were not set. As a consequence, you may observe "
                "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
            )
        eos_token_id = generation_config.eos_token_id
        if isinstance(eos_token_id, list):
            eos_token_id = eos_token_id[0]
        logger.warning(
            f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
        )
        generation_config.pad_token_id = eos_token_id

    # 3. Define model inputs
    # inputs_tensor has to be defined
    # model_input_name is defined if model-specific keyword input is passed
    # otherwise model_input_name is None
    # all model-specific keyword inputs are removed from `model_kwargs`
    inputs_tensor, model_input_name, model_kwargs = model._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )

    batch_size = inputs_tensor.shape[0]

    # 4. Define other model kwargs
    # These two entries are later forwarded through `**model_kwargs` and therefore bind to the
    # `output_attentions` / `output_hidden_states` parameters of `greedy_search`.
    model_kwargs["output_attentions"] = generation_config.output_attentions
    model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
    # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
    # generating the first new token or not, and we only want to use the embeddings for the first new token)
    if not model.config.is_encoder_decoder and model_input_name == "inputs_embeds":
        model_kwargs["use_cache"] = True
    else:
        model_kwargs["use_cache"] = generation_config.use_cache

    accepts_attention_mask = "attention_mask" in set(
        inspect.signature(model.forward).parameters.keys()
    )
    requires_attention_mask = "encoder_outputs" not in model_kwargs

    # Build an attention mask only when the caller did not supply one and the model can use it.
    if (
        model_kwargs.get("attention_mask", None) is None
        and requires_attention_mask
        and accepts_attention_mask
    ):
        model_kwargs["attention_mask"] = model._prepare_attention_mask_for_generation(
            inputs_tensor,
            generation_config.pad_token_id,
            generation_config.eos_token_id,
        )

    # decoder-only models should use left-padding for generation
    if not model.config.is_encoder_decoder:
        # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
        # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
        if (
            generation_config.pad_token_id is not None
            and len(inputs_tensor.shape) == 2
            and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
        ):
            logger.warning(
                "A decoder-only architecture is being used, but right-padding was detected! For correct "
                "generation results, please set `padding_side='left'` when initializing the tokenizer."
            )

    # NOTE(review): `greedy_search` in this module raises for encoder-decoder models, so the
    # encoder-decoder preparation below only does work before that error is raised.
    if model.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
        # if model is encoder decoder encoder_outputs are created
        # and added to `model_kwargs`
        model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(
            inputs_tensor, model_kwargs, model_input_name
        )

    # 5. Prepare `input_ids` which will be used for auto-regressive generation
    if model.config.is_encoder_decoder:
        input_ids, model_kwargs = model._prepare_decoder_input_ids_for_generation(
            batch_size=batch_size,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config.decoder_start_token_id,
            bos_token_id=generation_config.bos_token_id,
            device=inputs_tensor.device,
        )
    else:
        input_ids = (
            inputs_tensor
            if model_input_name == "input_ids"
            else model_kwargs.pop("input_ids")
        )

    if streamer is not None:
        streamer.put(input_ids.cpu())

    # 6. Prepare `max_length` depending on other stopping criteria.
    input_ids_length = input_ids.shape[-1]
    has_default_max_length = (
        kwargs.get("max_length") is None and generation_config.max_length is not None
    )
    if generation_config.max_new_tokens is not None:
        if not has_default_max_length and generation_config.max_length is not None:
            logger.warning(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
        generation_config.max_length = (
            generation_config.max_new_tokens + input_ids_length
        )
    model._validate_generated_length(
        generation_config, input_ids_length, has_default_max_length
    )

    # 7. determine generation mode
    # NOTE(review): the mode is only compared against ASSISTED_GENERATION below; whatever the mode is,
    # the function ends with `greedy_search`. Sampling / beam settings are therefore not rejected here.
    generation_mode = model._get_generation_mode(generation_config, assistant_model)

    if streamer is not None and (generation_config.num_beams > 1):
        raise ValueError(
            "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
        )

    if model.device.type != input_ids.device.type:
        warnings.warn(
            "You are calling .generate() with the `input_ids` being on a device type different"
            f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
            f" is on {model.device.type}. You may experience unexpected behaviors or slower generation."
            " Please make sure that you have put `input_ids` to the"
            f" correct device by calling for example input_ids = input_ids.to('{model.device.type}') before"
            " running `.generate()`.",
            UserWarning,
        )

    # 8. prepare distribution pre_processing samplers
    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_length,
        encoder_input_ids=inputs_tensor,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
        model_kwargs=model_kwargs,
        negative_prompt_ids=negative_prompt_ids,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
    )

    # 9. prepare stopping criteria
    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )

    # 10. go into different generation modes
    # NOTE(review): this branch validates assisted-generation settings and may add
    # `assistant_encoder_outputs` to `model_kwargs`, but no assisted decoding is started here;
    # the extra entry is simply forwarded to `greedy_search` via `**model_kwargs`.
    if generation_mode == GenerationMode.ASSISTED_GENERATION:
        if generation_config.num_return_sequences > 1:
            raise ValueError(
                "num_return_sequences has to be 1 when doing assisted generate, "
                f"but is {generation_config.num_return_sequences}."
            )
        if batch_size > 1:
            raise ValueError("assisted generate is only supported for batch_size = 1")
        if not model_kwargs["use_cache"]:
            raise ValueError("assisted generate requires `use_cache=True`")

        assistant_accepts_encoder_outputs = "encoder_outputs" in set(
            inspect.signature(assistant_model.forward).parameters.keys()
        )

        # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
        if (
            assistant_model.config.is_encoder_decoder
            and "assistant_encoder_outputs" not in model_kwargs
        ):
            assistant_model_kwargs = copy.deepcopy(model_kwargs)
            (
                inputs_tensor,
                model_input_name,
                assistant_model_kwargs,
            ) = assistant_model._prepare_model_inputs(
                inputs_tensor,
                assistant_model.generation_config.bos_token_id,
                assistant_model_kwargs,
            )
            assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, assistant_model_kwargs, model_input_name
            )
            model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs[
                "encoder_outputs"
            ]

        if (
            not assistant_model.config.is_encoder_decoder
            and assistant_accepts_encoder_outputs
            and "encoder_outputs" in model_kwargs
        ):
            # some assistants might be asymmetric (many more enc layers than dec layers)
            # encoder-decoder models that share the exact same encoder as the teacher
            # in this case the assistant only needs to load the light-weight decoder,
            # but still requires `encoder_outputs` to be passed
            model_kwargs["assistant_encoder_outputs"] = model_kwargs["encoder_outputs"]

    # 12. build the initial greedy-search state (no token is generated by this call)
    return greedy_search(
        model,
        input_ids,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
        pad_token_id=generation_config.pad_token_id,
        eos_token_id=generation_config.eos_token_id,
        output_scores=generation_config.output_scores,
        return_dict_in_generate=generation_config.return_dict_in_generate,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )
