import concurrent.futures
import dspy
import numpy as np
import os
import time

from contextlib import contextmanager
from dataclasses import dataclass, field, asdict, fields
from itertools import zip_longest
from typing import List, Union, Literal, Optional
from sklearn.metrics.pairwise import cosine_similarity

from src.interface import LMConfigs
from src.lm import OpenAIModel, AzureOpenAIModel
from collaborative_storm.modules.article_generation import ArticleGenerationModule
import collaborative_storm.modules.collaborative_storm_utils as collaborative_storm_utils
from collaborative_storm.modules.dataclass import KnowledgeBase, NewInformation, ConversationTurn, KnowledgeNode
from collaborative_storm.modules.encoder import get_text_embeddings
from collaborative_storm.modules.expert_generation import GenerateExpertModule
from collaborative_storm.modules.grounded_question_answering import AnswerQuestionModule
from collaborative_storm.modules.grounded_question_generation import GroundedQuestionGenerationModule
from collaborative_storm.modules.information_insertion_module import (
    InsertInformationModule,
    update_conv_turn_to_knowledge_base,
    ExpandSectionModule
)
from collaborative_storm.modules.retriever import NewRetriever
from collaborative_storm.modules.round_table_chat import RoundTableConversationModule
from collaborative_storm.modules.simulate_user import GenSimulatedUserUtterance
from collaborative_storm.modules.warmstart_hierarchical_chat import WarmStartModule


