import os
import json
from tqdm import tqdm

import random

# Maps article name -> {section path string -> one sampled reference snippet}.
# Filled in place by process_reference_snippet().
experiment_data = {}

def process_reference_snippet(content, path, article_name):
    """Sample one reference snippet per section and record it in ``experiment_data``.

    Walks ``content`` (a list of section dicts) recursively. For every section
    that has at least one non-empty ``reference_snippet`` among the items of
    its ``section_content``, one snippet is picked at random and stored under
    ``experiment_data[article_name][path_string]``, where ``path_string`` is
    the chain of section titles from the root joined with ``" -> "``.

    Args:
        content: List of section dicts. Each must have ``section_title``;
            ``section_content`` (list of item dicts) and ``subsections``
            (list of section dicts) are optional and treated as empty when
            missing or falsy.
        path: List of ancestor section titles (not mutated).
        article_name: Key under which results are stored. An entry is created
            for it even when ``content`` is empty.

    Returns:
        None. Results are written to the module-level ``experiment_data``.
    """
    article_entry = experiment_data.setdefault(article_name, {})
    for section in content:
        new_path = path + [section["section_title"]]
        # De-duplicate while keeping first-seen order. Picking from a list
        # built out of a set would depend on string hash randomisation, which
        # makes the sample irreproducible even when the RNG is seeded.
        snippets = list(dict.fromkeys(
            item["reference_snippet"]
            for item in section.get("section_content") or []
            if "reference_snippet" in item and item["reference_snippet"]
        ))
        if snippets:
            article_entry[" -> ".join(new_path)] = random.choice(snippets)
        # Recursively handle subsections; leaf sections may omit the key.
        subsections = section.get("subsections")
        if subsections:
            process_reference_snippet(content=subsections,
                                      path=new_path,
                                      article_name=article_name)

def get_all_json_files_in_dir(directory_path):
    """Return the paths of all entries in ``directory_path`` ending in ".json".

    Only the immediate children of the directory are considered (no
    recursion). Each returned path is ``directory_path`` joined with the
    entry name, in the order ``os.listdir`` yields them.
    """
    return [
        os.path.join(directory_path, entry)
        for entry in os.listdir(directory_path)
        if entry.endswith(".json")
    ]

def get_article_headings(article_text_path):
    """Return the heading lines of a text file as one newline-joined string.

    A line counts as a heading when it starts with "#" and does not contain
    the text "References". Newlines are removed and surrounding whitespace is
    stripped from each kept line. An empty string is returned when no line
    qualifies.
    """
    with open(article_text_path) as handle:
        headings = [
            raw_line.replace("\n", "").strip()
            for raw_line in handle
            if raw_line.startswith("#") and "References" not in raw_line
        ]
    return "\n".join(headings)

# Gather every JSON file found directly under the input directory.
files_to_process = get_all_json_files_in_dir(directory_path="json_with_snippet")

# Visit the files in sorted path order and fold each article's snippets into
# experiment_data, keyed by the file name without its ".json" part.
for file_path in tqdm(sorted(files_to_process)):
    current_article = os.path.basename(file_path).replace(".json", "")
    with open(file_path, 'r') as handle:
        data = json.load(handle)
    process_reference_snippet(
        content=data['content'],
        path=[],
        article_name=current_article,
    )

# Number of (article, section path) samples kept in the final dataset.
total_samples = 0
# Number of kept samples per section depth (1 = top-level section).
level_count = {}

data_config = {}
for article_name, article_data in experiment_data.items():
    # Get the article outline from the matching text file.
    data_config[article_name] = {}
    data_config[article_name]["structure"] = get_article_headings(os.path.join("txt", f"{article_name}.txt"))
    data_config[article_name]["exp_data"] = {}
    for path, snippet in article_data.items():
        # Depth is the number of titles in the " -> "-joined path.
        level = path.count("->") + 1
        # Sections deeper than level 2 are always kept; level 1-2 sections
        # are kept with probability 0.5. The RNG is only consulted for the
        # shallow levels.
        if level > 2 or random.random() >= 0.5:
            data_config[article_name]["exp_data"][path] = snippet
            level_count[level] = level_count.get(level, 0) + 1
            # Count every kept sample so the summary below is meaningful.
            total_samples += 1

print(f"total_samples = {total_samples}")
print(json.dumps(level_count, indent=2))

with open("info_insertion_accuracy_dataset.json", "w") as f:
    json.dump(data_config, f, indent=2)


