from typing import Any, Dict, List, Set, Tuple

import numpy as np
import ray
import torch

from utils.ray_actor import get_remote_model_generator_class
from utils.TSP_gen_utils import *

def setup_model_actors_and_data(
    config: List[Dict], norm_type: str, threshold: float
) -> Tuple[
    List[Any],
    List[Any],
    Set[str],
    List[torch.Tensor],
    Dict[int, str],
    Dict[Any, str],
    List[Dict[str, int]],
    int,
    List[str],
    int,
    float,
]:
    """
    Create one Ray model actor per config entry and build the shared vocabulary
    data needed for ensembled text generation.

    ``config`` is modified in place: ``update_scores`` rewrites every entry's
    ``"score"`` according to ``norm_type``, and ``normalize_scores`` then rescales
    the scores so that they sum to 1.

    Args:
        config (List[Dict]): One dict per model. Keys read here and by the helpers
            called from here: ``"name"``, ``"weight"`` (passed as ``model_path``),
            ``"num_gpus"``, ``"max_memory"``, ``"score"``, ``"priority"``
            (``"primary"`` or ``"supportive"``) and, for ``norm_type="ece_norm"``,
            ``"ece"``.
        norm_type (str): Score normalization strategy: ``'ece_norm'``,
            ``'average'`` or ``'score'``.
        threshold (float): Base threshold. When a primary model is configured it
            is multiplied by that model's normalized score. When there is no
            primary model it is ignored and the returned threshold is 1.

    Returns:
        An 11-tuple, in this order:
        - model_actors_list (List[ActorHandle]): Ray actor handles, one per config entry.
        - tokenizers (List[Tokenizer]): Tokenizer fetched from each actor.
        - vocab_union (Set[str]): Union of all tokenizers' vocabularies.
        - mapping_matrices (List[torch.Tensor]): One matrix per model, built by
          ``create_mapping_matrix`` from the model's token-id mapping. Each matrix
          maps that model's token ids to ids in the unified vocabulary, with shape
          ``[model_vocab_size, len(vocab_union)]``. It is presumably a sparse COO
          tensor; confirm in ``create_mapping_matrix``.
        - index_to_vocab (Dict[int, str]): Unified-vocabulary index -> token.
        - special_prefix_tokens_dict (Dict[Tokenizer, str]): Special prefix token per tokenizer.
        - byte_mappings_list (List[Dict[str, int]]): Per-tokenizer mapping for the
          byte tokens ``'<0x00>'`` to ``'<0xFF>'``.
        - min_max_position_embeddings (int): Smallest max-position-embeddings
          value reported by any actor.
        - model_name_list (List[str]): Model name reported by each actor, in actor order.
        - primary_index (int): Index of the primary model in ``config``, or -1 if none.
        - real_threshold (float): ``threshold * primary score``, or 1 when there is
          no primary model.

    Raises:
        ValueError: If ``config`` is empty, if ``norm_type`` is not supported, or
            if a priority is invalid or more than one model is ``"primary"``.
    """
    if not config:
        raise ValueError("config must describe at least one model")

    update_scores(config, norm_type)
    config = normalize_scores(config)
    logger.info(f"Model ensemble weight{[(c['name'], round(c['score'],4)) for c in config]}")

    # Locate the optional primary model; its normalized score scales the threshold.
    primary_index = check_priorities(config)
    if primary_index != -1:
        real_threshold = threshold * config[primary_index]["score"]
        logger.info(f"Primary model is {config[primary_index]['name']} with threshold {real_threshold}!")
    else:
        real_threshold = 1
        logger.info(f"Every token will be ensembled, which means threshold is {real_threshold}!")

    # Initialize model actors based on configuration and GPU requirements.
    model_actors_list = [
        get_remote_model_generator_class(model_config["num_gpus"]).remote(
            model_path=model_config["weight"],
            max_memory=model_config["max_memory"],
            model_name=model_config["name"],
            model_ensemble_weight=model_config["score"],
        )
        for model_config in config
    ]

    # Submit every remote call first and block once per query, so all actors
    # answer in parallel instead of one blocking round-trip per actor.
    tokenizers = ray.get([actor.get_tokenizer.remote() for actor in model_actors_list])
    model_name_list = ray.get([actor.get_model_name.remote() for actor in model_actors_list])

    # Determine special prefix tokens for all tokenizers.
    special_prefix_tokens_dict = get_special_prefix_tokens_for_all(tokenizers)

    # Create a unified vocabulary and per-tokenizer mappings into it.
    vocab_union, tokenizers_mapping, index_to_vocab, byte_mappings_list = get_vocab_union_and_mapping(
        tokenizers
    )

    model_vocab_size_list = ray.get([actor.get_vocab_size.remote() for actor in model_actors_list])
    mapping_matrices = [
        create_mapping_matrix(mapping, len(vocab_union), vocab_size)
        for mapping, vocab_size in zip(tokenizers_mapping, model_vocab_size_list)
    ]

    # The ensemble can only use a context as long as the shortest model allows.
    min_max_position_embeddings = min(
        ray.get([actor.get_max_position_embeddings.remote() for actor in model_actors_list])
    )

    return (
        model_actors_list,
        tokenizers,
        vocab_union,
        mapping_matrices,
        index_to_vocab,
        special_prefix_tokens_dict,
        byte_mappings_list,
        min_max_position_embeddings,
        model_name_list,
        primary_index,
        real_threshold,
    )

