 
import os



from collaborative_storm.modules.collaborative_storm_utils import load_api_key, extract_and_remove_citations
from collaborative_storm.engine import CollaborativeStormLMConfigs, RunnerArgument, RoundTableConversation, LoggingWrapper
from collaborative_storm.modules.dataclass import KnowledgeBase

# Load API credentials from "secrets.toml" in the parent directory.
# Presumably load_api_key exports them into os.environ, since init() reads
# them back with os.getenv — confirm against collaborative_storm_utils.
# NOTE(review): the relative path is resolved against the current working
# directory, not this file's location, so run the script from its own folder.
load_api_key(toml_file_path=os.path.join(os.pardir, "secrets.toml"))

def init(
    topic: str,
    *,
    warmstart_max_num_experts: int = 3,
    warmstart_max_turn_per_experts: int = 1,
    max_search_queries: int = 1,
    rag_only_baseline_mode: bool = False,
) -> "RoundTableConversation":
    """Build a round-table conversation about *topic*.

    Language-model settings are taken from the environment variables
    OPENAI_API_KEY, AZURE_API_KEY, OPENAI_API_TYPE, API_BASE and API_VERSION.
    ``os.getenv`` yields ``None`` for any that are unset, and that ``None`` is
    passed through unchanged to ``init_openai_model``.

    Args:
        topic: Subject of the conversation; forwarded to ``RunnerArgument``.
        warmstart_max_num_experts: Forwarded to ``RunnerArgument``.
        warmstart_max_turn_per_experts: Forwarded to ``RunnerArgument``.
        max_search_queries: Forwarded to ``RunnerArgument``.
        rag_only_baseline_mode: Forwarded to ``RunnerArgument``.

    Returns:
        A ``RoundTableConversation`` built from the LM config, the runner
        arguments and a ``LoggingWrapper`` over the same LM config.
    """
    # LM config: credentials and endpoint settings come from the environment.
    lm_config = CollaborativeStormLMConfigs()
    lm_config.init_openai_model(
        openai_api_key=os.getenv("OPENAI_API_KEY"),
        azure_api_key=os.getenv("AZURE_API_KEY"),
        openai_type=os.getenv("OPENAI_API_TYPE"),
        api_base=os.getenv("API_BASE"),
        api_version=os.getenv("API_VERSION"),
    )

    # Runner knobs are parameters (defaults equal the previous hard-coded
    # values) so callers can tune the run without editing this function.
    runner_argument = RunnerArgument(
        topic=topic,
        warmstart_max_num_experts=warmstart_max_num_experts,
        warmstart_max_turn_per_experts=warmstart_max_turn_per_experts,
        max_search_queries=max_search_queries,
        rag_only_baseline_mode=rag_only_baseline_mode,
    )

    # The logging wrapper is constructed from the same LM config the
    # conversation uses.
    logging_wrapper = LoggingWrapper(lm_config)

    return RoundTableConversation(
        lm_config=lm_config,
        runner_argument=runner_argument,
        logging_wrapper=logging_wrapper,
    )

def main(num_turns: int = 5) -> None:
    """Run a round-table conversation on a topic typed by the user.

    Prompts for a topic on stdin, builds the conversation with ``init``,
    warm-starts it, then advances it automatically ``num_turns`` times.

    Args:
        num_turns: Number of automatic ``conversation_step`` calls made after
            the warm start. A value of 0 or less runs no steps.

    Raises:
        SystemExit: If the entered topic is empty or only whitespace.
    """
    topic = input("Topic of interest: ").strip()
    if not topic:
        # Fail fast instead of spending LM/search calls on an empty topic.
        raise SystemExit("A non-empty topic is required.")

    round_table_conversation = init(topic=topic)
    round_table_conversation.warm_start()
    for _ in range(num_turns):
        # The returned turn is not used by this driver and is discarded.
        round_table_conversation.conversation_step()

if __name__ == "__main__":
    # Run the interactive driver only when executed as a script, not when
    # this module is imported.
    main()