from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import ORPOConfig
from custom_trainer import *
import torch
import wandb
import argparse
import os
from datasets import load_dataset
from train_tools import *


def get_model_setups(args, trainer_name):
    """
    Build the tokenizer, the trainable policy model and (optionally) a reference model.

    Args:
        args: Model section of the config. Reads ``model_name`` (a key of
            ``MODEL_DICT``), ``params`` (extra ``from_pretrained`` kwargs, which
            must expose ``cache_dir``) and ``peft_setups``.
        trainer_name: Name of the trainer; decides whether a reference model is loaded.

    Returns:
        A dict with the keys ``"model"``, ``"ref_model"`` and ``"tokenizer"``.
        ``"ref_model"`` is ``None`` for trainers that do not use one.
    """
    # Default quantization settings, shared by the policy and the reference model.
    bnb_config = get_default_quantization_config()
    checkpoint = MODEL_DICT[args.model_name]

    def _load_quantized_lm():
        # Both models are loaded from the same checkpoint with identical settings.
        return AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,
            **args.params,
        )

    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        cache_dir=args.params.cache_dir,
    )

    # The policy model is the one wrapped for PEFT (Q-LoRA) training.
    policy_model = prepare_model_for_qlora(_load_quantized_lm(), args.peft_setups)

    # Only these trainers get a separate reference model; it is switched to eval mode.
    uses_reference = trainer_name in ("dpo", "ipo", "rso", "dpop", "bapo")
    ref_model = _load_quantized_lm().eval() if uses_reference else None

    return {"model": policy_model, "ref_model": ref_model, "tokenizer": tokenizer}


def get_data_setups(args, trainer_name=None, model_name="gemma-2b-it"):
    """
    Load the preference data, split it into train/eval parts and preprocess both.

    The preference file is ``./data/train/<dataset_name>/<preference_type>.json``.
    For the ``"bapo"`` trainer, extra base data is picked up per dataset:

    * ``psoups``: the preference file is swapped for its
      ``_with_base_<model_name>.json`` variant.
    * ``psoups_identical``: one base dict is loaded and used for both splits.
    * ``dsp``: separate train and test base dicts are loaded.

    Args:
        args: Data section of the config (reads ``dataset_name`` and ``preference_type``).
        trainer_name: Name of the trainer; only ``"bapo"`` changes the loading.
        model_name: Key into ``CHAT_FORMAT``; also part of the base-data file names.

    Returns:
        A dict with the keys ``"train_dataset"`` and ``"eval_dataset"``.

    Raises:
        ValueError: If ``dataset_name`` is not ``psoups``, ``psoups_identical`` or ``dsp``.
    """
    corpus = args.dataset_name
    data_root = os.path.join("./data/train", corpus)
    data_path = os.path.join(data_root, f"{args.preference_type}.json")

    # Base dicts stay None unless the bapo trainer asks for them.
    train_base_dict = None
    eval_base_dict = None

    if trainer_name == "bapo":
        if corpus == "psoups":
            data_path = data_path.replace(".json", f"_with_base_{model_name}.json")
        elif corpus == "psoups_identical":
            shared_base = load_json_file(
                os.path.join(data_root, f"psoups_identical_base_{model_name}.json")
            )
            train_base_dict = shared_base
            eval_base_dict = shared_base
        elif corpus == "dsp":
            train_base_path = os.path.join(
                data_root, f"dsp_train_base_{model_name}.json"
            )
            train_base_dict = load_json_file(train_base_path)
            # The test base file sits next to the train one, differing only in name.
            eval_base_dict = load_json_file(
                train_base_path.replace("dsp_train", "dsp_test")
            )

    # Shuffle with a fixed seed.
    shuffled = load_dataset("json", data_files=data_path, split="train").shuffle(
        seed=2024
    )

    if corpus == "dsp":
        # dsp has a dedicated test file; the whole shuffled file is the train split.
        train_dataset = shuffled
        eval_dataset = load_dataset(
            "json",
            data_files=data_path.replace(".json", "_test.json"),
            split="train",
        )
    elif corpus in ("psoups", "psoups_identical"):
        # First 45000 shuffled rows are for training, the remainder for evaluation.
        train_size = 45000
        total_rows = len(shuffled)
        train_dataset = shuffled.select(range(train_size))
        eval_dataset = shuffled.select(range(train_size, total_rows))
    else:
        raise ValueError(f"Invalid dataset name: {corpus}")

    # Apply the model-specific chat format to both splits.
    chat_format = CHAT_FORMAT[model_name]
    processed = {
        "train_dataset": preprocess_dataset(
            train_dataset, chat_format, base_dict=train_base_dict
        ),
    }
    processed["eval_dataset"] = preprocess_dataset(
        eval_dataset, chat_format, base_dict=eval_base_dict
    )
    return processed


