import json
import time
import jsonlines
import argparse
import os


#from venus_chat import VenusClient

# Command-line interface: where the generated predictions live, which
# registered evaluation dataset they belong to, and which task produced them.
parser = argparse.ArgumentParser(
    description='Merge generated predictions into an evaluation dataset.')
parser.add_argument(
    '--generate_predictions', dest='generate_predictions', default='',
    help='Directory containing generated_predictions.jsonl; '
         'predictions.json is written to the same directory.')
parser.add_argument(
    '--dataset', dest='dataset', default='',
    help='Name of the evaluation dataset as registered in dataset_info.json.')
parser.add_argument(
    '--task_type', dest='task_type', default='scene_abs',
    help='scene_abs or refine (any value other than scene_abs also prints '
         'each prediction).')

# Default location of the dataset-info registry mapping dataset names to files.
DATASET_INFO_PATH = '/cfs/cfs-0s927vn5/haojinghuang/codes/human_preference_alignment_algorithms/data/dataset_info.json'


def process_data(res_data_file, dataset_name, task_type=None,
                 dataset_info_path=DATASET_INFO_PATH):
    """Merge generated predictions into an evaluation dataset and save it.

    Reads ``generated_predictions.jsonl`` from *res_data_file* (one JSON object
    per line, each with a ``'predict'`` field) and copies the predictions, in
    order, into the ``'output'`` field of the records of the evaluation dataset
    registered as *dataset_name* in the dataset-info file. Pairing stops at
    whichever side runs out first; records without a matching prediction keep
    their original content. The merged records are written as indented UTF-8
    JSON to ``predictions.json`` inside *res_data_file*.

    Args:
        res_data_file: Directory containing ``generated_predictions.jsonl``;
            ``predictions.json`` is written there as well.
        dataset_name: Key of the evaluation dataset in the dataset-info file.
        task_type: ``'scene_abs'`` or another task name; for any value other
            than ``'scene_abs'`` each prediction is also printed. When ``None``
            the value is read from the module-level ``args`` set by the script
            entry point (the original behaviour).
        dataset_info_path: JSON file mapping dataset names to entries with a
            ``"file_name"`` key pointing at the evaluation JSON file.

    Raises:
        KeyError: If *dataset_name* is not present in the dataset-info file.
    """
    if task_type is None:
        # Backward compatibility: callers that don't pass task_type rely on
        # the global `args` parsed under `if __name__ == '__main__'`.
        task_type = args.task_type

    with open(dataset_info_path, "r", encoding="utf-8") as info_f:
        dataset_info = json.load(info_f)

    if dataset_name not in dataset_info:
        raise KeyError("Dataset {!r} not found in {}".format(dataset_name, dataset_info_path))
    # NOTE(review): file_name is opened as-is (relative paths resolve against
    # the current working directory, not the dataset-info directory) — confirm.
    eval_data_file = dataset_info[dataset_name]["file_name"]

    with open(eval_data_file, "r", encoding="utf-8") as eval_f:
        eval_data = json.load(eval_f)

    data_file = os.path.join(res_data_file, "generated_predictions.jsonl")
    with open(data_file, "r", encoding="utf-8") as pred_f:
        # Lazily parse one JSON object per non-blank line (JSON Lines).
        predictions = (json.loads(line) for line in pred_f if line.strip())
        # zip() stops at the shorter side: extra predictions are ignored and
        # an empty evaluation set no longer raises IndexError.
        for record, prediction in zip(eval_data, predictions):
            response = prediction['predict']
            if task_type != 'scene_abs':
                print(response)
            record['output'] = response

    save_path = os.path.join(res_data_file, "predictions.json")
    with open(save_path, "w", encoding="utf-8") as save_f:
        json.dump(eval_data, save_f, indent=4, ensure_ascii=False)

    print("Json file save in {}.".format(save_path))

if __name__ == '__main__':
    # `args` must remain a module-level global: process_data reads the task
    # type from args.task_type when it is not passed explicitly.
    args = parser.parse_args()
    predictions_dir, dataset = args.generate_predictions, args.dataset
    process_data(predictions_dir, dataset)

    
    