import os
# These environment variables are deliberately set BEFORE the third-party
# imports below, so the libraries see them when they are first imported.
# Disable HuggingFace tokenizers' internal parallelism (avoids fork-related
# warnings/deadlocks when the process later forks).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# NOTE(review): presumably read by lean_dojo (imported below) to run Lean
# natively instead of inside a container — confirm against LeanDojo docs.
os.environ['CONTAINER'] = 'native'
import json
import heapq
import transformers
import subprocess
import vllm
import time
import sys
import re
from datetime import datetime
from tqdm import tqdm, trange
from pathlib import Path
# Wildcard import: this is the only visible source of Dojo, Theorem,
# TacticState, ProofFinished, LeanError and the Dojo*Error exception types
# used further down in this file.
from lean_dojo import *
from loguru import logger
# Besides loguru's default sink, also write INFO-and-above logs to a file
# under output/log/ (one file per start time). The file rotates at 200MB and
# rotated files are zip-compressed; enqueue=True routes messages through a
# queue so logging is safe across threads/processes.
logger.add(
    "output/log/log_{time}.log", 
    encoding="utf-8", 
    enqueue=True, 
    rotation="200MB", 
    compression="zip",
    level="INFO")
# Project-local agents and helpers.
from agent import draft,prove,reverse
from agent.gpt_api import request_gpt
from utils.common import generate_vllm,_load_data,_prompt_fn,_prompt_proofstep,_tactic_state,_unique_sorted
from utils.common import print_stats,resume_from,make_output_dir,_save,_make_save,hypo_overlap_check
from re_assist import KnowledgeRetrieval

def steps_concat(steps:list) -> str:
    """Join tactic strings into one comma-separated tactic sequence.

    Tactics produced by the agents sometimes already end with the ``,``
    separator (or are empty). Each step is therefore trimmed of surrounding
    whitespace and separator commas, and empty steps are dropped, before the
    steps are joined with a single ``','``. This never produces doubled
    separators such as ``a,,b`` or ``a,,,b``, and never leaves stray
    leading or trailing commas.

    Args:
        steps: Tactic strings, in application order.

    Returns:
        The non-empty, trimmed steps joined with ``','``. Returns ``''`` when
        there is nothing to join.
    """
    pieces = (step.strip().strip(',').strip() for step in steps)
    return ','.join(piece for piece in pieces if piece)

