# This code is based on the revised code from fastchat based on tatsu-lab/stanford_alpaca.

from dataclasses import dataclass, field
import json
import math
import random
import logging
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from typing import Dict, Optional, List
import torch
# from pip._internal.index import sources
from torch.utils.data import Dataset
from deepspeed import zero
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import transformers
from transformers import BitsAndBytesConfig, deepspeed
from trainer_qwen import QwenTrainer

from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate.utils import DistributedType
from evaluate import *

import warnings
# NOTE (original author): this filter did not take effect in practice.
# `module` is a regex matched against the start of the warning's module name and
# `lineno=23` further restricts it to warnings issued from line 23 of that module.
warnings.filterwarnings("ignore", category=FutureWarning, module="deepspeed", lineno=23)
warnings.filterwarnings("ignore", category=UserWarning, module="bitsandbytes")

# Label value written into `labels` for positions that must not contribute to the loss
# (prompt tokens, padding).
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


# Default `system_message` for the preprocess* helpers (None = no system prompt configured).
SYSTEM = None
# Default for DataArguments.target_role; None disables the role filter in
# LazySupervisedMultiroleDataset.segment().
TARGET_ROLE = None


@dataclass
class ModelArguments:
    """Model-selection options parsed from the command line.

    ``model_name_or_path`` is passed to ``AutoConfig``, ``AutoModelForCausalLM`` and
    ``AutoTokenizer.from_pretrained``; it may be a Hub id or a local directory.
    """

    model_name_or_path: Optional[str] = "Qwen/Qwen-7B"


@dataclass
class DataArguments:
    """Data-loading options parsed from the command line.

    Flags ``is_pretrain``, ``replay_path``, ``multirole_conv`` and ``classfication``
    select which dataset class ``make_supervised_data_module`` builds when
    ``lazy_preprocess`` is enabled.
    """

    # One or more training JSON files; their top-level lists are concatenated.
    data_path: List[str] = field(
        default_factory=list
    )
    eval_data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    # Tokenize each example on first access instead of all up front.
    lazy_preprocess: bool = False
    replay_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the replay data."}
    )
    is_pretrain: bool = False
    multirole_conv: bool = False
    # JSON file mapping a role name to its system prompt (multi-role mode only).
    role_system_profile: Optional[str] = field(
        default=None, metadata={"help": "Path to the role profile data."}
    )
    target_role: Optional[str] = field(
        default=TARGET_ROLE, metadata={"help": "target role of response, None for all roles"}
    )
    # Spelling is part of the CLI (``--classfication``); do not rename.
    classfication: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with this script's own options."""

    # Forwarded as ``cache_dir`` to every ``from_pretrained`` call in train().
    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    model_max_length: int = field(
        default=8192,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # Wrap the model with PEFT LoRA adapters.
    use_lora: bool = False
    # Load the model class "modeling_qwen.QWenLMHeadModelLWF" through config.auto_map.
    use_lwf: bool = False
    # Not referenced in this file; presumably consumed by QwenTrainer -- confirm.
    eval_output_dir: Optional[str] = None
    # Treat the checkpoint as a chat model: LoRA then trains adapters only
    # (no embed_tokens / lm_head in modules_to_save).
    is_chat_version: bool = False


@dataclass
class LoraArguments:
    """LoRA / QLoRA hyper-parameters parsed from the command line."""

    lora_r: int = 18  # passed to LoraConfig(r=...)
    lora_alpha: int = 16  # passed to LoraConfig(lora_alpha=...)
    lora_dropout: float = 0.05  # passed to LoraConfig(lora_dropout=...)
    # Module names that receive LoRA adapters. For the original Qwen(1) layout the
    # equivalents were transformer.h.[0-31].attn.c_attn / attn.c_proj / mlp.w1 / mlp.w2.
    lora_target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "up_proj", "gate_proj", "down_proj",
        ]
    )
    lora_weight_path: str = ""  # not referenced in the visible code
    # "none" | "all" | "lora_only": which biases are trained and saved with the adapters.
    lora_bias: str = "none"
    # Load the base model with a BitsAndBytes quantization config before applying LoRA.
    q_lora: bool = False


