"""
STORM Wiki pipeline powered by OpenAI GPT-4-family models (gpt-4o / gpt-4-1106-preview) and the You.com search engine.
You need to set up the following environment variables to run this script
(process_article also calls load_api_key on a secrets.toml file before reading them):
    - OPENAI_API_KEY: OpenAI API key
    - OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure')
    - AZURE_API_BASE: Azure API base URL if using Azure API
    - AZURE_API_VERSION: Azure API version if using Azure API (not read directly by this script)
    - YDC_API_KEY: You.com API key

Output will be structured as below:
args.output_dir/
    topic_name/  # the topic with every space and slash replaced by an underscore
        conversation_log.json           # Log of information-seeking conversation
        raw_search_results.json         # Raw search results from search engine
        direct_gen_outline.txt          # Outline directly generated with LLM's parametric knowledge
        storm_gen_outline.txt           # Outline refined with collected information
        url_to_info.json                # Sources that are used in the final article
        storm_gen_article.txt           # Final article generated
        storm_gen_article_polished.txt  # Polished final article (polishing is always enabled; if this file exists, the topic is skipped on reruns)
        log.txt                         # Output captured from runner.summary()
"""
import contextlib
import os
import json
import sys
from argparse import ArgumentParser

import concurrent.futures
from tqdm import tqdm
from threading import Lock

# Put the directory two levels above this script on sys.path so that the
# project-local imports that follow (src.lm, storm_wiki, ...) can be resolved
# when this file is executed directly. This must run before those imports.
script_dir = os.path.dirname(os.path.abspath(__file__))
src_root = os.path.join(script_dir, os.pardir, os.pardir)
sys.path.append(src_root)

from src.lm import OpenAIModel
from storm_wiki.engine import STORMWikiLMConfigs
from collaborative_storm.baseline.modified_storm_engine import ModifiedSTORMWikiRunner, ModifiedSTORMWikiRunnerArguments
from utils import load_api_key

# contextlib.redirect_stdout replaces the process-wide sys.stdout. When several
# worker threads enter/exit it concurrently, one thread's exit can restore
# sys.stdout to another thread's log file that has already been closed, making
# every later print() fail. This lock serializes the redirect (and the error
# print below, so error messages are not captured into a log file). Output
# printed by other threads while a redirect is active may still land in that
# thread's log.txt, because sys.stdout is shared.
_STDOUT_REDIRECT_LOCK = Lock()


def _build_lm_configs():
    """Create a STORMWikiLMConfigs with one OpenAIModel per pipeline stage.

    Connection settings come from the OPENAI_API_KEY, OPENAI_API_TYPE and
    AZURE_API_BASE environment variables; sampling uses temperature=1.0 and
    top_p=0.9 for every stage.
    """
    lm_configs = STORMWikiLMConfigs()
    openai_kwargs = {
        'api_key': os.getenv("OPENAI_API_KEY"),
        'api_provider': os.getenv('OPENAI_API_TYPE'),
        'temperature': 1.0,
        'top_p': 0.9,
        'api_base': os.getenv('AZURE_API_BASE'),
    }

    conv_simulator_lm = OpenAIModel(model='gpt-4o-2024-05-13', max_tokens=500, **openai_kwargs)
    question_asker_lm = OpenAIModel(model='gpt-4o-2024-05-13', max_tokens=500, **openai_kwargs)
    outline_gen_lm = OpenAIModel(model='gpt-4-1106-preview', max_tokens=400, **openai_kwargs)
    article_gen_lm = OpenAIModel(model='gpt-4o-2024-05-13', max_tokens=700, **openai_kwargs)
    article_polish_lm = OpenAIModel(model='gpt-4o-2024-05-13', max_tokens=4000, **openai_kwargs)

    lm_configs.set_conv_simulator_lm(conv_simulator_lm)
    lm_configs.set_question_asker_lm(question_asker_lm)
    lm_configs.set_outline_gen_lm(outline_gen_lm)
    lm_configs.set_article_gen_lm(article_gen_lm)
    lm_configs.set_article_polish_lm(article_polish_lm)
    return lm_configs


