import ujson
import os

from model import MODE

# Checkpoint to evaluate, plus the labels (engine, prompting setting) that are
# combined into the output directory name ./data/{engine}-{shot}/...
model_path = 'THUDM/chatglm2-6b'
engine = 'chatglm2-6b'
shot = 'zero-shot'

# Load the tokenizer first, then the model, from the same checkpoint using the
# loader classes registered under MODE['glm2']; then move the model to the GPU.
glm2_tokenizer, glm2_model = (
    MODE['glm2'][component].from_pretrained(model_path)
    for component in ('tokenizer', 'model')
)
glm2_model.cuda()
print('done')

from uni_user_prompt import zero_shot_prompt
def get_user_prompt(text: str):
    """Return the user prompt for one input.

    Fills *text* into ``zero_shot_prompt`` (imported from ``uni_user_prompt``,
    presumably a ``str.format``-style template — confirm) via its ``format``
    method, passing *text* as the single positional argument.
    """
    prompt = zero_shot_prompt.format(text)
    return prompt

def _dump_json_atomic(obj, path):
    """Serialize *obj* as UTF-8 JSON to *path* so *path* only appears complete.

    The data is written to ``<path>.tmp`` first and then renamed over *path*
    with ``os.replace``. This matters because the resume logic below treats
    any existing ``{idx}.json`` as finished work: a half-written file (crash,
    Ctrl-C, or an error in the middle of ``ujson.dump``) would otherwise be
    skipped on every later run.
    """
    tmp_path = f'{path}.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as fp:
            ujson.dump(obj, fp, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort removal of the partial temp file, then propagate the
        # original error unchanged.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


# Read the test set once. Items are never mutated below (each saved record is
# a copy), so the same list is safe to reuse for every iteration.
with open('./data/test.json', 'r', encoding='utf-8') as f:
    test_data = ujson.load(f)

# Run the whole test set three times, into iter1..iter3. Any item whose
# {idx}.json already exists is skipped, so re-running the script resumes
# after a crash and retries items that previously failed.
for iteration in range(1, 4):
    saved_dir = f'./data/{engine}-{shot}/iter{iteration}'
    os.makedirs(saved_dir, exist_ok=True)
    # A set gives O(1) membership tests in the per-item check below.
    existing_files = set(os.listdir(saved_dir))
    for idx, item in enumerate(test_data):
        if f'{idx}.json' in existing_files:
            print(f'{idx}: DONE')
            continue
        try:
            instruction = get_user_prompt(text=item['text'])
            # An empty history makes every item an independent single-turn query;
            # the returned (updated) history is not needed.
            response, _ = glm2_model.chat(
                glm2_tokenizer,
                instruction,
                history=[])

            print(response)
            # Copy rather than mutate so test_data stays pristine across
            # iterations; the response is stored under the engine name.
            record = dict(item)
            record[engine] = response
            _dump_json_atomic(record, f'{saved_dir}/{idx}.json')
        except Exception as e:
            # Deliberate best-effort: report which item failed and continue.
            # No output file is written, so the item is retried on the next run.
            print(f'{idx}: ERROR {e!r}')