import os
from dataclasses import dataclass, field

from storm_wiki.modules.callback import BaseCallbackHandler
from storm_wiki.modules.persona_generator import StormPersonaGenerator
from storm_wiki.modules.storm_dataclass import StormInformationTable, StormArticle
from utils import FileIOHelper, makeStringRed
from storm_wiki.engine import STORMWikiRunner, STORMWikiRunnerArguments, STORMWikiLMConfigs
from collaborative_storm.baseline.modified_storm_knowledge_curation import ModifiedStormKnowledgeCurationModule

@dataclass
class ModifiedSTORMWikiRunnerArguments:
    """Configuration for the modified STORM Wiki pipeline.

    Only ``output_dir`` is required; every other field has a default. Each
    field documents itself through ``metadata["help"]``. Field order is part
    of the positional constructor signature and must not be changed.
    """

    # Required: root directory under which per-topic result folders are created.
    output_dir: str = field(metadata={"help": "Output directory for the results."})

    # Conversation budget for knowledge curation.
    total_conv_turn: int = field(default=10, metadata={
        "help": "Maximum number of questions in conversational question asking."})
    max_perspective: int = field(default=3, metadata={
        "help": "Maximum number of perspectives to consider in perspective-guided question asking."})
    max_conv_turn: int = field(default=3, metadata={
        "help": "Maximum number of turn per perspective."})

    # Search behaviour during the conversation.
    max_search_queries_per_turn: int = field(default=3, metadata={
        "help": "Maximum number of search queries to consider in each turn."})
    disable_perspective: bool = field(default=False, metadata={
        "help": "If True, disable perspective-guided question asking."})
    search_top_k: int = field(default=3, metadata={
        "help": "Top k search results to consider for each search query."})

    # Retrieval of collected references when writing sections.
    retrieve_top_k: int = field(default=3, metadata={
        "help": "Top k collected references for each section title."})

    # Concurrency limit for LM calls.
    max_thread_num: int = field(default=10, metadata={
        "help": ("Maximum number of threads to use. "
                 "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API.")})


