import argparse
import os
# Route Hugging Face Hub requests through the hf-mirror.com endpoint.
# NOTE(review): this sits between the imports presumably so that it is set
# before transformers/peft are imported — confirm before reordering imports.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import platform
import shutil
from copy import deepcopy
from threading import Thread
from tqdm import tqdm
import json
import torch.distributed as dist
from transformers.trainer_utils import EvalPrediction
from peft import AutoPeftModelForCausalLM
from transformers.generation import GenerationConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.trainer_utils import set_seed
from torch.utils.data import DataLoader

from finetune import LazySupervisedMultiroleDataset, preprocess_multirole_eval, load_json_data
from evaluate.rouge import compute_rouge

# DEFAULT_CKPT_PATH = "Qwen/Qwen1.5-1.8B-Chat"
# DEFAULT_EVAL_RESULT_PATH = "output/wulin/multirole-Qwen1.5-1.8B-Chat_fp16_eval_result"
# DEFAULT_EVAL_DATA_PATH = "data/chat/wulin/wulinwaizhuan_diags_L512_dev.json"
# ROLE_SYSTEM_PROFILE_PATH="data/chat/wulin/role_system_profile.json"

# Defaults for the command-line options parsed in the ``__main__`` block.
# Default for -c/--checkpoint-path: the checkpoint that gets evaluated.
DEFAULT_CKPT_PATH = "/mnt/workspace/zenghang/LLM/Qwen/output/wulin/multirole_qwen7b_L512_instruction/checkpoint-550"
# Default for --eval_result_path: directory that receives pred_result.json.
# NOTE(review): "insturction" looks like a typo, but it is part of the output
# path, so correcting it changes where results are written.
DEFAULT_EVAL_RESULT_PATH = "output/wulin/tongbai-multirole-L512-insturction-tuned-fp16_eval_result"
# Default for --eval_data_path: the JSON file with the evaluation dialogues.
DEFAULT_EVAL_DATA_PATH = "data/chat/wulin/wulinwaizhuan_diags_tong_bai_L512_dev.json"
# Default for --role_system_profile_path: optional JSON file; the script only
# loads it when the path exists.
ROLE_SYSTEM_PROFILE_PATH = "data/chat/wulin/role_system_profile.json"
# TARGET_ROLE = "佟湘玉"
# Forwarded as ``target_role`` to the evaluation dataset constructor.
# NOTE(review): how the dataset interprets None is defined in finetune.py,
# not here — confirm there before relying on it.
TARGET_ROLE = None

def _load_model_tokenizer(args):
    """Load the model, tokenizer and generation config used for evaluation.

    Args:
        args: Parsed command-line namespace. Reads ``cpu_only`` (selects the
            ``"cpu"`` device map instead of ``"auto"``), ``peft`` (load the
            checkpoint as a PEFT adapter) and ``checkpoint_path``.

    Returns:
        A ``(model, tokenizer, config)`` tuple. The model has already been
        switched to eval mode; ``config`` is the ``GenerationConfig`` loaded
        from ``args.checkpoint_path``.

    Raises:
        OSError: In PEFT mode, if ``adapter_config.json`` cannot be opened
            inside ``args.checkpoint_path``.
    """
    device_map = "cpu" if args.cpu_only else "auto"

    if args.peft:
        # In PEFT mode the tokenizer comes from the base model recorded in
        # adapter_config.json rather than from the checkpoint directory.
        adapter_config_path = os.path.join(args.checkpoint_path, "adapter_config.json")
        # Context manager so the file handle is closed deterministically.
        with open(adapter_config_path, encoding="utf-8") as config_file:
            adapter_config = json.load(config_file)
        base_model = adapter_config['base_model_name_or_path']
        tokenizer = AutoTokenizer.from_pretrained(
            base_model, trust_remote_code=True, resume_download=True,
        )
        model = AutoPeftModelForCausalLM.from_pretrained(
            args.checkpoint_path,
            device_map=device_map,
            trust_remote_code=True,
            resume_download=True,
        ).eval()
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            args.checkpoint_path, trust_remote_code=True, resume_download=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            args.checkpoint_path,
            device_map=device_map,
            trust_remote_code=True,
            resume_download=True,
        ).eval()

    # The generation config is always read from the checkpoint path itself,
    # in both PEFT and non-PEFT mode.
    config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    return model, tokenizer, config

def setup_for_distributed(is_master):
    """
    Replace the global ``print`` so that non-master processes stay silent.

    After this call ``print`` only produces output when ``is_master`` is true
    or the caller passes ``force=True``. The ``force`` keyword is always
    removed before delegating to the previously installed ``print``.
    """
    import builtins

    previous_print = builtins.print

    def print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        previous_print(*args, **kwargs)

    builtins.print = print

