import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['CONTAINER'] = 'native'
# os.environ['NUM_PROCS'] = '1'
import json
import heapq
import transformers
import subprocess
import vllm
import time
from random import random
import sys
import re
from datetime import datetime
from tqdm import tqdm, trange
from pathlib import Path
from lean_dojo import *
# from local_pkg.lean_dojo import *
from loguru import logger
logger.add(
    "output/log/log_{time}.log", 
    encoding="utf-8", 
    enqueue=True, 
    rotation="200MB", 
    compression="zip",
    level="INFO")
from agent import draft,prove,reverse
from agent.gpt_api import request_gpt
from utils.common import generate_vllm,_load_data,_prompt_fn,_prompt_proofstep,_tactic_state,_unique_sorted
from utils.common import print_stats,resume_from,make_output_dir,_save,_make_save
from utils.error_history import save_error_log,save_informal_log,save_pseudo_log
from re_assist import KnowledgeRetrieval
from db.reprover_assist_db.reprover_gen import ReproverTacticGen
from db.mathlib_db.build_copra_aug import Lean3Bm25ReRanker,Lean3EncodeRanker

def steps_concat(steps:list):
    """Join tactic strings with commas into one tactic sequence string.

    After joining, a single left-to-right ``',,' -> ','`` replacement pass
    removes the doubled comma produced when a step already ends with a comma
    or is empty.  Runs of three or more commas are only partially collapsed
    (``',,,'`` becomes ``',,'``).

    Args:
        steps: tactic strings in execution order.

    Returns:
        The comma-joined string.
    """
    return ','.join(steps).replace(',,', ',')

def hypo_overlap_check(state,old_state):
    """Report whether two states have the same number of distinct hypotheses.

    Every line of a pretty-printed tactic state that does not start with
    ``⊢`` is reduced to its text after the first ``:`` with commas removed
    and surrounding whitespace stripped.  Hypotheses are therefore compared
    by statement rather than by name.

    Args:
        state: pretty-printed tactic state after running a tactic.
        old_state: pretty-printed tactic state before running it.

    Returns:
        True when ``state`` and ``old_state`` contain the same number of
        distinct hypothesis descriptions, otherwise False.
    """
    def _descriptions(text):
        # partition(':')[2] is everything after the first ':' ('' if none),
        # identical to ':'.join(line.split(':')[1:]).
        return {
            line.partition(':')[2].replace(',', '').strip()
            for line in text.split('\n')
            if not line.startswith('⊢')
        }

    # Only the distinct-count comparison decides the result; a duplicate-
    # hypothesis flag that used to be computed first was always overwritten.
    return len(_descriptions(old_state)) == len(_descriptions(state))

def generate_vllm_dojo(prompt, model, tokenizer, temperatures, num_samples, stop, max_tokens=256):
    """Sample candidate tactics from a ``tactic_gen``-style generator.

    ``model.tactic_gen`` is called once per entry of ``temperatures`` and the
    returned texts/scores are concatenated.  The temperature value itself is
    not forwarded to ``tactic_gen``.  ``tokenizer``, ``stop`` and
    ``max_tokens`` are accepted but unused here.

    Args:
        prompt: text passed to ``model.tactic_gen``.
        model: object exposing
            ``tactic_gen(prompt, output_scores=True, num_response=N)`` that
            returns ``(texts, scores)``.
        tokenizer: unused.
        temperatures: iterable of temperatures, or a single number (treated
            as a one-element list).
        num_samples: number of responses requested per call.
        stop: unused.
        max_tokens: unused.

    Returns:
        ``(texts, scores)`` lists gathered across all calls; both are empty
        when every call produced nothing.
    """
    if isinstance(temperatures, (int, float)):
        temperatures = [temperatures]

    texts, scores = [], []
    for _temperature in temperatures:
        outputs, seq_scores = model.tactic_gen(prompt, output_scores=True, num_response=num_samples)
        # An empty batch contributes nothing; keep candidates already gathered
        # from earlier calls instead of discarding them.
        if len(outputs) == 0:
            continue
        texts.extend(outputs)
        scores.extend(seq_scores)
    return texts, scores


