"""
Data loading and preprocessing for MDPO training, where each prompt has several chosen responses and one rejected response.
"""
import json
from llamafactory.data.template import get_template_and_fix_tokenizer
from llamafactory.data.parser import get_dataset_list
from typing import Any, Dict, List, Generator, Tuple
from datasets import Dataset
from datasets import load_from_disk
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from llamafactory.hparams import DataArguments
from llamafactory.extras.logging import get_logger
import os

# Module-level logger for this data-loading module.
logger = get_logger(__name__)

def load_json(file_path):
    """Load an MDPO preference JSON file and flatten it into a ``Dataset``.

    Each record of the top-level JSON array is read as::

        {"instruction": <prompt str>, "output": [<list of chosen responses>, <rejected response>]}

    Args:
        file_path: Path to a UTF-8 encoded JSON file holding a list of such records.

    Returns:
        A tuple ``(dataset, max_K)``.

        - ``dataset`` has columns ``prompt`` and ``response``. ``response`` is every
          chosen response followed by the rejected one, so the rejected response is
          always the last element.
        - ``max_K`` is the largest number of chosen responses seen in any single record.
          Records may have fewer chosen responses than ``max_K``.
    """
    max_K = 0
    records = []
    # JSON is UTF-8 by specification; be explicit so non-ASCII data does not depend on
    # the platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    for d in data:
        chosen_list = d["output"][0]
        max_K = max(max_K, len(chosen_list))
        # The last element is the rejected response; all preceding ones are chosen.
        response = list(chosen_list)
        response.append(d["output"][1])
        records.append({
            'prompt': d["instruction"],
            'response': response,
        })
    dataset = Dataset.from_list(records)
    return dataset, max_K

def get_dataset(model_args, data_args):
    """Load the MDPO training data described by ``data_args``.

    Only the first entry returned by ``get_dataset_list(data_args)`` is used. Multiple
    datasets are not concatenated. That entry's ``dataset_name`` is passed to
    ``load_json`` as the path of the JSON file to read.

    Args:
        model_args: Unused; accepted for signature compatibility.
        data_args: Data arguments forwarded to ``get_dataset_list``.

    Returns:
        The ``(dataset, max_K)`` tuple produced by ``load_json``.
    """
    first_dataset_attr = get_dataset_list(data_args)[0]
    return load_json(first_dataset_attr.dataset_name)

def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
    """Split ``data_args.cutoff_len`` between the prompt and the response.

    The response budget is proportional to ``target_len``'s share of the total length.
    It is never smaller than ``data_args.reserved_label_len``. The prompt gets whatever
    remains of ``cutoff_len``.

    Args:
        source_len: Token length of the prompt.
        target_len: Token length of the (longest) response.
        data_args: Provides ``cutoff_len`` and ``reserved_label_len``.

    Returns:
        ``(max_source_len, max_target_len)``.
    """
    total_len = source_len + target_len
    if total_len > 0:
        max_target_len = int(data_args.cutoff_len * (target_len / total_len))
    else:
        # Both lengths are zero: avoid ZeroDivisionError and fall back to the reserved budget.
        max_target_len = 0
    max_target_len = max(max_target_len, data_args.reserved_label_len)
    max_source_len = data_args.cutoff_len - max_target_len
    return max_source_len, max_target_len