def agent_search(
        theorem,
        formal_statement,
        max_iters,
        temperatures,
        num_samples,
        prompt_fn,
        retriver,
        timeout=600,
        early_stop=True,
        max_tokens=256,
        use_hyp = True,
        use_suffice = False,
        use_model_proof = False,
) -> list:
    """Run one LLM-agent-guided best-first proof search for ``theorem``.

    Pipeline:
      1. An informal-draft agent produces an informal statement and proof.
      2. A pseudo-code agent turns them into a proof sketch.
      3. Inside a LeanDojo session, repeatedly pop the heap entry with the
         smallest priority and expand it:
         - Optionally apply "reverse" hypothesis tactics (``use_hyp``).
         - Explain the resulting state and align the sketch to it.
         - Ask the prover agent for candidate next tactics.
         - Run each candidate in Lean.
      New states are pushed with priority ``len(pretty-printed state)``, so
      shorter goal states are expanded first.

    Args:
        theorem: LeanDojo theorem to prove (``theorem.full_name`` is used).
        formal_statement: Formal statement text given to the draft agents.
        max_iters: Maximum number of heap pops (expansions).
        temperatures: Only recorded in success results under
            ``'temperature'``. The agents use their own fixed temperatures,
            and the search runs exactly once.
        num_samples, prompt_fn, max_tokens, use_suffice: Accepted for
            interface compatibility; not used by this function.
        retriver: Knowledge-retrieval object. Its ``premise_summerizer`` is
            queried each iteration and it is handed to the prover agent.
        timeout: LeanDojo hard timeout in seconds.
        early_stop: Return immediately after the first proof is found.
        use_hyp: Apply reverse-hypothesis tactics before each expansion.
        use_model_proof: When True, the informal draft agent is run with
            ``exact_match=False``.

    Returns:
        A non-empty list of result dicts.
        - Success entries carry ``proof``, ``trace``, ``elapsed`` (seconds),
          ``iteration`` and token totals.
        - A failure entry carries ``failure_reason``: a LeanDojo exception
          class name, or ``'SearchEnded'``.

    Exceptions raised by the LLM agents propagate. LeanDojo init, timeout and
    crash errors, and ``subprocess.CalledProcessError``, are converted into a
    failure entry.
    """
    attempt_results = []
    token_prompt_list = []
    token_compli_list = []

    def _record_tokens(token_prompt, token_compli):
        # Accumulate per-call token usage so every result reports totals.
        token_prompt_list.append(token_prompt)
        token_compli_list.append(token_compli)

    def _success_result(steps, trace, proof_step, step_trace, iteration, start):
        # Build the result dict for a proof that was just closed by `proof_step`.
        return {
            'theorem': theorem.full_name,
            'proof': steps + [proof_step],
            'score': 0,
            'success': True,
            'failure_reason': '',
            'trace': trace + [step_trace],
            'temperature': temperatures,
            'elapsed': time.time() - start,
            'iteration': iteration,
            'prompt_tokens': sum(token_prompt_list),
            'compli_tokens': sum(token_compli_list),
        }

    def _failure_result(reason):
        # Build the result dict for a search that produced no proof.
        return {
            'theorem': theorem.full_name,
            'success': False,
            'failure_reason': reason,
            'prompt_tokens': sum(token_prompt_list),
            'compli_tokens': sum(token_compli_list),
        }

    informal_draft_agent = draft.InformalProofDraft(llm_engine='gpt-4-turbo-128k',is_conversation=True,temperature=0.2)
    pseudo_agent = draft.PseudoProofDraft(llm_engine='gpt-4-turbo-128k',is_conversation=False,temperature=0.2)
    nextstep_align_agent = draft.AlignProofDraft(llm_engine='gpt-4-turbo-128k',is_conversation=False,temperature=0.2)
    state_explain_agent = draft.StateExplainer(llm_engine="gpt-4-turbo-128k",is_conversation=False,temperature=0.2)
    reverse_have_agent = reverse.ReverseHypo(curiosity='lv1',k=8,is_conversation=False,temperature=0.0)

    informal_statement, informal_proof, token_prompt, token_compli = informal_draft_agent.run(
        formal_statement,
        theorem.full_name,
        exact_match=not use_model_proof,
    )
    logger.info(f"# informal proof: \n{informal_proof}")
    _record_tokens(token_prompt, token_compli)
    full_pseudo_code, token_prompt, token_compli = pseudo_agent.run(informal_statement,informal_proof,formal_statement)
    logger.info(f"# pseudo code \n{full_pseudo_code}")
    _record_tokens(token_prompt, token_compli)

    try:
        with Dojo(theorem, hard_timeout=timeout) as (dojo, init_state):
            start = time.time()
            proof_finished = False
            # Heap entries: (priority, applied tactic strings, state, trace, proof_ckpt).
            queue = [(0.0, [], init_state, [], "")]
            visited = set()
            formal_prove_agent = prove.NextTacticFroamProof(retrive_agent=retriver,llm_engine='gpt-4-turbo-128k',k=8,is_conversation=False,temperature=0.0)
            # Failed (state, tactic, error) triples. Collected for error
            # feedback but not consumed anywhere in this function yet.
            error_db = []
            # How many extra times the same resulting state may be accepted
            # within one expansion before further duplicates are pruned.
            dup_tolerance = 0
            for iteration in trange(max_iters):
                if not queue or proof_finished:
                    break
                _, steps, state, trace, _ = heapq.heappop(queue)
                ts = _tactic_state(state)
                re_premises, _, _ = retriver.premise_summerizer.run(full_pseudo_code, ts)
                # Hypothesis tactics that were kept; they are prepended to
                # every tactic produced in this expansion.
                prefix_hypoth = []
                if use_hyp:
                    hypoth_tactics, token_prompt, token_compli = reverse_have_agent.run(
                        ts,
                        full_pseudo_code,
                        informal_statement,
                        informal_proof,
                        premise=re_premises,
                    )
                    _record_tokens(token_prompt, token_compli)
                    logger.info(f"# phypo\n{hypoth_tactics}")
                    for step in hypoth_tactics:
                        result = dojo.run_tac(state, step)
                        step_trace = {
                            "tactic": step,
                            "state_before": _tactic_state(state)
                        }
                        if isinstance(result, ProofFinished):
                            # A hypothesis tactic closed the goal outright.
                            # Record it as a proof instead of raising an
                            # uncaught exception that would abort the run.
                            proof_step = steps_concat(prefix_hypoth + [step])
                            logger.info(f"# success tac\n{proof_step}")
                            attempt_results.append(
                                _success_result(steps, trace, proof_step, step_trace, iteration, start)
                            )
                            if early_stop:
                                return attempt_results
                            proof_finished = True
                            break
                        elif isinstance(result, TacticState):
                            overlap_flag = hypo_overlap_check(_tactic_state(result), _tactic_state(state))
                            # Keep only hypotheses that add something new and are not sorry'd.
                            if not overlap_flag and 'sorry' not in step:
                                logger.critical(f"# New phypo\n{step}")
                                prefix_hypoth.append(step)
                                state = result
                    if proof_finished:
                        break

                ts = _tactic_state(state)
                state_informal, token_prompt, token_compli = state_explain_agent.run(ts)
                _record_tokens(token_prompt, token_compli)

                pseudo_code_aglin, token_prompt, token_compli = nextstep_align_agent.run(
                    full_pseudo_code,
                    ts+f'\n/- {state_informal} -/\n'
                )
                _record_tokens(token_prompt, token_compli)

                pseudo_code = f"[Current State]\n{ts}\n/- {state_informal} -/\n[Next Step]\n{pseudo_code_aglin}\n"
                logger.info(f"# Pseudo Code\n{pseudo_code}")
                visited.add(ts)
                search_steps, token_prompt, token_compli = formal_prove_agent.run(
                    state=ts,
                    pseudo_code=informal_proof+'\n'+pseudo_code,
                    informal_statement=informal_statement,
                    theorem_name=theorem.full_name,
                    premise_inject=re_premises
                )
                _record_tokens(token_prompt, token_compli)
                logger.info(f"# proof stpes: {search_steps}")

                # Count of each resulting state within this expansion (duplicate pruning).
                prun_queue = {}
                for _step in search_steps:
                    result = dojo.run_tac(state, _step)
                    step_trace = {
                        "tactic": _step,
                        "state_before": _tactic_state(state)
                    }
                    if isinstance(result, ProofFinished):
                        step = steps_concat(prefix_hypoth + [_step])
                        logger.info(f"# success tac\n{step}")
                        attempt_results.append(
                            _success_result(steps, trace, step, step_trace, iteration, start)
                        )
                        if early_stop:
                            return attempt_results
                        proof_finished = True
                        break
                    elif isinstance(result, TacticState):
                        result_state = _tactic_state(result)
                        prun_flag = False
                        if result_state not in prun_queue:
                            prun_queue[result_state] = 1
                        elif prun_queue[result_state] <= dup_tolerance:
                            prun_queue[result_state] += 1
                        else:
                            prun_flag = True
                        step = steps_concat(prefix_hypoth + [_step])
                        if result_state not in visited and not prun_flag:
                            logger.info(f"# success tac\n{step}")
                            # Priority = length of the pretty-printed state:
                            # shorter goal states are expanded first.
                            heapq.heappush(
                                queue, (len(result_state), steps + [step], result, trace + [step_trace], "")
                            )
                    elif isinstance(result, LeanError):
                        # Errors from common closing tactics are too generic to be useful feedback.
                        if _step not in ['norm_num', 'nlinarith', 'ring', 'dec_trivial!']:
                            error_db.append({
                                "State": _tactic_state(state).strip(),
                                "Tactic": _step.strip(),
                                "Error": result.error.strip(),
                            })
                if not queue or proof_finished:
                    break
    except (DojoInitError, DojoHardTimeoutError, DojoCrashError, subprocess.CalledProcessError) as e:
        if not attempt_results:
            attempt_results.append(_failure_result(type(e).__name__))

    if not attempt_results:
        attempt_results.append(_failure_result('SearchEnded'))

    return attempt_results



