import dspy
import concurrent.futures
from threading import Lock
from typing import List, Optional

from .expert_generation import GenerateExpertModule
from .grounded_question_answering import AnswerQuestionModule
from .dataclass import ConversationTurn, KnowledgeBase
from storm_wiki.modules.outline_generation import WritePageOutline
from .information_insertion_module import update_conv_turn_to_knowledge_base

# NOTE(review): in dspy, a Signature's docstring and field prefixes are presumably
# rendered into the LM prompt, so they are runtime text, not mere documentation.
# Any wording change below is a prompt change; edit deliberately.
class WarmStartModerator(dspy.Signature):
    """
    You are a moderator in a roundtable discussion. The goal is to chat with multiple experts to discuss the facts and background of the topic to familiarize the audience with the topic. 
    You will be presented with the topic, the history of question you have already asked, and the current expert you are discussing with. 
    Based on these information, generate the next question for the current expert to further the discussion.

    The output should only include the next question for the current expert. Do not include any other information or preamble.
    """
    # Topic string; WarmStartConversation.forward passes its `topic` argument here.
    topic = dspy.InputField(prefix="Topic for roundtable discussion: ", format=str)
    # NOTE(review): WarmStartConversation fills this with the output of
    # format_dialogue_question_history_string, a numbered list of prior questions/claims.
    # It is not a list of experts, so this prefix mislabels the content. Confirm and reword.
    history = dspy.InputField(prefix="Experts you have already interacted with: ", format=str)
    # The expert currently being questioned (one element of the generated experts list).
    # NOTE(review): unlike the sibling prefixes, this one lacks a trailing space.
    current_expert = dspy.InputField(prefix="Expert you are talking with:", format=str)
    # Callers read only `.question` from the prediction.
    question = dspy.OutputField(prefix="Next question for the expert you are talking with: ", format=str)


class WarmStartConversation(dspy.Module):
    """Seed a roundtable discussion on a topic before the main collaboration.

    The conversation is produced in three steps:

    1. One "background research" turn is produced for the topic.
    2. Experts are generated from that background.
    3. Each expert (up to ``max_num_experts``) is questioned by an LM moderator
       for ``max_turn_per_experts`` rounds.

    Experts are questioned concurrently on a thread pool.
    """

    def __init__(self,
                 moderator_question_asking_lm,
                 generate_expert_module: GenerateExpertModule,
                 answer_question_module: AnswerQuestionModule,
                 max_num_experts: int=3,
                 max_turn_per_experts: int=2,
                 max_thread: int=3):
        """
        Args:
            moderator_question_asking_lm: LM activated via ``dspy.settings.context``
                while the moderator generates questions.
            generate_expert_module: Called with ``topic``, ``background_info`` and
                ``num_experts``; its result's ``.experts`` is used.
            answer_question_module: Called with ``topic``, ``question``, ``mode`` and
                ``style`` to answer questions.
            max_num_experts: Number of experts requested, and upper bound on how
                many are questioned.
            max_turn_per_experts: Question/answer rounds per expert.
            max_thread: Maximum worker threads used to question experts.
        """
        # Run dspy.Module's initializer so base-class state is set up like any
        # other dspy program.
        super().__init__()
        self.ask_question = dspy.Predict(WarmStartModerator)
        self.max_num_experts = max_num_experts
        self.max_turn_per_experts = max_turn_per_experts
        self.moderator_question_asking_lm = moderator_question_asking_lm
        self.answer_question_module = answer_question_module
        self.max_thread = max_thread
        self.generate_experts_module = generate_expert_module

    def format_dialogue_question_history_string(self, conversation_history):
        """Render turns as ``"<n>: <text>"`` lines, numbered from 1.

        Each turn contributes its ``claim_to_make`` when truthy, otherwise its
        ``utterance``.
        """
        output = []
        for idx, turn in enumerate(conversation_history):
            info = turn.claim_to_make if turn.claim_to_make else turn.utterance
            output.append(f"{idx + 1}: {info}")
        return "\n".join(output)

    def generate_warmstart_experts(self, topic: str):
        """Research the topic's background and generate experts from it.

        Returns:
            A ``(experts, background_seeking_dialogue)`` tuple. The second item is
            the background-research ConversationTurn.
        """
        background_seeking_dialogue = self.get_background_info(topic=topic)
        # NOTE(review): assumes ConversationTurn exposes `utterance` derived from
        # the `raw_utterance` set in get_background_info; verify in dataclass.
        background_info = background_seeking_dialogue.utterance
        gen_expert_output = self.generate_experts_module(topic=topic, background_info=background_info, num_experts=self.max_num_experts)
        return gen_expert_output.experts, background_seeking_dialogue

    def get_background_info(self, topic):
        """Ask for an extensive, conversational background answer on ``topic``.

        Returns:
            A ConversationTurn attributed to "Default Background Researcher".
            Its ``claim_to_make`` is the background question, and it carries the
            answer's response, queries and retrieved/cited info.
        """
        question = f"Background information about {topic}"
        answer = self.answer_question_module(topic=topic,
                                             question=question,
                                             mode="extensive",
                                             style="conversational")

        return ConversationTurn(
            role="Default Background Researcher",
            raw_utterance=answer.response,
            utterance_type="Questioning",
            claim_to_make=question,
            queries=answer.queries,
            raw_retrieved_info=answer.raw_retrieved_info,
            cited_info=answer.cited_info
        )

    def forward(self, topic):
        """Run the warm-start conversation for ``topic``.

        Returns:
            dspy.Prediction with two fields:

            - ``conversation_history``: the background turn followed by the expert
              turns. Expert turns appear in completion order, which depends on
              thread scheduling.
            - ``experts``: every generated expert, not only the questioned ones.
        """
        # Do background research, then generate experts from it.
        experts, background_seeking_dialogue = self.generate_warmstart_experts(topic=topic)
        # Shared across worker threads; every read and write goes through `lock`.
        conversation_history: List[ConversationTurn] = []
        lock = Lock()

        # Hierarchical chat with one expert: generate a question, then get an answer.
        def process_expert(expert):
            for _ in range(self.max_turn_per_experts):
                try:
                    with lock:
                        history = self.format_dialogue_question_history_string(conversation_history)
                    with dspy.settings.context(lm=self.moderator_question_asking_lm):
                        question = self.ask_question(topic=topic, history=history, current_expert=expert).question
                    answer = self.answer_question_module(topic=topic, question=question, mode="brief", style="conversational")
                    conversation_turn = ConversationTurn(
                        role=expert,
                        claim_to_make=question,
                        raw_utterance=answer.response,
                        utterance_type="Support",
                        queries=answer.queries,
                        raw_retrieved_info=answer.raw_retrieved_info,
                        cited_info=answer.cited_info
                    )
                    with lock:
                        conversation_history.append(conversation_turn)
                except Exception as e:
                    # Best effort: a failed turn is reported and the remaining
                    # turns and experts still run.
                    print(f"Error processing expert {expert}: {e}")

        # Slicing already clamps to len(experts). Leaving the `with` block waits
        # for every submitted task (executor shutdown with wait=True).
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread) as executor:
            for expert in experts[:self.max_num_experts]:
                executor.submit(process_expert, expert)

        conversation_history = [background_seeking_dialogue] + conversation_history

        return dspy.Prediction(conversation_history=conversation_history, experts=experts)