class CollaborativeStormLMConfigs(LMConfigs):
    """Configurations for LLM used in different parts of Co-STORM.

    Given that different parts in the Co-STORM framework have different complexity, we use different LLM
    configurations to achieve a balance between quality and efficiency. Every LM slot starts as ``None``;
    populate the slots with ``init_openai_model`` (the default setup in the paper) or with the individual
    ``set_*_lm`` setters.
    """

    def __init__(self):
        self.answer_question_lm = None
        self.generate_experts_lm = None
        self.warmstart_question_asking_lm = None
        self.action_planning_lm = None
        self.utterance_polishing_lm = None
        self.warmstart_outline_gen_lm = None
        self.information_insert_lm = None
        self.grounded_question_generation_lm = None
        self.article_generation_lm = None

    def init_openai_model(
        self,
        openai_api_key: str,
        azure_api_key: str,
        openai_type: Literal["openai", "azure"],
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 0.9,
    ):
        """Legacy: Corresponding to the original setup in the NAACL'24 paper.

        Args:
            openai_api_key: API key passed to every ``OpenAIModel`` created here.
            azure_api_key: API key passed to every ``AzureOpenAIModel`` created here.
            openai_type: ``"openai"`` builds every slot with ``OpenAIModel``; ``"azure"`` builds every
                slot with ``AzureOpenAIModel`` except ``warmstart_outline_gen_lm`` (see note below).
            api_base: Azure endpoint; only passed to the Azure models.
            api_version: Azure API version; only passed to the Azure models.
            temperature: Sampling temperature shared by all models.
            top_p: Nucleus-sampling parameter shared by all models.

        Raises:
            ValueError: if ``openai_type`` is neither ``"openai"`` nor ``"azure"``.
        """
        azure_kwargs = {
            "api_key": azure_api_key,
            "temperature": temperature,
            "top_p": top_p,
            "api_base": api_base,
            "api_version": api_version
        }

        openai_kwargs = {
            "api_key": openai_api_key,
            "api_provider": "openai",
            "temperature": temperature,
            "top_p": top_p,
            "api_base": None,
        }

        if openai_type == "openai":
            self.answer_question_lm = OpenAIModel(
                model="gpt-4o-2024-05-13", max_tokens=1000, **openai_kwargs
            )
            self.generate_experts_lm = OpenAIModel(
                model="gpt-4o-2024-05-13", max_tokens=500, **openai_kwargs
            )
            self.warmstart_question_asking_lm = OpenAIModel(
                model="gpt-4o-2024-05-13", max_tokens=300, **openai_kwargs
            )
            self.action_planning_lm = OpenAIModel(
                model="gpt-4o-2024-05-13", max_tokens=500, **openai_kwargs
            )
            self.utterance_polishing_lm = OpenAIModel(
                model="gpt-4o-2024-05-13", max_tokens=2000, **openai_kwargs
            )
            self.warmstart_outline_gen_lm = OpenAIModel(
                model="gpt-4-1106-preview", max_tokens=500, **openai_kwargs
            )
            self.information_insert_lm = OpenAIModel(
                model="gpt-4o-2024-05-13", max_tokens=300, **openai_kwargs
            )
            self.grounded_question_generation_lm = OpenAIModel(
                model="gpt-4o-2024-05-13", max_tokens=300, **openai_kwargs
            )
            self.article_generation_lm = OpenAIModel(
                model="gpt-4o-2024-05-13", max_tokens=1000, **openai_kwargs
            )
        elif openai_type == "azure":
            self.answer_question_lm = AzureOpenAIModel(
                model="gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat"
            )
            self.generate_experts_lm = AzureOpenAIModel(
                model="gpt-4o", max_tokens=500, **azure_kwargs, model_type="chat"
            )
            self.warmstart_question_asking_lm = AzureOpenAIModel(
                model="gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat"
            )
            self.action_planning_lm = AzureOpenAIModel(
                model="gpt-4o", max_tokens=500, **azure_kwargs, model_type="chat"
            )
            self.utterance_polishing_lm = AzureOpenAIModel(
                model="gpt-4o", max_tokens=2000, **azure_kwargs, model_type="chat"
            )
            # NOTE(review): this slot uses the OpenAI API (and therefore openai_api_key) even in
            # "azure" mode, so a valid OpenAI key is still required here — confirm this is intended.
            self.warmstart_outline_gen_lm = OpenAIModel(
                model="gpt-4-1106-preview", max_tokens=500, **openai_kwargs
            )
            self.information_insert_lm = AzureOpenAIModel(
                model="gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat"
            )
            self.grounded_question_generation_lm = AzureOpenAIModel(
                model="gpt-4o", max_tokens=300, **azure_kwargs, model_type="chat"
            )
            self.article_generation_lm = AzureOpenAIModel(
                model="gpt-4o", max_tokens=1000, **azure_kwargs, model_type="chat"
            )
        else:
            # ValueError is a subclass of Exception, so existing ``except Exception`` handlers still work.
            raise ValueError(
                "No valid OpenAI API provider is provided. Cannot use default LLM configurations."
            )

    def set_answer_question_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.answer_question_lm = model

    def set_generate_experts_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.generate_experts_lm = model

    def set_warmstart_question_asking_lm(
        self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]
    ):
        self.warmstart_question_asking_lm = model

    def set_action_planning_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.action_planning_lm = model

    def set_utterance_polishing_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.utterance_polishing_lm = model

    def set_warmstart_outline_gen_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.warmstart_outline_gen_lm = model

    def set_information_insert_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.information_insert_lm = model

    def set_grounded_question_generation_lm(
        self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]
    ):
        self.grounded_question_generation_lm = model

    def set_article_generation_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.article_generation_lm = model

    def collect_and_reset_lm_usage(self):
        """Collect token usage from every LM slot and reset each LM's counters.

        Only attributes whose name contains ``'_lm'`` and whose value has a
        ``get_usage_and_reset`` method are considered.

        Returns:
            dict: ``{attr_name: usage}`` for the slots whose usage dict contains at least one
            entry with a non-zero ``prompt_tokens`` or ``completion_tokens`` count.
        """
        lm_usage = {}
        for attr_name, lm in self.__dict__.items():
            if '_lm' not in attr_name or not hasattr(lm, 'get_usage_and_reset'):
                continue
            usage = lm.get_usage_and_reset()
            if any(value['prompt_tokens'] != 0 or value['completion_tokens'] != 0 for value in usage.values()):
                lm_usage[attr_name] = usage
        return lm_usage

    def to_dict(self):
        """
        Converts the CollaborativeStormLMConfigs instance to a dictionary representation.

        Slots that are still ``None`` (never configured) or whose value has no ``kwargs``
        attribute are recorded as ``None`` instead of raising ``AttributeError``.

        Returns:
            dict: ``{attr_name: lm.kwargs or None}`` for every instance attribute.
        """
        return {
            attr_name: getattr(lm, "kwargs", None)
            for attr_name, lm in self.__dict__.items()
        }