def _unique_sorted(texts, scores):
    texts_ = []
    scores_ = []
    for t, s in sorted(zip(texts, scores), key=lambda x: -x[1]):
        if t not in texts_:
            texts_.append(t)
            scores_.append(s)
    return texts_, scores_


def _tactic_state(state):
    """Return the printable text of a LeanDojo state.

    ``TacticState`` objects expose it as ``pp``; any other state object is
    read through its ``unsolved_tactic_state`` attribute.
    """
    if not isinstance(state, TacticState):
        return state.unsolved_tactic_state
    return state.pp


def _prompt_proofstep(ts):
    prompt = f"[GOAL]{ts}[PROOFSTEP]"
    return ts


def _prompt_fn(model_name):
    if 'dojo' in model_name:
        return _prompt_proofstep
    
    raise NotImplementedError(model_name)



def best_first_search(
        theorem,
        formal_statement,
        model,
        tokenizer,
        max_iters,
        temperatures,
        num_samples,
        prompt_fn,
        retriver,
        timeout=600,
        early_stop=True,
        max_tokens=256,
        use_hyp = True,
        use_suffice = False,
        use_model_proof = False,
) -> list:
    """Best-first proof search guided by LLM-drafted hypotheses.

    Before searching, GPT agents draft an informal proof and pseudo-code for
    ``formal_statement``.  The search then repeatedly pops the lowest-score
    node from a heap.  For the first 10 expansions (when ``use_hyp`` is set)
    it asks ``reverse_have_agent`` for extra tactics and applies, in order,
    those that succeed, change the number of distinct hypotheses (see
    ``hypo_overlap_check``) and do not contain ``sorry``.  It then samples
    candidate tactics from ``model`` and pushes every resulting state that
    has not been expanded yet.

    Args:
        theorem: LeanDojo ``Theorem`` to prove.
        formal_statement: formal statement text given to the drafting agents.
        model: tactic generator used through ``generate_vllm_dojo``.
        tokenizer: its ``eos_token`` is forwarded as ``stop``.
        max_iters: maximum number of heap pops.
        temperatures: forwarded unchanged to ``generate_vllm_dojo`` and
            recorded in successful results.
        num_samples: number of tactic candidates requested per expansion.
        prompt_fn: maps a tactic-state string to the generator prompt.
        retriver: currently unused.
        timeout: ``hard_timeout`` passed to ``Dojo`` (presumably seconds).
        early_stop: return as soon as the first proof is found.
        max_tokens: forwarded to ``generate_vllm_dojo``.
        use_hyp: enable the reverse-hypothesis step.
        use_suffice: currently unused.
        use_model_proof: currently unused.

    Returns:
        A list of attempt-result dicts.  A successful entry holds the proof
        steps, score, trace, elapsed seconds and LLM token usage.  If no
        proof is found, a single failure record is returned whose
        ``failure_reason`` is the caught Dojo/subprocess exception name or
        ``'SearchEnded'``.
    """
    attempt_results = []
    token_prompt_list = []
    token_compli_list = []
    informal_draft_agent = draft.InformalProofDraft(llm_engine='gpt-4-turbo-128k',is_conversation=True,temperature=0.2)
    pseudo_agent = draft.PseudoProofDraft(llm_engine='gpt-4-turbo-128k',is_conversation=False,temperature=0.2)
    # NOTE(review): the next two agents are constructed but never used below.
    nextstep_align_agent = draft.AlignProofDraft(llm_engine='gpt-4-turbo-128k',is_conversation=False,temperature=0.2)
    state_explain_agent = draft.StateExplainer(llm_engine="gpt-4-turbo-128k",is_conversation=False,temperature=0.2)
    reverse_have_agent = reverse.ReverseHypo(curiosity='lv1',k=8,is_conversation=False,temperature=0.0)

    # Draft the informal proof and pseudo-code once per theorem.
    informal_statement, informal_proof, token_prompt, token_compli = informal_draft_agent.run(
        formal_statement,
        theorem.full_name,
        exact_match=False
    )
    logger.info(f"# informal proof: \n{informal_proof}")
    token_prompt_list.append(token_prompt)
    token_compli_list.append(token_compli)
    full_pseudo_code, token_prompt, token_compli = pseudo_agent.run(informal_statement,informal_proof,formal_statement)
    token_prompt_list.append(token_prompt)
    token_compli_list.append(token_compli)

    # Remaining expansions that may still query reverse_have_agent.
    max_reverse_try = 10
    try:
        with Dojo(theorem, hard_timeout=timeout) as (dojo, init_state):
            # Single pass: the whole ``temperatures`` value goes to the generator.
            for temperature in [temperatures]:
                start = time.time()
                proof_finished = False
                # Heap entries are (score, steps, push_id, state, trace).
                # push_id comes after steps, so ordering matches comparing
                # (score, steps), but equal keys never fall through to comparing
                # LeanDojo state objects (which may not be orderable).
                push_id = 0
                queue = [(0.0, [], push_id, init_state, [])]
                visited = set()

                for iteration in trange(max_iters):
                    if len(queue) == 0 or proof_finished:
                        break
                    total_score, steps, _, state, trace = heapq.heappop(queue)
                    ts = _tactic_state(state)
                    logger.info(f"# Current state\n{ts}")
                    prefix_hypoth = []
                    if use_hyp and max_reverse_try:
                        max_reverse_try -= 1

                        hypoth_tactics, token_prompt, token_compli = reverse_have_agent.run(
                            ts,
                            full_pseudo_code,
                            informal_statement,
                            informal_proof,
                        )
                        token_prompt_list.append(token_prompt)
                        token_compli_list.append(token_compli)

                        # Apply the proposed tactics one after another, each on
                        # top of the last accepted state.
                        for step in hypoth_tactics:
                            result = dojo.run_tac(state, step)
                            if isinstance(result, ProofFinished):
                                # A bare RuntimeError used to escape this function
                                # (it is not in the except list below) and abort the
                                # whole evaluation run; log and skip the tactic instead.
                                logger.warning(f"# Hypothesis tactic closed the goal, skipped\n{step}")
                                continue
                            if isinstance(result, TacticState):
                                overlap_flag = hypo_overlap_check(_tactic_state(result),_tactic_state(state))
                                if not overlap_flag and 'sorry' not in step:
                                    logger.critical(f"# New phypo\n{step}")
                                    prefix_hypoth.append(step)
                                    state = result
                    # Expand from the (possibly hypothesis-augmented) state.
                    ts = _tactic_state(state)
                    visited.add(ts)

                    step_cands, step_scores = generate_vllm_dojo(
                        prompt_fn(ts),
                        model,
                        tokenizer,
                        temperature,
                        num_samples=num_samples,
                        stop=tokenizer.eos_token,
                        max_tokens=max_tokens
                    )
                    step_cands = [s.strip() for s in step_cands]
                    logger.info(f"#tactic\n{step_cands[:10]} ......")
                    for step, score in zip(step_cands, step_scores):
                        result = dojo.run_tac(state, step)
                        step_trace = {
                            "tactic": step,
                            "state_before": _tactic_state(state)
                        }
                        if isinstance(result, ProofFinished):
                            # Record the accepted hypotheses together with the closing tactic.
                            step = steps_concat(prefix_hypoth+[step])
                            attempt_results.append({
                                'theorem': theorem.full_name,
                                'proof': steps + [step],
                                'score': total_score - score,
                                'success': True,
                                'failure_reason': '',
                                'trace': trace + [step_trace],
                                'temperature': temperatures,
                                'elapsed': time.time() - start,
                                'iteration': iteration,
                                'prompt_tokens':sum(token_prompt_list),
                                'compli_tokens':sum(token_compli_list)
                            })
                            if early_stop:
                                return attempt_results
                            proof_finished = True
                            break
                        elif isinstance(result, TacticState):
                            if _tactic_state(result) not in visited:
                                # Score is negative log probability summed across steps
                                new_score = (total_score - score)
                                step = steps_concat(prefix_hypoth+[step])
                                logger.info(f"# Succ Tac: \n{step}")
                                push_id += 1
                                heapq.heappush(
                                    queue, (new_score, steps+[step], push_id, result, trace+[step_trace])
                                )
    except (DojoInitError, DojoHardTimeoutError, DojoCrashError, subprocess.CalledProcessError) as e:
        if len(attempt_results) == 0:
            attempt_results.append({
                'theorem': theorem.full_name,
                'success': False,
                'failure_reason': type(e).__name__,
                'prompt_tokens':sum(token_prompt_list),
                'compli_tokens':sum(token_compli_list)
            })

    if len(attempt_results) == 0:
        attempt_results.append({
            'theorem': theorem.full_name,
            'success': False,
            'failure_reason': 'SearchEnded',
            'prompt_tokens':sum(token_prompt_list),
            'compli_tokens':sum(token_compli_list)
        })

    return attempt_results


