import torch
from transformers import TrainerCallback
import wandb

from .eval_funcs.choice import *
from .eval_funcs.mmlu import *

__all__ = ["EvalCallback"]


class EvalCallback(TrainerCallback):
    """Run a fixed suite of benchmark evaluations when training finishes.

    In ``on_train_end`` the callback evaluates ``model`` on every task listed in
    ``self.evaluation_tasks`` and then on MMLU. Each result dict is merged into
    ``self.total_results`` and logged to Weights & Biases from the world-main
    process.

    Args:
        model: Model handed to every evaluation function.
        tokenizer: Tokenizer handed to every evaluation function.
        eval_data_dict: Mapping from dataset key (e.g. ``"piqa"``) to the data
            passed to the matching evaluation function. Every key used in
            ``self.evaluation_tasks`` must be present, otherwise a ``KeyError``
            is raised when evaluation runs.
        chat_format_dict: Prompt/chat formatting configuration forwarded to
            every evaluation function.
        fast_mode: Stored on the instance but not read by this class; kept for
            backward compatibility.
        cache_dir: Cache directory forwarded to the MMLU evaluation.
        mmlu_kwargs: Optional overrides merged over ``DEFAULT_MMLU_KWARGS``
            (``k=5``, ``batch_size=4``, ``max_length=2048``) and forwarded to
            ``run_mmlu_eval``. Must not contain ``cache_dir`` (use the
            ``cache_dir`` argument instead).
    """

    # Settings previously hard-coded in the MMLU call; overridable per instance.
    DEFAULT_MMLU_KWARGS = {"k": 5, "batch_size": 4, "max_length": 2048}

    def __init__(
        self,
        model,
        tokenizer,
        eval_data_dict,
        chat_format_dict,
        fast_mode=False,
        cache_dir="./",
        mmlu_kwargs=None,
    ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.eval_data_dict = eval_data_dict
        # Merged result dicts of every evaluation run by this callback.
        self.total_results = {}
        # Not read by this class; kept so external code relying on it still works.
        self.epoch_counter = 0
        self.chat_format_dict = chat_format_dict
        # Not read by this class; kept for backward compatibility.
        self.fast_mode = fast_mode
        self.cache_dir = cache_dir
        # Fresh dict per instance so the class-level defaults are never mutated.
        self.mmlu_kwargs = {**self.DEFAULT_MMLU_KWARGS, **(mmlu_kwargs or {})}

        # (key into eval_data_dict, evaluation function, extra keyword arguments)
        self.evaluation_tasks = [
            ("science_qa", run_science_qa_eval, {}),
            ("winogrande", run_winogrande_eval, {}),
            ("piqa", run_piqa_eval, {}),
            ("hhh_eval", run_static_hhh_eval, {}),
            ("arc_easy", run_arc_eval, {"subset_key": "ARC-Easy"}),
            ("arc_challenge", run_arc_eval, {"subset_key": "ARC-Challenge"}),
            ("commonsense_qa", run_commonsense_qa_eval, {}),
            ("social_i_qa", run_social_i_qa_eval, {}),
            ("truthful_qa", run_truthfulqa_eval, {}),
        ]

    def _record_results(self, results, state):
        """Merge ``results`` into ``total_results`` and log them to W&B.

        Only the world-main process logs: with the standard HF W&B integration
        the run is initialised there only, and ``wandb.log`` fails on processes
        without an active run. The global step is logged as the
        ``train/global_step`` key rather than via ``step=``, because W&B ignores
        explicit steps lower than its internal (auto-incremented) step counter.
        """
        self.total_results.update(results)
        if not state.is_world_process_zero:
            return
        wandb.log({**results, "train/global_step": state.global_step})

    @torch.no_grad()
    def on_train_end(self, args, state, control, **kwargs):
        """Evaluate the model on every configured benchmark and log the results.

        Evaluation runs on every process (a sharded model may need all ranks
        for its forward pass); logging is restricted to the world-main process.
        The model's train/eval mode is restored afterwards, even on error.

        Raises:
            KeyError: if a dataset key from ``evaluation_tasks`` is missing from
                ``eval_data_dict``.
        """
        was_training = self.model.training
        # The model may still be in train mode after the last training step;
        # evaluate with dropout and similar layers disabled.
        self.model.eval()
        try:
            for dataset_key, eval_func, task_kwargs in self.evaluation_tasks:
                results = eval_func(
                    self.model,
                    self.tokenizer,
                    self.eval_data_dict[dataset_key],
                    chat_format_dict=self.chat_format_dict,
                    **task_kwargs,
                )
                self._record_results(results, state)

            mmlu_results = run_mmlu_eval(
                self.model,
                self.tokenizer,
                self.chat_format_dict,
                cache_dir=self.cache_dir,
                **self.mmlu_kwargs,
            )
            self._record_results(mmlu_results, state)
        finally:
            if was_training:
                self.model.train()