@dataclass
class RunnerArgument:
    """Arguments for controlling the Co-STORM (collaborative STORM) pipeline."""

    topic: str = field(
        metadata={"help": "Topic of discourse"},
    )
    retrieve_top_k: int = field(
        default=10,
        metadata={"help": "retrieve top k results for each query in retriever"},
    )
    max_search_queries: int = field(
        default=2,
        metadata={
            "help": "Maximum number of search queries to consider for each question."
        },
    )
    total_conv_turn: int = field(
        default=20,
        metadata={
            "help": "Maximum number turn in conversation."
        },
    )
    max_search_thread: int = field(
        default=5,
        metadata={"help": "Maximum number of parallel thread for retriever"},
    )
    max_search_queries_per_turn: int = field(
        default=3,
        metadata={"help": "Maximum number of search queries to consider in each turn."},
    )
    warmstart_max_num_experts: int = field(
        default=3,
        metadata={
            "help": "Max number of experts in perspective guided QA in warm start process"
        },
    )
    warmstart_max_turn_per_experts: int = field(
        default=2,
        metadata={"help": "Max number of turns per perspective in warm start process"},
    )
    warmstart_max_thread: int = field(
        default=3,
        metadata={
            "help": "Max number thread for parallel perspective guided QA in warm start process"
        },
    )
    max_thread_num: int = field(
        default=10,
        metadata={
            "help": "Maximum number of threads to use. "
            "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API."
        },
    )
    max_num_round_table_experts: int = field(
        default=2,
        metadata={"help": "Max number of active experts in round table discussion."},
    )
    moderator_override_N_consecutive_answering_turn: int = field(
        default=3,
        metadata={"help": "Number of consecutive experts answering turn before moderator override the conversation"},
    )
    # The expansion check compares ``len(node.content) >= node_expansion_trigger_count``.
    node_expansion_trigger_count: int = field(
        default=10,
        metadata={"help": "Trigger node expansion for node that contain at least N snippets"},
    )
    disable_moderator: bool = field(
        default=False,
        metadata={"help": "If True, disable moderator."},
    )
    disable_multi_experts: bool = field(
        default=False,
        metadata={"help": "If True, disable multiple experts and use a single general knowledge provider."},
    )
    rag_only_baseline_mode: bool = field(
        default=False,
        metadata={"help": "If True, switch to RAG-only baseline mode"},
    )

    def to_dict(self):
        """
        Converts the RunnerArgument instance to a dictionary representation.

        Returns:
            dict: The dictionary representation of the RunnerArgument.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data):
        """
        Constructs a RunnerArgument instance from a dictionary representation.

        Keys that do not correspond to a field of this dataclass (e.g. written by a
        newer/older version of the code) are ignored, so saved sessions stay loadable.

        Args:
            data (dict): The dictionary representation of the RunnerArgument.

        Returns:
            RunnerArgument: The constructed RunnerArgument instance.

        Raises:
            TypeError: if a required field (``topic``) is missing.
        """
        known_fields = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known_fields})

class EventLog:
    """Wall-clock timer for one named event.

    ``record_start_time`` / ``record_end_time`` bracket one occurrence of the event.
    Calling ``record_start_time`` again restarts the timer, so ``get_total_time``
    reports only the most recently completed interval.
    """

    def __init__(self, event_name):
        self.event_name: str = event_name
        self.start_time: Optional[float] = None  # seconds since the epoch (time.time())
        self.end_time: Optional[float] = None
        self.total_time_in_seconds: Optional[float] = None

    def record_start_time(self):
        """Mark (or re-mark) the start of the event."""
        self.start_time = time.time()

    def record_end_time(self):
        """Mark the end of the event and store its duration in ``total_time_in_seconds``.

        Raises:
            AssertionError: if ``record_start_time`` was never called. The check runs before any
                state is mutated and, unlike a bare ``assert``, is not stripped under ``python -O``.
        """
        if self.start_time is None:
            raise AssertionError(
                f"failure to record end time for event {{{self.event_name}}} as start time is not recorded"
            )
        self.end_time = time.time()
        self.total_time_in_seconds = self.end_time - self.start_time

    def get_total_time(self):
        """Return the last completed duration in seconds, or ``None`` if never completed."""
        return self.total_time_in_seconds

class LoggingWrapper:
    """Per-pipeline-stage logger for event timing, LM usage/history and query counts.

    ``logging_dict`` layout::

        {pipeline_stage: {
            "time_usage": {event_name: EventLog},
            "lm_usage": ...,          # lm_config.collect_and_reset_lm_usage()
            "lm_history": ...,        # lm_config.collect_and_reset_lm_history()
            "query_count": int,
            "total_wall_time": float, # seconds
        }}

    ``lm_usage``, ``lm_history`` and ``total_wall_time`` are filled in when the stage ends.
    """

    def __init__(self, lm_config: CollaborativeStormLMConfigs):
        self.logging_dict: dict = {}
        self.lm_config = lm_config
        self.current_pipeline_stage = None

    def _pipeline_stage_start(self, pipeline_stage: str):
        """Make ``pipeline_stage`` current and (re)initialize its log entry."""
        self.current_pipeline_stage = pipeline_stage
        self.logging_dict[pipeline_stage] = {"time_usage": {}, "lm_usage": {}, "query_count": 0}

    def _event_start(self, event_name: str):
        """Start (or restart) the timer for ``event_name`` within the current stage."""
        time_usage = self.logging_dict[self.current_pipeline_stage]["time_usage"]
        event = time_usage.get(event_name)
        if event is None:
            event = EventLog(event_name=event_name)
            time_usage[event_name] = event
        event.record_start_time()

    def _event_end(self, event_name: str):
        """Stop the timer for ``event_name`` within the current stage.

        Raises:
            AssertionError: if the event was never started in the current stage
                (explicit raise so the check survives ``python -O``).
        """
        time_usage = self.logging_dict[self.current_pipeline_stage]["time_usage"]
        if event_name not in time_usage:
            raise AssertionError(
                f"failure to record end time for event {{{event_name}}} as start time is not recorded"
            )
        time_usage[event_name].record_end_time()

    def _pipeline_stage_end(self):
        """Attach LM usage and history (collected and reset on ``lm_config``) to the current stage."""
        self.logging_dict[self.current_pipeline_stage]["lm_usage"] = self.lm_config.collect_and_reset_lm_usage()
        self.logging_dict[self.current_pipeline_stage]["lm_history"] = self.lm_config.collect_and_reset_lm_history()

    def add_query_count(self, count):
        """Add ``count`` to the current stage's query counter."""
        self.logging_dict[self.current_pipeline_stage]["query_count"] += count

    @contextmanager
    def log_event(self, event_name):
        """Time the enclosed block as ``event_name``; the end is recorded even if the block raises."""
        self._event_start(event_name)
        try:
            yield
        finally:
            self._event_end(event_name)

    @contextmanager
    def log_pipeline_stage(self, pipeline_stage):
        """Run the enclosed block as pipeline stage ``pipeline_stage``.

        Wall time, LM usage and LM history are recorded even if the block raises, so a
        failed stage never leaves an incomplete entry behind for ``dump_logging_and_reset``.
        """
        start_time = time.time()
        self._pipeline_stage_start(pipeline_stage)
        try:
            yield
        finally:
            self.logging_dict[self.current_pipeline_stage]["total_wall_time"] = time.time() - start_time
            self._pipeline_stage_end()

    def dump_logging_and_reset(self, reset_logging=True):
        """Return a plain-dict snapshot of all stages, optionally clearing the log.

        Event timings are converted to seconds via ``EventLog.get_total_time``. Stages that
        have not finished yet report ``None`` for ``lm_history`` / ``total_wall_time``.

        Args:
            reset_logging: If True, clear ``logging_dict`` after dumping.
        """
        log_dump = {}
        for pipeline_stage, pipeline_log in self.logging_dict.items():
            time_stamp_log = {event_name: event.get_total_time() for event_name, event in pipeline_log["time_usage"].items()}
            log_dump[pipeline_stage] = {"time_usage": time_stamp_log,
                                        "lm_usage": pipeline_log["lm_usage"],
                                        "lm_history": pipeline_log.get("lm_history"),
                                        "query_count": pipeline_log["query_count"],
                                        "total_wall_time": pipeline_log.get("total_wall_time")}
        if reset_logging:
            self.logging_dict.clear()
        return log_dump
    
    
