from trl import DPOTrainer, ORPOTrainer
from custom_trainer import DPOPTrainer, BAPOTrainer


# Names exported by `from <this module> import *`; one per line so additions
# produce single-line diffs.
__all__ = [
    "MODEL_DICT",
    "CHAT_FORMAT",
    "TRAINER_DICT",
    "EVAL_DATASETS",
    "ARGS_UPDATE_PATH",
]

# Short model alias -> model identifier string. The values look like Hugging
# Face Hub repo ids ("org/name") -- presumably handed to a `from_pretrained`
# style loader elsewhere; confirm against the caller.
# NOTE(review): the aliases are the same three keys used by CHAT_FORMAT.
MODEL_DICT = {
    "gemma-2b-it": "google/gemma-2b-it",
    "phi3-mini": "microsoft/Phi-3-mini-128k-instruct",
    "tinyllama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}

# Per-model chat markers, keyed by model alias. Every entry exposes the same
# three fields, in this order: "user_start", "model_start", "turn_end".
# The table lists the marker strings in that order; the comprehension attaches
# the field names so they are spelled only once.
CHAT_FORMAT = {
    alias: {"user_start": user_tag, "model_start": model_tag, "turn_end": end_tag}
    for alias, (user_tag, model_tag, end_tag) in (
        (
            "gemma-2b-it",
            ("<start_of_turn>user\n", "<start_of_turn>model\n", "<end_of_turn>"),
        ),
        ("phi3-mini", ("<|user|>\n", "<|assistant|>\n", "<|end|>")),
        ("tinyllama", ("<|user|>\n", "<|assistant|>\n", "</s>")),
    )
}

# Trainer name -> trainer class.
# "dpo", "ipo" and "rso" all resolve to the same DPOTrainer class; presumably
# they are told apart by configuration elsewhere -- confirm against the caller.
TRAINER_DICT = {
    **dict.fromkeys(("dpo", "ipo", "rso"), DPOTrainer),
    "dpop": DPOPTrainer,
    "orpo": ORPOTrainer,
    "bapo": BAPOTrainer,
}

# Evaluation benchmark name -> dataset specification.
# The keys used ("path", "name", "split", "data_files") match keyword
# arguments of `datasets.load_dataset`; presumably each spec is unpacked into
# that call -- confirm against the evaluation code.
EVAL_DATASETS = {
    "science_qa": {
        "path": "tasksource/ScienceQA_text_only",
        "split": "validation",
    },
    "winogrande": {
        "path": "winogrande",
        "name": "winogrande_s",
        "split": "validation",
    },
    "piqa": {
        "path": "piqa",
        "split": "validation",
    },
    # The two ARC subsets share one path and differ only in "name".
    "arc_challenge": {
        "path": "allenai/ai2_arc",
        "name": "ARC-Challenge",
        "split": "validation",
    },
    "arc_easy": {
        "path": "allenai/ai2_arc",
        "name": "ARC-Easy",
        "split": "validation",
    },
    "commonsense_qa": {
        "path": "commonsense_qa",
        "split": "validation",
    },
    "social_i_qa": {
        "path": "social_i_qa",
        "split": "validation",
    },
    # NOTE(review): this is the only entry without a "split" key. Confirm the
    # consumer handles that case before adding one here.
    "truthful_qa": {
        "path": "truthful_qa",
        "name": "multiple_choice",
    },
    # Local JSON file rather than a named dataset; uses the "train" split.
    "hhh_eval": {
        "path": "json",
        "data_files": "./data/eval/hhh/hhh_eval.json",
        "split": "train",
    },
}

# Flat argument name -> tuple of keys.
# Each tuple reads like a path into a nested configuration (section first,
# then nested keys); presumably it says where an override for that argument
# is written -- confirm against the code that consumes this table.
# Keys must be unique: a repeated key in a dict literal silently overwrites
# the earlier entry.
ARGS_UPDATE_PATH = {
    # model / data selection
    "model_name": ("model_setups", "model_name"),
    "dataset_name": ("data_setups", "dataset_name"),
    "preference_type": ("data_setups", "preference_type"),
    # trainer section
    "trainer_name": ("trainer", "name"),
    "fast_mode": ("trainer", "fast_mode"),
    "output_dir": ("trainer", "output_dir"),
    "bapo_lambda1": ("trainer", "bapo_lambda1"),
    "bapo_lambda2": ("trainer", "bapo_lambda2"),
    # train_setups section
    "num_train_epochs": ("train_setups", "num_train_epochs"),
    "learning_rate": ("train_setups", "learning_rate"),
    "lr_scheduler_type": ("train_setups", "lr_scheduler_type"),
    # nested under model_setups (three-level paths)
    "lora_r": ("model_setups", "peft_setups", "r"),
    "lora_alpha": ("model_setups", "peft_setups", "lora_alpha"),
    "cache_dir": ("model_setups", "params", "cache_dir"),
    # wandb_setups section; note "exp_name" maps to the "name" key
    "project": ("wandb_setups", "project"),
    "group": ("wandb_setups", "group"),
    "exp_name": ("wandb_setups", "name"),
}