def calculate_non_pad_lengths(models_inputs, tokenizers, recorder):
    """
    Count the non-padding tokens of every input of every model and report the
    counts to ``recorder``.

    Args:
    models_inputs: One entry per model. Each entry is iterated to obtain token-id
        tensors, e.g. a list of 1-D tensors or a 2-D tensor walked row by row.
    tokenizers: Tokenizers aligned by index with ``models_inputs``. Only their
        ``pad_token_id`` attribute is read.
    recorder: Object exposing ``update_input_length``. It receives a nested list
        ``lengths[model_idx][input_idx]`` holding each count obtained via ``.item()``.

    Returns nothing; the result is delivered only through ``recorder``.
    """
    per_model_lengths = []

    for model_idx, batch in enumerate(models_inputs):
        pad_id = tokenizers[model_idx].pad_token_id
        # Comparing against the pad id gives a boolean mask; its sum is the
        # number of real (non-pad) tokens in that sequence.
        per_model_lengths.append([torch.sum(seq != pad_id).item() for seq in batch])

    recorder.update_input_length(per_model_lengths)

def check_priorities(dict_list):
    """
    Validate each entry's ``"priority"`` and locate the primary entry, if any.

    Every entry must have ``"priority"`` equal to ``"supportive"`` or
    ``"primary"``. At most one entry may be ``"primary"``; having none is allowed.

    Args:
    dict_list (list of dict): Entries read via ``entry.get("priority")``.

    Returns:
    int: Index of the single ``"primary"`` entry, or -1 when no entry is primary.

    Raises:
    ValueError: On the first entry whose priority is missing or not allowed, or,
        after all entries are validated, if more than one entry is ``"primary"``.
    """
    permitted = ("supportive", "primary")
    primary_positions = []

    for idx, entry in enumerate(dict_list):
        prio = entry.get("priority")
        if prio not in permitted:
            raise ValueError(f"'priority' value '{prio}' at index {idx} is not allowed!")
        if prio == "primary":
            primary_positions.append(idx)

    if len(primary_positions) > 1:
        raise ValueError("More than one 'primary' found!")

    return primary_positions[0] if primary_positions else -1


def normalize_scores(config, n=1):
    """
    Raise every configuration's score to the power ``n``, then rescale the
    results so that they sum to 1.

    With ``n=1`` this is plain sum-normalization. For positive scores, ``n > 1``
    concentrates weight on the highest-scoring entries.

    Parameters:
        config (list of dict): Configurations, each with a numeric ``'score'`` key.
            Modified in place.
        n (int | float, optional): Exponent applied to every score before
            normalization. Defaults to 1.

    Returns:
        list of dict: The same ``config`` list, with each ``'score'`` replaced by
        its normalized value (a NumPy float). An empty list is returned unchanged.

    Raises:
        ValueError: If the exponentiated scores sum to zero (e.g. ECE-adjusted
            scores that cancel out). Without this check the weights would
            silently become NaN or inf.
    """
    if not config:
        # Nothing to normalize.
        return config

    # Exponentiate (not multiply) each score by n.
    powered = np.array([configuration['score'] for configuration in config]) ** n
    total = np.sum(powered)
    if total == 0:
        raise ValueError(
            "Cannot normalize scores: they sum to zero after raising to the power "
            f"{n} (scores: {[configuration['score'] for configuration in config]})."
        )

    for configuration, new_score in zip(config, powered / total):
        configuration['score'] = new_score

    return config