if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model-name',
        default='gpt',
    )
    parser.add_argument(
        '--dataset-name',
        default='minif2f-valid',
        choices=['minif2f-valid', 'minif2f-test', 'leandojo']
    )
    parser.add_argument('--shard', type=int, required=True)
    parser.add_argument('--resume-from', type=str, default=None)
    parser.add_argument('--dataset-path', default='data/minif2f.jsonl')
    parser.add_argument('--output-dir', default='output/')
    parser.add_argument('--early-stop', action='store_true')
    parser.add_argument('--num-shards', type=int, default=1)
    parser.add_argument('--max-iters', type=int, default=500)
    parser.add_argument('--timeout', type=int, default=600)
    parser.add_argument('--num-examples', type=int, default=-1)
    parser.add_argument('--num-samples', type=int, default=32)
    parser.add_argument('--clear-process-hours', type=int, default=3)
    parser.add_argument('--temperatures', type=float, nargs='+', default=[0.0])
    parser.add_argument('--resume-ckpt', action="store_true")
    parser.add_argument('--retrieval-device',type=int, default=3)
    parser.add_argument('--target-q', action="store_true")
    parser.add_argument('--use-model-proof', action="store_true")

    args = parser.parse_args()

    output_dir = make_output_dir(args.output_dir)
    logger.info(f"#args\n{args}")

    repo, data = _load_data(args.dataset_name, args.dataset_path, use_target_q=args.target_q)
    # NOTE: data sharding is currently disabled. Every run processes the whole
    # dataset; --shard is only forwarded to _make_save/_save below.
    print("Shard size: %d" % (len(data)))

    if args.resume_from is not None:
        results, data = resume_from(args.resume_from, data)
    else:
        results = []

    cache_log_file = _make_save(
        model_name=args.model_name,
        results=results,
        args_dict=args.__dict__,
        output_dir=output_dir,
        shard=args.shard
    )
    resume_ckpt = args.resume_ckpt
    # Theorems that are never attempted.
    give_up = {'numbertheory_4x3m7y3neq2003', 'amc12a_2019_p21'}
    # Set of theorem names to skip (O(1) membership tests in the main loop).
    skip_examples = set(give_up)
    if os.path.exists(cache_log_file):
        # Reuse results from a previous run's save file: successes are kept
        # and skipped. With --resume-ckpt (and no --target-q), failures are
        # kept and skipped too.
        with open(cache_log_file) as cache_fp:
            cache_log_content = json.load(cache_fp)
        for res in cache_log_content['results']:
            full_name = res['example']['full_name']
            if res['success']:
                results.append(res)
                skip_examples.add(full_name)
            elif full_name in give_up:
                results.append(res)
            elif resume_ckpt and not args.target_q:
                results.append(res)
                skip_examples.add(full_name)

    retriver = KnowledgeRetrieval(device_id=args.retrieval_device,is_conversation=False)
    for example in tqdm(data, total=len(data)):
        file_path = example['file_path']
        theorem_name = example['full_name']
        if theorem_name in skip_examples:
            continue
        formal_statement = example['statement']
        logger.info(f"Running on {theorem_name}")
        theorem = Theorem(repo, file_path, theorem_name)
        attempt_results = agent_search(
            theorem,
            formal_statement = formal_statement,
            max_iters=args.max_iters,
            prompt_fn=_prompt_fn(args.model_name),
            temperatures=args.temperatures,
            num_samples=args.num_samples,
            retriver=retriver,
            timeout=args.timeout,
            early_stop=args.early_stop,
            use_model_proof = args.use_model_proof
        )
        result = {
            'attempt_results': attempt_results,
            'success': any(x['success'] for x in attempt_results),
            'example': example
        }
        results.append(result)

        # Persist after every theorem so an interrupted run can be resumed.
        _save(
            model_name=args.model_name,
            results=results,
            args_dict=args.__dict__,
            output_dir=output_dir,
            shard=args.shard
        )

        print_stats(results)
        result_success = result['success']
        logger.info(f"Success? {result_success}")
