import fire
import os
import json

from seg2act.eval.ChCatExt.metrics import calc_hierarchical_metrics
from transformers import AutoTokenizer


def vary_length(
    exp_dir: str,
    pred_name: str = "",
    base_model: str = "", 
):
    """Print hierarchical P/R/F1 metrics grouped by document token length.

    Reads ``{exp_dir}/{pred_name}.json``, a JSON list of records. Each record
    provides ``segments`` (strings joined with spaces to form the document),
    ``preds`` and ``answers``. Each document is tokenized with the tokenizer
    loaded from ``base_model``. Documents are grouped into three buckets:
    fewer than 1000 tokens, 1000-4999 tokens, and 5000 or more tokens. For
    every non-empty bucket, this prints the sample count, then the precision,
    recall and F1 (scaled by 100) for the ``heading``, ``text`` and
    ``overall`` levels returned by ``calc_hierarchical_metrics``.

    Args:
        exp_dir: Directory containing the prediction file.
        pred_name: Stem of the prediction JSON file inside ``exp_dir``.
        base_model: Name or path passed to ``AutoTokenizer.from_pretrained``.

    Raises:
        FileNotFoundError: If the prediction file does not exist.
    """
    pred_path = os.path.join(exp_dir, f"{pred_name}.json")
    # Fail fast, before the potentially slow tokenizer load. An explicit raise
    # is used rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(pred_path):
        raise FileNotFoundError(f"File not exists: {pred_path}")
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    with open(pred_path, "r", encoding="utf-8") as f_r:
        records = json.load(f_r)

    # Token-count cut points and the matching human-readable bucket labels.
    # Bucket index = number of cut points the document length reaches, i.e.
    # 0: < 1000, 1: 1000-4999, 2: >= 5000. This is the same grouping as the
    # old min(n // 1000, 5) -> {0: 0, 1-4: 1, 5: 2} mapping.
    edges = (1000, 5000)
    labels = ("token < 1000", "1000 <= token < 5000", "token >= 5000")

    # Group predictions and gold answers by document-length bucket.
    groups = {}
    for record in records:
        whole_text = " ".join(record["segments"])
        # Counts input_ids exactly as the tokenizer returns them by default
        # (including any special tokens it adds).
        num_tokens = len(tokenizer(whole_text)["input_ids"])
        bucket = sum(num_tokens >= edge for edge in edges)
        group = groups.setdefault(bucket, {"preds": [], "golds": []})
        group["preds"].append(record["preds"])
        group["golds"].append(record["answers"])

    for bucket in sorted(groups):
        group = groups[bucket]
        print(f"=== ({labels[bucket]}) ===", len(group["golds"]))
        metric = calc_hierarchical_metrics(group["preds"], group["golds"])
        for level in ("heading", "text", "overall"):
            print(f"{level.capitalize()}: ")
            for score in ("p", "r", "f1"):
                print(f"\t{score.capitalize()}: {metric[level][score] * 100:.3f}\t")

if __name__ == "__main__":
    # Expose vary_length's parameters as command-line flags via Python Fire.
    fire.Fire(component=vary_length)