# NOTE(review): this Signature's docstring and prefixes are presumably rendered into
# the LM prompt by dspy. They are runtime text and are left byte-identical here.
# Known wording issues to consider fixing as a deliberate prompt change:
#   - "Generate a outline" should read "an outline".
#   - "wikiepdia" is a typo for "wikipedia".
#   - The heading-markup instruction is stated twice (line 3 and guideline 1).
class GenerateWarmStartOutline(dspy.Signature):
    """Generate a outline of the wikipedia-like report from a roundtable discussion. You will be presented discussion points in the conversation and corresponding queries.
       You will be given a draft outline which you can borrow some inspiration. Do not include sections that are not mentioned in the given discussion history.
       Use "#" to denote section headings, "##" to denote subsection headings, and so on.
        Follow these guidelines:
        1. Use "#" for section titles, "##" for subsection titles, "###" for subsubsection titles, and so on.
        2. Do not include any additional information.
        3. Exclude the topic name from the outline.
        The organization of outline should adopt wikiepdia style.
    """

    # Topic string passed through by GenerateWarmStartOutlineModule.forward.
    topic = dspy.InputField(prefix="The topic discussed: ", format=str)
    # Draft outline produced by the WritePageOutline predictor in GenerateWarmStartOutlineModule.
    draft = dspy.InputField(prefix="Draft outline you can reference to: ", format=str)
    # Text built by GenerateWarmStartOutlineModule.extract_questions_and_queries:
    # numbered discussion foci, each followed by its tab-indented queries.
    conv = dspy.InputField(prefix="Discussion history:\n", format=str)
    # Read as `.outline` by the caller. WarmStartModule passes it to
    # KnowledgeBase.insert_from_outline_string.
    outline = dspy.OutputField(
        prefix='Write the conversation outline (Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, ...):\n',
        format=str
    )