def maybe_zero_3(param):
    """Return a detached CPU copy of ``param``.

    Tensors that carry a ``ds_id`` attribute are DeepSpeed ZeRO-3 partitioned
    parameters: they are asserted to be in the NOT_AVAILABLE state and gathered
    with ``zero.GatheredParameters`` before copying. Any other tensor is copied
    directly.
    """
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
    with zero.GatheredParameters([param]):
        gathered_copy = param.data.detach().cpu().clone()
    return gathered_copy


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Select the LoRA parameters to save and return them as CPU tensors.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs, e.g. ``model.named_parameters()``.
        bias: which bias parameters to keep alongside the ``lora_`` weights:
            ``"none"``: no biases; ``"all"``: every parameter whose name contains
            ``"bias"``; ``"lora_only"``: only the bias of modules that also carry
            ``lora_`` weights.

    Returns:
        dict mapping parameter name to a detached CPU copy (gathered first when the
        parameter is ZeRO-3 partitioned, via ``maybe_zero_3``).

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # "<module>.lora_A.<...>" -> "<module>.bias"
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the biases that belong to a LoRA-adapted module. (Previously this
        # iterated the dict's keys as (k, t) pairs and tested a stale `bias_name`.)
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"Unsupported lora bias mode: {bias!r}")
    return {k: maybe_zero_3(v) for k, v in to_return.items()}


# Set by train() from TrainingArguments.local_rank; read by rank0_print().
local_rank = None


def rank0_print(*args):
    """Print ``args`` only on the main process.

    Rank 0 of a distributed run and a non-distributed run (HF ``TrainingArguments``
    uses ``local_rank == -1`` for that case) both print. Nothing is printed while
    ``local_rank`` is still ``None``, i.e. before train() has set it.
    """
    if local_rank in (0, -1):
        print(*args)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"):
    """Collect the model state dict and write it to ``output_dir``.

    Under DeepSpeed ZeRO-3 the 16-bit state dict is consolidated from the engine
    (``trainer.model_wrapped``). Otherwise either the LoRA parameters only
    (``use_lora``, filtered by ``bias``) or the full ``state_dict()`` is used.
    The file is written by ``trainer._save`` on processes where
    ``trainer.args.should_save`` is true.

    Args:
        trainer: the trainer that owns the model.
        output_dir: destination directory.
        bias: LoRA bias mode forwarded to ``get_peft_state_maybe_zero_3``.
    """
    if deepspeed.is_deepspeed_zero3_enabled():
        # Called on every process (outside the save check below); presumably a
        # collective gather in DeepSpeed -- keep it unconditional.
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    elif trainer.args.use_lora:
        state_dict = get_peft_state_maybe_zero_3(
            trainer.model.named_parameters(), bias
        )
    else:
        state_dict = trainer.model.state_dict()
    # `should_save` already selects the writing process. The former extra
    # `local_rank == 0` test skipped saving entirely in single-process runs,
    # where local_rank is -1.
    if trainer.args.should_save:
        trainer._save(output_dir, state_dict=state_dict)

def preprocess_pretrain(
        sources: list[str],
        tokenizer: transformers.PreTrainedTokenizer,
        max_len: int,) -> Dict:
    """Tokenize raw texts for (continued) pretraining.

    Each text becomes ``<|im_start|>{text}<|im_end|>\\n`` and every one of those
    tokens is a label (no prompt masking). Rows are right-padded with
    ``tokenizer.pad_token_id`` (labels with IGNORE_TOKEN_ID) and truncated to
    ``max_len``.

    Returns:
        dict with ``input_ids`` and ``labels`` (int32 tensors of shape
        ``(len(sources), max_len)``) and a boolean ``attention_mask`` equal to
        ``input_ids != tokenizer.pad_token_id``.
    """
    im_start = tokenizer("<|im_start|>").input_ids
    im_end = tokenizer("<|im_end|>").input_ids
    nl_tokens = tokenizer('\n').input_ids

    input_ids, targets = [], []
    for source in sources:
        # Tokenize once; the labels are an identical copy of the inputs.
        input_id = im_start + tokenizer(source).input_ids + im_end + nl_tokens
        target = list(input_id)
        pad_len = max_len - len(input_id)  # negative -> no padding, truncated below
        input_id += [tokenizer.pad_token_id] * pad_len
        target += [IGNORE_TOKEN_ID] * pad_len
        input_ids.append(input_id[:max_len])
        targets.append(target[:max_len])
    input_ids = torch.tensor(input_ids, dtype=torch.int)
    targets = torch.tensor(targets, dtype=torch.int)

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )

def preprocess_multirole_eval(sources, target_role, tokenizer, max_len, system_message: str = SYSTEM):
    """Build generation prompts and reference labels for multi-role evaluation.

    Every turn except the last is rendered as ``<|im_start|>{role}\\n{value}<|im_end|>\\n``.
    The final turn, which must come from ``target_role``, contributes only its header
    ``<|im_start|>{role}\\n``, so generation continues from there. Its text is the
    reference label.

    Args:
        sources: list of conversations; each is a list of ``{"from": role, "value": str}``.
        target_role: role whose final turn is the expected answer.
        tokenizer: tokenizer that knows the ``<|im_start|>``/``<|im_end|>`` tokens.
        max_len: row length for both prompts and labels.
        system_message: optional system prompt; None or "" omits the system turn.

    Returns:
        dict with int32 ``input_ids`` (left-padded; if a prompt is longer than
        ``max_len`` its *last* ``max_len`` tokens are kept so it still ends with the
        target header), int32 ``labels`` (reference tokens, truncated to ``max_len``,
        right-padded with IGNORE_TOKEN_ID) and a boolean ``attention_mask``.

    Raises:
        AssertionError: if a conversation's last turn is not from ``target_role``.
    """
    im_start = tokenizer("<|im_start|>").input_ids
    im_end = tokenizer("<|im_end|>").input_ids
    nl_tokens = tokenizer('\n').input_ids
    _system = tokenizer('system').input_ids + nl_tokens

    # The system turn does not depend on the conversation: tokenize it once.
    system = []
    if system_message is not None and len(system_message) > 0:
        system = im_start + _system + tokenizer(system_message).input_ids + im_end + nl_tokens

    input_ids, targets = [], []
    for source in sources:
        assert source[-1]['from'] == target_role, "[ZH ERROR] last sentence is not from {target}".format(target=target_role)
        input_id = list(system)
        target = ""
        contain_target_role = False
        last = len(source) - 1
        for j, sentence in enumerate(source):
            role = sentence["from"]
            header = im_start + tokenizer(role).input_ids + nl_tokens
            if role == target_role and j == last:
                contain_target_role = True
                target = sentence['value']
                input_id += header
            else:
                input_id += header + tokenizer(sentence["value"]).input_ids + im_end + nl_tokens
        assert contain_target_role, "[ZH ERROR] There are no target role or the last conversation is not from target role"
        # Left-pad for generation. When too long, truncate from the front so the
        # prompt still ends with the target-role header.
        input_id = [tokenizer.pad_token_id] * (max_len - len(input_id)) + input_id
        input_ids.append(input_id[len(input_id) - max_len:])
        # TODO: pass the reference string directly instead of tokenizing here and
        # decoding later; this may require changing data_collate_fn (train.py).
        target_id = tokenizer(target).input_ids[:max_len]
        targets.append(target_id + [IGNORE_TOKEN_ID] * (max_len - len(target_id)))
    input_ids = torch.tensor(input_ids, dtype=torch.int)
    targets = torch.tensor(targets, dtype=torch.int)

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),  # not used for generation-based eval
    )


