# from train_tools.data_utils.preprocess.mmlu import *
import torch
from tqdm import tqdm
from datasets import load_dataset


@torch.no_grad()
def run_mmlu_eval(
    model, tokenizer, chat_format_dict, k=0, batch_size=8, max_length=2048, cache_dir="./"
):
    """Evaluate ``model`` on every MMLU subject and return accuracy per category.

    Each subject listed by ``get_subcategories`` has its ``test`` split of
    ``cais/mmlu`` turned into chat-formatted prompts (with ``k`` few-shot
    examples from the ``dev`` split when ``k > 0``). For every prompt, the
    scores of the first generated token are restricted to the token ids of
    "A", "B", "C" and "D", and the highest-scoring letter is the prediction.

    Args:
        model: object exposing ``generate`` and ``device`` (assumed to be a
            Hugging Face decoder-only LM -- confirm with callers).
        tokenizer: tokenizer paired with ``model``; it must be able to pad.
        chat_format_dict: dict with ``"user_start"``, ``"model_start"`` and
            ``"turn_end"`` strings wrapped around each prompt.
        k: number of few-shot examples prepended to every question.
        batch_size: number of prompts per ``generate`` call.
        max_length: maximum prompt length in tokens; longer prompts are
            truncated from the left so the question and answer cue survive.
        cache_dir: cache directory passed to ``load_dataset``.

    Returns:
        dict mapping "humanities", "STEM", "social sciences",
        "other (business, health, misc.)" and "all" to accuracy in [0, 1]
        (NaN for a category that received no questions).

    Note:
        ``tokenizer.padding_side`` and ``tokenizer.truncation_side`` are set to
        "left" while evaluating and restored before returning.
    """
    categories = get_subcategories()
    super_categories = get_categories()

    label_space = ["A", "B", "C", "D"]
    # Take the last id so that a BOS token the tokenizer may prepend is skipped.
    label_indices = [tokenizer.encode(label)[-1] for label in label_space]

    results = {
        "humanities": {"correct": 0, "total": 0},
        "STEM": {"correct": 0, "total": 0},
        "social sciences": {"correct": 0, "total": 0},
        "other (business, health, misc.)": {"correct": 0, "total": 0},
        "all": {"correct": 0, "total": 0},
    }

    # Batched generation with a decoder-only model needs left padding. With right
    # padding, the first generated token of every shorter prompt is predicted
    # after pad tokens instead of right after the prompt. Truncating from the
    # left likewise keeps the question and the "The correct answer is: " cue.
    # The caller's settings are restored in the finally clause.
    saved_padding_side = tokenizer.padding_side
    saved_truncation_side = getattr(tokenizer, "truncation_side", None)
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    try:
        for task_name, task_category in tqdm(
            categories.items(), desc="MMLU Tasks", position=0, leave=False
        ):
            super_category = get_super_category(task_category[0], super_categories)

            # tqdm.write keeps the progress bar intact, unlike a bare print.
            tqdm.write(f"\nEvaluating {task_name}...")

            testset = _load_mmlu_prompts(task_name, chat_format_dict, k, cache_dir)
            task_correct, task_total = _score_mmlu_prompts(
                model, tokenizer, testset, label_space, label_indices, batch_size, max_length
            )

            for key in (super_category, "all"):
                results[key]["correct"] += task_correct
                results[key]["total"] += task_total
    finally:
        tokenizer.padding_side = saved_padding_side
        if saved_truncation_side is not None:
            tokenizer.truncation_side = saved_truncation_side

    # Guard against an empty category instead of raising ZeroDivisionError.
    return {
        key: value["correct"] / value["total"] if value["total"] else float("nan")
        for key, value in results.items()
    }


