import argparse
import glob
import json
import os
import sys
from collections import defaultdict
from os.path import dirname

# Make the source root (three directory levels above the directory containing
# this file) importable, so the ``collaborative_storm`` imports below resolve
# when this file is run directly as a script.
script_dir = os.path.dirname(os.path.abspath(__file__))
src_root = dirname(dirname(dirname(script_dir)))
sys.path.append(src_root)

# These project imports must stay after the sys.path tweak above.
from collaborative_storm.modules.collaborative_storm_utils import load_api_key
from collaborative_storm.engine import RoundTableConversation
# Runs at import time: reads ``secrets.toml`` from the parent of ``src_root``.
# NOTE(review): presumably this exports API keys (e.g. as environment
# variables) for downstream clients — confirm in load_api_key.
load_api_key(toml_file_path=os.path.join(src_root, "..", "secrets.toml"))


def co_storm_extract_conversation_turns(log_file_path, num_skipped_turns=7):
    """Load a Co-STORM round-table dump and flatten its conversation turns.

    Args:
        log_file_path: Path to a ``round_table_dump.json`` file whose contents
            ``RoundTableConversation.from_dict`` can parse.
        num_skipped_turns: Number of leading entries of
            ``conversation_history`` to drop before extracting turns. Defaults
            to 7, the value previously hard-coded here.
            NOTE(review): presumably these leading turns are warm-start /
            setup turns that should not be graded — confirm against the engine.

    Returns:
        A list of ``{"utterance", "type", "role"}`` dicts, one per kept turn,
        in conversation order.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale.
    with open(log_file_path, encoding="utf-8") as f:
        data = json.load(f)
    round_table_conversation = RoundTableConversation.from_dict(data)
    return [
        {"utterance": turn.utterance, "type": turn.utterance_type, "role": turn.role}
        for turn in round_table_conversation.conversation_history[num_skipped_turns:]
    ]

def baseline_extract_conversation_turns(log_file_path):
    """Load a baseline conversation log and flatten it into alternating turns.

    Only the last element of the top-level JSON list is used; its
    ``"dlg_turns"`` entries each contribute two turns: the user utterance
    (typed ``"Original Question"``, role ``"Guest"``) followed by the agent
    utterance (typed ``"Potential Answer"``, role ``"Default speaker"``).

    Args:
        log_file_path: Path to a ``conversation_log.json`` file.

    Returns:
        A list of ``{"utterance", "type", "role"}`` dicts in conversation order.

    Raises:
        IndexError: If the log's top-level list is empty.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale.
    with open(log_file_path, encoding="utf-8") as f:
        data = json.load(f)

    conversation_turns = []
    for turn in data[-1]["dlg_turns"]:
        conversation_turns.append({"utterance": turn["user_utterance"], "type": "Original Question", "role": "Guest"})
        conversation_turns.append({"utterance": turn["agent_utterance"], "type": "Potential Answer", "role": "Default speaker"})
    return conversation_turns