def preprocess_multirole(sources, target_role, tokenizer, max_len, system_message: str = SYSTEM):
    """Tokenize multi-role conversations for training on the final ``target_role`` turn.

    Each turn is rendered as ``<|im_start|>{role}\\n{value}<|im_end|>\\n``, preceded by
    an optional system turn. A leading turn by ``target_role`` is dropped. Only the
    text of the last turn (which must be from ``target_role``) is supervised, plus
    the ``<|im_start|>``, ``<|im_end|>`` and trailing newline tokens of every turn.
    All other positions get IGNORE_TOKEN_ID.

    NOTE(review): the label arithmetic (``- 3``, ``+ 1``, ``+ 2``, ``-2``) presumes
    ``<|im_start|>``, ``<|im_end|>`` and ``"\\n"`` each tokenize to a single token.

    Args:
        sources: list of conversations; each is a list of ``{"from": role, "value": str}``.
        target_role: role whose final turn is supervised.
        tokenizer: tokenizer that knows the chat special tokens.
        max_len: rows are right-padded / truncated to this length.
        system_message: optional system prompt; None or "" omits the system turn.

    Returns:
        dict with int32 ``input_ids`` / ``labels`` of shape ``(len(sources), max_len)``
        and a boolean ``attention_mask``.

    Raises:
        AssertionError: if the last turn is not from ``target_role``.
    """
    im_start = tokenizer("<|im_start|>").input_ids
    im_end = tokenizer("<|im_end|>").input_ids
    nl_tokens = tokenizer('\n').input_ids
    _system = tokenizer('system').input_ids + nl_tokens

    # The system turn is identical for every conversation: tokenize it once.
    system_ids, system_labels = [], []
    if system_message is not None and len(system_message) > 0:
        system_ids = im_start + _system + tokenizer(system_message).input_ids + im_end + nl_tokens
        system_labels = im_start + [IGNORE_TOKEN_ID] * (len(system_ids) - 3) + im_end + nl_tokens

    input_ids, targets = [], []
    for source in sources:
        if source[0]["from"] == target_role:
            source = source[1:]
        input_id, target = list(system_ids), list(system_labels)
        assert len(input_id) == len(target), "len input_id != target"
        contain_target_role = False
        last = len(source) - 1
        for j, sentence in enumerate(source):
            role = sentence["from"]
            role_ids = tokenizer(role).input_ids
            _input_id = im_start + role_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + im_end + nl_tokens
            input_id += _input_id
            if role == target_role and j == last:
                contain_target_role = True
                # Mask the role header; supervise the reply text, <|im_end|> and "\n".
                _target = im_start + [IGNORE_TOKEN_ID] * (len(role_ids) + 1) + \
                          _input_id[len(role_ids) + 2:-2] + im_end + nl_tokens
            else:
                _target = im_start + [IGNORE_TOKEN_ID] * (len(_input_id) - 3) + im_end + nl_tokens
            target += _target
        assert contain_target_role, "[ZH ERROR] There are no target role or the last conversation is not from target role"
        assert len(input_id) == len(target), "Length of input_id and target are different"
        if len(input_id) > max_len:
            rank0_print(f"[ZH WARNING] input_id is longer than {max_len}")
        pad_len = max_len - len(input_id)
        input_id += [tokenizer.pad_token_id] * pad_len
        target += [IGNORE_TOKEN_ID] * pad_len
        input_ids.append(input_id[:max_len])
        targets.append(target[:max_len])
    input_ids = torch.tensor(input_ids, dtype=torch.int)
    targets = torch.tensor(targets, dtype=torch.int)

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int,
    system_message: str = SYSTEM,
) -> Dict:
    """Tokenize user/assistant conversations into right-padded SFT examples.

    Conversations are rendered in ChatML form: an optional system turn, then
    ``<|im_start|>{role}\\n{value}<|im_end|>\\n`` per turn. A leading turn that is not
    from ``user`` is dropped. Labels supervise the assistant reply text and the
    ``<|im_start|>``, ``<|im_end|>`` and trailing newline tokens of every turn.
    Everything else is IGNORE_TOKEN_ID.

    Args:
        sources: list of conversations; each is a list of
            ``{"from": "user" | "assistant", "value": str}``.
        tokenizer: tokenizer that knows the chat special tokens.
        max_len: rows are right-padded / truncated to this length.
        system_message: optional system prompt (e.g. "You are a helpful assistant.");
            None or "" omits the system turn.

    Returns:
        dict with int32 ``input_ids`` / ``labels`` of shape ``(len(sources), max_len)``
        and a boolean ``attention_mask``.

    Raises:
        KeyError: if a turn's ``from`` is neither ``"user"`` nor ``"assistant"``.
    """
    roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}

    im_start = tokenizer("<|im_start|>").input_ids
    im_end = tokenizer("<|im_end|>").input_ids
    nl_tokens = tokenizer('\n').input_ids
    _system = tokenizer('system').input_ids + nl_tokens

    # Only build a system turn when a prompt is given. The module default
    # (SYSTEM) is None, and tokenizing None fails.
    system_ids, system_labels = [], []
    if system_message is not None and len(system_message) > 0:
        system_ids = im_start + _system + tokenizer(system_message).input_ids + im_end + nl_tokens
        system_labels = im_start + [IGNORE_TOKEN_ID] * (len(system_ids) - 3) + im_end + nl_tokens

    input_ids, targets = [], []
    for source in sources:
        if roles[source[0]["from"]] != roles["user"]:
            source = source[1:]

        input_id, target = list(system_ids), list(system_labels)
        assert len(input_id) == len(target)
        for sentence in source:
            role = roles[sentence["from"]]
            # role_ids already starts with the <|im_start|> token.
            role_ids = tokenizer(role).input_ids
            _input_id = role_ids + nl_tokens + \
                tokenizer(sentence["value"]).input_ids + im_end + nl_tokens
            input_id += _input_id
            if role == roles["user"]:
                _target = im_start + [IGNORE_TOKEN_ID] * (len(_input_id) - 3) + im_end + nl_tokens
            elif role == roles["assistant"]:
                _target = im_start + [IGNORE_TOKEN_ID] * len(role_ids) + \
                          _input_id[len(role_ids) + 1:-2] + im_end + nl_tokens
            else:
                raise NotImplementedError
            target += _target
        assert len(input_id) == len(target)
        pad_len = max_len - len(input_id)
        input_id += [tokenizer.pad_token_id] * pad_len
        target += [IGNORE_TOKEN_ID] * pad_len
        input_ids.append(input_id[:max_len])
        targets.append(target[:max_len])
    input_ids = torch.tensor(input_ids, dtype=torch.int)
    targets = torch.tensor(targets, dtype=torch.int)

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )

def preprocess_classification(
    sources,
    sources_target,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int,
    system_message: str = SYSTEM,
) -> Dict:
    """Tokenize (input text, target text) pairs for classification-style SFT.

    Each pair is rendered as two blocks::

        <|im_start|>{input}<|im_end|>\\n<|im_start|>target:{target}<|im_end|>\\n

    Supervised positions: the target text plus the ``<|im_start|>``, ``<|im_end|>``
    and newline tokens of both blocks. The input text and the ``target:`` prefix
    are masked with IGNORE_TOKEN_ID. ``system_message`` is accepted but unused.

    Returns:
        dict with int32 ``input_ids`` / ``labels`` of shape ``(len(sources), max_len)``
        (right-padded, truncated) and a boolean ``attention_mask``.
    """
    im_start = tokenizer("<|im_start|>").input_ids
    im_end = tokenizer("<|im_end|>").input_ids
    nl_tokens = tokenizer('\n').input_ids
    prefix_ids = tokenizer("target:").input_ids
    n_prefix = len(prefix_ids)

    all_inputs, all_labels = [], []
    for text, label_text in zip(sources, sources_target):
        # Block 1: the input text, fully masked apart from the structural tokens.
        prompt = im_start + tokenizer(text).input_ids + im_end + nl_tokens
        prompt_labels = im_start + [IGNORE_TOKEN_ID] * (len(prompt) - 3) + im_end + nl_tokens

        # Block 2: "target:" prefix (masked) followed by the supervised answer.
        answer = im_start + prefix_ids + tokenizer(label_text).input_ids + im_end + nl_tokens
        answer_labels = (
            im_start
            + [IGNORE_TOKEN_ID] * n_prefix
            + answer[n_prefix + 1:-2]
            + im_end
            + nl_tokens
        )

        row = prompt + answer
        row_labels = prompt_labels + answer_labels
        assert len(row) == len(row_labels)

        missing = max_len - len(row)
        row += [tokenizer.pad_token_id] * missing
        row_labels += [IGNORE_TOKEN_ID] * missing
        all_inputs.append(row[:max_len])
        all_labels.append(row_labels[:max_len])

    input_ids = torch.tensor(all_inputs, dtype=torch.int)
    targets = torch.tensor(all_labels, dtype=torch.int)
    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )

def preprocess_classification_eval(sources, sources_target, tokenizer, max_len, system_message: str = SYSTEM):
    """Build generation prompts and reference labels for classification evaluation.

    The prompt is ``<|im_start|>{input}<|im_end|>\\n<|im_start|>target:``; generation
    continues from there. ``system_message`` is accepted but unused.

    Returns:
        dict with int32 ``input_ids`` (left-padded; if longer than ``max_len`` the
        *last* ``max_len`` tokens are kept so the prompt still ends with ``target:``),
        int32 ``labels`` (reference target tokens truncated to ``max_len``, right-padded
        with IGNORE_TOKEN_ID) and a boolean ``attention_mask``.
    """
    im_start = tokenizer("<|im_start|>").input_ids
    im_end = tokenizer("<|im_end|>").input_ids
    nl_tokens = tokenizer('\n').input_ids
    answer_header = im_start + tokenizer("target:").input_ids

    input_ids, targets = [], []
    for source, source_target in zip(sources, sources_target):
        input_id = im_start + tokenizer(source).input_ids + im_end + nl_tokens + answer_header
        # Left-pad for generation; truncate from the front so the header survives.
        input_id = [tokenizer.pad_token_id] * (max_len - len(input_id)) + input_id
        input_ids.append(input_id[len(input_id) - max_len:])

        target_id = tokenizer(source_target).input_ids[:max_len]
        targets.append(target_id + [IGNORE_TOKEN_ID] * (max_len - len(target_id)))

    input_ids = torch.tensor(input_ids, dtype=torch.int)
    targets = torch.tensor(targets, dtype=torch.int)

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),  # not used for generation-based eval
    )


