import json
import os
from tqdm import tqdm 

import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model_and_tokenizer(model_path, tokenizer_path=None):
    """Load a float16 causal LM and its tokenizer.

    Args:
        model_path: Path (or identifier accepted by ``from_pretrained``) of
            the model checkpoint.
        tokenizer_path: Where to load the tokenizer from; falls back to
            ``model_path`` when ``None``.

    Returns:
        ``(model, tokenizer)``. The tokenizer's EOS and pad token ids are
        both overwritten with its ``eod_id``.
    """
    if tokenizer_path is None:
        tokenizer_path = model_path

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        # NOTE(review): these flags are not standard from_pretrained arguments;
        # presumably they are consumed by the checkpoint's remote-code config
        # (Qwen-style) to force fp16 and disable flash attention — confirm.
        bf16=False,
        fp16=True,
        fp32=False,
        use_flash_attn=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
    # The remote-code tokenizer is expected to expose `eod_id` (assumed to be
    # an end-of-document token, as in Qwen-style tokenizers — confirm); it is
    # reused as both the EOS and the padding id.
    tokenizer.eos_token_id = tokenizer.eod_id
    tokenizer.pad_token_id = tokenizer.eod_id

    return model, tokenizer


def make_prompt_zh(query):
    """Build the Chinese rewrite prompt for ``query``.

    The instruction asks the model to rewrite the given query into different
    queries; generation is expected to continue after the ``Output: `` tag.
    """
    return f"请你根据给定的query，改写出不同的查询\nQuery: {query}\nOutput: "


def make_prompt(query):
    """Build the English rewrite prompt for ``query``.

    The model is expected to continue the text after the trailing
    ``Output: `` tag with a rewritten query.
    """
    return f"Instruction: output the rewrite of input query\nQuery: {query}\nOutput: "


def response(query, model, tokenizer, generation_config):
    """Generate a single rewrite continuation for ``query``.

    Builds the English prompt with ``make_prompt``, runs ``model.generate``
    with ``generation_config`` as keyword arguments, and decodes only the
    newly generated tokens.

    Args:
        query: Query text to rewrite.
        model: Causal LM exposing ``generate`` (and normally ``device``).
        tokenizer: Callable tokenizer returning ``input_ids``, with ``decode``.
        generation_config: Keyword arguments forwarded to ``model.generate``.

    Returns:
        The decoded continuation with special tokens skipped.
    """
    prompt_ids = tokenizer(make_prompt(query)).input_ids
    # Put the inputs on the device reported by the model instead of
    # hard-coding "cuda": the model may be placed automatically
    # (device_map="auto"). Fall back to "cuda" if `.device` is unavailable.
    device = getattr(model, "device", "cuda")
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    outputs = model.generate(input_ids, **generation_config)
    # The returned sequence starts with the prompt tokens; drop them so only
    # the generated continuation is decoded.
    return tokenizer.decode(outputs[0][len(prompt_ids):], skip_special_tokens=True)

def clean_pred(text):
    """Keep only the first line of a generated prediction.

    Everything from the first newline onward is discarded; text without a
    newline is returned unchanged.
    """
    first_line, _sep, _rest = text.partition("\n")
    return first_line

def response_single(query, model, tokenizer, generation_config, debug=False,
                    num_rewrites=2, max_tries=5):
    """Sample rewrites of ``query`` until enough distinct ones are collected.

    Repeatedly calls ``response``, keeps the first line of each output (via
    ``clean_pred``), and de-duplicates. Stops once ``num_rewrites`` distinct
    rewrites exist or after ``max_tries`` generation calls, whichever is first.

    Args:
        query: Query text to rewrite.
        model: Model passed through to ``response``.
        tokenizer: Tokenizer passed through to ``response``.
        generation_config: Generation kwargs passed through to ``response``.
        debug: Currently unused; kept for call-site compatibility.
        num_rewrites: Target number of distinct rewrites (default 2).
        max_tries: Maximum number of generation calls (default 5).

    Returns:
        List of distinct rewrites in the order they were first produced; it
        may hold fewer than ``num_rewrites`` items if generations repeat.
    """
    # A dict de-duplicates while preserving insertion order, so the returned
    # order is reproducible (list(set(...)) order depends on str hash seeds).
    seen = {}
    for _ in range(max_tries):
        if len(seen) >= num_rewrites:
            break
        pred = clean_pred(response(query, model, tokenizer, generation_config))
        seen[pred] = None
    return list(seen)


def _parse_args():
    """Parse the command-line arguments of the rewrite script."""
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--input_file', required=True, type=str)
    parser.add_argument('--output_file', required=True, type=str)
    parser.add_argument('--ckpt_dir', required=True, type=str)
    parser.add_argument('--debug', action="store_true")
    return parser.parse_args()


def _annotate(mentions, model, tokenizer, generation_config, debug, desc,
              fill_ctxs=False):
    """Add a ``"rewrite"`` list to every mention dict in ``mentions``, in place.

    When ``fill_ctxs`` is true, a mention without a ``"ctxs"`` key first gets
    one from ``search(mention["query"])``.
    """
    for mention in tqdm(mentions, desc=desc, total=len(mentions)):
        if fill_ctxs and "ctxs" not in mention:
            # NOTE(review): `search` is neither defined nor imported in this
            # file as seen here, so this line raises NameError for any mention
            # lacking "ctxs". Provide/import a retriever before using this path.
            mention["ctxs"] = search(mention["query"])
        mention["rewrite"] = response_single(
            mention["query"], model, tokenizer, generation_config, debug
        )


def main():
    """Rewrite every query in ``--input_file`` and write the result.

    The input JSON is either a list of mention dicts or a dict whose values
    are such lists; each mention needs a ``"query"`` key. Each mention gets a
    ``"rewrite"`` list added in place, and the whole structure is dumped to
    ``--output_file``. With ``--debug`` only the first 5 mentions (list input)
    or the first 2 per key (dict input) are processed and written.
    """
    args = _parse_args()

    model, tokenizer = load_model_and_tokenizer(args.ckpt_dir)

    generation_config = {
        "max_new_tokens": 128,
        "do_sample": True,
        # NOTE(review): penalty_alpha drives contrastive search, which needs
        # do_sample=False; with sampling enabled it appears to be ignored —
        # confirm against the transformers generation docs.
        "penalty_alpha": 0.3,
        "temperature": 1.0,
        "top_p": 0.8,
        "repetition_penalty": 1.0,
        "eos_token_id": tokenizer.eos_token_id,
    }

    with open(args.input_file, encoding="utf-8") as f:
        data = json.load(f)
    print(args.input_file)

    if isinstance(data, list):
        if args.debug:
            data = data[:5]
        _annotate(data, model, tokenizer, generation_config, args.debug,
                  desc="data")
    else:
        for k in list(data):
            if args.debug:
                # Truncate the stored list (not just a local copy) so the debug
                # output matches the list case: only processed items are written.
                data[k] = data[k][:2]
            _annotate(data[k], model, tokenizer, generation_config, args.debug,
                      desc=k, fill_ctxs=True)

    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


if __name__ == "__main__":
    main()