import concurrent.futures
import copy
import os
import json
import sys
import traceback


from argparse import ArgumentParser
from tqdm import tqdm
from threading import Lock

# Put the source root (two directories above this script) on sys.path so the
# `collaborative_storm` imports that follow can be resolved when this file is
# executed directly as a script.
_this_script = os.path.abspath(__file__)
script_dir = os.path.dirname(_this_script)
src_root = os.path.join(script_dir, "..", "..")
sys.path.append(src_root)

from collaborative_storm.modules.collaborative_storm_utils import load_api_key
from collaborative_storm.engine import CollaborativeStormLMConfigs, RunnerArgument, RoundTableConversation, LoggingWrapper
from collaborative_storm.modules.dataclass import KnowledgeBase


def _get_total_non_questioning_conv_turn(conversation_turns):
    """Return how many turns carry a non-empty ``claim_to_make``.

    ``process_article`` compares this count, not the raw number of turns,
    against ``total_conv_turn`` to decide when the conversation is done.
    """
    return sum(
        1
        for turn in conversation_turns
        if turn.claim_to_make is not None and len(turn.claim_to_make) > 0
    )


def _apply_warm_start_snapshot(round_table_conversation, snapshot_path):
    """Seed ``round_table_conversation`` from a saved round-table dump, if present.

    Copies the experts, conversation history and knowledge base out of the
    JSON dump at ``snapshot_path`` into ``round_table_conversation``.

    Returns:
        True if the snapshot file existed and was applied, False if it is missing.
    """
    if not os.path.exists(snapshot_path):
        return False
    with open(snapshot_path, encoding="utf-8") as f:
        data = json.load(f)
    warm_start_snapshot = RoundTableConversation.from_dict(data)
    round_table_conversation.experts = copy.deepcopy(warm_start_snapshot.experts)
    round_table_conversation.conversation_history = copy.deepcopy(warm_start_snapshot.conversation_history)
    # Rebuilt through to_dict/from_dict instead of assigned directly —
    # presumably to obtain a detached KnowledgeBase copy; TODO confirm.
    round_table_conversation.knowledge_base = KnowledgeBase.from_dict(warm_start_snapshot.knowledge_base.to_dict())
    return True


def _save_conversation_artifacts(round_table_conversation, article_output_dir):
    """Write ``round_table_dump.json`` and ``log.json`` into ``article_output_dir``.

    Each payload is serialized before its file is opened, so a failure in
    ``to_dict``/JSON encoding does not leave an empty or truncated file behind.
    """
    dump_text = json.dumps(round_table_conversation.to_dict(), indent=2)
    with open(os.path.join(article_output_dir, "round_table_dump.json"), "w", encoding="utf-8") as f:
        f.write(dump_text)

    log_text = json.dumps(round_table_conversation.logging_wrapper.dump_logging_and_reset(), indent=2)
    with open(os.path.join(article_output_dir, "log.json"), "w", encoding="utf-8") as f:
        f.write(log_text)


