import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['CONTAINER'] = 'native'
# os.environ['NUM_PROCS'] = '1'
import json
import heapq
import transformers
import subprocess
import vllm
import time
from random import random
import sys
import re
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = 1
from datetime import datetime
from tqdm import tqdm, trange
from pathlib import Path
from lean_dojo import *
# from local_pkg.lean_dojo import *
from loguru import logger
from agent.gpt_api_2 import request
# Send INFO-and-above records to a file sink under output/log/. Per loguru's
# sink options: "{time}" in the path is filled in by loguru, the file is
# rotated at 200MB, rotated files are zip-compressed, and enqueue=True routes
# records through loguru's queue before writing.
logger.add(
    "output/log/log_{time}.log",
    level="INFO",
    encoding="utf-8",
    rotation="200MB",
    compression="zip",
    enqueue=True,
)
from agent import draft,prove,reverse
from agent.gpt_api import request_gpt
# from utils.common import _load_data,_prompt_fn,_prompt_proofstep,_tactic_state,_unique_sorted
from utils.common import print_stats,resume_from,make_output_dir,_save,_make_save,_load_data
from utils.error_history import save_error_log,save_informal_log,save_pseudo_log
from re_assist import KnowledgeRetrieval
from db.reprover_assist_db.reprover_gen import ReproverTacticGen
from db.mathlib_db.build_copra_aug import Lean3Bm25ReRanker,Lean3EncodeRanker
# logger.add(sys.stderr, level="DEBUG")
# Name mapping loaded at import time from the current working directory;
# presumably Lean 4 -> Lean 3 identifiers (per its name) — TODO confirm.
# NOTE(review): the file name is spelled "lea4_..." — confirm it matches the file on disk.
# Read via Path.read_text so the file handle is closed and decoding is UTF-8
# regardless of the platform locale.
LEAN4_2_LEAN3 = json.loads(Path("lea4_to_lean3.json").read_text(encoding="utf-8"))
def steps_concat(steps: list):
    """Join *steps* with ',' and collapse each ',,' pair into a single ','.

    The collapse is one non-overlapping ``str.replace`` pass, so a run of three
    commas becomes two; commas inside an individual step are affected as well.
    """
    return ",".join(steps).replace(",,", ",")

def _hypothesis_descriptions(state):
    """Return, for every line of *state* not starting with '⊢', the text after
    its first ':' with all commas removed and surrounding whitespace stripped."""
    return [
        line.partition(':')[2].replace(',', '').strip()
        for line in state.split('\n')
        if not line.startswith('⊢')
    ]


def hypo_overlap_check(state, old_state):
    """Compare the hypothesis sets of two pretty-printed tactic states.

    Args:
        state: the new tactic-state text.
        old_state: the previous tactic-state text.

    Returns:
        True when both states contain the same number of *distinct* hypothesis
        descriptions (text after the first ':' on each non-goal line).
    """
    # The original also computed a "duplicate hypotheses within `state`" flag,
    # but overwrote it on the next line; that dead computation is dropped.
    return len(set(_hypothesis_descriptions(old_state))) == len(set(_hypothesis_descriptions(state)))


def generate_vllm(prompt, model, tokenizer, temperatures, num_samples, stop, max_tokens=256):
    """Sample completions for *prompt* from a vLLM model at several temperatures.

    For each temperature, ``num_samples`` completions are requested; beam search
    is enabled exactly when the temperature is 0.0. Each completion is scored by
    its cumulative log-probability divided by its token count (at least 1).

    Args:
        prompt: the prompt text.
        model: a ``vllm.LLM``-like object exposing ``generate``.
        tokenizer: used only for its ``eos_token``, which is stripped from outputs.
        temperatures: iterable of sampling temperatures.
        num_samples: completions requested per temperature.
        stop: stop sequences passed to vLLM.
        max_tokens: maximum generated tokens per completion.

    Returns:
        ``(texts, scores)`` de-duplicated and sorted by descending score via
        ``_unique_sorted``.
    """
    texts, scores = [], []
    eos = tokenizer.eos_token
    for temperature in temperatures:
        params = vllm.SamplingParams(
            n=num_samples,
            temperature=temperature,
            use_beam_search=temperature == 0.0,
            max_tokens=max_tokens,
            stop=stop,
        )
        outputs = model.generate([prompt], params, use_tqdm=False)
        if not outputs:
            # Skip this temperature rather than returning early, which would
            # discard the samples already collected at earlier temperatures.
            continue
        for output in outputs[0].outputs:
            # Guard: str.replace(None, ...) raises if the tokenizer has no EOS token.
            text = output.text.replace(eos, '') if eos else output.text
            texts.append(text)
            scores.append(output.cumulative_logprob / max(len(output.token_ids), 1))

    return _unique_sorted(texts, scores)