class GenerateWarmStartOutlineModule(dspy.Module):
    """Produce a report outline from a warm-start conversation.

    A topic-only draft outline is generated first with ``WritePageOutline``.
    It is then refined against the conversation's discussion foci and search
    queries with ``GenerateWarmStartOutline``. Both calls run with ``engine``
    as the active LM.
    """

    def __init__(self, engine):
        """
        Args:
            engine: LM activated via ``dspy.settings.context`` for both outline calls.
        """
        # Run dspy.Module's initializer so base-class state is set up like any
        # other dspy program.
        super().__init__()
        self.engine = engine
        self.gen_outline = dspy.Predict(GenerateWarmStartOutline)
        self.draft_outline = dspy.Predict(WritePageOutline)

    def extract_questions_and_queries(self, conv: List[ConversationTurn]):
        """Render each turn's focus and queries as numbered text.

        Each turn becomes one entry of the form
        ``"Discussion focus <i>: <claim_to_make>\\n\\tQuery 1: ...\\n\\tQuery 2: ..."``.
        Entries are joined with newlines. Numbering starts at 1.
        """
        context = []
        for focus_idx, turn in enumerate(conv, start=1):
            # Treat a missing query list as empty instead of failing on
            # enumerate(None).
            queries = turn.queries or []
            queries_string = "\n\t".join(
                f"Query {query_idx}: {query}"
                for query_idx, query in enumerate(queries, start=1)
            )
            context.append(f"Discussion focus {focus_idx}: {turn.claim_to_make}\n\t{queries_string}")
        return "\n".join(context)

    def get_draft_outline(self, topic: str):
        """Return a draft outline for ``topic`` generated from the topic alone."""
        with dspy.settings.context(lm=self.engine):
            return self.draft_outline(topic=topic).outline

    def forward(self, topic: str, conv: List[ConversationTurn]):
        """Generate the warm-start outline.

        Returns:
            dspy.Prediction with ``outline`` (the refined outline) and
            ``draft_outline`` (the topic-only draft it was based on).
        """
        discussion_history = self.extract_questions_and_queries(conv)
        draft_outline = self.get_draft_outline(topic=topic)
        with dspy.settings.context(lm=self.engine):
            outline = self.gen_outline(topic=topic, draft=draft_outline, conv=discussion_history).outline
        return dspy.Prediction(outline=outline, draft_outline=draft_outline)
    
class WarmStartModule:
    """Orchestrate the warm-start phase.

    The phase has three steps:

    1. Run a seeded expert conversation.
    2. Derive an outline from that conversation.
    3. Build a KnowledgeBase from the outline and populate it with the
       conversation turns.
    """

    def __init__(self,
                 warmstart_outline_gen_lm,
                 moderator_question_asking_lm,
                 answer_question_module,
                 information_insertion_module,
                 generate_expert_module,
                 warmstart_max_num_experts: int,
                 warmstart_max_turn_per_experts: int,
                 warmstart_max_thread: int,
                 logging_wrapper
                 ):
        """
        Args:
            warmstart_outline_gen_lm: LM used to draft and refine the outline.
            moderator_question_asking_lm: LM used by the moderator to ask questions.
            answer_question_module: Module that answers moderator questions.
            information_insertion_module: Forwarded to update_conv_turn_to_knowledge_base.
            generate_expert_module: Module that generates the experts.
            warmstart_max_num_experts: Cap on generated/questioned experts.
            warmstart_max_turn_per_experts: Question rounds per expert.
            warmstart_max_thread: Worker threads for the conversation.
            logging_wrapper: Object whose ``log_event(name)`` context manager wraps each phase.
        """
        conversation_settings = dict(
            moderator_question_asking_lm=moderator_question_asking_lm,
            generate_expert_module=generate_expert_module,
            answer_question_module=answer_question_module,
            max_num_experts=warmstart_max_num_experts,
            max_turn_per_experts=warmstart_max_turn_per_experts,
            max_thread=warmstart_max_thread,
        )
        self.warmstart_conv = WarmStartConversation(**conversation_settings)
        self.warmstart_outline_gen_module = GenerateWarmStartOutlineModule(engine=warmstart_outline_gen_lm)
        self.information_insert_module = information_insertion_module
        self.logging_wrapper = logging_wrapper

    def initiate_warm_start(self,
                            topic: str):
        """Run the warm-start pipeline for ``topic``.

        Each phase is wrapped in its own ``logging_wrapper.log_event`` context.

        Args:
            topic (str): Subject of the warm-start discussion.

        Returns:
            Tuple[List[ConversationTurn], List[str], KnowledgeBase]:
                - The conversation turns, starting with the background-research turn.
                - The experts returned by the conversation module.
                - The knowledge base built from the outline and filled with the turns.
        """
        # NOTE(review): these defaults only matter if log_event suppresses an
        # exception raised inside it. Whether it does is not visible here, so
        # they are kept.
        history: List[ConversationTurn] = []
        expert_list = None

        with self.logging_wrapper.log_event("WarmStartModule.warmstart_conv"):
            conv_prediction = self.warmstart_conv(topic=topic)
            history = conv_prediction.conversation_history
            expert_list = conv_prediction.experts

        with self.logging_wrapper.log_event("WarmStartModule.warmstart_outline_gen_module"):
            outline_prediction = self.warmstart_outline_gen_module(topic=topic, conv=history)

        with self.logging_wrapper.log_event("WarmStartModule.update_conv_turn_to_knowledge_base"):
            # Build the node structure from the outline first, then place turns into it.
            knowledge_base = KnowledgeBase(topic=topic)
            knowledge_base.insert_from_outline_string(outline_string=outline_prediction.outline)
            for conv_turn in history:
                update_conv_turn_to_knowledge_base(
                    conv_turn=conv_turn,
                    kb=knowledge_base,
                    information_insert_module=self.information_insert_module,
                    allow_create_new_node=False,
                )

        return history, expert_list, knowledge_base