def process_article(args, topic, intent, progress_lock=None, progress_bar=None):
    """Run one Co-STORM round-table conversation for ``topic`` and save its outputs.

    Outputs go to ``<args.output_dir>/<topic with ' ' and '/' replaced by '_'>/``:

    - ``warm_start_snapshot_round_table_dump.json``: state after the first,
      user-simulated conversation step;
    - ``report.txt``: the generated report;
    - ``round_table_dump.json`` and ``log.json``: written whenever a
      conversation object was created, even if the run failed part-way.

    A topic whose ``report.txt`` already exists is skipped before any API/LM
    setup. Errors are printed with a traceback instead of being raised, so one
    failing topic does not abort a batch. When both ``progress_lock`` and
    ``progress_bar`` are given, the bar is advanced by exactly one per call,
    whatever the outcome.

    Args:
        args: parsed command-line namespace (see the argparse setup in this script).
        topic: conversation topic; also used to derive the output directory name.
        intent: intent handed to the simulated user on the first conversation step.
        progress_lock: lock guarding ``progress_bar`` across worker threads.
        progress_bar: object with an ``update(n)`` method (a tqdm bar in ``main``).
    """
    round_table_conversation = None
    article_output_dir = ""
    try:
        # Resolve the output directory first so finished topics are skipped
        # without loading keys or building LM clients.
        article_dir_name = topic.replace(' ', '_').replace('/', '_')
        article_output_dir = os.path.join(args.output_dir, article_dir_name)
        os.makedirs(article_output_dir, exist_ok=True)
        if os.path.exists(os.path.join(article_output_dir, "report.txt")):
            return

        # load api key
        load_api_key(toml_file_path=os.path.join(src_root, "..", "secrets.toml"))
        # LM config
        lm_config: CollaborativeStormLMConfigs = CollaborativeStormLMConfigs()
        lm_config.init_openai_model(openai_api_key=os.getenv("OPENAI_API_KEY"),
                                    openai_type=os.getenv('OPENAI_API_TYPE'))
        # set runner argument
        runner_argument = RunnerArgument(topic=topic,
                                         warmstart_max_num_experts=args.max_perspective,
                                         warmstart_max_turn_per_experts=args.max_turn_per_perspective,
                                         max_search_queries=args.max_search_queries_per_turn,
                                         total_conv_turn=args.total_conv_turn,
                                         retrieve_top_k=args.search_top_k,
                                         max_search_thread=1,
                                         disable_moderator=args.disable_moderator,
                                         disable_multi_experts=args.disable_multi_experts,
                                         moderator_override_N_consecutive_answering_turn=args.moderator_override_N_consecutive_answering_turn)

        # set logging wrapper
        logging_wrapper = LoggingWrapper(lm_config)

        # init round table conversation
        round_table_conversation = RoundTableConversation(lm_config=lm_config,
                                                          runner_argument=runner_argument,
                                                          logging_wrapper=logging_wrapper)

        # Reuse a previous run's warm-start snapshot when one exists for this
        # topic; otherwise run the warm start from scratch.
        should_skip_warm_start = False
        if args.warm_start_override_dir:
            warm_start_snapshot_path = os.path.join(args.warm_start_override_dir, article_dir_name,
                                                    "warm_start_snapshot_round_table_dump.json")
            should_skip_warm_start = _apply_warm_start_snapshot(round_table_conversation,
                                                                warm_start_snapshot_path)
        if not should_skip_warm_start:
            round_table_conversation.warm_start()

        # start discourse: the first step is driven by the simulated user's intent
        round_table_conversation.conversation_step(simulate_user=True, simulate_user_intent=intent)

        # Same file name that --warm-start-override-dir looks up, so this output
        # directory can seed later runs.
        with open(os.path.join(article_output_dir, "warm_start_snapshot_round_table_dump.json"), "w",
                  encoding="utf-8") as f:
            json.dump(round_table_conversation.to_dict(), f, indent=2)

        # NOTE(review): this only ends once enough claim-making turns have
        # accumulated; if conversation_step stops producing claims it never
        # terminates — consider an explicit step cap.
        while (_get_total_non_questioning_conv_turn(round_table_conversation.conversation_history)
               < round_table_conversation.runner_argument.total_conv_turn):
            round_table_conversation.conversation_step()

        # gen report
        round_table_conversation.reorganize_knowledge_base()
        report = round_table_conversation.generate_report()

        # save results
        with open(os.path.join(article_output_dir, "report.txt"), "w", encoding="utf-8") as f:
            f.write(report)

    except Exception as e:
        print(f"error processing topic: {{{topic}}}\nError: {e}")
        traceback.print_exc()
    finally:
        try:
            # Skipped topics, and failures before the conversation object was
            # built, have nothing to dump.
            if round_table_conversation is not None:
                _save_conversation_artifacts(round_table_conversation, article_output_dir)
        except Exception as e:
            print(f"error saving outputs for topic: {{{topic}}}\nError: {e}")
            traceback.print_exc()
        finally:
            # Advance the shared progress bar exactly once per topic, even if
            # saving the artifacts failed.
            if progress_lock is not None and progress_bar is not None:
                with progress_lock:
                    progress_bar.update(1)

def main(args):
    if args.batch_experiment_json_path is not None and os.path.exists(args.batch_experiment_json_path):
        with open(args.batch_experiment_json_path) as f:
            all_exp_data = json.load(f)
            
            progress_lock = Lock()
            with tqdm(total=len(all_exp_data), desc="Processing Articles") as progress_bar:
                # Use ThreadPoolExecutor for parallel processing with max workers set to 10
                with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                    futures = []
                    for data_point in all_exp_data:
                        futures.append(executor.submit(
                            process_article, args, data_point["topic"], data_point["intent"], progress_lock, progress_bar))
                    
                    # Ensure all futures are handled and catch exceptions
                    for future in concurrent.futures.as_completed(futures):
                        try:
                            future.result()
                        except Exception as e:
                            print(f"Error in processing task: {e}")
    else:
        topic = input("Topic: ")
        intent = input("Intent: ")
        process_article(args, topic, intent)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--batch-experiment-json-path', type=str, default=None,
                        help='Path to the JSON file for batch experiments.')
    # global arguments
    parser.add_argument('--output-dir', type=str, default='./results/gpt',
                        help='Directory to store the outputs.')
    parser.add_argument('--warm-start-override-dir', type=str,
                        help='directory of output where can borrow warm start data')
    parser.add_argument('--total-conv-turn', type=int, default=15,
                        help='Total number of conversation turns.')
    parser.add_argument('--max-perspective', type=int, default=0,
                        help='Maximum number of perspectives.')
    parser.add_argument('--moderator-override-N-consecutive-answering-turn', type=int, default=3,
                        help='Maximum number of turns moderator override')
    parser.add_argument('--max-turn-per-perspective', type=int, default=2,
                        help='Maximum number of turns per perspective.')
    parser.add_argument('--max-search-queries-per-turn', type=int, default=2,
                        help='Maximum number of search queries per turn.')
    parser.add_argument('--search-top-k', type=int, default=10,
                        help='Top k results for each query in the search.')
    parser.add_argument('--disable-moderator', action='store_true', default=False,
                    help='Disable the moderator functionality.')
    parser.add_argument('--disable-multi-experts', action='store_true', default=False,
                    help='Disable the multi experts discourse functionality.')
    main(parser.parse_args())