def _load_mmlu_prompts(task_name, chat_format_dict, k, cache_dir):
    """Load one subject's test split, mapped to ``prompts``/``golden_labels`` columns."""
    # Few-shot examples come from the dev split and are only needed when k > 0.
    trainset = (
        load_dataset(
            "cais/mmlu",
            task_name,
            cache_dir=cache_dir,
            split="dev",
            trust_remote_code=True,
        )
        if k > 0
        else None
    )

    testset = load_dataset(
        "cais/mmlu",
        task_name,
        cache_dir=cache_dir,
        split="test",
        trust_remote_code=True,
    )

    return testset.map(
        lambda sample: preprocess_sample(sample, task_name, chat_format_dict, trainset, k),
        batched=False,
        remove_columns=testset.column_names,
    )


def _score_mmlu_prompts(
    model, tokenizer, dataset, label_space, label_indices, batch_size, max_length
):
    """Return ``(num_correct, num_total)`` for letter predictions over ``dataset``."""
    # Read each column once. Indexing a datasets.Dataset by column name builds
    # the column on each access, so doing it per batch repeats that work.
    prompts = list(dataset["prompts"])
    golden_labels = list(dataset["golden_labels"])

    num_correct = 0
    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start : start + batch_size]
        batch_labels = golden_labels[start : start + batch_size]
        answer_indices = torch.tensor(
            [label_space.index(label) for label in batch_labels], device=model.device
        )

        encoded = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        output = model.generate(
            input_ids=encoded.input_ids.to(model.device),
            attention_mask=encoded.attention_mask.to(model.device),
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=1,
            return_dict_in_generate=True,
            output_scores=True,
        )
        del encoded

        # output.scores[0] holds the scores for the first (and only) generated
        # token; keep just the four answer-letter columns and take the argmax.
        pred = torch.argmax(output.scores[0][:, label_indices], dim=1)
        num_correct += torch.sum(pred == answer_indices).item()

    return num_correct, len(prompts)


def preprocess_sample(sample, subject, chat_format_dict, trainset=None, k=0):
    """Convert one raw MMLU row into a ``{"prompts", "golden_labels"}`` record.

    ``prompts`` is the chat-formatted prompt produced by ``_generate_prompt``.
    ``golden_labels`` is the letter from "A"-"D" at position
    ``int(sample["answer"])``.
    """
    record_prompt = _generate_prompt(sample, subject, chat_format_dict, trainset, k)
    letters = ["A", "B", "C", "D"]
    gold_letter = letters[int(sample["answer"])]
    return dict(prompts=record_prompt, golden_labels=gold_letter)


def _generate_prompt(sample, subject, chat_format_dict, trainset=None, k=0):
    """Assemble the full chat-formatted MMLU prompt for one question.

    Layout, in order:
    1. the user-turn opener;
    2. a header naming the subject;
    3. ``k`` answered few-shot examples (only when ``k > 0``);
    4. the unanswered question;
    5. the turn terminator;
    6. the model-turn opener followed by ``"The correct answer is: "``, so the
       model's next token is expected to be the answer letter.

    Args:
        sample: row with ``question``, ``choices`` and ``answer`` fields.
        subject: MMLU config name; underscores are rendered as spaces.
        chat_format_dict: must provide ``user_start``, ``model_start`` and
            ``turn_end`` strings.
        trainset: source of few-shot examples, consulted only when ``k > 0``.
        k: number of few-shot examples to include.
    """
    readable_subject = _format_subject(subject)
    opener = chat_format_dict["user_start"]
    reply_opener = chat_format_dict["model_start"]
    closer = chat_format_dict["turn_end"]

    parts = [
        f"{opener}",
        f"The following are multiple choice questions (with answers) about {readable_subject}.\n\n",
    ]
    if k > 0:
        parts.append(_generate_examples(trainset, k))

    question_block = _format_instance(sample, include_answer=False)
    parts.append(f"{question_block}{closer}")
    parts.append(f"{reply_opener}The correct answer is: ")
    return "".join(parts)


def _generate_examples(trainset, k=0):
    example_shots = ""
    for i in range(k):
        example_shots += _format_instance(trainset[i], include_answer=True)

    return example_shots