def _unique_sorted(texts, scores):
    texts_ = []
    scores_ = []
    for t, s in sorted(zip(texts, scores), key=lambda x: -x[1]):
        if t not in texts_:
            texts_.append(t)
            scores_.append(s)
    return texts_, scores_


def _tactic_state(state):
    """Return the printable tactic-state text of a LeanDojo search node.

    ``TacticState`` instances expose it as ``pp``; any other state object is
    read through its ``unsolved_tactic_state`` attribute.
    """
    if not isinstance(state, TacticState):
        return state.unsolved_tactic_state
    return state.pp


def _prompt_fewshot(ts,informal_proof=None,k=12):
    prompt = """
As a mathematician and expert in Lean 3 theorem prover, you should provide {k} tactic(s) helpful toward proving the proof state.
## Each tactic should be enclosed in code block '```lean' and '```'.
## Do not change the current proof state because you are only focusing on solving the current problem. Restate the current state before the tactic
[Current State]{current_state}\n[Tactic]
```lean\nnorm_num\n```
[Current State]{current_state}\n[Tactic]
```lean\nnlinarith\n```
[Current State]{current_state}\n[Tactic]
```lean\nring\n```
[Current State]{current_state}\n[Tactic]
    """.format(current_state = ts,k = k)
    return prompt

def _prompt_fewshot_autoformulation(ts,informal_proof,k=12):
    prompt = """
As a mathematician and expert in Lean 3 theorem prover, you should provide {k} tactic(s) helpful toward proving the proof state.
## The proof written in informal language is:
{informal_proof}
## Each tactic should be enclosed in code block '```lean' and '```'.
## Do not change the current proof state because you are only focusing on solving the current problem. Restate the current state before the tactic
[Current State]{current_state}\n[Tactic]
```lean\nnorm_num\n```
[Current State]{current_state}\n[Tactic]
```lean\nnlinarith\n```
[Current State]{current_state}\n[Tactic]
```lean\nsimp\n```
[Current State]{current_state}\n[Tactic]
    """.format(current_state = ts,informal_proof=informal_proof,k = k)
    return prompt



def _prompt_fn(model_name):
    if 'gpt' in model_name:
        return _prompt_fewshot
        # return _prompt_fewshot_autoformulation

    raise NotImplementedError(model_name)



def _extract_tactics(completion):
    """Extract candidate tactics from a model completion.

    Each ```lean ... ``` code block in *completion* yields one candidate:
    - Lean block comments (``/- ... -/``) and line comments (``-- ...``) are removed.
    - Stray fence markers are dropped.
    - Line breaks, together with the indentation around them, are collapsed
      to a single space.

    Order and duplicates are kept; de-duplication is left to the caller.
    """
    tactics = []
    for block in re.findall(r"```lean(.*?)```", completion, re.S):
        block = re.sub(r"/-[\s\S]*?-/", "", block)
        # '.' does not match '\n', so each line comment is stripped separately.
        block = re.sub(r"--.*", "", block)
        block = block.replace("```lean", "").replace("```", "")
        # Join lines with a space rather than '' so that "norm_num\n  at h"
        # does not become the unparsable "norm_numat h".
        block = re.sub(r"\s*\n\s*", " ", block)
        tactics.append(block.strip())
    return tactics