def init_distributed_mode(args):
    """Initialise ``torch.distributed`` from launcher environment variables.

    Two launch styles are recognised:

    * ``RANK`` and ``WORLD_SIZE`` are both set: rank, world size and the local
      GPU index are read from ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``.
    * ``SLURM_PROCID`` is set: it is used as the rank, and the GPU index is
      the rank modulo the number of visible CUDA devices.

    If neither applies, ``args.distributed`` is set to ``False`` and the
    function returns without touching CUDA or the process group.

    Args:
        args: Namespace that is updated in place with ``rank``, ``world_size``,
            ``gpu``, ``distributed`` and ``dist_backend``. ``args.dist_url``
            is read as the ``init_method`` for the process group.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
        # This branch used to leave args.world_size unset, so the
        # init_process_group call below failed with AttributeError. Keep a
        # caller-supplied value if there is one; otherwise take the task count
        # from SLURM_NTASKS.
        if getattr(args, 'world_size', None) is None:
            args.world_size = int(os.environ.get('SLURM_NTASKS', 1))
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # Wait for every rank before silencing print on the non-master ones.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)

if __name__ == '__main__':
    # Batch evaluation: generate a reply for every sample in the eval set,
    # then score the predictions against the reference labels with ROUGE.
    parser = argparse.ArgumentParser(description='QWen1.5-Chat command-line interactive chat demo.')
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH, help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
    parser.add_argument("--max_len", type=int, default=592, help="")
    parser.add_argument("--batch_size", type=int, default=8, help="")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    parser.add_argument("--multi-role", action="store_true", help="Multi-role mode")  # TODO: multi-role is currently the default
    parser.add_argument("--peft", action="store_true", help="is peft model or not")
    parser.add_argument("--eval_data_path", type=str, default=DEFAULT_EVAL_DATA_PATH, help="")
    parser.add_argument("--role_system_profile_path", type=str, default=ROLE_SYSTEM_PROFILE_PATH, help="")
    parser.add_argument("--eval_result_path", type=str, default=DEFAULT_EVAL_RESULT_PATH, help="")
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument("--distributed", action="store_true", help="")
    parser.add_argument("--fp16", action="store_true", help="")
    args = parser.parse_args()

    # init_distributed_mode may reset args.distributed to False when no
    # launcher environment variables are present.
    if args.distributed:
        init_distributed_mode(args)

    set_seed(args.seed)

    model, tokenizer, config = _load_model_tokenizer(args)
    # Reuse EOS as the padding token for batching and generation.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    if args.fp16:
        model.half()
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        # generate() is called on the unwrapped module, not the DDP wrapper.
        model_without_ddp = model.module
    else:
        model_without_ddp = model
    model.eval()

    # build eval dataset
    dataset_cls = LazySupervisedMultiroleDataset
    eval_json = load_json_data(args.eval_data_path)
    # The role profile is optional: it is loaded only when the file exists.
    role_system_profile = None
    if os.path.exists(args.role_system_profile_path):
        with open(args.role_system_profile_path, 'r', encoding='utf-8') as profile_file:
            role_system_profile = json.load(profile_file)
    eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, max_len=args.max_len, target_role=TARGET_ROLE, is_train=False, role_system_profile=role_system_profile)
    # NOTE(review): no distributed sampler is used, so in distributed mode
    # every rank iterates the full dataset — confirm this is intended.
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    all_inputs, all_preds, all_labels = [], [], []
    for inputs in tqdm(eval_dataloader):
        inputs['input_ids'] = inputs['input_ids'].to(model_without_ddp.device).to(torch.int64)
        inputs['attention_mask'] = inputs['attention_mask'].to(model_without_ddp.device).to(torch.int64)
        model_pred = model_without_ddp.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], max_new_tokens=128)

        # For a causal LM, generate() returns the prompt ids followed by the
        # continuation, so slicing at the prompt length isolates the new
        # tokens. The old code decoded the full sequence and removed the
        # prompt with str.replace, which drops every occurrence of the prompt
        # text and fell back to a meaningless pred[-10:] when the decoded
        # prefix did not match.
        prompt_len = inputs['input_ids'].shape[1]
        new_token_ids = model_pred[:, prompt_len:]

        inputs_str = tokenizer.batch_decode(inputs['input_ids'], skip_special_tokens=True)
        model_pred_str = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)
        all_inputs += inputs_str
        all_preds += [pred.replace("<|endoftext|>", "").strip() for pred in model_pred_str]

        # Negative label ids cannot be decoded, so map them to the pad token,
        # which skip_special_tokens then removes.
        labels = torch.where(inputs['labels'] < 0, tokenizer.pad_token_id, inputs['labels'])
        all_labels += tokenizer.batch_decode(labels, skip_special_tokens=True)
        assert len(all_preds) == len(all_labels), f"[ZH ERROR] len(all_preds) != len(all_labels): {len(all_preds)} / {len(all_labels)}"

    os.makedirs(args.eval_result_path, exist_ok=True)
    eval_result_file = os.path.join(args.eval_result_path, "pred_result.json")
    metrics = compute_rouge(EvalPrediction(predictions=all_preds, label_ids=all_labels), all_inputs, eval_result_file)
    print("{} metrics: {}".format(args.checkpoint_path, json.dumps(metrics)))
