import os
import torch
from torch import nn
from torch.nn import functional as F

from typing import Dict, List, Tuple, Union, Literal
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel

from transformers import BitsAndBytesConfig, TrainingArguments, default_data_collator, DataCollatorForSeq2Seq, AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

from trl import DPOTrainer

from dataclasses import dataclass
from accelerate import Accelerator

# Module-level state shared by the preprocessing helpers and the trainer below.
# All three are placeholders at import time; main() assigns them via `global`.
TEMPLATE= None
tokenizer = None
cutoff_len = None

# Device placement handed to from_pretrained() in main(): the root module ("")
# is mapped to this process's local index as reported by Accelerate.
# NOTE: Accelerator() is constructed at import time, as a module-level side effect.
device_map={"": Accelerator().local_process_index}
@dataclass
class DataCollatorForSeq2SeqForNeg:
    """Collate pre-tokenized sentence groups into separately padded tensors.

    Every feature dict is expected to hold one ``<prefix>input_ids`` entry for
    each of the six prefixes listed in ``__call__``. Each group is padded on
    its own through ``tokenizer.pad``, so different groups may end up with
    different sequence lengths.
    """

    # Tokenizer whose ``pad`` method performs the padding.
    tokenizer: PreTrainedTokenizerBase
    # Not read by this collator.
    model: Optional[Any] = None
    # The next three are forwarded unchanged to ``tokenizer.pad``.
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    # Not read by this collator.
    label_pad_token_id: int = -100
    # Tensor format used when ``__call__`` is not given an explicit one.
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        """Return a dict mapping ``<prefix><field>`` to the padded batch.

        ``<field>`` is whatever ``tokenizer.pad`` returns for the group
        (its ``input_ids`` plus any extra fields it adds).
        """
        tensor_format = self.return_tensors if return_tensors is None else return_tensors

        group_prefixes = (
            'high_sent0_', 'high_sent1_',
            'low_sent0_', 'low_sent1_',
            'highneg_sent0_', 'lowneg_sent0_',
        )
        batch = {}
        for prefix in group_prefixes:
            column = [row[prefix + 'input_ids'] for row in features]
            padded = self.tokenizer.pad(
                {'input_ids': column},
                padding=self.padding,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=tensor_format,
            )
            # Re-attach the group prefix so the groups do not collide.
            batch.update({prefix + field: value for field, value in padded.items()})
        return batch

