import fire
import os
import json

from seg2act.eval.HierDoc.metrics import (calc_heading_detection, 
                                          calc_tree_edit_distance_similarity)


def vary_length(
    exp_dir: str,
    pred_name: str = "pred-HierDoc",
):
    """Print HierDoc evaluation metrics for each document-depth group.

    Loads ``<exp_dir>/<pred_name>.json``. The file is expected to hold a JSON
    array of samples, each with an ``answers`` list of nodes that carry a
    ``depth`` field. Samples are grouped by the deepest node among their
    ``answers``. For each group, in ascending depth order, the function prints
    heading-detection precision/recall/F1 (as percentages) and the
    tree-edit-distance-based similarity.

    Args:
        exp_dir: Directory containing the prediction file.
        pred_name: Prediction file name without the ``.json`` extension.

    Raises:
        FileNotFoundError: If the prediction file does not exist.
    """
    pred_path = os.path.join(exp_dir, f"{pred_name}.json")
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if not os.path.exists(pred_path):
        raise FileNotFoundError(f"File not exists! ({pred_path})")

    # Only the read needs the handle; close the file before the
    # (potentially slow) metric computation below.
    with open(pred_path, "r", encoding="utf-8") as reader:
        samples = json.load(reader)

    # Group by document depth: the deepest node among a sample's answers.
    groups = {}
    for sample in samples:
        depth = max(node["depth"] for node in sample["answers"])
        groups.setdefault(depth, []).append(sample)

    # Report in ascending depth so the output order does not depend on the
    # order in which samples appear in the file.
    for depth in sorted(groups):
        group = groups[depth]
        print("===", depth, "===")
        hd_p, hd_r, hd_f1 = calc_heading_detection(group)
        teds = calc_tree_edit_distance_similarity(group)
        print("Heading Detection: ")
        print(f"\tP: {hd_p * 100:.3f}\n\tR: {hd_r * 100:.3f}\n\tF1: {hd_f1 * 100:.3f}")
        print(f"Tree Edit Distance-based Similarity = {teds:.3f}")


if __name__ == "__main__":
    # Expose `vary_length` as a command-line tool: python-fire maps CLI
    # arguments/flags onto its `exp_dir` and `pred_name` parameters.
    fire.Fire(component=vary_length)
