import ast
import json
from pathlib import Path

from tqdm.auto import tqdm

from utils import set_logger
from gpt_ambiguity_prompt import Prompter


def _chunk_list(original_list, chunk_size):
    """Split ``original_list`` into consecutive slices of at most ``chunk_size`` items.

    The final slice is shorter when the length is not a multiple of
    ``chunk_size``; an empty input yields an empty list.
    """
    return [original_list[i:i + chunk_size] for i in range(0, len(original_list), chunk_size)]


def _parse_mention_list(raw):
    """Parse the prompter's textual reply into a list of mentions.

    The reply is expected to be a Python list literal. ``ast.literal_eval``
    accepts such literals but, unlike ``eval``, never executes arbitrary
    expressions coming back from the model.

    Raises:
        ValueError / SyntaxError: if ``raw`` is not a valid Python literal.
        TypeError: if the literal is not a list or tuple (extending with a bare
            string would otherwise add its individual characters).
    """
    parsed = ast.literal_eval(raw.strip())
    if not isinstance(parsed, (list, tuple)):
        raise TypeError(f"expected a list of mentions, got {type(parsed).__name__}")
    return list(parsed)


def _save_json(obj, path):
    """Write ``obj`` to ``path`` as JSON (indent=3) via a temp file and a rename.

    The script checkpoints after every entity. Writing to a sibling temp file
    and then replacing the target means an interrupted write never leaves a
    truncated checkpoint behind.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=3)
    tmp_path.replace(path)


def main():
    """Filter each entity's mentions through the GPT prompter and save the result.

    Reads ``../entity_mention_list.json`` (a mapping from entity key to a list
    of mentions; the entity name is taken as the last ``/``-separated segment of
    the key). It sends each entity's mentions to ``Prompter.prompt`` in chunks of
    10 and writes the kept mentions to ``../filtered_mention_list.json`` under
    the same keys. The output file is rewritten after every entity, so partial
    progress survives a crash.

    A chunk whose prompt call or reply parsing fails is logged and skipped, and
    its mentions are dropped from the output.
    """
    logger = set_logger()
    prompter = Prompter()

    input_json_path = Path("../entity_mention_list.json")
    save_json_path = Path("../filtered_mention_list.json")

    with open(input_json_path, "r", encoding="utf-8") as f:
        entity_mentions = json.load(f)

    output = {}

    for i, (key, mention_list) in enumerate(tqdm(entity_mentions.items(), desc="total")):
        filtered_mentions = []
        entity = key.split('/')[-1]
        for mentions in tqdm(_chunk_list(mention_list, 10), desc=f"entity: {entity}", leave=False):
            try:
                str_list = prompter.prompt(mentions, entity)
                parsed = _parse_mention_list(str_list)
            except Exception as err:
                # Best-effort: one bad chunk should not abort a long run, but we no
                # longer swallow KeyboardInterrupt/SystemExit or hide the cause.
                logger.warning(
                    f"Failed to filter mentions (entity: {entity}, mentions: {repr(mentions)}, error: {err!r})"
                )
                continue
            filtered_mentions.extend(parsed)
            logger.info(f"Filtered mentions from gpt-4 (entity: {entity}, mentions: {repr(mentions)})")

        output[key] = filtered_mentions

        # Checkpoint after every entity (the former `if i % 1 == 0` was always true).
        _save_json(output, save_json_path)
        logger.info(f"Saved json (current i: {i}, current entity: {entity})")

    # Final write; also creates the output file when the input mapping is empty.
    _save_json(output, save_json_path)

if __name__ == "__main__":
    main()