def process_article(args, topic, intent, progress_lock=None, progress_bar=None):
    """Run research, outline, article generation and polishing for one topic.

    Any exception raised while processing is caught and reported on stdout, so
    one failing topic does not abort a batch run.

    Args:
        args: Parsed CLI namespace. Reads output_dir, total_conv_turn,
            max_perspective, max_turn_per_perspective,
            max_search_queries_per_turn and search_top_k.
        topic: Topic to write about. Its output directory name is the topic
            with spaces and slashes replaced by underscores.
        intent: Intent string forwarded to ``runner.run``.
        progress_lock: Optional lock guarding ``progress_bar``.
        progress_bar: Optional progress bar. When both it and
            ``progress_lock`` are given, ``update(1)`` is called exactly once
            per call, including for skipped and failed topics.

    Returns:
        None. A topic whose ``storm_gen_article_polished.txt`` already exists
        under ``args.output_dir`` is skipped without loading API keys or
        building any model.
    """
    try:
        article_dir_name = topic.replace(' ', '_').replace('/', '_')
        article_output_dir = os.path.join(args.output_dir, article_dir_name)
        if os.path.exists(os.path.join(article_output_dir, "storm_gen_article_polished.txt")):
            # Finished in an earlier run; the finally clause still ticks the bar.
            return
        load_api_key(toml_file_path=os.path.join(src_root, "..", "secrets.toml"))
        lm_configs = _build_lm_configs()

        engine_args = ModifiedSTORMWikiRunnerArguments(
            output_dir=args.output_dir,
            total_conv_turn=args.total_conv_turn,
            max_perspective=args.max_perspective,
            max_conv_turn=args.max_turn_per_perspective,
            max_search_queries_per_turn=args.max_search_queries_per_turn,
            search_top_k=args.search_top_k,
        )
        runner = ModifiedSTORMWikiRunner(engine_args, lm_configs)

        runner.run(
            topic=topic,
            intent=intent,
            do_research=True,
            do_generate_outline=True,
            do_generate_article=True,
            do_polish_article=True,
        )
        runner.post_run()
        with open(os.path.join(runner.article_output_dir, "log.txt"), 'w', encoding='utf-8') as f:
            # Hold the lock for the whole redirect so no other thread swaps
            # sys.stdout underneath us (see _STDOUT_REDIRECT_LOCK).
            with _STDOUT_REDIRECT_LOCK, contextlib.redirect_stdout(f):
                runner.summary()
    except Exception as e:
        # Printed under the same lock so the message reaches the real stdout
        # rather than another thread's log.txt.
        with _STDOUT_REDIRECT_LOCK:
            print(f"error processing topic: {{{topic}}}\nError: {e}")
    finally:
        if progress_lock is not None and progress_bar is not None:
            with progress_lock:
                progress_bar.update(1)

def main(args):
    """Process a batch of topics from JSON, or a single topic read from stdin.

    If ``args.batch_experiment_json_path`` names an existing file, it is loaded
    as a JSON list of objects with ``"topic"`` and ``"intent"`` keys, and every
    entry is handed to ``process_article`` on a thread pool (10 workers unless
    ``args.max_workers`` is set). Otherwise the user is prompted for one topic
    and intent.
    """
    batch_path = args.batch_experiment_json_path
    if batch_path is not None and not os.path.exists(batch_path):
        # An explicitly supplied but missing path is most likely a typo; say so
        # instead of silently switching to interactive mode.
        print(f"Batch file not found: {batch_path}; falling back to interactive mode.",
              file=sys.stderr)
        batch_path = None

    if batch_path is None:
        topic = input("Topic: ")
        intent = input("Intent: ")
        process_article(args, topic, intent)
        return

    # Load everything up front so the file is not held open during the batch.
    with open(batch_path, encoding='utf-8') as f:
        all_exp_data = json.load(f)

    max_workers = getattr(args, 'max_workers', 10)
    progress_lock = Lock()
    with tqdm(total=len(all_exp_data), desc="Processing Articles") as progress_bar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(process_article, args, data_point["topic"],
                                data_point["intent"], progress_lock, progress_bar)
                for data_point in all_exp_data
            ]
            # process_article handles its own errors; this is a last-resort
            # guard so one unexpected failure cannot stop the others reporting.
            for future in concurrent.futures.as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    print(f"Error in processing task: {e}")



if __name__ == '__main__':
    # Every command-line option as (flag, type, default, help), listed in the
    # order it is registered with the parser.
    cli_options = [
        ('--batch-experiment-json-path', str, None,
         'Path to the JSON file for batch experiments.'),
        # global arguments
        ('--output-dir', str, './results/gpt',
         'Directory to store the outputs.'),
        ('--total-conv-turn', int, 15,
         'Total number of conversation turns.'),
        ('--max-perspective', int, 0,
         'Maximum number of perspectives.'),
        ('--max-turn-per-perspective', int, 2,
         'Maximum number of turns per perspective.'),
        ('--max-search-queries-per-turn', int, 2,
         'Maximum number of search queries per turn.'),
        ('--search-top-k', int, 10,
         'Top k results for each query in the search.'),
    ]
    cli_parser = ArgumentParser()
    for flag, arg_type, default_value, help_text in cli_options:
        cli_parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)
    main(cli_parser.parse_args())
