import ujson
import os
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

# Command-line interface: --engine names the Qwen checkpoint to evaluate; it
# is appended to the "Qwen/" hub namespace below.
args_obj = argparse.ArgumentParser()
args_obj.add_argument('--engine', type=str, required=True)
args = args_obj.parse_args()

# Checkpoint name and prompting regime; both also name the output directory
# used by the generation loop further down.
engine, shot = args.engine, 'zero-shot'
model_name = 'Qwen/' + engine

# torch_dtype="auto" keeps the checkpoint's configured dtype and
# device_map="auto" lets transformers distribute the layers over the
# available devices. The checkpoint's own generation defaults are then
# attached and the model is switched to inference mode.
model_kwargs = {'torch_dtype': "auto", 'device_map': "auto"}
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
model.generation_config = GenerationConfig.from_pretrained(model_name)
model.eval()

# use_fast=False requests the slow (pure-Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
print('done')

from uni_user_prompt import zero_shot_prompt
def get_user_prompt(text: str):
    """Return the zero-shot user prompt with *text* substituted in.

    ``zero_shot_prompt`` is imported from ``uni_user_prompt``. *text* is passed
    as the first positional argument to its ``format`` method. Presumably the
    template contains a single ``{}`` placeholder; confirm in uni_user_prompt.
    """
    template = zero_shot_prompt
    prompt = template.format(text)
    return prompt

def get_prediction(messages: list, max_new_tokens: int = 100) -> str:
    """Generate the model's reply to a chat conversation.

    Args:
        messages: Chat messages in the form accepted by
            ``tokenizer.apply_chat_template``, e.g.
            ``[{'role': 'user', 'content': ...}]``.
        max_new_tokens: Upper bound on the number of generated tokens
            (defaults to the previously hard-coded 100).

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        excluded), with special tokens such as the end-of-turn marker removed.
    """
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The model was loaded with device_map="auto", so send the inputs to the
    # model's device instead of assuming a CUDA device exists.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    # Prompt length taken from the tensor we already have (no re-tokenization).
    prompt_len = inputs["input_ids"].shape[1]

    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Keep only the newly generated tokens. skip_special_tokens drops a trailing
    # EOS / end-of-turn token. The old `[start:-1]` slice instead always cut the
    # last token, which truncated real text whenever generation stopped on
    # max_new_tokens.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

def _atomic_dump_json(obj, path: str) -> None:
    """Write *obj* as JSON to *path* so that *path* never holds a partial file.

    The data is first written to a sibling ``.tmp`` file and then moved into
    place with ``os.replace``. This matters because the resume check below
    treats any existing ``{idx}.json`` as finished work, so a truncated file
    left by a crash would otherwise be skipped forever.
    """
    tmp_path = f'{path}.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        ujson.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


# The test set does not change between runs, so load it once.
with open('./data/test.json', 'r', encoding='utf-8') as f:
    test_data = ujson.load(f)

# Three independent generation runs, each written to its own iter{n} directory.
for run in range(1, 4):
    source_dir = f'./data/{engine}-{shot}/iter{run}'
    os.makedirs(source_dir, exist_ok=True)
    # Outputs already on disk are skipped so an interrupted run can resume;
    # a set gives O(1) membership checks.
    existing_files = set(os.listdir(source_dir))
    for idx, item in enumerate(test_data):
        out_name = f'{idx}.json'
        if out_name in existing_files:
            print(f'{idx}: DONE')
            continue
        try:
            user_content = get_user_prompt(text=item['text'])
            messages = [{'role': 'user', 'content': user_content}]
            content = get_prediction(messages=messages)

            print(content)
            # Copy so the shared test_data item is not mutated across runs.
            record = dict(item)
            record[engine] = content
            _atomic_dump_json(record, f'{source_dir}/{out_name}')
        except Exception as e:
            # Best-effort: report which item failed and move on. No output file
            # is left behind, so a later rerun retries this item.
            print(f'{idx}: ERROR {e!r}')