def preprocess_dataset(dataset, tokenizer, data_args, training_args, multiple_K, stage="rm"):
    """Tokenize a multi-positive preference dataset for MDPO training.

    Args:
        dataset: A ``Dataset`` or ``IterableDataset`` whose rows look like
            ``{'prompt': str, 'response': [chosen_1, ..., chosen_n, rejected], ...}``.
            Optional columns are read if present:
            ``query`` is appended to the prompt on a new line; ``history`` and
            ``system`` are extracted but currently not used for encoding.
        tokenizer: Tokenizer passed to ``get_template_and_fix_tokenizer`` and used for
            encoding.
        data_args: Provides ``template``, ``cutoff_len``, ``reserved_label_len``,
            ``streaming``, ``preprocessing_num_workers`` and ``overwrite_cache``.
        training_args: Supplies ``main_process_first``, which wraps the map call.
        multiple_K: Number of chosen responses emitted per example.
            Rows with fewer than ``multiple_K`` chosen responses are skipped; that is,
            rows need ``len(response) > multiple_K`` (and always at least 2 entries).
        stage: Unused; kept for interface compatibility.

    Returns:
        The mapped dataset. Its columns are ``prompt_ids``, ``rejected_ids`` and
        ``chosen_ids_0`` ... ``chosen_ids_{multiple_K - 1}``. All original columns are
        removed.
    """
    template = get_template_and_fix_tokenizer(tokenizer, data_args.template)

    def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
        # Yield (query, response, history, system) for each row of a batched `examples` dict.
        for i in range(len(examples["prompt"])):
            query, response = examples["prompt"][i], examples["response"][i]
            query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
            history = examples["history"][i] if "history" in examples else None
            system = examples["system"][i] if "system" in examples else None
            yield query, response, history, system

    def preprocess_multiple_positive_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
        # Per example, build `prompt_ids` (`<bos> X`), one `chosen_ids_k` (`Y_k <eos>`) for
        # each of the first multiple_K responses, and `rejected_ids` from the last response.
        model_inputs = {"prompt_ids": [], "rejected_ids": []}
        for k in range(multiple_K):
            model_inputs["chosen_ids_{}".format(k)] = []

        for query, response, history, system in construct_example(examples):
            # Require a non-empty prompt, at least `multiple_K` chosen responses and one
            # rejected response (the last element). Without the `multiple_K` bound, a row
            # with too few chosen responses would either raise IndexError or silently use
            # its rejected response as one of the "chosen" targets.
            if not (
                isinstance(query, str)
                and isinstance(response, list)
                and query.strip() != ""
                and len(response) > max(multiple_K, 1)
            ):
                continue

            # NOTE(review): `history` and `system` are not passed to encode_oneturn;
            # confirm this is intended.
            messages = [
                {'role': 'user', 'content': query},
                {'role': 'assistant', 'content': response[-1]},
            ]
            prompt_ids, rejected_ids = template.encode_oneturn(tokenizer, messages)
            if template.efficient_eos:
                # Presumably the template omits EOS from the target in this mode; append it.
                rejected_ids += [tokenizer.eos_token_id]

            chosen_ids_list = []
            # Track the longest response so all targets share one truncation budget.
            target_len = len(rejected_ids)
            for k in range(multiple_K):
                messages = [
                    {'role': 'user', 'content': query},
                    {'role': 'assistant', 'content': response[k]},
                ]
                _, chosen_ids = template.encode_oneturn(tokenizer, messages)
                if template.efficient_eos:
                    chosen_ids += [tokenizer.eos_token_id]
                chosen_ids_list.append(chosen_ids)
                target_len = max(target_len, len(chosen_ids))

            source_len = len(prompt_ids)
            max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
            if source_len > max_source_len:
                prompt_ids = prompt_ids[:max_source_len]
            if target_len > max_target_len:
                chosen_ids_list = [ids[:max_target_len] for ids in chosen_ids_list]
                rejected_ids = rejected_ids[:max_target_len]

            model_inputs["prompt_ids"].append(prompt_ids)
            model_inputs["rejected_ids"].append(rejected_ids)
            for k, chosen_ids in enumerate(chosen_ids_list):
                model_inputs["chosen_ids_{}".format(k)].append(chosen_ids)

        return model_inputs

    def print_multiple_positive_dataset_example(example: Dict[str, List[int]]) -> None:
        # Debug helper: print token ids and decoded text of one processed example.
        # Not called within this function.
        print("prompt_ids:\n{}".format(example["prompt_ids"]))
        print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
        for i in range(multiple_K):
            print("chosen_ids_{}:\n{}".format(i, example["chosen_ids_{}".format(i)]))
            print("chosen_{}:\n{}".format(i, tokenizer.decode(example["chosen_ids_{}".format(i)], skip_special_tokens=False)))
        print("rejected_ids:\n{}".format(example["rejected_ids"]))
        print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))

    # NOTE: on-disk caching via `data_args.cache_path` is not used by this loader.
    with training_args.main_process_first(desc="dataset map pre-processing"):
        column_names = list(next(iter(dataset)).keys())
        kwargs = {}
        if not data_args.streaming:
            # Multiprocessing/caching/progress options are only passed for a
            # non-streaming Dataset; IterableDataset.map does not accept them.
            kwargs = dict(
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=(not data_args.overwrite_cache),
                desc="Running tokenizer on dataset"
            )

        dataset = dataset.map(
            preprocess_multiple_positive_dataset,
            batched=True,
            remove_columns=column_names,
            **kwargs
        )

    return dataset
    