class RoundTableConversation:
    """Stateful Co-STORM round-table session for ``runner_argument.topic``.

    Holds the conversation history, the rotating list of expert personas and the
    knowledge base, and wires together the modules (expert generation, grounded
    question answering/generation, information insertion, section expansion,
    article generation, simulated user) that advance the conversation one turn
    at a time. Timing and LM usage are recorded through ``logging_wrapper``.
    """

    def __init__(self,
                 lm_config: CollaborativeStormLMConfigs,
                 runner_argument: RunnerArgument,
                 logging_wrapper: LoggingWrapper):
        self.runner_argument = runner_argument
        self.lm_config = lm_config
        self.logging_wrapper = logging_wrapper
        self.experts = []
        self.conversation_history = []
        self.generate_expert_module = GenerateExpertModule(
            engine=self.lm_config.generate_experts_lm
        )
        self.information_insert_module = InsertInformationModule(
            engine=self.lm_config.information_insert_lm
        )
        self.grounded_question_generation_module = GroundedQuestionGenerationModule(
            engine=self.lm_config.grounded_question_generation_lm
        )
        retriever = NewRetriever(
            search_top_k_source=self.runner_argument.retrieve_top_k,
            max_thread=self.runner_argument.max_search_thread,
        )
        self.grounded_question_answering_module = AnswerQuestionModule(
            retriever=retriever,
            max_search_queries=self.runner_argument.max_search_queries,
            search_top_k=self.runner_argument.retrieve_top_k,
            expert_question_answer_lm=self.lm_config.answer_question_lm,
            logging_wrapper=logging_wrapper
        )
        self.knowledge_base = KnowledgeBase(topic=self.runner_argument.topic)
        self.warm_start_module = WarmStartModule(
            warmstart_outline_gen_lm=self.lm_config.warmstart_outline_gen_lm,
            moderator_question_asking_lm=self.lm_config.warmstart_question_asking_lm,
            answer_question_module=self.grounded_question_answering_module,
            information_insertion_module=self.information_insert_module,
            generate_expert_module=self.generate_expert_module,
            warmstart_max_num_experts=self.runner_argument.warmstart_max_num_experts,
            warmstart_max_turn_per_experts=self.runner_argument.warmstart_max_turn_per_experts,
            warmstart_max_thread=self.runner_argument.warmstart_max_thread,
            logging_wrapper=self.logging_wrapper
        )
        self.round_table_conv_module = RoundTableConversationModule(
            action_planning_lm=self.lm_config.action_planning_lm,
            utterance_polishing_lm=self.lm_config.utterance_polishing_lm,
            answer_question_module=self.grounded_question_answering_module,
            logging_wrapper=self.logging_wrapper
        )
        self.expand_section_module = ExpandSectionModule(engine=self.lm_config.information_insert_lm,
                                                         information_insert_module=self.information_insert_module)
        self.article_generation_module = ArticleGenerationModule(
            engine=self.lm_config.article_generation_lm
        )
        self.gen_simulated_user_utterance = GenSimulatedUserUtterance(engine=self.lm_config.warmstart_question_asking_lm)

    def to_dict(self):
        """Serialize runner arguments, LM kwargs, history, experts and knowledge base to a dict."""
        return {"runner_argument": self.runner_argument.to_dict(),
                "lm_config": self.lm_config.to_dict(),
                "conversation_history": [turn.to_dict() for turn in self.conversation_history],
                "experts": self.experts,
                "knowledge_base": self.knowledge_base.to_dict()}

    @classmethod
    def from_dict(cls, data):
        """Rebuild a session from the output of ``to_dict``.

        ``data["lm_config"]`` is NOT used: the LM configuration is re-created with
        ``init_openai_model`` from the OPENAI_API_KEY, AZURE_API_KEY, OPENAI_API_TYPE,
        AZURE_API_BASE and AZURE_API_VERSION environment variables.
        """
        lm_config = CollaborativeStormLMConfigs() # FIXME: does not use the lm_config data but naively use default setting
        lm_config.init_openai_model(openai_api_key=os.getenv("OPENAI_API_KEY"),
                                    azure_api_key=os.getenv("AZURE_API_KEY"),
                                    openai_type=os.getenv("OPENAI_API_TYPE"),
                                    api_base=os.getenv("AZURE_API_BASE"),
                                    api_version=os.getenv("AZURE_API_VERSION"))
        round_table_conversation = cls(lm_config=lm_config,
                                       runner_argument=RunnerArgument.from_dict(data["runner_argument"]),
                                       logging_wrapper=LoggingWrapper(lm_config))
        round_table_conversation.conversation_history = [ConversationTurn.from_dict(turn) for turn in data["conversation_history"]]
        round_table_conversation.experts = data["experts"]
        round_table_conversation.knowledge_base = KnowledgeBase.from_dict(data["knowledge_base"])
        return round_table_conversation

    def update_expert_list_from_utterance(self, focus: str, background_info: str):
        """Replace ``self.experts`` with up to ``max_num_round_table_experts`` freshly generated experts."""
        self.experts = self.generate_expert_module(
            topic=self.runner_argument.topic,
            background_info=background_info,
            focus=focus,
            num_experts=self.runner_argument.max_num_round_table_experts,
        ).experts

    def get_next_expert(self, dry_run=False):
        """Return the expert at the front of the rotation.

        Unless ``dry_run`` is True, the expert is also moved to the back of the list
        (round-robin). Raises IndexError if ``self.experts`` is empty.
        """
        if dry_run:
            return self.experts[0]
        current_expert = self.experts.pop(0)
        self.experts.append(current_expert)
        return current_expert

    def warm_start(self):
        """Seed the conversation inside the "warmstart" pipeline stage.

        Default mode: run the warm-start module to obtain the initial conversation,
        experts and knowledge base, then reorganize the knowledge base.
        RAG-only baseline mode: answer the topic itself as a question with one RAG turn,
        append it to the history and insert it under the knowledge-base root.
        """
        with self.logging_wrapper.log_pipeline_stage(pipeline_stage="warmstart"):
            if not self.runner_argument.rag_only_baseline_mode:
                warmstart_conv, warmstart_experts, kb = self.warm_start_module.initiate_warm_start(topic=self.runner_argument.topic)
                self.knowledge_base = kb
                self.experts = warmstart_experts
                self.conversation_history = warmstart_conv
                self.reorganize_knowledge_base()
            else:
                if self.knowledge_base is None:
                    self.knowledge_base = KnowledgeBase(topic=self.runner_argument.topic)
                if self.conversation_history is None:
                    self.conversation_history = []
                if self.experts is None:
                    self.experts = []
                conv_turn = self.generate_rag_response(question=self.runner_argument.topic)
                self.conversation_history.append(conv_turn)
                update_conv_turn_to_knowledge_base(conv_turn=conv_turn,
                                                   kb=self.knowledge_base,
                                                   information_insert_module=self.information_insert_module,
                                                   allow_create_new_node=True,
                                                   insert_under_root=self.runner_argument.rag_only_baseline_mode)

    def _get_conv_turn_unused_information(self, conv_turn):
        """Rank the retrieved-but-not-cited snippets of ``conv_turn``, most promising first.

        Every retrieved item is split into single-snippet ``NewInformation`` objects; those
        whose hash matches information already in the knowledge base are dropped. The rest
        are scored by ``sqrt(1 - max query similarity) * sqrt(1 - max cited-snippet
        similarity) * [claim similarity >= 0.25]``, i.e. snippets far from what was searched
        for and from what is already cited, but still related to the turn's claim.

        Returns:
            list of NewInformation sorted by descending score (empty if nothing is unused).
        """
        # extract all snippets from raw retrieved information
        raw_retrieved_single_snippet_info: List[NewInformation] = [
            collaborative_storm_utils.extract_storm_info_snippet(info, snippet_index=snippet_idx)
            for info in conv_turn.raw_retrieved_info
            for snippet_idx in range(len(info.snippets))
        ]
        # get all cited information
        cited_info = list(self.knowledge_base.info_uuid_to_info_dict.values())
        cited_info_hash_set = {hash(info) for info in cited_info}
        cited_snippets = [info.snippets[0] for info in cited_info]
        # get list of unused information
        unused_information: List[NewInformation] = [
            info
            for info in raw_retrieved_single_snippet_info
            if hash(info) not in cited_info_hash_set
        ]
        # Nothing to rank. This also avoids calling cosine_similarity with an empty
        # matrix, which scikit-learn rejects with a ValueError.
        if not unused_information:
            return []
        num_unused = len(unused_information)
        # extract snippets to get embeddings
        unused_information_snippets = [info.snippets[0] for info in unused_information]
        # get embeddings
        cache = self.knowledge_base.embedding_cache
        unused_snippets_embeddings, _ = get_text_embeddings(
            unused_information_snippets, embedding_cache=cache, max_workers=100
        )
        claim_embedding, _ = get_text_embeddings(
            conv_turn.claim_to_make, embedding_cache=cache
        )
        # similarity to this turn's search queries (0 when the turn issued no queries)
        if conv_turn.queries:
            query_embedding, _ = get_text_embeddings(
                conv_turn.queries, embedding_cache=cache
            )
            max_query_similarity = np.max(
                cosine_similarity(unused_snippets_embeddings, query_embedding), axis=1
            )
        else:
            max_query_similarity = np.zeros(num_unused)
        # similarity to already-cited snippets (0 when the knowledge base cites nothing yet)
        if cited_snippets:
            cited_snippets_embedding, _ = get_text_embeddings(
                cited_snippets, embedding_cache=cache
            )
            cited_snippets_similarity = np.max(
                cosine_similarity(unused_snippets_embeddings, cited_snippets_embedding),
                axis=1,
            )
        else:
            cited_snippets_similarity = np.zeros(num_unused)
        # use claim similarity to filter out "real" not useful data
        claim_similarity = cosine_similarity(
            unused_snippets_embeddings, claim_embedding.reshape(1, -1)
        ).flatten()
        claim_similarity = np.where(claim_similarity >= 0.25, 1.0, 0.0)
        # calculate score: snippet that is close to topic but far from query
        query_sim_weight = 0.5
        cited_snippets_sim_weight = 1 - query_sim_weight
        combined_scores = (
            ((1 - max_query_similarity) ** query_sim_weight)
            * ((1 - cited_snippets_similarity) ** cited_snippets_sim_weight)
            * claim_similarity
        )
        sorted_indices = np.argsort(combined_scores)[::-1]
        return [unused_information[idx] for idx in sorted_indices]

    def _get_sorted_unused_snippets(self, last_n_conv_turn: int = 2):
        """Collect ranked unused snippets from recent turns, merged round-robin.

        Walks back from the latest turn, taking at most ``last_n_conv_turn`` turns and
        stopping early at a turn whose utterance_type is "Questioning". All related strings
        are embedded in one batch first so later per-turn calls hit the embedding cache.
        """
        # get last N conv turn and batch encode all related strings
        considered_conv_turn = []
        batch_snippets = [self.runner_argument.topic]
        for conv_turn in reversed(self.conversation_history):
            if len(considered_conv_turn) == last_n_conv_turn:
                break
            if conv_turn.utterance_type == "Questioning":
                break
            considered_conv_turn.append(conv_turn)
            # flatten snippets linearly (``sum(lists, [])`` is quadratic)
            batch_snippets.extend(
                snippet for info in conv_turn.raw_retrieved_info for snippet in info.snippets
            )
            batch_snippets.append(conv_turn.claim_to_make)
            batch_snippets.extend(conv_turn.queries)
        cache = self.knowledge_base.embedding_cache
        get_text_embeddings(batch_snippets, embedding_cache=cache, max_workers=300)

        # get sorted unused snippets for each turn
        sorted_snippets = [
            self._get_conv_turn_unused_information(conv_turn)
            for conv_turn in considered_conv_turn
        ]

        # use round robin rule to merge these snippets
        merged_snippets = []
        for elements in zip_longest(*sorted_snippets, fillvalue=None):
            merged_snippets.extend(e for e in elements if e is not None)
        return merged_snippets

    def generate_grounded_question(self):
        """Build a "Moderator" / "Original Question" turn grounded in unused retrieved snippets."""
        unused_snippets: List[NewInformation] = self._get_sorted_unused_snippets()
        generated_question = self.grounded_question_generation_module(
            topic=self.runner_argument.topic,
            knowledge_base=self.knowledge_base,
            last_conv_turn=self.conversation_history[-1],
            unused_snippets=unused_snippets,
        )
        return ConversationTurn(
            role="Moderator",
            raw_utterance=generated_question.raw_utterance,
            utterance_type="Original Question",
            utterance=generated_question.utterance,
            cited_info=generated_question.cited_info,
        )

    def _should_generate_question(self):
        """True when the trailing run of non-question turns is long enough for the moderator to step in.

        Counts consecutive turns from the end whose utterance_type is neither "Original
        Question" nor "Information Request" and compares the count with
        ``moderator_override_N_consecutive_answering_turn``.
        """
        consecutive_non_questioning_turn = 0
        for conv_turn in reversed(self.conversation_history):
            if conv_turn.utterance_type not in [
                "Original Question",
                "Information Request",
            ]:
                consecutive_non_questioning_turn += 1
            else:
                break
        return consecutive_non_questioning_turn >= self.runner_argument.moderator_override_N_consecutive_answering_turn

    def reorganize_knowledge_base(self):
        """Tidy the knowledge-base tree and expand overly large nodes.

        Trims empty leaves and merges single-child nodes, then repeatedly expands the first
        node (pre-order, depth-first) holding at least ``node_expansion_trigger_count``
        content items that has not been expanded yet, tidies the tree again and refreshes
        all info paths.
        """
        with self.logging_wrapper.log_event("reorganize_knowledge_base.expand_node"):
            self.knowledge_base.trim_empty_leaf_nodes()
            self.knowledge_base.merge_single_child_nodes()

            expanded_nodes = []

            def _find_first_node_to_expand(root: KnowledgeNode):
                # pre-order search for the first not-yet-expanded node over the threshold
                if root is None:
                    return None
                if root not in expanded_nodes and len(root.content) >= self.runner_argument.node_expansion_trigger_count:
                    return root
                for child in root.children:
                    to_return = _find_first_node_to_expand(root=child)
                    if to_return is not None:
                        return to_return
                return None

            while True:
                node_to_expand = _find_first_node_to_expand(root=self.knowledge_base.root)
                if node_to_expand is None:
                    break
                self.expand_section_module(node_to_expand, knowledge_base=self.knowledge_base)
                # each node is expanded at most once, which guarantees termination
                expanded_nodes.append(node_to_expand)
            self.knowledge_base.trim_empty_leaf_nodes()
            self.knowledge_base.merge_single_child_nodes()
            self.knowledge_base.update_all_info_path()

    def generate_report(self) -> str:
        """Generate the article for the topic from the current knowledge base."""
        with self.logging_wrapper.log_pipeline_stage("generate_report"):
            with self.logging_wrapper.log_event("generate_report"):
                return self.article_generation_module(
                    topic=self.runner_argument.topic, knowledge_base=self.knowledge_base
                )

    def dump_logging_and_reset(self):
        """Return the accumulated logs and clear them."""
        return self.logging_wrapper.dump_logging_and_reset()

    def simulate_user_ask_question(self, intent: str):
        """Generate a simulated user utterance for ``intent`` given the conversation so far."""
        with self.logging_wrapper.log_event("RoundTableConversation.simulate_user_ask_question"):
            return self.gen_simulated_user_utterance(topic=self.runner_argument.topic,
                                                     intent=intent,
                                                     conv_history=self.conversation_history)

    def generate_rag_response(self, question):
        """Answer ``question`` with the grounded QA module and wrap it as a "RAG system" turn.

        The turn's ``claim_to_make`` is the question itself; utterance, queries, retrieved
        and cited information come from the grounded answer.
        """
        grounded_answer = self.grounded_question_answering_module(topic=self.runner_argument.topic,
                                                                  question=question,
                                                                  mode="brief",
                                                                  style="conversational and concise")
        conversation_turn = ConversationTurn(
            role="RAG system",
            raw_utterance="",
            utterance_type="Potential Answer"
        )
        conversation_turn.claim_to_make = question
        conversation_turn.raw_utterance = grounded_answer.response
        conversation_turn.utterance = grounded_answer.response
        conversation_turn.queries = grounded_answer.queries
        conversation_turn.raw_retrieved_info = grounded_answer.raw_retrieved_info
        conversation_turn.cited_info = grounded_answer.cited_info
        return conversation_turn

    def conversation_step(self, user_utterance: str="", simulate_user: bool = False, simulate_user_intent: str = ""):
        """Advance the conversation by one turn and return the new turn (or None).

        Who speaks, in priority order:
          1. a simulated user (``simulate_user``) or the real user (``user_utterance``): a "Guest" question;
          2. RAG-only baseline mode: a RAG answer to the last turn if it came from a Guest, else nothing;
          3. the moderator, when enabled and ``_should_generate_question()``: a grounded question;
          4. otherwise an expert answer. A general knowledge provider answers right after a question
             or when multi-experts is disabled; otherwise the next expert in rotation answers.

        A produced turn is appended to ``conversation_history`` and inserted into the knowledge base.
        The expert branch reads ``conversation_history[-1]``, so it assumes a non-empty history
        (e.g. after ``warm_start``).
        """
        with self.logging_wrapper.log_pipeline_stage(pipeline_stage=f"conv_turn_{len(self.conversation_history) + 1}"):
            conv_turn = None
            # if user inject a question, just update the history
            if simulate_user:
                utterance = self.simulate_user_ask_question(simulate_user_intent)
                conv_turn = ConversationTurn(role="Guest",
                                             raw_utterance=utterance,
                                             utterance_type="Original Question")
            elif user_utterance:
                conv_turn = ConversationTurn(role="Guest",
                                             raw_utterance=user_utterance,
                                             utterance_type="Original Question")
            elif self.runner_argument.rag_only_baseline_mode:
                if self.conversation_history and self.conversation_history[-1].role == "Guest":
                    conv_turn = self.generate_rag_response(question=self.conversation_history[-1].utterance)
            # check if need moderator steering
            elif not self.runner_argument.disable_moderator and self._should_generate_question():
                with self.logging_wrapper.log_event("RoundTableConversation.generate_grounded_question"):
                    conv_turn = self.generate_grounded_question()
                    self.reorganize_knowledge_base()
            # experts RAG gen
            else:
                is_last_turn_questioning = bool(self.conversation_history) and self.conversation_history[-1].utterance_type in ["Original Question", "Information Request"]
                current_expert = "General Knowledge Provider: Focus on broadly covering the basic facts about the question."
                if not is_last_turn_questioning and not self.runner_argument.disable_multi_experts:
                    current_expert = self.get_next_expert()
                with self.logging_wrapper.log_event("RoundTableConversation.grounded_question_generation_module._generate_summary"):
                    conversation_summary = self.grounded_question_generation_module._generate_summary(topic=self.runner_argument.topic, knowledge_base=self.knowledge_base)

                last_conv_turn = self.conversation_history[-1]
                conv_turn = self.round_table_conv_module(topic=self.runner_argument.topic,
                                                         current_expert=current_expert,
                                                         conversation_summary=conversation_summary,
                                                         last_conv_turn=last_conv_turn).conversation_turn

                def polish_subtask():
                    self.round_table_conv_module.polish_utterance(conversation_turn=conv_turn, last_conv_turn=last_conv_turn)

                def update_expert_list_subtask():
                    # After a question, regenerate the expert list focused on it.
                    if is_last_turn_questioning and not self.runner_argument.disable_multi_experts:
                        with self.logging_wrapper.log_event("RoundTableConversation.update_expert_list_from_utterance"):
                            self.update_expert_list_from_utterance(focus=last_conv_turn.raw_utterance,
                                                                   background_info=conv_turn.raw_utterance)

                # Polishing and expert-list regeneration are independent, so run them concurrently.
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    polish_future = executor.submit(polish_subtask)
                    update_expert_future = executor.submit(update_expert_list_subtask)
                    # Retrieve both results so an exception raised while updating the expert
                    # list is propagated instead of being silently discarded.
                    polish_future.result()
                    update_expert_future.result()

            if conv_turn is not None:
                self.conversation_history.append(conv_turn)
                with self.logging_wrapper.log_event("RoundTableConversation.update_conv_turn_to_knowledge_base"):
                    update_conv_turn_to_knowledge_base(conv_turn=conv_turn,
                                                       kb=self.knowledge_base,
                                                       information_insert_module=self.information_insert_module,
                                                       allow_create_new_node=True,
                                                       insert_under_root=self.runner_argument.rag_only_baseline_mode)
        return conv_turn

    def get_next_turn_experts_for_frontend_render(self):
        """Predict, without rotating the experts, who speaks next if no user input arrives.

        Returns None in RAG-only mode, "Moderator" when the moderator would step in,
        "General Knowledge Provider" right after a question or with multi-experts disabled,
        and otherwise the part before the first ':' of the next expert's description.
        """
        if self.runner_argument.rag_only_baseline_mode:
            return None
        elif not self.runner_argument.disable_moderator and self._should_generate_question():
            return "Moderator"
        else:
            is_last_turn_questioning = bool(self.conversation_history) and self.conversation_history[-1].utterance_type in ["Original Question", "Information Request"]
            current_expert = "General Knowledge Provider"
            if not is_last_turn_questioning and not self.runner_argument.disable_multi_experts:
                current_expert = self.get_next_expert(dry_run=True)
                current_expert = current_expert.split(":")[0]
            return current_expert
