import argparse
import glob
import json
import os
import sys
from collections import defaultdict
from os.path import dirname


def co_storm_extract_cited_snippets(log_file_path):
    """Return the set of cited snippets recorded in a Co-STORM dump file.

    Reads ``data["knowledge_base"]["info_uuid_to_info_dict"]`` from the JSON
    file and keeps the first entry of each info item's ``"snippets"`` list.
    Items whose ``"snippets"`` list is empty contribute nothing; previously
    they raised ``IndexError``.

    Args:
        log_file_path: Path to a ``round_table_dump.json`` file.

    Returns:
        A set of snippet strings (duplicates collapse).
    """
    # Decode explicitly as UTF-8 rather than relying on the platform locale.
    with open(log_file_path, encoding="utf-8") as f:
        data = json.load(f)

    snippets = set()
    for info in data["knowledge_base"]["info_uuid_to_info_dict"].values():
        info_snippets = info["snippets"]
        if info_snippets:
            snippets.add(info_snippets[0])
    return snippets

def baseline_extract_cited_snippets(log_file_path):
    """Return the set of cited snippets recorded in a baseline ``url_to_info`` file.

    Reads ``data["url_to_info"]`` from the JSON file and keeps the first entry
    of each URL's ``"snippets"`` list. URLs whose ``"snippets"`` list is empty
    contribute nothing; previously they raised ``IndexError``.

    Args:
        log_file_path: Path to a ``url_to_info.json`` file.

    Returns:
        A set of snippet strings (duplicates collapse).
    """
    # Decode explicitly as UTF-8 rather than relying on the platform locale.
    with open(log_file_path, encoding="utf-8") as f:
        data = json.load(f)

    snippets = set()
    for url_info in data["url_to_info"].values():
        url_snippets = url_info["snippets"]
        if url_snippets:
            snippets.add(url_snippets[0])
    return snippets
     

def _load_cited_snippets(article_path):
    """Return cited snippets for one article run directory, or None if no known log exists.

    ``round_table_dump.json`` (Co-STORM) takes precedence over
    ``url_to_info.json`` (baseline) when both are present.
    """
    co_storm_log = os.path.join(article_path, "round_table_dump.json")
    if os.path.exists(co_storm_log):
        return co_storm_extract_cited_snippets(co_storm_log)
    baseline_log = os.path.join(article_path, "url_to_info.json")
    if os.path.exists(baseline_log):
        return baseline_extract_cited_snippets(baseline_log)
    return None


def main(args):
    """Gather cited snippets for every (method, article) run and write them as JSON.

    Expects the layout ``<base_dir>/<method_name>/<article_name>/``. Snippets
    are de-duplicated per article across methods. For each article the output
    records:

    - ``"snippets"``: the unique snippet strings, and
    - ``"method_to_index_mapping"``: for each method that cited at least one
      snippet, the indices into ``"snippets"`` of the snippets it cited.

    The function prints the total whitespace-delimited word count over all
    unique snippets. It then writes the result to ``args.output_dir``, which
    is opened as a file despite its name.

    Args:
        args: Namespace with ``base_dir`` and ``output_dir`` attributes.
    """
    all_snippets = {}
    # Per-article snippet -> index lookup: O(1) de-duplication instead of a
    # linear ``in``/``list.index`` scan for every snippet.
    snippet_index = {}

    # Sorted traversal so the output is reproducible; os.listdir order is arbitrary.
    for method_name in sorted(os.listdir(args.base_dir)):
        method_path = os.path.join(args.base_dir, method_name)
        # Stray files at the method level would make os.listdir raise.
        if not os.path.isdir(method_path):
            continue
        for article_name in sorted(os.listdir(method_path)):
            snippets = _load_cited_snippets(os.path.join(method_path, article_name))
            if snippets is None:
                continue
            entry = all_snippets.setdefault(
                article_name, {"snippets": [], "method_to_index_mapping": {}}
            )
            lookup = snippet_index.setdefault(article_name, {})
            # `snippets` is a set of str, whose iteration order varies with hash
            # randomization; sort so index assignment is deterministic.
            for snippet in sorted(snippets):
                idx = lookup.get(snippet)
                if idx is None:
                    idx = len(entry["snippets"])
                    entry["snippets"].append(snippet)
                    lookup[snippet] = idx
                entry["method_to_index_mapping"].setdefault(method_name, []).append(idx)

    total = sum(
        len(snippet.split())
        for entry in all_snippets.values()
        for snippet in entry["snippets"]
    )
    print(total)

    with open(args.output_dir, "w", encoding="utf-8") as f:
        json.dump(all_snippets, f, indent=2)

if __name__ == '__main__':
    # CLI entry point. Note: despite its name, --output-dir is the path of the
    # JSON file that main() writes (it is opened with mode "w").
    parser = argparse.ArgumentParser(description='Prepare discourse trace for auto evaluation')
    parser.add_argument('-b', '--base-dir', type=str, default='../evaluation_dataset/raw_experiments_data/',
                        help='Base directory for raw experiments data')
    parser.add_argument('-o', '--output-dir', type=str, default='./snippets_to_grade.json',
                        help='Output JSON file path for grading items')

    main(parser.parse_args())