def extract_generated_texts(tokenizer, input_ids_0: torch.Tensor, output: torch.Tensor) -> List[str]:
    """
    Decode only the newly generated tokens of every sequence in ``output``.

    For each batch row the prompt is taken from ``input_ids_0``, starting at its
    first non-pad token and running to the end of the row, so any padding is
    expected on the left. The first place where that prompt occurs in the
    matching ``output`` row is located. Everything after it is decoded with
    ``tokenizer.decode(..., skip_special_tokens=True,
    clean_up_tokenization_spaces=False)``.

    :param tokenizer: Provides ``pad_token_id`` and ``decode``.
    :param input_ids_0: Prompt token ids, indexed as ``[batch, position]``.
    :param output: Generated token ids, indexed as ``[batch, position]``. Each row
                   is expected to contain the prompt followed by the response.
    :return: One decoded response string per row of ``output``.
    :raises ValueError: If a prompt row is entirely padding, or if the prompt
                        cannot be found in the corresponding output row.
    """
    pad_id = tokenizer.pad_token_id
    responses: List[str] = []

    for row_idx in range(output.shape[0]):
        prompt_row = input_ids_0[row_idx]
        out_row = output[row_idx]

        # Positions of real tokens; squeeze() collapses a single hit to 0-d,
        # so re-expand it to keep indexing uniform.
        real_positions = (prompt_row != pad_id).nonzero().squeeze()
        if real_positions.dim() == 0:
            real_positions = real_positions.unsqueeze(0)
        if real_positions.numel() == 0:
            raise ValueError("No non-pad tokens found in the input for batch index {}".format(row_idx))

        prompt = prompt_row[real_positions[0].item():]
        prompt_len = prompt.shape[0]

        # Only windows that fit entirely inside the output row are candidates.
        for offset in range(out_row.shape[0] - prompt_len + 1):
            if torch.equal(out_row[offset:offset + prompt_len], prompt):
                reply_start = offset + prompt_len
                break
        else:
            raise ValueError(f"No matching sequence found in the output for batch index {row_idx}")

        reply_ids = out_row[reply_start:]
        responses.append(
            tokenizer.decode(reply_ids.tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=False)
        )

    return responses

def update_scores(config, norm_type):
    """
    Adjust the ``"score"`` of every entry in ``config`` in place according to
    ``norm_type``, and return the same list.

    Strategies:
    - ``'ece_norm'``: sets ``score`` to ``score - ece`` for each entry that has
      both keys. Entries missing either key are left untouched.
    - ``'average'``: sets every entry's score to 1, creating the key if absent.
    - ``'score'``: keeps the scores as they are.

    Parameters:
    - config (list of dict): Model configurations.
    - norm_type (str): One of ``'ece_norm'``, ``'average'`` or ``'score'``.

    Returns:
    - The same ``config`` list object, mutated in place.

    Raises:
    - ValueError: If ``norm_type`` is not one of the supported strategies.
    """
    if norm_type not in ('ece_norm', 'average', 'score'):
        raise ValueError(f"Invalid norm_type: {norm_type}. Expected 'ece_norm', 'average', or 'score'.")

    if norm_type == 'average':
        for entry in config:
            entry["score"] = 1
    elif norm_type == 'ece_norm':
        for entry in config:
            # Best-effort: only entries carrying both values are calibrated.
            if "score" not in entry or "ece" not in entry:
                continue
            entry["score"] = entry["score"] - entry["ece"]
    # 'score': nothing to change.

    return config

# Initialize Ray once when this module is imported. Ray raises if ray.init()
# is called on an already-initialized runtime (unless ignore_reinit_error=True),
# e.g. when the caller set Ray up with custom options before importing this
# module. In that case, reuse the existing runtime.
if not ray.is_initialized():
    ray.init()
