import os
import json

def read_json(data_file):
    """Load and return the parsed JSON content of ``data_file``.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale default, so non-ASCII text is read the same way on every system.

    Args:
        data_file: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Without an explicit encoding, open() falls back to the locale's
    # preferred encoding, which is not UTF-8 on every platform.
    with open(data_file, "r", encoding="utf-8") as f:
        return json.load(f)

def save_json(data_file, data):
    """Serialize ``data`` as JSON into ``data_file``.

    The file is opened for writing as UTF-8 (replacing any existing content).
    Output is indented by four spaces and non-ASCII characters are written
    as-is rather than as ``\\uXXXX`` escapes.

    Args:
        data_file: Path of the file to write.
        data: JSON-serializable object to store.
    """
    with open(data_file, mode="w", encoding="utf-8") as out_file:
        json.dump(
            data,
            out_file,
            ensure_ascii=False,
            indent=4,
        )

if __name__ == '__main__':
    # Build the "_sft" variant of the training file: every record keeps its
    # fields, but "output" is replaced by its own first element (index 0).
    # NOTE(review): presumably "output" is a list of candidate responses and
    # the first one is the preferred answer -- confirm against the dataset.
    source_path = "./data/hh_rlhf_en/hh_rlhf_en_train.json"
    target_path = "./data/hh_rlhf_en/hh_rlhf_en_train_sft.json"

    records = read_json(source_path)

    # Records are updated in place, so ``records`` itself is what gets saved.
    for record in records:
        first_output = record["output"][0]
        record["output"] = first_output

    save_json(target_path, records)


    