def get_trainer_setups(model_dict, data_dict, output_dir, cache_dir, args):
    """
    Assemble the evaluation callback, the training arguments and the trainer.

    Args:
        model_dict: Dict with ``"model"``, ``"tokenizer"`` and ``"ref_model"``
            (the latter is only read for non-ORPO trainers).
        data_dict: Dict with ``"train_dataset"`` and ``"eval_dataset"``.
        output_dir: Directory passed to the training arguments.
        cache_dir: Cache directory for the evaluation datasets and the callback.
        args: Full config; reads ``trainer``, ``train_setups``, ``wandb_setups``
            and ``model_setups.model_name``.

    Returns:
        The trainer instance built from ``TRAINER_DICT[args.trainer.name]``.
    """
    # Load every evaluation dataset listed in EVAL_DATASETS.
    # NOTE: the two keys below are written into the shared EVAL_DATASETS entries
    # themselves, so the change persists after this function returns.
    eval_data_dict = {}
    for eval_name, load_kwargs in EVAL_DATASETS.items():
        load_kwargs["cache_dir"] = cache_dir
        load_kwargs["trust_remote_code"] = True
        eval_data_dict[eval_name] = load_dataset(**load_kwargs)

    eval_callback = EvalCallback(
        model_dict["model"],
        model_dict["tokenizer"],
        eval_data_dict=eval_data_dict,
        chat_format_dict=CHAT_FORMAT[args.model_setups.model_name],
        fast_mode=args.trainer.fast_mode,
        cache_dir=cache_dir,
    )

    # Arguments common to every trainer; entries of train_setups win on key clashes.
    base_trainer_args = {
        "output_dir": output_dir,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "run_name": args.wandb_setups.name,
    }
    base_trainer_args.update(args.train_setups)

    # ORPO uses its own config class and takes no reference model.
    is_orpo = args.trainer.name == "orpo"
    if is_orpo:
        train_args = ORPOConfig(**args.trainer.orpo_args, **base_trainer_args)
    else:
        train_args = TrainingArguments(**base_trainer_args)

    trainer_cls = TRAINER_DICT[args.trainer.name]
    trainer_kwargs = {
        "model": model_dict["model"],
        "tokenizer": model_dict["tokenizer"],
        "train_dataset": data_dict["train_dataset"],
        "eval_dataset": data_dict["eval_dataset"],
        "args": train_args,
        "callbacks": [eval_callback],
    }
    if not is_orpo:
        trainer_kwargs["ref_model"] = model_dict["ref_model"]

    return trainer_cls(**trainer_kwargs, **args.trainer.params)


def main(args):
    """
    Run one policy-training job described by the configuration ``args``.

    Steps, in order: start the W&B run, seed the random generators, build the
    models and datasets, create the output directory, train, and save the model
    to that directory.
    """
    # Start the W&B run with the whole config attached.
    wandb.init(config=args, **args.wandb_setups)

    # Seed everything for reproducibility.
    random_seeder(args.train_setups.seed)

    trainer_cfg = args.trainer
    model_cfg = args.model_setups
    data_cfg = args.data_setups

    # Models first, then data.
    model_dict = get_model_setups(model_cfg, trainer_cfg.name)
    data_dict = get_data_setups(data_cfg, trainer_cfg.name, model_cfg.model_name)

    # Output layout: <trainer.output_dir>/<dataset_name>/policy/<wandb run name>
    output_dir = os.path.join(
        trainer_cfg.output_dir,
        data_cfg.dataset_name,
        "policy",
        args.wandb_setups.name,
    )
    os.makedirs(output_dir, exist_ok=True)

    # Build the trainer and run the training.
    trainer = get_trainer_setups(
        model_dict=model_dict,
        data_dict=data_dict,
        output_dir=output_dir,
        cache_dir=model_cfg.params.cache_dir,
        args=args,
    )
    trainer.train()

    # Save the final model.
    trainer.save_model(output_dir)


# Command-line parser for terminal execution.
# Only the parser is built at import time. The arguments themselves are parsed
# inside the ``__main__`` guard below, so importing this module never reads
# ``sys.argv`` and cannot exit the importing process on unknown arguments.
parser = argparse.ArgumentParser(description="Config file processing")

# Path of the JSON config file that provides the base configuration.
parser.add_argument("--config_path", default="./config/train/policy_dpo.json", type=str)

# Optional values handed to overwrite_config together with the loaded config.
# All of them default to None, except --fast_mode which defaults to False.
parser.add_argument("--model_name", type=str)
parser.add_argument("--dataset_name", type=str)
parser.add_argument("--preference_type", type=str)
parser.add_argument("--trainer_name", type=str)
parser.add_argument("--fast_mode", action="store_true")
parser.add_argument("--output_dir", type=str)
parser.add_argument("--bapo_lambda1", type=float)
parser.add_argument("--bapo_lambda2", type=float)
parser.add_argument("--num_train_epochs", type=int)
parser.add_argument("--learning_rate", type=float)
parser.add_argument("--lr_scheduler_type", type=str)
parser.add_argument("--lora_r", type=int)
parser.add_argument("--lora_alpha", type=float)
parser.add_argument("--cache_dir", type=str)
parser.add_argument("--group", type=str)
parser.add_argument("--project", type=str)
parser.add_argument("--exp_name", type=str)

if __name__ == "__main__":
    # Parse the command line only when executed as a script.
    args = parser.parse_args()

    # Set torch base print precision
    torch.set_printoptions(6)

    # Load configuration options
    config_loader = ConfLoader(args.config_path)
    opt = config_loader.opt

    # Apply updates from command line arguments to the configuration
    opt = overwrite_config(opt, args, ARGS_UPDATE_PATH)

    # Print configuration dictionary pretty
    pprint_config(opt)

    # Run experiment
    main(opt)
