import argparse
import json
import os
from collections import defaultdict

def load_and_process_report(path):
    """Read a report text file, dropping empty lines and lines that start with ``#``.

    Lines starting with ``#`` (presumably Markdown section headings) are removed;
    whitespace-only lines are kept because they are not empty strings.

    Args:
        path: Path to the report text file (read as UTF-8).

    Returns:
        The kept lines joined with newlines, with surrounding whitespace stripped.
        An empty string if nothing is kept.
    """
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # can fail or garble non-ASCII report text.
    with open(path, encoding="utf-8") as f:
        final_page = f.read()
    kept_lines = [
        line for line in final_page.split("\n")
        if line and not line.startswith("#")
    ]
    return "\n".join(kept_lines).strip()

def co_storm_format_data(article_dir_path):
    """Build the grading item for a Co-STORM article directory.

    Reads ``report.txt`` (via ``load_and_process_report``) and
    ``round_table_dump.json`` from ``article_dir_path``.

    Returns:
        ``{'output': <processed report text>, 'docs': [...]}``. Here ``docs`` holds one
        ``{"title", "text"}`` dict per entry of
        ``knowledge_base.info_uuid_to_info_dict``. Entries are ordered by the
        uuid key parsed as an integer, and ``text`` is the entry's first snippet.
    """
    report_path = os.path.join(article_dir_path, "report.txt")
    dump_path = os.path.join(article_dir_path, "round_table_dump.json")

    report_output = load_and_process_report(report_path)
    with open(dump_path) as dump_file:
        dump = json.load(dump_file)

    raw_info = dump["knowledge_base"]["info_uuid_to_info_dict"]
    # JSON object keys are strings; convert to int so ordering is numeric, not lexical.
    info_by_uuid = {int(uuid): info for uuid, info in raw_info.items()}
    docs = [
        {"title": info_by_uuid[uuid]["title"], "text": info_by_uuid[uuid]["snippets"][0]}
        for uuid in sorted(info_by_uuid)
    ]
    return {'output': report_output, 'docs': docs}

def baseline_format_data(article_dir_path):
    """Build the grading item for a baseline (STORM) article directory.

    Reads ``storm_gen_article.txt`` (via ``load_and_process_report``) and
    ``url_to_info.json`` from ``article_dir_path``.

    Returns:
        ``{'output': <processed article text>, 'docs': [...]}``. Here ``docs[i - 1]``
        describes the URL whose unified index is ``i``, for ``i`` in
        ``1..max_index``. Indices with no URL get an empty ``{"title": "", "text": ""}``
        placeholder. An empty ``url_to_unified_index`` yields ``docs == []``.
    """
    report_output = load_and_process_report(
        os.path.join(article_dir_path, "storm_gen_article.txt"))
    # JSON text is UTF-8 by specification; don't rely on the platform default.
    with open(os.path.join(article_dir_path, "url_to_info.json"), encoding="utf-8") as f:
        data = json.load(f)

    url_to_unified_index = data["url_to_unified_index"]
    # default=0 avoids ValueError from max() when there are no cited URLs.
    max_index = max(url_to_unified_index.values(), default=0)
    unified_index_to_url = {index: url for url, index in url_to_unified_index.items()}
    found_indices = set(url_to_unified_index.values())
    if found_indices != set(range(1, max_index + 1)):
        # Indices are not exactly 1..max_index (gaps or out-of-range values).
        print(f"expect {max_index} but found {len(found_indices)}")

    docs = []
    for index in range(1, max_index + 1):
        url = unified_index_to_url.get(index)
        if url is None:
            # Placeholder keeps docs[index - 1] aligned with unified index `index`.
            docs.append({"title": "", "text": ""})
        else:
            url_data = data["url_to_info"][url]
            docs.append({"title": url_data["title"], "text": url_data["snippets"][0]})
    return {'output': report_output, 'docs': docs}
     

def main(args):
    """Collect grading items for every method/article directory and write them as JSON.

    Expected layout: ``<base_dir>/<method_name>/<article_name>/``.

    - Methods whose name contains ``"baseline"`` are parsed with ``baseline_format_data``.
    - Methods whose name contains ``"new_method"`` are parsed with ``co_storm_format_data``.
    - Any other method only contributes an empty entry for each of its articles.

    Non-directory entries at either level are skipped.

    Args:
        args: Parsed CLI namespace with ``base_dir`` and ``grading_item_output_dir``
            (the output JSON file path).
    """
    base_dir = args.base_dir
    grading_item_output_dir = args.grading_item_output_dir

    grading_items = {}
    # sorted() makes the output JSON key order reproducible across filesystems.
    for method_name in sorted(os.listdir(base_dir)):
        method_path = os.path.join(base_dir, method_name)
        if not os.path.isdir(method_path):
            # os.listdir also returns plain files (e.g. .DS_Store); they are not methods.
            continue
        for article_name in sorted(os.listdir(method_path)):
            article_path = os.path.join(method_path, article_name)
            if not os.path.isdir(article_path):
                continue
            article_items = grading_items.setdefault(article_name, {})
            if "baseline" in method_name:
                article_items[method_name] = baseline_format_data(article_path)
            elif "new_method" in method_name:
                article_items[method_name] = co_storm_format_data(article_path)

    with open(grading_item_output_dir, "w") as f:
        json.dump(grading_items, f, indent=2)

if __name__ == '__main__':
    # CLI entry point: parse paths and build the grading-items JSON file.
    parser = argparse.ArgumentParser(description='Prepare discourse trace for auto evaluation')
    parser.add_argument('-b', '--base_dir', type=str, default='../evaluation_dataset/raw_experiments_data',
                        help='Base directory for raw experiments data')
    parser.add_argument('-o', '--grading_item_output_dir', type=str, default='./grading_items.json',
                        help='Output JSON file path for grading items')

    args = parser.parse_args()
    main(args)