def generate_items_to_grade(conversation_turns, topic, intent, method, question_grading_rubric, answer_grading_rubric,
                            max_history_words=1500):
    """Build per-rubric grading items for every non-guest conversation turn.

    Each turn whose role is not ``"guest"`` (case-insensitive) is paired with
    a prompt containing the preceding conversation. It is graded against
    ``question_grading_rubric`` when its type is ``"Original Question"`` or
    ``"Information Request"``, and against ``answer_grading_rubric`` otherwise.

    Args:
        conversation_turns: List of ``{"utterance", "type", "role"}`` dicts.
        topic: Topic name interpolated into the prompt and stored in ``meta``.
        intent: Intent text interpolated into the prompt.
        method: Method name stored in ``meta``.
        question_grading_rubric: List of rubric dicts, each with a
            ``"criteria_description"`` of the form ``"<short name>: ..."``.
        answer_grading_rubric: Same shape as ``question_grading_rubric``.
        max_history_words: Whitespace-token budget for the rendered history.
            Once reached, older non-question turns are replaced by
            ``"<role>: content omitted"``. Defaults to 1500.

    Returns:
        Dict mapping rubric short name to ``{"rubric", "instructions",
        "responses", "meta"}``, where the last three are parallel lists.
        ``"rubric"`` is rendered from the first rubric item seen with that
        short name.
    """
    question_types = ("Original Question", "Information Request")

    def _get_conv_history(current_idx):
        """Render turns before ``current_idx`` as ``"role: utterance"`` lines.

        Turns are walked newest-first, so the most recent ones are kept in
        full. Question turns are never omitted. Returns the string ``"None"``
        when there is no earlier turn.
        """
        history_lines = []
        total_length = 0
        for turn in reversed(conversation_turns[:current_idx]):
            turn_string = f"{turn['role']}: {turn['utterance']}"
            if (total_length + len(turn_string.split()) >= max_history_words
                    and turn["type"] not in question_types):
                turn_string = f"{turn['role']}: content omitted"
            history_lines.append(turn_string)
            total_length += len(turn_string.split())
        if not history_lines:
            return "None"
        return "\n".join(reversed(history_lines))

    def _rubric_to_string(rubric_item):
        """Render a rubric dict as one ``key: value`` line per entry."""
        return '\n'.join(f"{key}: {value}" for key, value in rubric_item.items())

    items_to_grade = {}
    for idx, current_turn in enumerate(conversation_turns):
        if current_turn["role"].lower() == "guest":
            continue
        current_turn_context = _get_conv_history(idx)
        # NOTE(review): there is no whitespace between "history:" and the
        # context. This is kept verbatim so prompts match previously generated
        # grading data.
        instruction = f"Generate the utterance in the discussion on the topic {topic} as you are interested in {intent}. Here's the conversation history:{current_turn_context}"
        is_question = current_turn["type"] in question_types
        rubrics = question_grading_rubric if is_question else answer_grading_rubric
        for rubric_item in rubrics:
            short_name = rubric_item["criteria_description"].split(":", 1)[0].strip()
            if short_name not in items_to_grade:
                items_to_grade[short_name] = {
                    "rubric": _rubric_to_string(rubric_item),
                    "instructions": [],
                    "responses": [],
                    "meta": [],
                }
            entry = items_to_grade[short_name]
            entry["instructions"].append(instruction)
            entry["responses"].append(current_turn["utterance"])
            entry["meta"].append({"topic": topic, "method": method, "type": current_turn["type"], "idx": idx})
    return items_to_grade
            
def merge_dicts(dict1, dict2):
    """Fold the grading items of ``dict2`` into ``dict1`` and return the result.

    ``None`` on either side means "nothing to merge": the other argument is
    returned unchanged (``None`` if both are ``None``). Otherwise ``dict1`` is
    mutated in place. For shared keys its ``instructions``, ``responses`` and
    ``meta`` lists are extended with ``dict2``'s, and its ``rubric`` is left as
    is. Keys only in ``dict2`` are added by reference, not copied.
    """
    if dict1 is None or dict2 is None:
        return dict2 if dict1 is None else dict1

    for name, incoming in dict2.items():
        if name not in dict1:
            dict1[name] = incoming
            continue
        existing = dict1[name]
        for field in ("instructions", "responses", "meta"):
            existing[field].extend(incoming[field])
    return dict1

def load_topic_and_intent(dataset_meta_path):
    """Load a ``{topic: intent}`` mapping from a dataset metadata JSON file.

    Args:
        dataset_meta_path: Path to a JSON list of objects, each with
            ``"topic"`` and ``"intent"`` keys.

    Returns:
        Dict mapping each topic to its intent. If a topic repeats, the last
        occurrence wins.
    """
    # Explicit UTF-8 so non-ASCII topics decode the same on every platform.
    with open(dataset_meta_path, encoding="utf-8") as f:
        data = json.load(f)
    return {meta["topic"]: meta["intent"] for meta in data}