class DPORankerTrainer(DPOTrainer):
    """DPO trainer whose "log-probabilities" are logs of shifted cosine similarities.

    Each batch row carries a ``high_sent0``/``high_sent1`` sentence pair and a
    ``low_sent0``/``low_sent1`` pair. The similarity ``cos + 1`` of a pair
    takes the place of a sequence likelihood in the DPO loss: the high pair is
    the "chosen" sample and the low pair the "rejected" one.
    """

    # Behaviour switches. They are plain attributes that callers overwrite on
    # the instance after construction; every switch read by the methods below
    # has a default here so the trainer also works when one is left unset.
    add_negative = False
    add_more_negtive = False
    add_more_more_negative = False
    add_hard_negative = False
    add_negative_gate = False
    gather_embs = False
    add_emb_klloss = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def get_cosing_embeddings(query_embs, product_embs):
        """Return the cosine similarity along the last dim, shifted to ``[0, 2]``.

        The ``+ 1`` shift keeps the value non-negative so that callers can
        take its logarithm.
        """
        return F.cosine_similarity(query_embs, product_embs, dim=-1) + 1

    @staticmethod
    def _rowwise_max(similarities):
        """Stack 1-D similarity tensors as columns and take the max of each row."""
        return torch.stack(similarities, dim=1).max(dim=1)[0]

    @staticmethod
    def _all_gather_keep_grad(local_embs):
        """Concatenate ``local_embs`` from every rank along dim 0.

        ``all_gather`` fills its output buffers without autograd history, so
        the slot belonging to this rank is replaced by the original tensor to
        keep gradients flowing through the local part.
        """
        import torch.distributed as dist

        gathered = [torch.zeros_like(local_embs) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, local_embs.contiguous())
        gathered[dist.get_rank()] = local_embs
        return torch.cat(gathered, dim=0)

    def tokenize_row(self, features, model=None):
        """Return ``features`` untouched, skipping the parent's tokenization."""
        return features

    def _reference_forward(self, batch):
        """Run ``concatenated_forward`` for the reference model without gradients.

        When no separate reference model is set, the policy model is used
        inside ``null_ref_context()``.
        """
        with torch.no_grad():
            if self.ref_model is None:
                with self.null_ref_context():
                    return self.concatenated_forward(self.model, batch)
            return self.concatenated_forward(self.ref_model, batch)

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the DPO loss and other metrics for the given batch of inputs for train or test.

        Returns:
            ``(loss, metrics)`` where ``loss`` is the mean DPO loss, plus a KL
            term between policy and reference embeddings when
            ``add_emb_klloss`` is set, and ``metrics`` maps metric names
            (prefixed with ``eval_`` for evaluation) to CPU scalars.
        """
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            embs,
        ) = self.concatenated_forward(model, batch)

        # Use precomputed reference log-probs when the batch provides them,
        # otherwise run the reference model.
        ref_embs = None
        if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
            reference_chosen_logps = batch["reference_chosen_logps"]
            reference_rejected_logps = batch["reference_rejected_logps"]
            if self.add_emb_klloss:
                # The KL term needs reference embeddings, which are never
                # precomputed, so a reference forward pass is still required.
                ref_embs = self._reference_forward(batch)[4]
        else:
            (
                reference_chosen_logps,
                reference_rejected_logps,
                _,
                _,
                ref_embs,
            ) = self._reference_forward(batch)

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()

        if self.add_emb_klloss:
            # KL between the softmax-normalised policy and reference embeddings.
            # log_softmax avoids the -inf that softmax().log() yields on underflow.
            # The 'mean' reduction (element-wise mean) is kept to preserve the loss scale.
            loss_kl = F.kl_div(F.log_softmax(embs, dim=-1), ref_embs.softmax(dim=-1), reduction='mean')
            return losses.mean() + loss_kl, metrics
        return losses.mean(), metrics

    def concatenated_forward(
            self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Embed all sentence groups in one forward pass and score the pairs.

        Returns:
            ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits, embs)``
            where the logits are shifted cosine similarities, the logps are
            their logarithms, and ``embs`` is the raw model output for the
            concatenated input of this process (before any cross-rank gather).
        """
        # Reads the module-level tokenizer that main() installs.
        pad_id = tokenizer.pad_token_id

        group_keys = [
            'high_sent0_input_ids', 'high_sent1_input_ids',
            'low_sent0_input_ids', 'low_sent1_input_ids',
        ]
        if self.add_hard_negative:
            group_keys += ['highneg_sent0_input_ids', 'lowneg_sent0_input_ids']
        groups = [batch[key] for key in group_keys]

        # Groups are padded independently by the collator, so left-pad them to
        # a common length before stacking them along the batch dimension.
        # NOTE(review): this assumes the collator also pads on the left -- confirm
        # for tokenizers whose padding_side is not set explicitly.
        max_len = max(ids.shape[1] for ids in groups)
        input_ids = torch.cat(
            [F.pad(ids, (max_len - ids.shape[1], 0), value=pad_id) for ids in groups],
            dim=0,
        )
        attention_mask = (input_ids != pad_id).long()
        embs = model(input_ids=input_ids, attention_mask=attention_mask)

        bsz = groups[0].shape[0]
        chunks = list(torch.split(embs, [bsz] * len(groups), dim=0))

        if self.gather_embs and torch.distributed.is_initialized():
            # Gather every group, hard negatives included, so that all
            # tensors compared below share the same global batch size.
            chunks = [self._all_gather_keep_grad(chunk) for chunk in chunks]
            bsz = chunks[0].shape[0]

        high_embs0, high_embs1, low_embs0, low_embs1 = chunks[:4]
        if self.add_hard_negative:
            highneg_embs, lowneg_embs = chunks[4:]
        else:
            highneg_embs = lowneg_embs = None

        chosen_logits = self.get_cosing_embeddings(high_embs0, high_embs1)
        rejected_logits = self.get_cosing_embeddings(low_embs0, low_embs1)

        if self.add_negative:
            chosen_logits, rejected_logits = self._extend_with_negatives(
                chosen_logits, rejected_logits,
                high_embs0, high_embs1, low_embs0, low_embs1,
                highneg_embs, lowneg_embs, bsz,
            )

        return (chosen_logits.log(), rejected_logits.log(), chosen_logits, rejected_logits, embs)

    def _extend_with_negatives(
            self, chosen_logits, rejected_logits,
            high_embs0, high_embs1, low_embs0, low_embs1,
            highneg_embs, lowneg_embs, bsz,
    ):
        """Append a second block of (chosen, rejected) comparisons built from negatives.

        The appended rejected half is always the row-wise hardest (largest)
        similarity among the mined negative pairings. The appended chosen half
        depends on the active switches.
        """
        cos = self.get_cosing_embeddings
        # Mismatched pairings between the high pair and the low pair.
        low_vs_high = cos(low_embs0, high_embs1)
        high_vs_low = cos(high_embs0, low_embs1)

        if self.add_more_negtive or self.add_more_more_negative:
            if self.add_more_more_negative:
                # Also pair each row with every other row of the batch.
                candidates = [low_vs_high, high_vs_low]
                for shift in range(1, bsz):
                    candidates.append(cos(low_embs0, high_embs1.roll(shift, 0)))
                    candidates.append(cos(high_embs0, low_embs1.roll(shift, 0)))
            else:
                # One extra in-batch pairing per side via a batch flip.
                candidates = [
                    low_vs_high,
                    cos(low_embs0.flip(0), high_embs1),
                    high_vs_low,
                    cos(high_embs0.flip(0), low_embs1),
                ]
            negative_logits = self._rowwise_max(candidates)
            if self.add_negative_gate:
                extra_chosen = torch.stack([rejected_logits, negative_logits], dim=1).min(dim=1)[0]
            else:
                extra_chosen = rejected_logits
        elif self.add_hard_negative:
            negative_logits = self._rowwise_max([
                cos(highneg_embs, high_embs0),
                cos(lowneg_embs, low_embs0),
                low_vs_high,
                high_vs_low,
            ])
            extra_chosen = chosen_logits
        else:
            negative_logits = self._rowwise_max([low_vs_high, high_vs_low])
            extra_chosen = rejected_logits

        chosen_logits = torch.cat([chosen_logits, extra_chosen], dim=0)
        rejected_logits = torch.cat([rejected_logits, negative_logits], dim=0)
        return chosen_logits, rejected_logits