def best_first_search(
        theorem,
        formal_statement,
        model,
        tokenizer,
        max_iters,
        temperatures,
        num_samples,
        prompt_fn,
        retriver,
        timeout=600,
        early_stop=True,
        max_tokens=256,
        use_hyp = True,
        use_suffice = False,
        use_model_proof = False,
) -> list:
    """Best-first proof search driven by GPT tactic suggestions.

    The search runs in two phases.

    1. Draft: ``draft.InformalProofDraft`` (gpt-4-turbo-128k) writes an
       informal proof of *formal_statement*. The draft is printed and passed
       to *prompt_fn*.
    2. Search: inside a LeanDojo session, tactic states are expanded in
       ascending order of the length of their pretty-printed text. For each
       popped state:
       - the prompt from ``prompt_fn(ts, informal_proof=..., k=16)`` is sent
         to ``request_gpt``;
       - every ```lean block of the reply is tried as a tactic;
       - each previously unseen resulting ``TacticState`` is queued.

    The search stops at the first ``ProofFinished``, or after *max_iters*
    expansions, or when the queue empties.

    Args:
        theorem: LeanDojo ``Theorem`` to prove.
        formal_statement: formal statement given to the informal drafter.
        max_iters: maximum number of state expansions.
        temperatures: only recorded in the success dict under 'temperature'.
        prompt_fn: builds the tactic prompt from the current state text.
        timeout: LeanDojo hard timeout, in seconds.
        early_stop: return immediately on the first proof. Either way, the
            search ends after a proof is found.
        model, tokenizer, num_samples, retriver, max_tokens, use_hyp,
        use_suffice, use_model_proof: accepted for interface compatibility;
            not used by this implementation.

    Returns:
        A list of attempt dicts.
        - On success the keys are: theorem, proof, success=True,
          failure_reason='', trace, temperature, elapsed (seconds), iteration.
        - Otherwise a single dict is returned with success=False and
          failure_reason set to either the caught exception's class name or
          'SearchEnded'.

    Raises:
        Exceptions other than DojoInitError, DojoHardTimeoutError,
        DojoCrashError and subprocess.CalledProcessError propagate, including
        any raised while drafting (which runs outside the ``try``).
    """
    attempt_results = []
    informal_draft_agent = draft.InformalProofDraft(llm_engine='gpt-4-turbo-128k', is_conversation=True)
    _, informal_proof, _, _ = informal_draft_agent.run(
        formal_statement,
        theorem.full_name,
        exact_match=True,
    )
    print(informal_proof)
    try:
        with Dojo(theorem, hard_timeout=timeout) as (dojo, init_state):
            start = time.time()
            proof_finished = False
            queue = [(0.0, [], init_state, [])]
            # Pretty-printed states that were already expanded or queued.
            visited = set()

            for iteration in trange(max_iters):
                if not queue or proof_finished:
                    break
                _score, steps, state, trace = heapq.heappop(queue)
                ts = _tactic_state(state)
                logger.info(f"# Current state\n{ts}")
                visited.add(ts)
                # Extra steps prepended to each recorded step; currently always empty.
                prefix_hypoth = []

                # The second element returned by request_gpt is used as the completion text.
                _, completion, _, _ = request_gpt(
                    prompt_fn(ts, informal_proof=informal_proof, k=16),
                )
                raw_tactics = _extract_tactics(completion)
                logger.info(f"# {len(raw_tactics)} Tactics (unverified)")
                # dict.fromkeys de-duplicates while keeping the model's order
                # (set() order varies between runs); empty blocks are dropped.
                step_cands = [t for t in dict.fromkeys(raw_tactics) if t]
                logger.info(f"# total len {len(step_cands)}")
                logger.info(f"# steps candidates: {step_cands[:10]}")

                for step in step_cands:
                    # LeanError results are not retried: with no rewrite of the
                    # tactic, re-running it on the same state only doubles the cost.
                    result = dojo.run_tac(state, step)
                    step_trace = {
                        "tactic": step,
                        "state_before": ts,
                    }
                    if isinstance(result, ProofFinished):
                        attempt_results.append({
                            'theorem': theorem.full_name,
                            'proof': steps + [step],
                            'success': True,
                            'failure_reason': '',
                            'trace': trace + [step_trace],
                            'temperature': temperatures,
                            'elapsed': time.time() - start,
                            'iteration': iteration,
                        })
                        if early_stop:
                            return attempt_results
                        proof_finished = True
                        break
                    if isinstance(result, TacticState):
                        new_ts = _tactic_state(result)
                        if new_ts in visited:
                            continue
                        # Mark on push so a state reached by several tactics is
                        # queued (and sent to the model) only once. Its priority
                        # depends only on the state, so nothing better is lost.
                        visited.add(new_ts)
                        # Priority is the length of the pretty-printed state:
                        # shorter states are expanded first.
                        new_score = len(new_ts)
                        recorded_step = steps_concat(prefix_hypoth + [step])
                        logger.info(f"# Succ Tac: \n{recorded_step}")
                        heapq.heappush(
                            queue, (new_score, steps + [recorded_step], result, trace + [step_trace])
                        )
    except (DojoInitError, DojoHardTimeoutError, DojoCrashError, subprocess.CalledProcessError) as e:
        if len(attempt_results) == 0:
            attempt_results.append({
                'theorem': theorem.full_name,
                'success': False,
                'failure_reason': type(e).__name__
            })

    if len(attempt_results) == 0:
        attempt_results.append({
            'theorem': theorem.full_name,
            'success': False,
            'failure_reason': 'SearchEnded'
        })

    return attempt_results


