# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import pdb
import argparse
import logging
import torch
import sys
import os
import jsonlines

sys.path.append(
            os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters

sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from utils.model.model_utils import create_hf_model
from utils.ds_utils import get_eval_ds_config


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def parse_args():
    """Parse the command-line options of the SFT evaluation script.

    Both model paths are required; every other option has a default.

    Returns:
        argparse.Namespace: The parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Eval the finetuned SFT model")
    parser.add_argument(
        "--model_name_or_path_baseline",
        type=str,
        help="Path to baseline model",
        required=True,
    )
    parser.add_argument(
        "--model_name_or_path_finetune",
        type=str,
        help="Path to finetuned model",
        required=True,
    )
    # Decoding options. Each help string describes its own option (several
    # of them used to share the text written for --num_beams).
    parser.add_argument(
        "--num_beams",
        type=int,
        default=1,
        help='Specify num of beams',
    )
    parser.add_argument(
        "--num_beam_groups",
        type=int,
        default=1,
        help='Specify num of beam groups',
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=4,
        help='Specify top_k for contrastive search',
    )
    parser.add_argument(
        "--penalty_alpha",
        type=float,
        default=0.6,
        help='Specify penalty_alpha for contrastive search',
    )
    parser.add_argument(
        "--num_return_sequences",
        type=int,
        default=1,
        help='Specify num of return sequences',
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=1024,
        help='Specify max num of new tokens to generate',
    )
    parser.add_argument("--language",
                        type=str,
                        default="English",
                        choices=["English", "Chinese", "Japanese", "German"],
                        help='Source language of the documents to translate')
    parser.add_argument("--input",
                        type=str,
                        default="",
                        help='Path to the input jsonlines file')
    # deepspeed features
    parser.add_argument('--offload',
                        action='store_true',
                        help='Enable ZeRO Offload techniques.')
    parser.add_argument(
        '--zero_stage',
        type=int,
        default=0,
        help='ZeRO optimization stage for Actor model (and clones).')
    return parser.parse_args()


def generate(model,
             tokenizer,
             inputs,
             num_beams=1,
             num_beam_groups=1,
             do_sample=False,
             num_return_sequences=1,
             max_new_tokens=100):
    """Run ``model.generate`` on a tokenized prompt and decode the result.

    Args:
        model: Object exposing ``generate``.
        tokenizer: Object exposing ``eos_token_id`` and ``batch_decode``.
        inputs: Tokenized prompt; only ``inputs.input_ids`` is read.
        num_beams: Forwarded to ``model.generate``.
        num_beam_groups: Forwarded to ``model.generate``.
        do_sample: Forwarded to ``model.generate``.
        num_return_sequences: Forwarded to ``model.generate``.
        max_new_tokens: Upper bound on new tokens; additionally capped at
            three times the prompt length (``input_ids.shape[1] * 3``).

    Returns:
        The value of ``tokenizer.batch_decode`` applied to everything that
        ``model.generate`` returned (no prompt positions are sliced off).
    """
    prompt_ids = inputs.input_ids
    # Never ask for more than three times as many tokens as the prompt has.
    token_budget = min(max_new_tokens, prompt_ids.shape[1] * 3)

    generation_options = dict(
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        repetition_penalty=1.3,
        do_sample=do_sample,
        num_return_sequences=num_return_sequences,
        max_new_tokens=token_budget,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        early_stopping=True,
    )
    output_ids = model.generate(prompt_ids, **generation_options)

    return tokenizer.batch_decode(output_ids,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)


def generate_constrastive_search(model,
                                 tokenizer,
                                 inputs,
                                 top_k=4,
                                 penalty_alpha=0.6,
                                 num_return_sequences=1,
                                 max_new_tokens=100):
    """Run ``model.generate`` with ``top_k``/``penalty_alpha`` and decode.

    Args:
        model: Object exposing ``generate``.
        tokenizer: Object exposing ``batch_decode``.
        inputs: Tokenized prompt; only ``inputs.input_ids`` is read.
        top_k: Forwarded to ``model.generate``.
        penalty_alpha: Forwarded to ``model.generate``.
        num_return_sequences: Forwarded to ``model.generate``.
        max_new_tokens: Forwarded to ``model.generate`` unchanged.

    Returns:
        The value of ``tokenizer.batch_decode`` applied to the generated ids.
    """
    search_options = {
        "top_k": top_k,
        "penalty_alpha": penalty_alpha,
        "num_return_sequences": num_return_sequences,
        "max_new_tokens": max_new_tokens,
    }
    output_ids = model.generate(inputs.input_ids, **search_options)

    return tokenizer.batch_decode(output_ids,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)


def print_utils(gen_output, idx=""):
    """Print each generated text as a ``T-<idx><TAB><text>`` line.

    Every entry is preceded and followed by one blank line.

    Args:
        gen_output: Sequence of generated strings.
        idx: String identifier placed after the ``T-`` prefix.
    """
    for text in gen_output:
        print()
        tagged_line = "T-" + idx + "\t" + text
        print(tagged_line)
        print()


def prompt_eval(args, model_baseline, model_fintuned, tokenizer, device,
                prompts):
    """Generate with the finetuned model for every prompt and print the output.

    Args:
        args: Parsed CLI options; ``num_return_sequences`` and
            ``max_new_tokens`` are read.
        model_baseline: Not used by this function; kept in the signature for
            existing callers.
        model_fintuned: Model handed to ``generate``.
        tokenizer: Callable tokenizer; its result is moved to ``device``.
        device: Device the tokenized prompt is moved to.
        prompts: Iterable of dicts with the keys ``input_prompt`` (text to
            tokenize) and ``docid`` (printed next to the generated text).
    """
    for prompt in prompts:
        inputs = tokenizer(prompt['input_prompt'],
                           return_tensors="pt").to(device)

        # Only beam search on the finetuned model is run. Other strategies
        # can be tried through the same helpers:
        #   - greedy search:              generate(..., num_beams=1)
        #   - multinomial sampling:       generate(..., num_beams=1,
        #                                          do_sample=True)
        #   - beam-search sampling:       generate(..., num_beams=args.num_beams,
        #                                          do_sample=True)
        #   - diverse beam search:        generate(..., num_beams=args.num_beams,
        #                                          num_beam_groups=args.num_beam_groups)
        #   - contrastive search:         generate_constrastive_search(...,
        #                                          top_k=args.top_k,
        #                                          penalty_alpha=args.penalty_alpha)
        print("==========finetune: Beam Search=========")
        # NOTE(review): num_beams stays fixed at 4 here; args.num_beams
        # defaults to 1, so wiring it in would change default behaviour.
        r_finetune_b = generate(model_fintuned, tokenizer, inputs,
                                num_beams=4,
                                num_return_sequences=args.num_return_sequences,
                                # Honour --max_new_tokens (its default, 1024,
                                # equals the previously hard-coded value).
                                max_new_tokens=args.max_new_tokens)
        print_utils(r_finetune_b, str(prompt['docid']))
        print("====================prompt end=============================")
        print()
        print()


# Instruction prepended to every source document, keyed by --language.
_TRANSLATION_INSTRUCTIONS = {
    "Chinese": "Human: Translate this document from Chinese to English\n",
    "German": "Human: Translate this document from German to English\n",
    "English": "Human: Translate this document from English to German\n",
}


def _load_prompts(input_path, language):
    """Build the prompt dicts for ``prompt_eval`` from a jsonlines file.

    Args:
        input_path: Path of a jsonlines file whose records carry the keys
            ``prompt`` and ``docid``.
        language: Key into ``_TRANSLATION_INSTRUCTIONS``.

    Returns:
        list[dict]: One dict per record with ``input_prompt`` (instruction
        followed by the record's ``prompt``) and ``docid``.

    Raises:
        ValueError: If no translation instruction exists for ``language``.
    """
    try:
        instruct = _TRANSLATION_INSTRUCTIONS[language]
    except KeyError:
        supported = ", ".join(sorted(_TRANSLATION_INSTRUCTIONS))
        raise ValueError(
            f"No translation instruction is defined for language "
            f"{language!r}; supported languages: {supported}") from None

    # The context manager closes the underlying file once all lines are read.
    with jsonlines.open(input_path) as reader:
        return [{
            'input_prompt': instruct + line['prompt'],
            'docid': line['docid'],
        } for line in reader]


def main():
    """Load the finetuned model and run ``prompt_eval`` on the input file.

    The tokenizer is loaded from the baseline path, the model from the
    finetune path; no baseline model is created (``None`` is passed on).
    """
    args = parse_args()

    # Read the input first so that an unsupported language or a bad input
    # file fails before the (expensive) model is loaded.
    prompts = _load_prompts(args.input, args.language)

    device = torch.device("cuda:0")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path_baseline)

    # NOTE(review): offload is forced on here, so the --offload flag has no
    # effect; kept as-is to preserve the existing default behaviour.
    ds_config = get_eval_ds_config(offload=True,
                                   stage=args.zero_stage)
    print(ds_config)
    print(args.model_name_or_path_finetune)
    print(args.input)
    model_fintuned = create_hf_model(AutoModelForCausalLM,
                                     args.model_name_or_path_finetune,
                                     tokenizer, ds_config)

    model_fintuned.to(device)

    prompt_eval(args, None, model_fintuned, tokenizer, device,
                prompts)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