class SupervisedDataset(Dataset):
    """Eagerly tokenized SFT dataset over ``{"conversations": [...]}`` records.

    All records are passed through ``preprocess`` (with its default system message)
    once, in the constructor. ``__getitem__`` only indexes the resulting tensors.
    """

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):
        super().__init__()

        rank0_print("Formatting inputs...")
        conversations = [record["conversations"] for record in raw_data]
        encoded = preprocess(conversations, tokenizer, max_len)

        self.input_ids = encoded["input_ids"]
        self.labels = encoded["labels"]
        self.attention_mask = encoded["attention_mask"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return {
            "input_ids": self.input_ids[i],
            "labels": self.labels[i],
            "attention_mask": self.attention_mask[i],
        }

class LazyPretrainDataset(Dataset):
    """Pretraining dataset that tokenizes ``record["content"]`` on first access.

    Each example goes through ``preprocess_pretrain``. The resulting per-example
    tensors are memoized by index in ``cached_data_dict``.
    """

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.raw_data = raw_data
        self.cached_data_dict = {}
        rank0_print("Formatting inputs...Skip in lazy mode... Using pretrain format")

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        batch = preprocess_pretrain([self.raw_data[i]["content"]], self.tokenizer, self.max_len)
        # Strip the batch dimension of the single-example batch.
        example = {key: batch[key][0] for key in ("input_ids", "labels", "attention_mask")}
        self.cached_data_dict[i] = example
        return example


class LazySupervisedDataset(Dataset):
    """Lazily tokenized dataset over ``{"input": ..., "target": ...}`` records.

    Despite the generic name, each example is encoded with
    ``preprocess_classification`` (not the conversation-based ``preprocess``).
    Encoded examples are memoized by index in ``cached_data_dict``.
    """

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.raw_data = raw_data
        self.cached_data_dict = {}
        rank0_print("Formatting inputs...Skip in lazy mode")

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        record = self.raw_data[i]
        batch = preprocess_classification([record["input"]], [record['target']], self.tokenizer, self.max_len)
        # Strip the batch dimension of the single-example batch.
        example = {key: batch[key][0] for key in ("input_ids", "labels", "attention_mask")}
        self.cached_data_dict[i] = example
        return example

class LazySupervisedMultiroleDataset(Dataset):
    """Lazily tokenized multi-role conversation dataset.

    The supervised role of an example is the speaker of its last turn. Training
    examples use ``preprocess_multirole`` and eval examples (``is_train=False``) use
    ``preprocess_multirole_eval``. The system prompt is looked up by that speaker in
    ``role_system_profile``; unknown speakers get no system prompt. Encoded examples
    are memoized by index.

    NOTE(review): ``target_role`` only drives ``segment()``, which keeps any
    conversation containing that role somewhere. ``__getitem__`` always trains on
    the last speaker, which can differ from ``target_role`` -- confirm intended.
    """

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int, target_role: str = None, is_train: bool = True, role_system_profile: dict = None):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.target_role = target_role
        self.is_train = is_train
        self.role_system_profile = {} if role_system_profile is None else role_system_profile

        rank0_print("Formatting inputs...Skip in lazy mode...in multi-role mode")
        self.raw_data = self.segment(raw_data, max_len, target_role)
        rank0_print(f"Data loaded len: {len(self.raw_data)}")
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        conversation = self.raw_data[i]["conversations"]
        speaker = conversation[-1]["from"]
        if speaker in self.role_system_profile:
            system_message = self.role_system_profile[speaker]
        else:
            system_message = None

        encode = preprocess_multirole if self.is_train else preprocess_multirole_eval
        batch = encode([conversation], speaker, self.tokenizer, self.max_len, system_message)
        # Strip the batch dimension of the single-example batch.
        example = {key: batch[key][0] for key in ("input_ids", "labels", "attention_mask")}
        self.cached_data_dict[i] = example
        return example

    def segment(self, raw_data, max_len, target_role):
        """Keep only conversations in which ``target_role`` speaks at least once.

        ``target_role=None`` returns ``raw_data`` unchanged. ``max_len`` is unused.
        TODO: pre-process raw_data with respect to max_len to avoid overflow.
        """
        if target_role is None:
            return raw_data
        return [
            item for item in raw_data
            if target_role in [turn['from'] for turn in item['conversations']]
        ]

class LazySupervisedClassificationDataset(Dataset):
    """Lazily tokenized dataset over ``{"input": ..., "target": ...}`` records.

    Training examples come from ``preprocess_classification`` (prompt plus supervised
    answer). Evaluation examples (``is_train=False``) come from
    ``preprocess_classification_eval`` (left-padded prompt plus reference label).
    Encoded examples are memoized by index.
    """

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int, is_train: bool = True):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.is_train = is_train

        rank0_print("Formatting inputs...Skip in lazy mode...in classification mode")
        self.raw_data = raw_data
        rank0_print(f"Data loaded len: {len(self.raw_data)}")
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        record = self.raw_data[i]
        encode = preprocess_classification if self.is_train else preprocess_classification_eval
        batch = encode([record["input"]], [record['target']], self.tokenizer, self.max_len)
        # Strip the batch dimension of the single-example batch.
        example = {key: batch[key][0] for key in ("input_ids", "labels", "attention_mask")}
        self.cached_data_dict[i] = example
        return example