def _save(model_name, results, args_dict, output_dir, shard):
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    output_file = os.path.join(
        output_dir,
        'results__%s__%s.json' % (model_name.replace('/', '_'), shard)
    )
    with open(output_file, 'w') as f:
        json.dump({
            'results': results,
            'args': args_dict
            }, f, indent=4)
        print(output_file)


def _load_model(model_name):
    if 'pythia' in model_name:
        model = vllm.LLM(
            model=model_name,
            tensor_parallel_size=1,
            dtype='float16'
        )
        tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(model_name)
    if 'dojo' in model_name:
        nn_ranker = Lean3EncodeRanker(dst_dir=\
        "db/mathlib_db/leandojo_minif2f",
        device_id=7)
        model = ReproverTacticGen(emb_ranker=nn_ranker,device_id=5)
        tokenizer = model.tokenizer
    else:
        raise NotImplementedError(model_name)

    return model, tokenizer



def print_stats(results):
    """Print the success rate and the number of successful results.

    Each entry must have a ``'success'`` key; truthy values count as
    successes.  An empty ``results`` prints a rate of 0.0 instead of raising
    ZeroDivisionError.
    """
    num_success = sum(1 for x in results if x['success'])
    print(num_success / len(results) if results else 0.0)
    print("# successes: ", num_success, sep="\t")


