# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import pdb
import os
import sys
import copy
import torch
import logging
import argparse
import jsonlines
import deepspeed

sys.path.append(
            os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from tqdm import tqdm
from typing import List

from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters
from utils.data.data_utils import batchify, shard_data
from utils.dist_utils import wait_for_everyone, is_main_process
from utils.utils import to_device
from utils.model.model_utils import create_hf_model
from utils.ds_utils import get_eval_ds_config
# Label value written over the source-token prefix by
# `build_input_and_label_ids` when `compute_target_only` is set.
IGNORE_INDEX=-100

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def parse_args():
    """Parse the command-line arguments of the evaluation script.

    Besides the options declared here, the parser is extended with the
    DeepSpeed config arguments via ``deepspeed.add_config_arguments``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Evaluate the finetuned SFT model")
    parser.add_argument(
        "--model_name_or_path_baseline",
        type=str,
        help="Path to baseline model",
        required=True,
    )
    parser.add_argument(
        "--model_name_or_path_finetune",
        type=str,
        help="Path to the finetuned model",
        required=True,
    )
    parser.add_argument(
        "--per_device_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        '--bf16',
        action='store_true',
        help='Enable bfloat16 techniques.',
    )

    # generation options
    parser.add_argument(
        "--num_beams",
        type=int,
        default=4,
        help='Number of beams for beam search.',
    )
    parser.add_argument(
        "--num_beam_groups",
        type=int,
        default=1,
        help='Number of groups to divide the beams into.',
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=4,
        help='Number of highest-probability tokens kept for top-k filtering.',
    )
    parser.add_argument(
        "--do_sample",
        action="store_true",
        help="Whether to use sampling; use greedy decoding if this option is disabled.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="The value used to modulate the next token probabilities.",
    )
    parser.add_argument(
        "--penalty_alpha",
        type=float,
        default=0.6,
        help='Value passed as `penalty_alpha` for contrastive search.',
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.3,
        help='Penalty of repetition',
    )
    parser.add_argument(
        "--num_return_sequences",
        type=int,
        default=1,
        help='Specify num of return sequences',
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=1024,
        help='Maximum number of new tokens to generate; '
             'values <= 0 leave the limit unset.',
    )

    # deepspeed-inference config
    parser.add_argument(
        "--mp_size",
        type=int,
        default=1,
        help="The model parallel size.",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=2048,
        help="The maximum sequence length.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp16",
        help="The data type of the model.",
    )

    # data sharding
    parser.add_argument(
        "--num_shards",
        type=int,
        default=1,
        help="Number of shards to split the data into.",
    )
    parser.add_argument(
        "--shard_id",
        type=int,
        default=0,
        help="The shard id of the current process.",
    )

    # data
    parser.add_argument(
        "--language",
        type=str,
        default="English",
        choices=["English", "Chinese", "Japanese", "German"],
        help="Source language; selects the default translation instruction.",
    )
    parser.add_argument(
        "--input",
        type=str,
        default="",
        help="Path to the input jsonlines file.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="",
        help="Path of the jsonlines file the predictions are written to.",
    )

    # deepspeed features
    parser.add_argument(
        '--offload',
        action='store_true',
        help='Enable ZeRO Offload techniques.',
    )
    parser.add_argument(
        '--zero_stage',
        type=int,
        default=0,
        help='ZeRO optimization stage used in the evaluation DeepSpeed config.',
    )
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()

    return args

def pad_sequence(inputs: List[torch.LongTensor], padding_value: int = 0, padding_side="left", batch_first=True):
    """Pad 1-D tensors to a common length and stack them into one tensor.

    Args:
        inputs: sequences to pad; each is extended along dimension 0.
        padding_value: value used to fill the added positions.
        padding_side: ``"left"`` prepends the padding; any other value
            appends it.
        batch_first: when False, the first two dimensions of the stacked
            result are swapped before returning.

    Returns:
        The stacked tensor, one row per input sequence when ``batch_first``.
    """
    target_len = max(tensor.size(0) for tensor in inputs)

    def _pad(tensor):
        # Sequences that already have the target length are passed through.
        missing = target_len - tensor.size(0)
        if missing <= 0:
            return tensor
        # Filler takes the dtype/device of the sequence it is attached to.
        filler = torch.ones((missing,)).to(tensor) * padding_value
        pieces = [filler, tensor] if padding_side == "left" else [tensor, filler]
        return torch.cat(pieces)

    stacked = torch.stack([_pad(tensor) for tensor in inputs])

    return stacked if batch_first else stacked.transpose(0, 1)

def build_input_and_label_ids(samples, tokenizer, eval_mode=True, compute_target_only=False):
    """
    Build ``input_ids`` and ``labels`` from SFT-formatted samples.

    The key layout (``input``/``inputs``, ``target``/``targets``, ``prompt``)
    is detected from the first sample and applied to the whole batch.

    Args:
        samples: list of dicts. Each holds the source text under ``input`` or
            ``inputs`` (a string or a list of strings), optionally a target
            under ``target``/``targets`` and optionally a ``prompt`` template
            that is filled with the inputs via ``str.format``.
        tokenizer: callable whose result exposes ``input_ids`` as a list;
            ``eos_token_id`` is read when an EOS token has to be appended.
        eval_mode: when True the target is left out (it is what the model has
            to generate) and no EOS token is appended.
        compute_target_only: when True and targets are used, the label
            positions covering the source are set to ``IGNORE_INDEX``.

    Returns:
        Tuple ``(input_ids, labels)`` of two lists of 1-D tensors, one entry
        per sample. Both lists are empty when ``samples`` is empty.

    Raises:
        ValueError: if the first sample has neither ``input`` nor ``inputs``.
    """
    if not samples:
        return [], []

    first = samples[0]

    if "input" in first:
        input_key = "input"
    elif "inputs" in first:
        input_key = "inputs"
    else:
        raise ValueError("The input key should be either `input` or `inputs`")

    if "target" in first:
        target_key = "target"
    elif "targets" in first:
        target_key = "targets"
    else:
        target_key = None  # no target side, all the data are in the `input`

    with_prompt_template = "prompt" in first

    def _build_source(item):
        """Return the source text of one sample, applying the prompt template if any."""
        inputs = item[input_key]

        if isinstance(inputs, str):
            inputs = [inputs, ]

        if with_prompt_template:
            return item['prompt'].format(*inputs)
        return inputs[0]

    def _tokenize_fn(strings, add_eos):
        """Tokenize each string on its own and return a list of 1-D tensors."""
        token_lists = [
            tokenizer(
                text,
                padding="longest",
                truncation=False,
            ).input_ids
            for text in strings
        ]
        # Append EOS manually only when the tokenizer does not do it itself.
        if add_eos and getattr(tokenizer, 'add_eos_token', False) is False:
            token_lists = [ids + [tokenizer.eos_token_id, ] for ids in token_lists]

        return [torch.tensor(ids) for ids in token_lists]

    sources = [_build_source(item) for item in samples]

    # Decide whether the target is concatenated into the final example.
    # During evaluation the target is the part to be generated, so it is not.
    if not eval_mode and target_key is not None:
        targets = [item[target_key] for item in samples]
    else:
        targets = []

    if targets:
        examples = [s + t for s, t in zip(sources, targets)]
    else:
        examples = sources

    input_ids = _tokenize_fn(examples, add_eos=not eval_mode)
    labels = copy.deepcopy(input_ids)

    if targets and compute_target_only:
        # Every text is tokenized separately, so no padding is ever added and
        # the source length is simply the number of tokens. Counting
        # "tokens != pad" instead would fail for tokenizers without a pad
        # token and undercount whenever the pad id occurs in the text itself.
        source_lens = [ids.size(0) for ids in _tokenize_fn(sources, add_eos=False)]
        for label, source_len in zip(labels, source_lens):
            label[:source_len] = IGNORE_INDEX

    return input_ids, labels

def generate(model,
             tokenizer,
             inputs,
             num_beams=1,
             num_beam_groups=1,
             do_sample=False,
             num_return_sequences=1,
             max_new_tokens=100,
             repetition_penalty=1.3):
    """Generate with ``model.generate`` and decode the result to text.

    Args:
        model: object exposing ``generate``.
        tokenizer: object exposing ``eos_token_id`` and ``batch_decode``.
        inputs: object whose ``input_ids`` attribute is a 2-D tensor.
        num_beams: forwarded to ``model.generate``.
        num_beam_groups: forwarded to ``model.generate``.
        do_sample: forwarded to ``model.generate``.
        num_return_sequences: forwarded to ``model.generate``.
        max_new_tokens: upper bound on the generated length; it is further
            capped at three times the second dimension of ``input_ids``.
        repetition_penalty: forwarded to ``model.generate``; defaults to the
            previously hard-coded value of 1.3.

    Returns:
        The list returned by ``tokenizer.batch_decode`` for the generated ids
        (decoded with special tokens skipped).
    """
    # Never generate more than 3x the (padded) prompt length.
    max_new_tokens = min(max_new_tokens, inputs.input_ids.shape[1] * 3)
    generate_ids = model.generate(inputs.input_ids,
                                  num_beams=num_beams,
                                  num_beam_groups=num_beam_groups,
                                  repetition_penalty=repetition_penalty,
                                  do_sample=do_sample,
                                  num_return_sequences=num_return_sequences,
                                  max_new_tokens=max_new_tokens,
                                  use_cache=True,
                                  eos_token_id=tokenizer.eos_token_id,
                                  early_stopping=True)
    result = tokenizer.batch_decode(generate_ids,
                                    skip_special_tokens=True,
                                    clean_up_tokenization_spaces=False)
    return result


def generate_constrastive_search(model,
                                 tokenizer,
                                 inputs,
                                 top_k=4,
                                 penalty_alpha=0.6,
                                 num_return_sequences=1,
                                 max_new_tokens=100):
    """Generate with ``top_k``/``penalty_alpha`` settings and decode to text.

    Args:
        model: object exposing ``generate``.
        tokenizer: object exposing ``batch_decode``.
        inputs: object whose ``input_ids`` are handed to ``model.generate``.
        top_k: forwarded to ``model.generate``.
        penalty_alpha: forwarded to ``model.generate``.
        num_return_sequences: forwarded to ``model.generate``.
        max_new_tokens: forwarded to ``model.generate``.

    Returns:
        The list returned by ``tokenizer.batch_decode`` for the generated ids
        (decoded with special tokens skipped).
    """
    search_options = {
        "top_k": top_k,
        "penalty_alpha": penalty_alpha,
        "num_return_sequences": num_return_sequences,
        "max_new_tokens": max_new_tokens,
    }
    sequences = model.generate(inputs.input_ids, **search_options)

    return tokenizer.batch_decode(sequences,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)

def main():
    """Generate predictions for a jsonlines dataset with the finetuned model.

    Steps: parse the arguments, load tokenizer and model, read the samples
    from ``args.input`` (each line needs a ``prompt`` and may carry its own
    ``instruct`` and a ``docid``), generate batch by batch, and let the main
    process write the samples plus their ``prediction`` to ``args.output``.

    Raises:
        ValueError: if an input line has no ``instruct`` and no default
            instruction exists for ``args.language``.
    """
    args = parse_args()

    local_rank = int(os.environ.get("LOCAL_RANK", -1))

    if local_rank == -1:
        device = torch.device("cuda")
    else:
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        deepspeed.init_distributed()

    # `get_rank()` raises when no process group has been initialised, which is
    # the case when the script is started without a distributed launcher.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        args.global_rank = torch.distributed.get_rank()
    else:
        args.global_rank = 0

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path_baseline)
    if tokenizer.pad_token is None:
        # Fall back to the unknown token so that batches can be padded.
        tokenizer.pad_token_id = tokenizer.unk_token_id

    # NOTE(review): offload is hard-coded to True, so `--offload` has no
    # effect here -- confirm whether `args.offload` was intended.
    ds_config = get_eval_ds_config(offload=True,
                                   stage=args.zero_stage)
    if args.bf16:
        ds_config['bf16'] = {"enabled": True,}
        ds_config['fp16']['enabled'] = False
    print(ds_config)
    print(args.model_name_or_path_finetune)
    print(args.input)

    # The model class is picked from the baseline model name.
    if 'mt5' in args.model_name_or_path_baseline:
        model_class = AutoModelForSeq2SeqLM
    else:
        model_class = AutoModelForCausalLM
    model_fintuned = create_hf_model(model_class,
                                     args.model_name_or_path_finetune,
                                     tokenizer, ds_config)

    # Default instruction per source language. Languages without an entry
    # (e.g. "Japanese") require every input line to provide its own `instruct`.
    default_instructs = {
        "Chinese": "Human: Translate this document from Chinese to English\n",
        "German": "Human: Translate this document from German to English\n",
        "English": "Human: Translate this document from English to German\n",
    }
    default_instruct = default_instructs.get(args.language)

    all_data = []
    with jsonlines.open(args.input) as reader:
        for line in reader:
            # A per-line instruction applies to that line only; it must not
            # leak into the following lines.
            instruct = line.get('instruct', default_instruct)
            if instruct is None:
                raise ValueError(
                    "No `instruct` field in the input line and no default "
                    f"instruction for language `{args.language}`")
            sample = {'input': instruct + line['prompt']}
            if 'docid' in line:
                sample['docid'] = line['docid']
            all_data.append(sample)

    if args.num_shards > 1:
        all_data = shard_data(all_data, num_shards=args.num_shards, shard_id=args.shard_id)

    batches = list(batchify(all_data, batch_size=args.per_device_batch_size))

    # NOTE(review): the inference dtype follows `--dtype`, while the DeepSpeed
    # config above follows `--bf16` -- confirm both are meant to be set together.
    engine = deepspeed.init_inference(model_fintuned, mp_size=args.mp_size,
                                      dtype=torch.bfloat16 if args.dtype == 'bf16' else torch.float16)
    engine.module.eval()
    model_fintuned = engine.module

    # Show the bar on local rank 0 and in single-process runs (rank -1).
    progress_bar = tqdm(range(len(batches)), disable=local_rank not in (-1, 0))

    generate_kwargs = dict(max_length=args.max_seq_len, do_sample=args.do_sample, temperature=args.temperature,
                           top_k=args.top_k, repetition_penalty=args.repetition_penalty,
                           pad_token_id=tokenizer.pad_token_id, num_beams=args.num_beams,
                           eos_token_id=tokenizer.eos_token_id, early_stopping=True)

    print(generate_kwargs)
    if args.max_new_tokens > 0:
        generate_kwargs['max_new_tokens'] = args.max_new_tokens

    results = []

    for raw_batch in batches:
        input_ids, _ = build_input_and_label_ids(samples=raw_batch, tokenizer=tokenizer, eval_mode=True)

        batch = {
            "input_ids": pad_sequence(input_ids, padding_value=tokenizer.pad_token_id,
                                      padding_side=tokenizer.padding_side)
        }

        if args.max_new_tokens > 0:
            # Cap the generation length at 3x the padded prompt length.
            generate_kwargs['max_new_tokens'] = min(args.max_new_tokens, batch['input_ids'].shape[1] * 3)
        batch = to_device(batch, device)
        outputs = model_fintuned.generate(batch['input_ids'], **generate_kwargs)
        raw_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        for raw_sample, raw_output in zip(raw_batch, raw_outputs):
            # The copy keeps every field of the sample (including an optional
            # `docid`); only the prediction is added.
            output_sample = copy.deepcopy(raw_sample)
            # Strip the echoed prompt so only the generated text remains.
            output_sample['prediction'] = raw_output.replace(raw_sample['input'], '').strip()
            results.append(output_sample)

        progress_bar.update(1)

    progress_bar.close()

    if is_main_process():
        with jsonlines.open(args.output, flush=True, mode='w') as writer:
            writer.write_all(results)

    wait_for_everyone()

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
