import json
import argparse

# Command-line interface: a single optional ``--prompt`` string (default
# "naive"). Its value is interpolated into the input and output JSON file
# names further down in this script.
parser = argparse.ArgumentParser(description="TRP")
parser.add_argument("--prompt", type=str, default="naive")
args = parser.parse_args()


# Load the raw reward records for the selected prompt.
# The encoding is given explicitly so that decoding does not depend on the
# platform's locale encoding (JSON text is normally UTF-8).
with open(f'output/{args.prompt}.json', 'r', encoding='utf-8') as f:
    rewards = json.load(f)

# Regroup the flat list of records into a nested mapping:
#     {task_id: {len(history): {answer: rounded_reward}}}
# Each record is read via its "task_id", "history" and "answer" keys;
# "answer" is iterated with .items(), so it is expected to be a mapping of
# answer -> numeric reward.
transformed_structure = {}

# Sort by task id, then by history length, so that dict insertion order (and
# therefore the order of keys in the dumped JSON) is deterministic.
rewards = sorted(rewards, key=lambda x: (x['task_id'], len(x['history'])))

for task in rewards:
    task_id = task["task_id"]
    history = task["history"]
    answer = task["answer"]
    len_his = len(history)

    # Create the per-task and per-history-length buckets on first use.
    # Records sharing the same task id and history length are merged into
    # one bucket; a repeated answer key overwrites the earlier value.
    bucket = transformed_structure.setdefault(task_id, {}).setdefault(len_his, {})

    for ans, rew in answer.items():
        print(ans, rew)
        # round() with no ndigits returns an int (ties go to the even value).
        bucket[ans] = round(rew)

# Persist the regrouped rewards as pretty-printed JSON under the same prompt
# name. The path is relative to the current working directory, and open()
# does not create missing parent directories. Integer history-length keys
# are written as strings, because JSON object keys must be strings.
output_path = f'../Reward_Ensemble/dataset/reward/{args.prompt}.json'
with open(output_path, 'w') as f:
    json.dump(transformed_structure, f, indent=4)