def _format_instance(sample, include_answer=False):
    prompt = sample["question"]
    choice_names = ["A", "B", "C", "D"]

    for i in range(4):
        choice = sample["choices"][i]
        choice_name = choice_names[i]
        prompt += f"\n{choice_name}. {choice}"

    prompt += "\nAnswer:"

    if include_answer:
        answer = choice_names[int(sample["answer"])]
        prompt += f" {answer}\n\n"

    return prompt


def _format_subject(subject):
    subwords = subject.split("_")
    subject = " ".join(subwords)

    return subject

def get_super_category(category_name, category_dict):
    """Return the first key of ``category_dict`` whose members include ``category_name``.

    Keys are checked in the dict's iteration order. ``None`` is returned when
    no member list contains ``category_name``.
    """
    return next(
        (group for group, members in category_dict.items() if category_name in members),
        None,
    )

def get_subcategories():
    """Map every MMLU subject (dataset config name) to its subcategory.

    Each value is a one-element list holding a subcategory name of the kind
    grouped by ``get_categories``. A new dict with new lists is built on every
    call, with subjects in alphabetical order.
    """
    subject_to_subcategory = (
        ("abstract_algebra", "math"),
        ("anatomy", "health"),
        ("astronomy", "physics"),
        ("business_ethics", "business"),
        ("clinical_knowledge", "health"),
        ("college_biology", "biology"),
        ("college_chemistry", "chemistry"),
        ("college_computer_science", "computer science"),
        ("college_mathematics", "math"),
        ("college_medicine", "health"),
        ("college_physics", "physics"),
        ("computer_security", "computer science"),
        ("conceptual_physics", "physics"),
        ("econometrics", "economics"),
        ("electrical_engineering", "engineering"),
        ("elementary_mathematics", "math"),
        ("formal_logic", "philosophy"),
        ("global_facts", "other"),
        ("high_school_biology", "biology"),
        ("high_school_chemistry", "chemistry"),
        ("high_school_computer_science", "computer science"),
        ("high_school_european_history", "history"),
        ("high_school_geography", "geography"),
        ("high_school_government_and_politics", "politics"),
        ("high_school_macroeconomics", "economics"),
        ("high_school_mathematics", "math"),
        ("high_school_microeconomics", "economics"),
        ("high_school_physics", "physics"),
        ("high_school_psychology", "psychology"),
        ("high_school_statistics", "math"),
        ("high_school_us_history", "history"),
        ("high_school_world_history", "history"),
        ("human_aging", "health"),
        ("human_sexuality", "culture"),
        ("international_law", "law"),
        ("jurisprudence", "law"),
        ("logical_fallacies", "philosophy"),
        ("machine_learning", "computer science"),
        ("management", "business"),
        ("marketing", "business"),
        ("medical_genetics", "health"),
        ("miscellaneous", "other"),
        ("moral_disputes", "philosophy"),
        ("moral_scenarios", "philosophy"),
        ("nutrition", "health"),
        ("philosophy", "philosophy"),
        ("prehistory", "history"),
        ("professional_accounting", "other"),
        ("professional_law", "law"),
        ("professional_medicine", "health"),
        ("professional_psychology", "psychology"),
        ("public_relations", "politics"),
        ("security_studies", "politics"),
        ("sociology", "culture"),
        ("us_foreign_policy", "politics"),
        ("virology", "health"),
        ("world_religions", "philosophy"),
    )
    return {subject: [subcategory] for subject, subcategory in subject_to_subcategory}


def get_categories():
    """Group subcategory names (as used by ``get_subcategories``) into super-categories.

    Returns a new dict on every call, mapping each super-category name to a
    list of the subcategory names it contains.
    """
    stem = ["physics", "chemistry", "biology", "computer science", "math", "engineering"]
    humanities = ["history", "philosophy", "law"]
    social_sciences = ["politics", "culture", "economics", "geography", "psychology"]
    other = ["other", "business", "health"]
    return {
        "STEM": stem,
        "humanities": humanities,
        "social sciences": social_sciences,
        "other (business, health, misc.)": other,
    }