def load_rubric(rubric_path):
    """Load and return the parsed JSON content of a grading-rubric file.

    Args:
        rubric_path: Path to the rubric JSON file.

    Returns:
        The decoded JSON value. Callers in this file iterate it as a list of
        rubric dicts.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale.
    with open(rubric_path, encoding="utf-8") as f:
        return json.load(f)
     

def _load_conversation_turns(article_path):
    """Return flattened turns for one article directory, or ``None`` if it has no known log."""
    co_storm_log = os.path.join(article_path, "round_table_dump.json")
    if os.path.exists(co_storm_log):
        return co_storm_extract_conversation_turns(co_storm_log)
    baseline_log = os.path.join(article_path, "conversation_log.json")
    if os.path.exists(baseline_log):
        return baseline_extract_conversation_turns(baseline_log)
    return None


def main(args):
    """Collect grading items for every experiment run and write them as JSON.

    ``args.base_dir`` is expected to contain ``<method>/<article>/``
    directories. Each article directory is read through
    ``round_table_dump.json`` (Co-STORM) if present, otherwise through
    ``conversation_log.json`` (baseline). Directories with neither are skipped.
    The article directory name, with underscores turned into spaces, must be a
    topic in ``args.topic_and_intent``.

    Args:
        args: Parsed CLI namespace with ``topic_and_intent``,
            ``answer_grading_rubric``, ``question_grading_rubric``,
            ``base_dir`` and ``grading_item_output_dir`` (an output file path).

    Raises:
        KeyError: If an article with a conversation log has no matching topic
            in the topic/intent file.
    """
    topic_and_intent_mapping = load_topic_and_intent(args.topic_and_intent)
    answer_grading_rubric = load_rubric(args.answer_grading_rubric)
    question_grading_rubric = load_rubric(args.question_grading_rubric)
    base_dir = args.base_dir

    grading_items = None
    # Sorted so the output order is reproducible regardless of file system order.
    for method_name in sorted(os.listdir(base_dir)):
        method_path = os.path.join(base_dir, method_name)
        # Skip stray files (e.g. .DS_Store): os.listdir on a file would raise.
        if not os.path.isdir(method_path):
            continue
        for article_name in sorted(os.listdir(method_path)):
            article_path = os.path.join(method_path, article_name)
            conversation_turns = _load_conversation_turns(article_path)
            if conversation_turns is None:
                continue
            cleaned_article_name = article_name.replace("_", " ").strip()
            if cleaned_article_name not in topic_and_intent_mapping:
                raise KeyError(
                    f"No intent for topic {cleaned_article_name!r} "
                    f"(from {article_path!r}) in {args.topic_and_intent!r}"
                )
            current_grading_items = generate_items_to_grade(conversation_turns,
                                                            topic=article_name,
                                                            intent=topic_and_intent_mapping[cleaned_article_name],
                                                            method=method_name,
                                                            question_grading_rubric=question_grading_rubric,
                                                            answer_grading_rubric=answer_grading_rubric)
            grading_items = merge_dicts(grading_items, current_grading_items)

    # NOTE(review): writes JSON `null` when no conversation logs were found.
    with open(args.grading_item_output_dir, "w", encoding="utf-8") as f:
        json.dump(grading_items, f, indent=2)

if __name__ == '__main__':
    # CLI entry point; default paths are relative to the current working directory.
    parser = argparse.ArgumentParser(description='Prepare discourse trace for auto evaluation')
    parser.add_argument('-t', '--topic_and_intent', type=str, default='../../dataset/final_core_dataset_meta.json',
                        help='Path to the topic and intent file')
    parser.add_argument('-a', '--answer_grading_rubric', type=str, default='./answer_grading_rubric.json',
                        help='Path to the answer grading rubric file')
    parser.add_argument('-q', '--question_grading_rubric', type=str, default='./question_grading_rubric.json',
                        help='Path to the question grading rubric file')
    parser.add_argument('-b', '--base_dir', type=str, default='../evaluation_dataset/raw_experiments_data',
                        help='Base directory for raw experiments data')
    # The option name says "dir" but the value is the output JSON file path.
    parser.add_argument('-o', '--grading_item_output_dir', type=str, default='./grading_items.json',
                        help='Output JSON file path for grading items')

    args = parser.parse_args()
    main(args)