def preprocess_function(examples, small_score_gap=False, e5=False):
    """Turn a batch of scored sentence pairs into high/low prompt pairs.

    Rows are ranked by ``score`` (descending) and split into a "high" and a
    "low" group that are then paired position by position:

    * ``small_score_gap=True``: neighbouring ranks are paired (even ranks are
      high, odd ranks are low).
    * otherwise: the top half is paired with the bottom half.

    Each sentence (and hard negative) is clipped to ``cutoff_len`` tokens with
    the module-level tokenizer and wrapped either in the e5 instruction format
    (``e5=True``) or in the module-level ``TEMPLATE``.

    Returns:
        Dict with the six prompt columns plus ``gap_score`` (high score minus
        low score per pair). With an odd number of rows one row is dropped.
    """
    sent0_col = examples["sent0"]
    sent1_col = examples["sent1"]
    hard_neg_col = examples["hard_neg"]
    score_col = examples['score']

    ranked = sorted(
        zip(score_col, sent0_col, sent1_col, hard_neg_col),
        key=lambda row: row[0],
        reverse=True,
    )
    half = len(ranked) // 2
    if small_score_gap:
        upper, lower = ranked[0::2], ranked[1::2]
    else:
        upper, lower = ranked[:half], ranked[half:]
    # zip stops at the shorter group, which always holds exactly ``half`` rows.
    pairs = list(zip(upper, lower))

    template = TEMPLATE.replace("_", " ")
    task_description = 'Retrieve semantically similar text.'

    def clip(text):
        # Round-trip through the tokenizer to truncate to cutoff_len tokens.
        token_ids = tokenizer.encode(text, add_special_tokens=False)[:cutoff_len]
        return tokenizer.decode(token_ids)

    def render(text):
        if e5:
            return f'Instruct: {task_description}\nQuery: {text}'
        return template.replace("*sent 0*", text)

    columns = (
        'high_sent0', 'high_sent1',
        'low_sent0', 'low_sent1',
        'highneg_sent0', 'lowneg_sent0',
    )
    result = {name: [] for name in columns}
    for high_row, low_row in pairs:
        # Row layout is (score, sent0, sent1, hard_neg).
        raw_texts = (high_row[1], high_row[2], low_row[1], low_row[2], high_row[3], low_row[3])
        clipped = [clip(text) for text in raw_texts]
        for name, text in zip(columns, clipped):
            result[name].append(render(text))

    result['gap_score'] = [high_row[0] - low_row[0] for high_row, low_row in pairs]
    return result

