from storm_wiki.modules.knowledge_curation import StormKnowledgeCurationModule, WikiWriter, TopicExpert
from storm_wiki.modules.storm_dataclass import DialogueTurn, StormInformationTable, StormInformation

from typing import Union, List, Tuple, Optional, Dict

import logging
import dspy
from src.interface import Retriever
from storm_wiki.modules.callback import BaseCallbackHandler
from storm_wiki.modules.persona_generator import StormPersonaGenerator
from storm_wiki.modules.storm_dataclass import StormInformationTable

import concurrent.futures
from concurrent.futures import as_completed
from utils import ArticleTextProcessing


class ModifiedConvSimulator(dspy.Module):
    """Simulate a conversation between a Wikipedia writer with specific persona and an expert.

    Compared with a simulator that always starts from an empty dialogue, this variant can be
    seeded with an existing dialogue history; only the turns produced by the current call are
    returned.
    """

    def __init__(self, topic_expert_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
                 question_asker_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
                 retriever: Retriever, max_search_queries_per_turn: int, search_top_k: int, max_turn: int):
        """
        topic_expert_engine: The language model backing the answering expert.
        question_asker_engine: The language model backing the question-asking writer.
        retriever: The retriever handed to the expert.
        max_search_queries_per_turn: Forwarded to the expert as its `max_search_queries`.
        search_top_k: Forwarded to the expert as its `search_top_k`.
        max_turn: Upper bound on the number of question/answer turns generated per `forward` call.
        """
        super().__init__()
        self.wiki_writer = WikiWriter(engine=question_asker_engine)
        self.topic_expert = TopicExpert(
            engine=topic_expert_engine,
            max_search_queries=max_search_queries_per_turn,
            search_top_k=search_top_k,
            retriever=retriever
        )
        self.max_turn = max_turn

    def forward(self, topic: str, persona: str, ground_truth_url: str, callback_handler: BaseCallbackHandler, dlg_history: Optional[List[DialogueTurn]]=None):
        """
        topic: The topic to research.
        persona: The persona of the Wikipedia writer.
        ground_truth_url: The ground_truth_url will be excluded from search to avoid ground truth leakage in evaluation.
        callback_handler: If not None, its `on_dialogue_turn_end` is invoked after every completed turn.
        dlg_history: Optional earlier turns shown to the writer as context. The list is not modified.

        Returns a `dspy.Prediction` whose `dlg_history` holds only the turns generated by this call
        (the seeded turns are excluded).
        """
        # Work on a private copy: appending to the caller's list would leak turns back into the
        # caller, and a list shared between concurrent simulations would make the slice below
        # return turns that belong to other simulations.
        full_history: List[DialogueTurn] = [] if dlg_history is None else list(dlg_history)
        seeded_turn_count = len(full_history)
        for _ in range(self.max_turn):
            user_utterance = self.wiki_writer(topic=topic, persona=persona, dialogue_turns=full_history).question
            # A falsy check also covers a missing (None) question, which would otherwise
            # crash on `.startswith` below.
            if not user_utterance:
                logging.error('Simulated Wikipedia writer utterance is empty.')
                break
            # The writer signals that it has no further questions with this closing phrase.
            if user_utterance.startswith('Thank you so much for your help!'):
                break
            expert_output = self.topic_expert(topic=topic, question=user_utterance, ground_truth_url=ground_truth_url)
            dlg_turn = DialogueTurn(
                agent_utterance=expert_output.answer,
                user_utterance=user_utterance,
                search_queries=expert_output.queries,
                search_results=expert_output.searched_results
            )
            full_history.append(dlg_turn)
            if callback_handler is not None:
                callback_handler.on_dialogue_turn_end(dlg_turn=dlg_turn)

        # Drop the seeded prefix so that callers only receive the newly generated turns.
        return dspy.Prediction(dlg_history=full_history[seeded_turn_count:])
    