def resume_from(results_filename, data):
    """Load previous results and drop the examples they already cover.

    Assumes the stored results correspond, in order, to a prefix of ``data``.
    The first ``len(results)`` examples are therefore skipped.

    Args:
        results_filename: JSON file with a top-level ``'results'`` list.
        data: the full list of examples.

    Returns:
        ``(results, remaining_data)``.
    """
    # Context manager closes the file handle (it was previously leaked).
    with open(results_filename) as f:
        results = json.load(f)['results']
    data = data[len(results):]
    print("=== Resuming from %d" % (len(results)))
    return results, data


def make_output_dir(output_dir):
    """Create (if missing) and return the fixed ``demo`` subdirectory of ``output_dir``."""
    run_dir = os.path.join(output_dir, 'demo')
    Path(run_dir).mkdir(parents=True, exist_ok=True)
    return run_dir


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model-name',
        default='dojo',
    )
    parser.add_argument(
        '--dataset-name',
        default='minif2f-valid',
        choices=['minif2f-valid', 'minif2f-test', 'leandojo','proofnet']
    )
    parser.add_argument('--shard', type=int, required=True)
    parser.add_argument('--resume-from', type=str, default=None)
    parser.add_argument('--dataset-path', default='data/minif2f.jsonl')
    parser.add_argument('--output-dir', default='output')
    parser.add_argument('--early-stop', action='store_true')
    parser.add_argument('--num-shards', type=int, default=1)
    parser.add_argument('--max-iters', type=int, default=100)
    parser.add_argument('--timeout', type=int, default=600)
    parser.add_argument('--num-examples', type=int, default=-1)
    parser.add_argument('--num-samples', type=int, default=32)
    parser.add_argument('--clear-process-hours', type=int, default=3)
    parser.add_argument('--temperatures', type=float, nargs='+', default=[0.0])
    parser.add_argument('--resume-ckpt', action="store_true")
    parser.add_argument('--retrieval-device',type=int, default=3)
    parser.add_argument('--target-q', action="store_true")
    parser.add_argument('--use-model-proof', action="store_true")

    args = parser.parse_args()

    model, tokenizer = _load_model(args.model_name)

    output_dir = make_output_dir(args.output_dir)

    logger.info(f"#args\n{args}")

    repo, data = _load_data(args.dataset_name, args.dataset_path,args.target_q)
    # NOTE: sharding is disabled; every process runs the whole dataset and
    # --shard is only used when naming saved results.
    print("Shard size: %d" % (len(data)))

    if args.resume_from is not None:
        results, data = resume_from(args.resume_from, data)
    else:
        results = []

    cache_log_file = _make_save(
        model_name=args.model_name,
        results=results,
        args_dict=args.__dict__,
        output_dir=output_dir,
        shard=args.shard
    )
    resume_ckpt = args.resume_ckpt
    # Theorem names to stop retrying; their cached results are kept as-is.
    give_up = []
    skip_examples = list(give_up)
    # Merge a previous run's cached results:
    #   - successes are kept and skipped;
    #   - give-up entries are kept (and already skipped);
    #   - other failures are kept and skipped only with --resume-ckpt,
    #     otherwise they are re-run.
    if os.path.exists(cache_log_file):
        with open(cache_log_file) as f:
            cache_log_content = json.load(f)
        for res in cache_log_content['results']:
            if res['success']:
                results.append(res)
                skip_examples.append(res['example']['full_name'])
            elif res['example']['full_name'] in give_up:
                results.append(res)
            elif resume_ckpt :
                results.append(res)
                skip_examples.append(res['example']['full_name'])

    # NOTE(review): device_id is hard-coded; --retrieval-device is not used.
    retriver = KnowledgeRetrieval(device_id=4,is_conversation=False)
    logger.info(f"# rest {len(data)-len(skip_examples)}")
    for example in tqdm(data, total=len(data)):
        file_path = example['file_path']
        theorem_name = example['full_name']
        if theorem_name in skip_examples:
            continue
        formal_statement = example['statement']
        logger.info(f"Running on {theorem_name}")
        theorem = Theorem(repo, file_path, theorem_name)
        attempt_results = best_first_search(
            theorem, formal_statement,model, tokenizer,
            max_iters=args.max_iters,
            prompt_fn=_prompt_fn(args.model_name),
            temperatures=args.temperatures,
            num_samples=args.num_samples,
            retriver=retriver,
            timeout=args.timeout,
            early_stop=args.early_stop,
            use_model_proof = args.use_model_proof
        )
        result = {
            'attempt_results': attempt_results,
            'success': any([x['success'] for x in attempt_results]),
            'example': example
        }
        results.append(result)

        # Rewrite the full results file after every theorem so progress survives crashes.
        _save(
            model_name=args.model_name,
            results=results,
            args_dict=args.__dict__,
            output_dir=output_dir,
            shard=args.shard
        )
        print_stats(results)
        result_success = result['success']
        logger.info(f"Success? {result_success}")