def tokenized_function(examples, e5=False):
    """Tokenize the six prompt columns with the module-level tokenizer.

    For ``e5=True`` the prompts are tokenized without attention masks
    (``max_length=513``) and the tokenizer's EOS id is appended to every
    sequence. Otherwise the tokenizer is called with ``max_length=80``.

    Returns:
        Dict mapping ``<column>_<field>`` to the tokenizer output for each
        column.
    """
    columns = (
        'high_sent0', 'high_sent1',
        'low_sent0', 'low_sent1',
        'highneg_sent0', 'lowneg_sent0',
    )
    result = {}
    for column in columns:
        texts = examples[column]
        if e5:
            encoded = tokenizer(texts, max_length=512 + 1, return_attention_mask=False, padding=False, truncation=True)
            # Append EOS to every sequence.
            encoded['input_ids'] = [ids + [tokenizer.eos_token_id] for ids in encoded['input_ids']]
        else:
            encoded = tokenizer(texts, truncation=True, max_length=80)
        result.update({f"{column}_{field}": values for field, values in encoded.items()})
    return result

from transformers import LlamaModel, LlamaPreTrainedModel
class Llama2ModelForSentenceEmbedding(LlamaPreTrainedModel):
    """Llama backbone that returns one embedding vector per input sequence.

    The embedding is the final layer's hidden state at the last sequence
    position.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        # Not used by forward(); presumably retained so checkpoints with a
        # `score` head still load -- TODO confirm.
        self.score = nn.Linear(10, 1, bias=False)
        self.post_init()

    def forward(self, **kwargs):
        """Run the backbone on ``kwargs`` and return the last-position hidden state."""
        # Hidden states and dict-style outputs are always requested,
        # overriding whatever the caller passed for these two options.
        backbone_kwargs = dict(kwargs, output_hidden_states=True, return_dict=True)
        outputs = self.model(**backbone_kwargs)
        final_layer = outputs.hidden_states[-1]
        return final_layer[:, -1, :]

from transformers import OPTPreTrainedModel, OPTModel
class OPT2ModelForSentenceEmbedding(OPTPreTrainedModel):
    """OPT backbone that returns one embedding vector per input sequence.

    The embedding is the final layer's hidden state at the last sequence
    position.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OPTModel(config)
        # Not used by forward(); presumably retained so checkpoints with a
        # `score` head still load -- TODO confirm.
        self.score = nn.Linear(10, 1, bias=False)
        self.post_init()

    def forward(self, **kwargs):
        """Run the backbone on ``kwargs`` and return the last-position hidden state."""
        # Hidden states and dict-style outputs are always requested,
        # overriding whatever the caller passed for these two options.
        backbone_kwargs = dict(kwargs, output_hidden_states=True, return_dict=True)
        outputs = self.model(**backbone_kwargs)
        final_layer = outputs.hidden_states[-1]
        return final_layer[:, -1, :]