class ModifiedStormKnowledgeCurationModule(StormKnowledgeCurationModule):
    """
    The interface for knowledge curation stage. Given topic, return collected information.

    Research runs in up to two stages: an optional perspective-guided stage (one simulated
    conversation per persona) followed by a conversation driven by a persona built from the
    user's intent, which is seeded with all turns gathered in the first stage.
    """

    def __init__(self,
                 retriever: Retriever,
                 persona_generator: Optional[StormPersonaGenerator],
                 conv_simulator_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
                 question_asker_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
                 max_search_queries_per_turn: int,
                 search_top_k: int,
                 total_conv_turn: int,
                 max_num_perspective: int,
                 max_turn_per_perspective: int,
                 max_thread_num: int):
        """
        Store args and finish initialization.

        total_conv_turn: Turn budget across both research stages; the second stage receives
            whatever the first stage did not use.
        max_turn_per_perspective: Turn limit of each perspective-guided conversation. It is also
            passed to the parent class as `max_conv_turn`.
        The remaining arguments are forwarded to the parent class and/or to the
        `ModifiedConvSimulator` that replaces the parent's conversation simulator.
        """
        super().__init__(retriever=retriever,
                         persona_generator=persona_generator,
                         conv_simulator_lm=conv_simulator_lm,
                         question_asker_lm=question_asker_lm,
                         max_search_queries_per_turn=max_search_queries_per_turn,
                         search_top_k=search_top_k,
                         max_conv_turn=max_turn_per_perspective,
                         max_thread_num=max_thread_num)
        self.max_num_perspective = max_num_perspective
        self.max_turn_per_perspective = max_turn_per_perspective
        self.total_conv_turn = total_conv_turn
        # Assigned after super().__init__() so it overrides any simulator the parent set up.
        self.conv_simulator = ModifiedConvSimulator(
            topic_expert_engine=conv_simulator_lm,
            question_asker_engine=question_asker_lm,
            retriever=retriever,
            max_search_queries_per_turn=max_search_queries_per_turn,
            search_top_k=search_top_k,
            max_turn=max_turn_per_perspective
        )

    def _run_conversation(self, conv_simulator, topic, ground_truth_url, considered_personas,
                          callback_handler: BaseCallbackHandler,
                          dlg_history: Optional[List[DialogueTurn]]=None) -> List[Tuple[str, List[DialogueTurn]]]:
        """
        Executes multiple conversation simulations concurrently, each with a different persona,
        and collects their dialog histories. The dialog history of each conversation is cleaned
        up before being stored.

        Parameters:
            conv_simulator (callable): The function to simulate conversations. It must accept five
                keyword parameters: `topic`, `ground_truth_url`, `persona`, `callback_handler`, and
                `dlg_history`, and return an object that has a `dlg_history` attribute.
            topic (str): The topic of conversation for the simulations.
            ground_truth_url (str): The URL to the ground truth data related to the conversation topic.
            considered_personas (list): A list of personas under which the conversation simulations
                will be conducted. Each persona is passed to `conv_simulator` individually.
            callback_handler (callable): A callback function that is passed to `conv_simulator`. It
                should handle any callbacks or events during the simulation.
            dlg_history (list, optional): Earlier dialogue turns used to seed every simulation.
                Each simulation receives its own copy; the given list is not modified here.

        Returns:
            list of tuples: A list where each tuple contains a persona and its corresponding cleaned
            dialog history (`dlg_history`) from the conversation simulation. Tuples appear in
            completion order, which may differ from the order of `considered_personas`.
        """
        # ThreadPoolExecutor rejects max_workers <= 0, so there is nothing to run (and nothing
        # to return) when no persona is given.
        if not considered_personas:
            return []

        conversations = []

        def run_conv(persona):
            # Give every simulation a private seed history so that concurrently running
            # simulations never append to the same list.
            seed_history = None if dlg_history is None else list(dlg_history)
            return conv_simulator(
                topic=topic,
                ground_truth_url=ground_truth_url,
                persona=persona,
                callback_handler=callback_handler,
                dlg_history=seed_history
            )

        max_workers = min(self.max_thread_num, len(considered_personas))

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_persona = {executor.submit(run_conv, persona): persona for persona in considered_personas}

            for future in as_completed(future_to_persona):
                persona = future_to_persona[future]
                # result() re-raises any exception raised inside the worker.
                conv = future.result()
                conversations.append((persona, ArticleTextProcessing.clean_up_citation(conv).dlg_history))

        return conversations

    def research(self,
                 topic: str,
                 intent: str,
                 ground_truth_url: str,
                 callback_handler: BaseCallbackHandler,
                 max_perspective: int = 0,
                 disable_perspective: bool = True,
                 return_conversation_log=False) -> Union[StormInformationTable, Tuple[StormInformationTable, Dict]]:
        """
        Curate information and knowledge for the given topic

        Args:
            topic: topic of interest in natural language.
            intent: the user's intent; it is embedded into the persona
                ("researcher with interest in <intent>") used for the second research stage.
            ground_truth_url: forwarded to the conversation simulator.
            callback_handler: forwarded to the conversation simulator; if not None, its
                `on_information_gathering_end` is called once all conversations are finished.
            max_perspective: when greater than 0, the perspective-guided stage is run; the value
                is also the maximum number of personas requested when perspectives are enabled.
            disable_perspective: if True, the perspective-guided stage uses a single empty
                persona instead of generated personas.
            return_conversation_log: if True, additionally return the conversation log dict.

        Returns:
            collected_information: collected information in InformationTable type. When
            `return_conversation_log` is True, a tuple of the information table and the log dict
            built by `StormInformationTable.construct_log_dict`.
        """

        # perspective guided QA
        conversations = []
        if max_perspective > 0:
            if disable_perspective:
                considered_personas = [""]
            else:
                considered_personas = self._get_considered_personas(topic=topic, max_num_persona=max_perspective)

            self.conv_simulator.max_turn = self.max_turn_per_perspective
            conversations.extend(self._run_conversation(conv_simulator=self.conv_simulator,
                                                        topic=topic,
                                                        ground_truth_url=ground_truth_url,
                                                        considered_personas=considered_personas,
                                                        callback_handler=callback_handler,
                                                        dlg_history=None))

        # simulated user QA
        # Flatten all turns gathered so far in a single pass (repeated list concatenation via
        # sum(..., []) copies the accumulated list for every conversation).
        conv_history = [turn for _, dlg in conversations for turn in dlg]
        # The second stage only gets what is left of the overall budget; never go below zero
        # so the simulator is not left with a negative turn limit.
        self.conv_simulator.max_turn = max(0, self.total_conv_turn - len(conv_history))
        conversations.extend(self._run_conversation(conv_simulator=self.conv_simulator,
                                                    topic=topic,
                                                    ground_truth_url=ground_truth_url,
                                                    considered_personas=[f"researcher with interest in {intent}"],
                                                    callback_handler=callback_handler,
                                                    dlg_history=conv_history))

        information_table = StormInformationTable(conversations)
        # The simulator accepts a missing handler, so tolerate None here as well.
        if callback_handler is not None:
            callback_handler.on_information_gathering_end()
        if return_conversation_log:
            return information_table, StormInformationTable.construct_log_dict(conversations)
        return information_table