def _save(model_name, results, args_dict, output_dir, shard):
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    output_file = os.path.join(
        output_dir,
        'results__%s__%s.json' % (model_name.replace('/', '_'), shard)
    )
    with open(output_file, 'w') as f:
        json.dump({
            'results': results,
            'args': args_dict
            }, f, indent=4)
        print(output_file)


def _load_model(model_name):
    if 'pythia' in model_name:
        model = vllm.LLM(
            model=model_name,
            tensor_parallel_size=1,
            dtype='float16'
        )
        tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(model_name)
    else:
        raise NotImplementedError(model_name)

    return model, tokenizer



def print_stats(results):
    """Print the success rate of *results*, then the raw success count.

    Each entry must have a ``'success'`` key. An empty list prints a rate of
    0.0 instead of raising ZeroDivisionError.
    """
    num_success = sum(1 for x in results if x['success'])
    rate = num_success / len(results) if results else 0.0
    print(rate)
    print("# successes: ", num_success, sep="\t")


def resume_from(results_filename, data):
    """Load prior results and drop the examples they already cover.

    Args:
        results_filename: JSON file with a top-level ``'results'`` list.
        data: the full list of examples, in the order they were processed.

    Returns:
        ``(results, remaining_data)``, where ``remaining_data`` skips the first
        ``len(results)`` examples.
    """
    with open(results_filename) as f:  # close the handle (was leaked before)
        results = json.load(f)['results']
    data = data[len(results):]
    print("=== Resuming from %d" % (len(results)))
    return results, data


def make_output_dir(output_dir):
    """Create (if missing) and return the fixed ``demo`` run directory inside *output_dir*."""
    run_dir = os.path.join(output_dir, 'demo')
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


