# Absolute Grading: Outputs score of 1 to 5
import argparse
from prometheus_eval import PrometheusEval
from prometheus_eval.prompts import ABSOLUTE_PROMPT, SCORE_RUBRIC_TEMPLATE
import json
import os


def load_json(path):
    """Read the JSON document stored at *path* and return the decoded object.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The Python object produced by ``json.load`` (dict, list, str, ...).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
    


def _mean_of_valid(values):
    """Return the mean of the non-None entries of *values*, or None if there are none."""
    valid = [value for value in values if value is not None]
    return sum(valid) / len(valid) if valid else None


def _fill_averages(result):
    """Fill in every per-topic and per-method "avg_rating" of *result* in place.

    A topic's average is the mean of its item ratings. A method's average is
    the unweighted mean of its topic averages, so each topic counts equally
    regardless of how many items it holds.
    """
    for rubric_data in result.values():
        for method_data in rubric_data.values():
            topics = method_data["method_raw_data"]
            for topic_data in topics.values():
                topic_data["avg_rating"] = _mean_of_valid(
                    item["rating"] for item in topic_data["topic_raw_data"]
                )
            method_data["avg_rating"] = _mean_of_valid(
                topic_data["avg_rating"] for topic_data in topics.values()
            )


def main(args):
    """Grade every rubric group in grading_items.json and write grading_result.json.

    The input file maps a rubric short name to a dict with the keys
    "instructions", "responses", "rubric" and "meta". Each group is graded in
    one ``absolute_grade`` call, and the resulting feedback and rating of every
    item are filed under rubric -> method -> topic. Averages are then computed
    per topic and per method, and the whole structure is dumped as JSON.

    Ratings that are None are kept in the raw data but left out of the
    averages; an average with no valid rating behind it is written as null.

    Args:
        args: Parsed command-line namespace. Currently unused; the input and
            output paths are fixed.
    """
    judge = PrometheusEval(model_id="prometheus-eval/prometheus-7b-v2.0", absolute_grade_template=ABSOLUTE_PROMPT)
    # NOTE(review): this is assigned after the judge has been constructed;
    # confirm the underlying engine still honours it at this point.
    judge.model.gpu_memory_utilization = 0.5

    grading_item = load_json("grading_items.json")
    result = {}
    for rubric_short_name, items_to_grade in grading_item.items():
        by_method = result[rubric_short_name] = {}
        instructions = items_to_grade["instructions"]
        responses = items_to_grade["responses"]
        # Read the metadata before grading so a malformed group fails before
        # any model time is spent on it.
        meta_list = items_to_grade["meta"]
        feedbacks, ratings = judge.absolute_grade(
            instructions=instructions,
            responses=responses,
            rubric=items_to_grade["rubric"],
            reference_answers=None,
            params={},
        )
        for feedback, rating, meta in zip(feedbacks, ratings, meta_list):
            # Relies on the key order of each meta object in the input file:
            # topic, method, utterance type, index.
            topic, method, utterance_type, idx = meta.values()
            method_entry = by_method.setdefault(
                method, {"method_raw_data": {}, "avg_rating": 0.0}
            )
            topic_entry = method_entry["method_raw_data"].setdefault(
                topic, {"topic_raw_data": [], "avg_rating": 0.0}
            )
            topic_entry["topic_raw_data"].append(
                {"type": utterance_type, "idx": idx, "feedback": feedback, "rating": rating}
            )

    # NOTE(review): None ratings are skipped here on the assumption that the
    # judge may return None for outputs it could not parse - confirm against
    # prometheus_eval. Without this a single None would raise TypeError and
    # discard every result gathered above.
    _fill_averages(result)

    with open("grading_result.json", "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)
            
            
if __name__ == "__main__":
    # Script entry point: the parser defines no options of its own, so this
    # only provides -h/--help and rejects unexpected arguments before the
    # grading run starts.
    cli_args = argparse.ArgumentParser(
        description="Process and evaluate articles"
    ).parse_args()
    main(cli_args)

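# example usage
# python prometeus_eval.py
# NOTE: the script currently defines no command-line options. It reads
# grading_items.json from the working directory and writes grading_result.json
# there. The flags shown here previously (--result-dir, --file-name-to-evaluate,
# --rubric-path, --output-dir) are not accepted by the argument parser.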