import json

# Public API of this module.
# NOTE: "preprocess_reward_datset" (sic) matches the spelling of the function
# defined in this file; it is part of the public interface, so the export and
# the definition must be renamed together or not at all.
__all__ = [
    "load_json_file",
    "preprocess_dataset",
    "preprocess_reward_datset",
]

def load_json_file(data_path):
    """Read the JSON document stored at ``data_path`` and return the parsed object.

    The file is always decoded as UTF-8 so that the result does not depend on
    the platform's locale encoding.

    Args:
        data_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Explicit encoding: the locale default (e.g. cp1252 on Windows) would
    # mis-decode or reject non-ASCII text.
    with open(data_path, "r", encoding="utf-8") as f:
        return json.load(f)

def preprocess_dataset(dataset, chat_format_dict, base_dict=None, num_proc=8):
    """Apply chat formatting to every batch of ``dataset`` via ``dataset.map``.

    Each batch is passed to ``_format_as_chat``. If the dataset already has a
    ``"base"`` column it is carried through; otherwise, when ``base_dict`` is
    given, ``_format_as_chat`` fills ``"base"`` by looking up each prompt in it.

    Args:
        dataset: Object exposing ``column_names`` and a
            ``map(function, batched=..., num_proc=...)`` method.
        chat_format_dict: Values for the ``{user_start}``, ``{turn_end}`` and
            ``{model_start}`` placeholders of the prompt template.
        base_dict: Optional mapping from raw prompt to its ``"base"`` value;
            only consulted when the dataset has no ``"base"`` column.
        num_proc: Value forwarded as ``num_proc`` to ``dataset.map``.
            Defaults to 8, the value that was previously hard-coded.

    Returns:
        The object returned by ``dataset.map``.
    """
    # Decide once, up front, whether "base" comes from the dataset itself.
    is_base_exist = "base" in dataset.column_names
    return dataset.map(
        lambda batch: _format_as_chat(
            batch,
            chat_format_dict,
            is_base_exist=is_base_exist,
            base_dict=base_dict,
        ),
        batched=True,
        num_proc=num_proc,
    )

def _format_as_chat(examples, chat_format_dict, is_base_exist=False, base_dict=None):
    """Build a chat-formatted batch from a batch of preference examples.

    Every prompt is wrapped as ``{user_start}{prompt}{turn_end}{model_start}``
    using the placeholder values in ``chat_format_dict``; the ``"chosen"`` and
    ``"rejected"`` entries are copied unchanged.

    A ``"base"`` column is emitted when ``is_base_exist`` is true (values taken
    from ``examples["base"]``) or, failing that, when ``base_dict`` is not
    ``None`` (values looked up by the raw, unformatted prompt).

    Returns:
        A dict of lists with keys ``"prompt"``, ``"chosen"``, ``"rejected"``
        and, under the conditions above, ``"base"``.
    """
    template = "{user_start}{prompt}{turn_end}{model_start}"

    # Gather the columns to walk in lockstep; "base" joins only if present.
    columns = [examples["prompt"], examples["chosen"], examples["rejected"]]
    if is_base_exist:
        columns.append(examples["base"])

    prompts, chosen_out, rejected_out, base_out = [], [], [], []
    for row in zip(*columns):
        raw_prompt = row[0]
        prompts.append(template.format(prompt=raw_prompt, **chat_format_dict))
        chosen_out.append(row[1])
        rejected_out.append(row[2])
        if is_base_exist:
            base_out.append(row[3])
        elif base_dict is not None:
            # The lookup key is the original prompt, not the formatted one.
            base_out.append(base_dict[raw_prompt])

    formatted = {
        "prompt": prompts,
        "chosen": chosen_out,
        "rejected": rejected_out,
    }
    if is_base_exist or base_dict is not None:
        formatted["base"] = base_out
    return formatted


def preprocess_reward_datset(tokenizer, examples, chat_format_dict, max_length=512):
    """Tokenize chosen/rejected conversations of a batch for a reward model.

    For each (prompt, chosen, rejected) triple the prompt is wrapped as
    ``{user_start}{prompt}{turn_end}{model_start}`` followed by the response,
    and each of the two resulting strings is passed to ``tokenizer`` with
    ``truncation=True``, ``padding="max_length"`` and the given ``max_length``.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=..., max_length=..., padding=...)``;
            its result must be indexable by ``"input_ids"`` and
            ``"attention_mask"``.
        examples: Batch with ``"prompt"``, ``"chosen"`` and ``"rejected"``
            columns, iterated in lockstep.
        chat_format_dict: Values for the ``{user_start}``, ``{turn_end}`` and
            ``{model_start}`` placeholders.
        max_length: Length forwarded to the tokenizer for truncation and
            padding. Defaults to 512, the value that was previously hard-coded.

    Returns:
        A dict of lists with keys ``"input_ids_chosen"``,
        ``"attention_mask_chosen"``, ``"input_ids_rejected"`` and
        ``"attention_mask_rejected"``, one entry per example.
    """
    # Loop-invariant pieces are built once.
    chosen_template = "{user_start}{prompt}{turn_end}{model_start}{chosen}"
    rejected_template = "{user_start}{prompt}{turn_end}{model_start}{rejected}"
    tokenize_kwargs = {
        "truncation": True,
        "max_length": max_length,
        "padding": "max_length",
    }

    new_examples = {
        "input_ids_chosen": [],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": [],
    }
    for prompt, chosen, rejected in zip(
        examples["prompt"], examples["chosen"], examples["rejected"]
    ):
        formatted_chosen = chosen_template.format(
            prompt=prompt, chosen=chosen, **chat_format_dict
        )
        formatted_rejected = rejected_template.format(
            prompt=prompt, rejected=rejected, **chat_format_dict
        )

        tokenized_chosen = tokenizer(formatted_chosen, **tokenize_kwargs)
        tokenized_rejected = tokenizer(formatted_rejected, **tokenize_kwargs)

        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(
            tokenized_rejected["attention_mask"]
        )

    return new_examples