class ModifiedSTORMWikiRunner(STORMWikiRunner):
    """STORM Wiki pipeline runner with an intent-aware knowledge curation stage.

    Behaves like ``STORMWikiRunner`` except that its knowledge curation module is a
    ``ModifiedStormKnowledgeCurationModule``, whose ``research`` call additionally
    receives an ``intent`` string. Outline generation, article generation and
    polishing are delegated to the inherited ``run_*_module`` methods.
    """

    def __init__(self,
                 args: STORMWikiRunnerArguments,
                 lm_configs: STORMWikiLMConfigs):
        """Initialize the parent runner, then install the modified curation module.

        Args:
            args: Pipeline arguments. In addition to what the parent reads, this
                constructor reads ``total_conv_turn``, ``max_search_queries_per_turn``,
                ``search_top_k``, ``max_perspective``, ``max_conv_turn`` and
                ``max_thread_num``. ``total_conv_turn`` is declared on
                ``ModifiedSTORMWikiRunnerArguments``, so pass an instance of that class.
            lm_configs: Per-stage LM configuration; ``question_asker_lm`` and
                ``conv_simulator_lm`` are used here.
        """
        super().__init__(args=args, lm_configs=lm_configs)
        storm_persona_generator = StormPersonaGenerator(self.lm_configs.question_asker_lm)
        # Assign (or overwrite, if the parent constructor already set one) the
        # curation module used by run_knowledge_curation_module below.
        self.storm_knowledge_curation_module = ModifiedStormKnowledgeCurationModule(
            retriever=self.retriever,
            persona_generator=storm_persona_generator,
            conv_simulator_lm=self.lm_configs.conv_simulator_lm,
            question_asker_lm=self.lm_configs.question_asker_lm,
            max_search_queries_per_turn=self.args.max_search_queries_per_turn,
            search_top_k=self.args.search_top_k,
            total_conv_turn=self.args.total_conv_turn,
            max_num_perspective=self.args.max_perspective,
            max_turn_per_perspective=self.args.max_conv_turn,
            max_thread_num=self.args.max_thread_num
        )

    def run_knowledge_curation_module(self,
                                      intent,
                                      ground_truth_url: str = "None",
                                      callback_handler: BaseCallbackHandler = None) -> StormInformationTable:
        """Research ``self.topic`` under ``intent`` and persist the raw results.

        Reads ``self.topic`` and ``self.article_output_dir``, which ``run`` sets,
        so call ``run`` (or set those attributes) first.

        Args:
            intent: Intent string forwarded to the curation module's ``research``.
            ground_truth_url: URL forwarded to ``research`` unchanged.
            callback_handler: Receives intermediate results; a fresh
                ``BaseCallbackHandler`` is used when None.

        Returns:
            The information table produced by ``research``. As side effects,
            writes ``conversation_log.json`` and ``raw_search_results.json``
            into ``self.article_output_dir``.
        """
        # NOTE(review): the default here is the string "None" while run() defaults
        # to '' -- confirm which sentinel research() treats as "no ground truth".
        if callback_handler is None:
            callback_handler = BaseCallbackHandler()

        information_table, conversation_log = self.storm_knowledge_curation_module.research(
            topic=self.topic,
            intent=intent,
            ground_truth_url=ground_truth_url,
            callback_handler=callback_handler,
            max_perspective=self.args.max_perspective,
            disable_perspective=self.args.disable_perspective,
            return_conversation_log=True
        )

        FileIOHelper.dump_json(conversation_log, os.path.join(self.article_output_dir, 'conversation_log.json'))
        information_table.dump_url_to_info(os.path.join(self.article_output_dir, 'raw_search_results.json'))
        return information_table

    def run(self,
            topic: str,
            intent: str,
            ground_truth_url: str = '',
            do_research: bool = True,
            do_generate_outline: bool = True,
            do_generate_article: bool = True,
            do_polish_article: bool = True,
            remove_duplicate: bool = False,
            callback_handler: BaseCallbackHandler = None):
        """
        Run the STORM pipeline.

        Stages that are skipped reload their output from the per-topic output
        directory ``<args.output_dir>/<topic with ' ' and '/' replaced by '_'>``.

        Args:
            topic: The topic to research.
            intent: Intent string passed to the research stage; only used when
             do_research is True.
            ground_truth_url: A ground truth URL including a curated article about the topic. The URL will be excluded.
            do_research: If True, research the topic through information-seeking conversation;
             if False, the information table is reloaded from conversation_log.json in the output
             directory when a later stage needs it.
            do_generate_outline: If True, generate an outline for the topic;
             if False, expect storm_gen_outline.txt to exist in the output directory.
            do_generate_article: If True, generate a curated article for the topic;
             if False, expect storm_gen_article.txt and url_to_info.json to exist in the output
             directory when polishing.
            do_polish_article: If True, polish the article by adding a summarization section and (optionally) removing
             duplicated content.
            remove_duplicate: If True, remove duplicated content.
            callback_handler: A callback handler to handle the intermediate results. When None,
             a fresh BaseCallbackHandler is created for this call.

        Raises:
            AssertionError: If none of the do_* stage flags is set.
        """
        assert do_research or do_generate_outline or do_generate_article or do_polish_article, \
            makeStringRed("No action is specified. Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article")

        # Create the handler per call: a default argument instance would be built
        # once at definition time and shared by every invocation.
        if callback_handler is None:
            callback_handler = BaseCallbackHandler()

        self.topic = topic
        self.article_dir_name = topic.replace(' ', '_').replace('/', '_')
        self.article_output_dir = os.path.join(self.args.output_dir, self.article_dir_name)
        os.makedirs(self.article_output_dir, exist_ok=True)
        # Where the research stage's conversation log lives; reloaded when research is skipped.
        conversation_log_path = os.path.join(self.article_output_dir, 'conversation_log.json')

        # research module
        information_table: StormInformationTable = None
        if do_research:
            information_table = self.run_knowledge_curation_module(intent=intent,
                                                                   ground_truth_url=ground_truth_url,
                                                                   callback_handler=callback_handler)

        # outline generation module
        outline: StormArticle = None
        if do_generate_outline:
            # load information table if it's not initialized
            if information_table is None:
                information_table = self._load_information_table_from_local_fs(conversation_log_path)
            outline = self.run_outline_generation_module(information_table=information_table,
                                                         callback_handler=callback_handler)

        # article generation module
        draft_article: StormArticle = None
        if do_generate_article:
            if information_table is None:
                information_table = self._load_information_table_from_local_fs(conversation_log_path)
            if outline is None:
                outline = self._load_outline_from_local_fs(
                    topic=topic,
                    outline_local_path=os.path.join(self.article_output_dir, 'storm_gen_outline.txt'))
            draft_article = self.run_article_generation_module(outline=outline,
                                                               information_table=information_table,
                                                               callback_handler=callback_handler)

        # article polishing module
        if do_polish_article:
            if draft_article is None:
                draft_article_path = os.path.join(self.article_output_dir, 'storm_gen_article.txt')
                url_to_info_path = os.path.join(self.article_output_dir, 'url_to_info.json')
                draft_article = self._load_draft_article_from_local_fs(topic=topic,
                                                                       draft_article_path=draft_article_path,
                                                                       url_to_info_path=url_to_info_path)
            self.run_article_polishing_module(draft_article=draft_article, remove_duplicate=remove_duplicate)