class LazySupervisedDatasetReplay(Dataset):
    """Lazily tokenized SFT dataset that mixes in randomly drawn replay examples.

    On each access, with probability ``primary_prob`` (default 0.3) the primary
    example at index ``i`` is returned; it is tokenized once and cached. Otherwise a
    uniformly random replay example is tokenized and returned; this is not cached,
    so repeated accesses differ. If ``replay_data`` is empty the primary example is
    always used. The dataset length is the length of the primary data. Both pools
    are ``{"conversations": [...]}`` records encoded with ``preprocess``.
    """

    def __init__(self, raw_data, replay_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int,
                 primary_prob: float = 0.3):
        super(LazySupervisedDatasetReplay, self).__init__()
        if not 0.0 <= primary_prob <= 1.0:
            raise ValueError(f"primary_prob must be within [0, 1], got {primary_prob}")
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.primary_prob = primary_prob

        rank0_print(f"Formatting inputs...Skip in lazy mode... [ZH] Using Replay {len(replay_data)} data")
        self.raw_data = raw_data
        self.replay_data = replay_data
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.raw_data)

    def _encode(self, conversation) -> Dict[str, torch.Tensor]:
        """Encode one conversation with ``preprocess`` and drop the batch dimension."""
        ret = preprocess([conversation], self.tokenizer, self.max_len)
        return dict(
            input_ids=ret["input_ids"][0],
            labels=ret["labels"][0],
            attention_mask=ret["attention_mask"][0],
        )

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # Primary (task) data. The original author's comment called it "Li Jiaqi's data".
        if not self.replay_data or random.random() < self.primary_prob:
            if i not in self.cached_data_dict:
                self.cached_data_dict[i] = self._encode(self.raw_data[i]["conversations"])
            return self.cached_data_dict[i]
        # Replay data: a uniformly random replay example.
        index = random.randint(0, len(self.replay_data) - 1)
        return self._encode(self.replay_data[index]["conversations"])

def load_json_data(data_path):
    """Load JSON from one file, or concatenate the top-level lists of several files.

    Args:
        data_path: a single path, a list/tuple of paths, or None.

    Returns:
        The parsed JSON for a single path. For a list of paths, the concatenation of
        each file's top-level JSON list. Returns None when ``data_path`` is None, or
        when a file cannot be read or parsed (a warning is logged in that case), so
        optional inputs such as the role profile can simply be omitted.
    """
    if data_path is None:
        return None
    try:
        if isinstance(data_path, (list, tuple)):
            json_data = []
            for dp in data_path:
                with open(dp, "r", encoding="utf-8") as f:
                    json_data += json.load(f)
            return json_data
        with open(data_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError, TypeError) as err:
        # Best-effort loader: report the problem instead of failing silently.
        logging.warning("Failed to load JSON data from %r: %s", data_path, err)
        return None


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args, max_len,
) -> Dict:
    """Build the train/eval datasets selected by ``data_args``.

    With ``lazy_preprocess`` the dataset class is picked by the first matching flag:
    ``is_pretrain`` -> LazyPretrainDataset, ``replay_path`` -> LazySupervisedDatasetReplay,
    ``multirole_conv`` -> LazySupervisedMultiroleDataset, ``classfication`` ->
    LazySupervisedClassificationDataset, otherwise LazySupervisedDataset. Without it,
    SupervisedDataset is always used (a warning is logged if a lazy-only flag is set).

    Returns:
        ``dict(train_dataset=..., eval_dataset=...)``; ``eval_dataset`` is None when
        ``data_args.eval_data_path`` is empty.
    """
    if data_args.lazy_preprocess:
        if data_args.is_pretrain:
            dataset_cls = LazyPretrainDataset
        elif data_args.replay_path:
            dataset_cls = LazySupervisedDatasetReplay
        elif data_args.multirole_conv:
            dataset_cls = LazySupervisedMultiroleDataset
        elif data_args.classfication:
            dataset_cls = LazySupervisedClassificationDataset
        else:
            dataset_cls = LazySupervisedDataset
    else:
        dataset_cls = SupervisedDataset
        if data_args.is_pretrain or data_args.replay_path or data_args.multirole_conv or data_args.classfication:
            logging.warning(
                "is_pretrain / replay_path / multirole_conv / classfication only take effect "
                "with lazy_preprocess; using SupervisedDataset instead."
            )
    rank0_print("Loading data...")

    # get training json
    rank0_print("[ZH DEBUG] training data path:", data_args.data_path)
    train_json = load_json_data(data_args.data_path)

    # The role profile is only used by the multi-role dataset; load it once for train and eval.
    role_system_profile = None
    if dataset_cls is LazySupervisedMultiroleDataset:
        role_system_profile = load_json_data(data_args.role_system_profile)

    # Dispatch on the class actually chosen above (not on the raw flags), so the
    # constructor arguments always match that class's signature.
    if dataset_cls is LazySupervisedDatasetReplay:
        with open(data_args.replay_path, "r", encoding="utf-8") as f:
            replay_json = json.load(f)
        train_dataset = dataset_cls(train_json, replay_json, tokenizer=tokenizer, max_len=max_len)
    elif dataset_cls is LazySupervisedMultiroleDataset:
        train_dataset = dataset_cls(train_json, tokenizer=tokenizer, max_len=max_len, target_role=data_args.target_role, role_system_profile=role_system_profile)
    else:
        train_dataset = dataset_cls(train_json, tokenizer=tokenizer, max_len=max_len)

    eval_dataset = None
    if data_args.eval_data_path:
        # get dev json
        rank0_print("[ZH DEBUG] evaluating data path:", data_args.eval_data_path)
        eval_json = load_json_data(data_args.eval_data_path)
        if dataset_cls is LazySupervisedMultiroleDataset:
            eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, max_len=max_len, target_role=data_args.target_role, is_train=False, role_system_profile=role_system_profile)
        elif dataset_cls is LazySupervisedClassificationDataset:
            eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, max_len=max_len, is_train=False)
        elif dataset_cls is LazySupervisedDatasetReplay:
            # The replay dataset requires replay data and randomly substitutes replay
            # examples. Evaluate deterministically on the eval set itself, using the
            # same `preprocess` conversation format.
            eval_dataset = SupervisedDataset(eval_json, tokenizer=tokenizer, max_len=max_len)
        else:
            eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, max_len=max_len)

    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)