from torch import Tensor
from transformers import MistralModel, MistralPreTrainedModel
def last_token_pool(last_hidden_states: Tensor,
                attention_mask: Tensor) -> Tensor:
    """Pick the hidden state of the last attended token of every sequence.

    If every row of ``attention_mask`` is 1 in its final column, the batch is
    treated as left-padded and the final position is returned for all rows.
    Otherwise the position is derived per row from the number of attended
    tokens.
    """
    num_rows = attention_mask.shape[0]
    if attention_mask[:, -1].sum() == num_rows:
        # Left padding: the real last token sits in the final column.
        return last_hidden_states[:, -1]

    # Right padding: the last real token is at index (attended token count - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_index, last_positions]

class MistralModelForSentenceEmbedding(MistralPreTrainedModel):
    """Mistral backbone that returns one embedding vector per input sequence.

    The embedding is the final layer's hidden state at the last attended
    token, selected by ``last_token_pool``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = MistralModel(config)
        self.vocab_size = config.vocab_size
        # Neither `lm_head` nor `score` is used by forward(); presumably both
        # are retained so existing checkpoints still load -- TODO confirm.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.score = nn.Linear(10, 1, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, **kwargs):
        """Run the backbone and pool the final layer with the attention mask.

        ``attention_mask`` must be passed as a keyword argument.
        """
        # Hidden states and dict-style outputs are always requested,
        # overriding whatever the caller passed for these two options.
        backbone_kwargs = dict(kwargs, output_hidden_states=True, return_dict=True)
        all_layers = self.model(**backbone_kwargs).hidden_states
        return last_token_pool(all_layers[-1], kwargs['attention_mask'])

def main(
        output_dir: str = "peft_adapter_weight_path",
        model_name_or_path: str = "",
        per_device_train_batch_size: int = 2,
        gradient_accumulation_steps: int = 4,
        learning_rate: float = 1e-4,
        small_score_gap: bool = False,
        same_ref_model: bool = False,
        add_negative: bool = False,
        add_more_negtive: bool = False,
        add_more_more_negative: bool = False,
        gather_embs: bool = False,
        add_hard_negative: bool = False,
        add_negative_gate: bool = False,
        seed: int = 42,
        not_use_bf16: bool = False,
        add_emb_klloss: bool = False,
        logging_steps: int = 10,
):
    """Train a 4-bit quantized sentence-embedding model with LoRA and DPORankerTrainer.

    Args:
        output_dir: Directory for checkpoints and for the final adapter.
        model_name_or_path: Model to load. Substrings of this value select the
            tokenizer and wrapper class: ``'opt'`` selects OPT, ``'e5'``
            selects e5-Mistral, anything else selects Llama.
        per_device_train_batch_size: Forwarded to ``TrainingArguments``.
        gradient_accumulation_steps: Forwarded to ``TrainingArguments``.
        learning_rate: Forwarded to ``TrainingArguments``.
        small_score_gap: Pairing strategy passed to ``preprocess_function``.
        same_ref_model: Reuse the policy model object as the reference model
            instead of loading a second copy.
        add_negative, add_more_negtive, add_more_more_negative, gather_embs,
        add_hard_negative, add_negative_gate, add_emb_klloss: Copied onto the
            trainer as attributes of the same name.
        seed: Forwarded to ``TrainingArguments``.
        not_use_bf16: Use float16 instead of bfloat16 for loading, for the
            4-bit compute dtype and for mixed-precision training.
        logging_steps: Forwarded to ``TrainingArguments``.
    """
    global TEMPLATE, tokenizer, cutoff_len

    TEMPLATE="This_sentence_:_\"*sent_0*\"_means_in_one_word:\""
    if 'opt' in model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
        MODEL_CLASS = OPT2ModelForSentenceEmbedding
        tokenizer.pad_token_id = 0
        tokenizer.padding_side = "left"  # Allow batched inference
    elif 'e5' in model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-mistral-7b-instruct")
        MODEL_CLASS = MistralModelForSentenceEmbedding
        tokenizer.pad_token_id = 0
    else:
        tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        MODEL_CLASS = Llama2ModelForSentenceEmbedding
        tokenizer.pad_token = '[PAD]'
        tokenizer.padding_side = "left"  # Allow batched inference
    cutoff_len = 64

    dataset = load_from_disk('./data/nil_score')

    from functools import partial

    # NOTE(review): prompt formatting keys on 'e5-' while the tokenizer choice
    # above keys on 'e5' -- confirm that the difference is intentional.
    use_e5_format = 'e5-' in model_name_or_path

    data_cache_dir = f'./data/nil_score/processed_dataset_{model_name_or_path.replace("/", "_")}'
    if small_score_gap:
        # The pairing strategy changes the processed data, so it needs its own
        # cache; otherwise a cache built with the other strategy is reused.
        data_cache_dir += '_small_score_gap'

    if os.path.exists(data_cache_dir):
        tokenized_datasets = load_from_disk(data_cache_dir)
    else:
        processed_datasets = dataset.map(
            partial(preprocess_function, small_score_gap=small_score_gap, e5=use_e5_format),
            batched=True,
            batch_size=1000,
            remove_columns=dataset.column_names,
            num_proc=50
        )

        tokenized_datasets = processed_datasets.map(
            partial(tokenized_function, e5=use_e5_format),
            batched=True,
            remove_columns=processed_datasets.column_names,
            num_proc=50
        )
        tokenized_datasets.save_to_disk(data_cache_dir)

    torch_dtype = torch.float16 if not_use_bf16 else torch.bfloat16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
    )

    model = MODEL_CLASS.from_pretrained(model_name_or_path,
                                        device_map=device_map,
                                        torch_dtype=torch_dtype,
                                        quantization_config=bnb_config)

    if same_ref_model:
        model_ref = model
    else:
        model_ref = MODEL_CLASS.from_pretrained(model_name_or_path,
                                                device_map=device_map,
                                                torch_dtype=torch_dtype,
                                                quantization_config=bnb_config)

    # LoRA is attached to every linear projection of the backbone; OPT names
    # its projections differently from Llama/Mistral.
    all_linear = 'q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj'
    if 'opt' in model_name_or_path:
        all_linear = 'q_proj,k_proj,v_proj,out_proj,fc1,fc2'

    peft_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION,
        target_modules=all_linear.split(','),
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        per_device_eval_batch_size=4,
        learning_rate=learning_rate,
        num_train_epochs=1,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        weight_decay=0.05,
        save_strategy="steps",
        save_steps=50,
        logging_strategy="steps",
        logging_steps=logging_steps,
        lr_scheduler_type="linear",
        warmup_steps=100,
        bf16=not not_use_bf16,
        fp16=not_use_bf16,
        # The collator needs the custom <prefix>_input_ids columns.
        remove_unused_columns=False,
        label_names=['labels'],
        save_total_limit=1000,
        seed=seed,
    )

    DC_FUN = DataCollatorForSeq2SeqForNeg

    dpo_trainer = DPORankerTrainer(
        model,
        model_ref,
        args=training_args,
        beta=0.1,
        data_collator=DC_FUN(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True),
        train_dataset=tokenized_datasets,
        tokenizer=tokenizer,
        peft_config=peft_config
    )
    # The behaviour switches are plain attributes read by the trainer's
    # forward/loss methods.
    dpo_trainer.add_negative = add_negative
    dpo_trainer.add_more_negtive = add_more_negtive
    dpo_trainer.add_more_more_negative = add_more_more_negative
    dpo_trainer.gather_embs = gather_embs
    dpo_trainer.add_hard_negative = add_hard_negative
    dpo_trainer.add_negative_gate = add_negative_gate
    dpo_trainer.add_emb_klloss = add_emb_klloss
    # call train
    dpo_trainer.train()

    # Save the final PEFT adapter next to the checkpoints. The default
    # output_dir equals the previously hard-coded path.
    dpo_trainer.save_model(output_dir)


if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags via python-fire.
    import fire

    fire.Fire(main)