if __name__ == '__main__':
    # Script entry point:
    #   1. parse CLI arguments and select this shard of the dataset;
    #   2. restore prior results from --resume-from and/or the cache log;
    #   3. run best_first_search on every theorem not skipped;
    #   4. save results after each theorem.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model-name',
        default='llmstep',
    )
    parser.add_argument(
        '--dataset-name',
        default='minif2f-valid',
        choices=['minif2f-valid', 'minif2f-test', 'leandojo', 'proofnet']
    )
    parser.add_argument('--shard', type=int, required=True)
    parser.add_argument('--resume-from', type=str, default=None)
    parser.add_argument('--dataset-path', default='data/minif2f.jsonl')
    parser.add_argument('--output-dir', default='output')
    parser.add_argument('--early-stop', action='store_true')
    parser.add_argument('--num-shards', type=int, default=1)
    parser.add_argument('--max-iters', type=int, default=100)
    parser.add_argument('--timeout', type=int, default=600)
    parser.add_argument('--num-examples', type=int, default=-1)
    parser.add_argument('--num-samples', type=int, default=32)
    parser.add_argument('--clear-process-hours', type=int, default=3)
    parser.add_argument('--temperatures', type=float, nargs='+', default=[0.0])
    parser.add_argument('--resume-ckpt', action="store_true")
    parser.add_argument('--retrieval-device', type=int, default=3)
    parser.add_argument('--target-q', action="store_true")
    parser.add_argument('--use-model-proof', action="store_true")

    args = parser.parse_args()

    # No local model is loaded; tactics come from request_gpt inside the search.
    model, tokenizer = None, None

    output_dir = make_output_dir(args.output_dir)
    logger.info(f"#args\n{args}")

    repo, data = _load_data(args.dataset_name, args.dataset_path, args.target_q)
    shard_size = len(data) // args.num_shards
    shard_start = args.shard * shard_size
    # The last shard runs to the end of the data, so the remainder of
    # len(data) / num_shards is not silently dropped.
    if args.shard == args.num_shards - 1:
        shard_end = len(data)
    else:
        shard_end = shard_start + shard_size
    data = data[shard_start:shard_end]
    print("Shard size: %d" % (len(data)))

    if args.resume_from is not None:
        results, data = resume_from(args.resume_from, data)
    else:
        results = []

    cache_log_file = _make_save(
        model_name=args.model_name,
        results=results,
        args_dict=args.__dict__,
        output_dir=output_dir,
        shard=args.shard
    )
    resume_ckpt = args.resume_ckpt
    # Theorems that are never attempted; any cached result for them is kept.
    give_up = ['numbertheory_4x3m7y3neq2003', 'amc12a_2019_p21']
    skip_examples = set(give_up)
    if os.path.exists(cache_log_file):
        with open(cache_log_file) as f:
            cache_log_content = json.load(f)
        for res in cache_log_content['results']:
            name = res['example']['full_name']
            if res['success']:
                results.append(res)
                skip_examples.add(name)
            elif name in give_up:
                results.append(res)
            elif resume_ckpt:
                # With --resume-ckpt, cached failures are kept and not retried.
                results.append(res)
                skip_examples.add(name)

    # Count what is actually left in this shard (previously a hard-coded 244).
    remaining = sum(1 for example in data if example['full_name'] not in skip_examples)
    logger.info(f"# rest {remaining}")

    # Resolved once: raises NotImplementedError up front for unsupported models.
    prompt_fn = _prompt_fn(args.model_name)
    for example in tqdm(data, total=len(data)):
        theorem_name = example['full_name']
        if theorem_name in skip_examples:
            continue
        logger.info(f"Running on {theorem_name}")
        theorem = Theorem(repo, example['file_path'], theorem_name)
        attempt_results = best_first_search(
            theorem, example['statement'], model, tokenizer,
            max_iters=args.max_iters,
            prompt_fn=prompt_fn,
            temperatures=args.temperatures,
            num_samples=args.num_samples,
            retriver=None,
            timeout=args.timeout,
            early_stop=args.early_stop,
            use_model_proof=args.use_model_proof
        )
        result = {
            'attempt_results': attempt_results,
            'success': any(x['success'] for x in attempt_results),
            'example': example
        }
        results.append(result)

        _save(
            model_name=args.model_name,
            results=results,
            args_dict=args.__dict__,
            output_dir=output_dir,
            shard=args.shard
        )
        print_stats(results)
        logger.info(f"Success? {result['success']}")