def train():
    """Parse CLI arguments, build model, tokenizer and datasets, train, then save.

    Arguments are parsed into ModelArguments, DataArguments, TrainingArguments and
    LoraArguments. The model is optionally wrapped with PEFT LoRA, trained with
    QwenTrainer (``compute_metrics=compute_accuracy``), and the final weights are
    written to ``training_args.output_dir`` via safe_save_model_for_hf_trainer.
    """
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    # This serves for single-gpu qlora: with a DeepSpeed config but WORLD_SIZE == 1,
    # force the distributed type to DEEPSPEED.
    if getattr(training_args, 'deepspeed', None) and int(os.environ.get("WORLD_SIZE", 1))==1:
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    # Module-level global consulted by rank0_print().
    local_rank = training_args.local_rank

    # NOTE(review): device_map is computed here but never used -- the
    # `device_map=` argument to from_pretrained below is commented out.
    device_map = "cuda"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if lora_args.q_lora:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else "auto"
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logging.warning(
                "FSDP or ZeRO3 are incompatible with QLoRA."
            )

    # Load the model config. (No RoPE scaling is actually configured here.)
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
    )
    config.use_cache = False
    if training_args.use_lwf:
        # Swap in the LWF model class from remote code; assumes the config has an
        # `auto_map` attribute (trust_remote_code checkpoint) -- TODO confirm.
        config.auto_map['AutoModelForCausalLM'] = "modeling_qwen.QWenLMHeadModelLWF"

    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )

    # Load model and tokenizer. pad_token_id is aliased to eos_token_id here and on
    # the tokenizer below.
    config.pad_token_id = config.eos_token_id
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        # device_map=device_map,
        trust_remote_code=True,
        # NOTE(review): load_in_4bit=False means this config does not request 4-bit
        # loading even when q_lora is set -- confirm this is intended.
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=False,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
        if training_args.use_lora and lora_args.q_lora
        else None,
    )

    # Cast the whole model to the requested half-precision dtype.
    if training_args.fp16:
        model.half()
    if training_args.bf16:
        model.bfloat16()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True,
    )
    # NOTE(review): with pad == eos, the preprocess* helpers' attention_mask
    # (input_ids != pad_token_id) also masks any genuine eos tokens in the text.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    if training_args.use_lora:
        # QLoRA or chat checkpoints: train adapters only. Otherwise embed_tokens and
        # lm_head are also trained/saved in full.
        if lora_args.q_lora or 'chat' in model_args.model_name_or_path.lower() or training_args.is_chat_version:
            modules_to_save = None
        else:
            modules_to_save = ["embed_tokens", "lm_head"]
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type="CAUSAL_LM",
            modules_to_save=modules_to_save  # This argument serves for adding new tokens.
        )
        if lora_args.q_lora:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=training_args.gradient_checkpointing
            )

        model = get_peft_model(model, lora_config)

        # Print peft trainable params
        model.print_trainable_parameters()

        if training_args.gradient_checkpointing:
            # Presumably needed so gradients reach checkpointed blocks while the input
            # embeddings are frozen -- TODO confirm for this model.
            model.enable_input_require_grads()

    # Load data
    data_module = make_supervised_data_module(
        tokenizer=tokenizer, data_args=data_args, max_len=training_args.model_max_length
    )

    # Start trainer
    trainer = QwenTrainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module, compute_metrics=compute_accuracy
    )

    trainer.train()
    trainer.save_state()

    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias)


if __name__ == "__main__":
    train()