from __future__ import annotations

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import os.path as osp
import json
import argparse
from datetime import datetime

from omegaconf import OmegaConf
from termcolor import colored
from tqdm import tqdm

from mcts_math.agents import SBSREACT
from mcts_math.agents import MCTS
from mcts_math.solver import Solver
from mcts_math.config import BaseConfig
from react_demo import load_qaf
from react_batch_demo import batch


def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace exposing ``custom_cfg``, ``qaf``, ``output_dir`` and
        ``model_ckpt``; each is a string that defaults to ``""`` when omitted.
        ``--question-answer-file`` is an alias stored under ``qaf``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--custom_cfg', type=str, default="")
    parser.add_argument(
        "--qaf",
        "--question-answer-file",
        type=str,
        default="",
        help="the file includes question / partial solution (optional) / answer (optional)",
    )
    # The remaining options are plain optional strings with an empty default.
    for flag in ('--output_dir', '--model_ckpt'):
        parser.add_argument(flag, type=str, default="")
    return parser.parse_args()


def resolve_output_dir(base_dir: str, model_dir: str) -> str:
    """Return the directory under *base_dir* where results for *model_dir* go.

    If *model_dir* contains ``"checkpoint"``, the name of its parent directory is
    kept as an extra level (``base_dir/<parent>/<checkpoint>``). Otherwise only
    the final path component is used (``base_dir/<model>``). Trailing slashes on
    *model_dir* are ignored, so ``".../model/"`` and ``".../model"`` map to the
    same directory instead of collapsing to *base_dir* itself.
    """
    model_dir = model_dir.rstrip("/")
    if "checkpoint" in model_dir:
        return osp.join(base_dir, osp.basename(osp.dirname(model_dir)), osp.basename(model_dir))
    return osp.join(base_dir, osp.basename(model_dir))


def build_method_tag(config) -> str:
    """Return the ``"<mode>-<step_beam_width>-<n_generate_sample>.<temperature>"`` tag used in output file names."""
    return f"{config.mode}-{config.step_beam_width}-{config.n_generate_sample}.{config.temperature}"


if __name__ == '__main__':
    args = parse_args()

    config = OmegaConf.structured(BaseConfig)
    if args.custom_cfg:
        custom_config = OmegaConf.load(args.custom_cfg)
        config = OmegaConf.merge(config, custom_config)
    # Round-trip through YAML: yields a plain DictConfig with interpolations resolved.
    config = OmegaConf.create(OmegaConf.to_yaml(config, resolve=True))
    # Only override model_dir when --model_ckpt was actually given; its default ""
    # would otherwise clobber a model_dir supplied via --custom_cfg.
    if args.model_ckpt:
        config.model_dir = args.model_ckpt
    print(config)

    # Validate the search mode before the (potentially expensive) data/solver setup.
    methods = {"mcts": MCTS, "sbs": SBSREACT}
    if config.mode not in methods:
        raise NotImplementedError(f"unsupported mode: {config.mode!r}")
    method = methods[config.mode]
    method_tag = build_method_tag(config)

    llm_version = osp.basename(config.model_dir.rstrip("/"))

    data = load_qaf(args.qaf)
    solver = Solver(config=config)

    args.output_dir = resolve_output_dir(args.output_dir, config.model_dir)
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    saved_jsonl_file = osp.join(
        args.output_dir,
        config.mode,
        f"{osp.basename(args.qaf)}.{method_tag}.{llm_version}.{timestamp}.jsonl",
    )
    # Creates args.output_dir as well, since it is a parent of this directory.
    os.makedirs(osp.dirname(saved_jsonl_file), exist_ok=True)

    # ensure_ascii=False writes raw non-ASCII characters, so pin the file encoding
    # instead of relying on the platform's locale default.
    with open(saved_jsonl_file, "w", encoding="utf-8") as writer:
        for cur_data in tqdm(batch(data, config.batch_size), desc="Main Processing"):
            agents = [
                method(
                    config=config,
                    question=d["question"],
                    ground_truth=d["answer"] if config.is_sampling else None,
                )
                for d in cur_data
            ]
            # The solver's result is looked up by question text below.
            jsonlines = solver.solve(agents)
            for d in cur_data:
                d["react"] = jsonlines[d["question"]]
                writer.write(json.dumps(d, ensure_ascii=False) + '\n')
            # Flush after every batch so finished results survive an interrupted run.
            writer.flush()