'''
- combine `challenge_test_list` into `test_list`
- skip the 2 samples where `test_setup_code` is present (non-empty)
- updating `test_list` for all combinations of:
    1. permute sequence of inputs
    2. permute sequence of outputs
    3. converting input types
    4. converting output types
    5. converting categorical responses to boolean outputs
'''
import json
import tqdm

from categories_bool import all_bool
from resp_type import all_type
from permute_args import all_permute
from remove_args import all_remove


def updating_assertions(test_lists):
    """Run the assertion list through every rewriting stage, in order.

    The stages are applied sequentially, each one receiving the output of
    the previous one: ``all_type`` -> ``all_bool`` -> ``all_permute`` ->
    ``all_remove``.

    Args:
        test_lists: The assertions to rewrite; passed as-is to ``all_type``.

    Returns:
        Whatever the final stage (``all_remove``) returns.
    """
    pipeline = (all_type, all_bool, all_permute, all_remove)
    current = test_lists
    for stage in pipeline:
        current = stage(current)
    return current

def _load_checkpoint(path):
    """Return the records previously saved at *path*, or ``None``.

    ``None`` means there is nothing usable to resume from: the file is
    missing/unreadable (``OSError``) or does not contain valid JSON text
    (``ValueError``, which covers ``json.JSONDecodeError`` and decode errors).
    """
    try:
        with open(path, 'r', encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):
        return None


def _save_checkpoint(path, records):
    """Overwrite *path* with *records* serialized as JSON."""
    with open(path, 'w', encoding="utf-8") as f:
        json.dump(records, f)


def main(limit=51, ip=r"Dataset\original\mbpp.json"):
    """Build the ``-updated.json`` companion of an MBPP-style JSON file.

    For each sample the ``challenge_test_list`` is merged into ``test_list``,
    the merged list is passed through ``updating_assertions`` and the result
    is stored under ``updated_test_list``. Samples with a non-empty
    ``test_setup_code`` are skipped. The output file is rewritten after every
    processed sample so that an interrupted run can be resumed.

    Args:
        limit: Stop when the loop reaches this index in the input, so at most
            ``limit`` leading samples are considered. ``None`` means no limit.
        ip: Path of the input JSON file. The output path is derived by
            replacing ``".json"`` with ``"-updated.json"``.

    Returns:
        None. Results are written to the derived output path.
    """
    op = ip.replace(".json", '-updated.json')

    with open(ip, 'r', encoding="utf-8") as f:
        data = json.load(f)

    updated_data = _load_checkpoint(op)
    if updated_data is None:
        updated_data = []
    else:
        print(f"Resuming: {len(updated_data)} samples already processed")

    # Resume by task id rather than by count: skipped samples are never
    # appended, so len(updated_data) does not correspond to a position in
    # `data` and using it as an index would reprocess (and duplicate) samples.
    done_ids = {entry["task_id"] for entry in updated_data}

    # Wrapping `data` (not the enumerate object) lets tqdm show a total.
    for i, d in enumerate(tqdm.tqdm(data)):
        if limit is not None and i == limit:
            break
        if d["task_id"] in done_ids:
            continue

        # - skip samples where `test_setup_code` is present (non-empty)
        if d['test_setup_code']:
            continue

        # - combine `challenge_test_list` into `test_list`
        if d['challenge_test_list']:
            d['test_list'].extend(d['challenge_test_list'])

        # - updating `test_list` for all transformations
        try:
            updated_test_lists = updating_assertions(d['test_list'])
        except Exception as e:
            # Stop at the first failing sample; everything processed so far
            # is already checkpointed and is written once more below.
            print(f"Error in {d['task_id']}: {e}")
            break

        updated_data.append({
            "task_id": d["task_id"],
            "text": d["text"],
            "test_list": d['test_list'],
            "updated_test_list": updated_test_lists,
            "code": d["code"],
        })
        # Checkpoint after every sample so a crash loses no finished work.
        _save_checkpoint(op, updated_data)

    # Final write guarantees the output file exists even if nothing was added.
    _save_checkpoint(op, updated_data)


# Run the update pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()