# Lean proof search with LeanDojo interaction
# Author: Sean Welleck
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['CONTAINER'] = 'native'
import json
import heapq
import transformers
import subprocess
import vllm
import time
import sys
from datetime import datetime
from tqdm import tqdm, trange
from pathlib import Path
from lean_dojo import *
# from local_pkg.lean_dojo import *
from loguru import logger
logger.remove()
logger.add(sys.stderr, level="DEBUG")

from agent import draft,prove
from agent.gpt_api import request_gpt


def generate_vllm(prompt, model, tokenizer, temperatures, num_samples, stop, max_tokens=256):
    """Sample completions for ``prompt`` at each temperature and rank them.

    For every temperature a ``vllm.SamplingParams`` is built (beam search is
    requested when the temperature equals ``0.0``) and ``model.generate`` is
    called on the single prompt. Each completion has the tokenizer's EOS
    string stripped and is scored by its cumulative log-probability divided
    by its token count (a zero-length completion divides by 1).

    Returns:
        ``(texts, scores)`` de-duplicated and ordered best score first, as
        produced by ``_unique_sorted``. If any ``generate`` call returns no
        outputs, ``([], [])`` is returned immediately.
    """
    candidates = []
    candidate_scores = []
    for temp in temperatures:
        sampling = vllm.SamplingParams(
            n=num_samples,
            temperature=temp,
            use_beam_search=(temp == 0.0),
            max_tokens=max_tokens,
            stop=stop,
        )
        request_outputs = model.generate([prompt], sampling, use_tqdm=False)
        if len(request_outputs) == 0:
            return [], []
        # Only one prompt was submitted, so only the first request is read.
        for completion in request_outputs[0].outputs:
            cleaned = completion.text.replace(tokenizer.eos_token, '')
            n_tokens = max(len(completion.token_ids), 1)
            candidates.append(cleaned)
            candidate_scores.append(completion.cumulative_logprob / n_tokens)

    return _unique_sorted(candidates, candidate_scores)


def _unique_sorted(texts, scores):
    texts_ = []
    scores_ = []
    for t, s in sorted(zip(texts, scores), key=lambda x: -x[1]):
        if t not in texts_:
            texts_.append(t)
            scores_.append(s)
    return texts_, scores_


def _tactic_state(state):
    """Return the textual goal state carried by ``state``.

    A ``TacticState`` yields its ``pp`` attribute; any other object is
    expected to expose ``unsolved_tactic_state``, which is returned instead.
    """
    return state.pp if isinstance(state, TacticState) else state.unsolved_tactic_state


def _prompt_proofstep(ts):
    prompt = f"[GOAL]{ts}[PROOFSTEP]"
    return prompt


def _prompt_fn(model_name):
    if 'pythia' in model_name:
        return _prompt_proofstep
    raise NotImplementedError(model_name)


def agent_search(
        theorem,
        formal_statement,
        max_iters,
        temperatures,
        num_samples,
        prompt_fn,
        timeout=600,
        early_stop=True,
        max_tokens=256
) -> list:
    """Try to prove ``theorem`` by replaying one GPT-written proof in Lean.

    An informal draft is produced first, then a complete formal proof is
    requested from GPT, split into tactics, and executed one tactic at a
    time through LeanDojo. Exactly one attempt is made per call; there is no
    branching search.

    Args:
        theorem: Theorem object opened with ``Dojo``; its ``full_name`` is
            recorded in every result.
        formal_statement: Formal statement text handed to both agents.
        max_iters: Unused; kept for interface compatibility.
        temperatures: Not used for sampling here; stored verbatim under
            ``'temperature'`` in a successful result.
        num_samples: Unused; kept for interface compatibility.
        prompt_fn: Unused; kept for interface compatibility.
        timeout: Forwarded to ``Dojo`` as ``hard_timeout``.
        early_stop: Kept for interface compatibility. Because only a single
            attempt is made, the function returns after the first finished
            proof regardless of this flag.
        max_tokens: Unused; kept for interface compatibility.

    Returns:
        A non-empty list of result dicts. A success record carries the full
        tactic list (``'proof'``), the per-step ``'trace'``, ``'elapsed'``
        and token totals. Failure records carry ``'failure_reason'``: the
        exception class name for the Dojo/subprocess/runtime errors handled
        below, or ``'SearchEnded'`` when the replay stopped without
        finishing the proof.
    """
    attempt_results = []
    token_prompt_list = []
    token_compli_list = []

    def _token_totals():
        # Token usage accumulated so far across all GPT calls of this attempt.
        return {
            "token_prompt_total": sum(token_prompt_list),
            "token_compli_total": sum(token_compli_list),
        }

    informal_draft_agent = draft.InformalProofDraft()
    informal_statement, informal_proof, token_prompt, token_compli = informal_draft_agent.run(
        formal_statement, theorem.full_name
    )
    token_prompt_list.append(token_prompt)
    token_compli_list.append(token_compli)

    try:
        with Dojo(theorem, hard_timeout=timeout) as (dojo, init_state):
            start = time.time()
            formal_prove_agent = prove.FormalProof('gpt-3.5-turbo')
            dfs_input_prompt = formal_prove_agent.prepare_input_prompt(
                formal_statement,
                informal_statement,
                theorem.full_name,
            )
            _, dfs_proof, token_prompt, token_compli = request_gpt(dfs_input_prompt)
            token_prompt_list.append(token_prompt)
            token_compli_list.append(token_compli)
            dfs_steps = [s.strip() for s in prove.parse_formal_steps(dfs_proof)]

            state = init_state
            # Accumulated as the replay advances so that a success record
            # holds the whole proof, not only its final tactic.
            steps = []
            trace = []
            for step in dfs_steps:
                steps.append(step)
                trace.append({
                    "tactic": step,
                    "state_before": _tactic_state(state),
                })
                # TODO: run_tac may raise RuntimeError; for now that is
                # reported through the except clause below.
                result = dojo.run_tac(state, step)
                if isinstance(result, ProofFinished):
                    attempt_results.append({
                        'theorem': theorem.full_name,
                        'proof': list(steps),
                        'score': 0,
                        'success': True,
                        'failure_reason': '',
                        'trace': list(trace),
                        'temperature': temperatures,
                        'elapsed': time.time() - start,
                        'iteration': 0,
                        **_token_totals(),
                    })
                    break
                if not isinstance(result, TacticState):
                    # LeanError or any other terminal result: the remaining
                    # tactics cannot be applied to it, so stop replaying.
                    logger.info(result)
                    break
                state = result
    except (DojoInitError, DojoHardTimeoutError, DojoCrashError,
            subprocess.CalledProcessError, RuntimeError) as e:
        logger.info(type(e).__name__)
        attempt_results.append({
            'theorem': theorem.full_name,
            'success': False,
            'failure_reason': type(e).__name__,
            **_token_totals(),
        })

    if len(attempt_results) == 0:
        attempt_results.append({
            'theorem': theorem.full_name,
            'success': False,
            'failure_reason': 'SearchEnded',
            **_token_totals(),
        })

    return attempt_results



def print_stats(results):
    """Log the success rate and the number of successes in ``results``.

    Args:
        results: List of dicts, each with a truthy/falsy ``'success'`` entry.
            An empty list logs a zero count instead of dividing by zero.
    """
    num_success = sum(1 for x in results if x['success'])
    if results:
        logger.info(num_success / len(results))
    # The logger formats its message with str.format-style arguments rather
    # than print-style positional values, so the count needs a placeholder.
    logger.info("# successes: \t{}", num_success)


def resume_from(results_filename, data):
    """Load earlier results and drop the examples they already cover.

    Args:
        results_filename: Path to a JSON file whose top-level object has a
            ``'results'`` list.
        data: Sequence of examples in the same order as the earlier run.

    Returns:
        ``(results, remaining)`` where ``remaining`` is ``data`` with the
        first ``len(results)`` entries removed.
    """
    # The context manager closes the file handle deterministically.
    with open(results_filename, encoding="utf-8") as f:
        results = json.load(f)['results']
    data = data[len(results):]
    print("=== Resuming from %d" % (len(results)))
    return results, data


def make_output_dir(output_dir):
    """Create a run sub-directory under ``output_dir`` and return its path.

    The sub-directory name comes from ``strftime("demo")``. That format
    string contains no directives, so the name is the literal text ``demo``.
    Missing parents are created and an existing directory is reused.
    """
    run_name = datetime.now().strftime("demo")
    target = os.path.join(output_dir, run_name)
    os.makedirs(target, exist_ok=True)
    return target



def _load_data(dataset_name, dataset_path):
    if dataset_name == 'leandojo':
        URL = "https://github.com/leanprover-community/mathlib4"
        COMMIT = "5a919533f110b7d76410134a237ee374f24eaaad"
        repo = LeanGitRepo(URL, COMMIT)
        with open(dataset_path) as f:
            data = json.load(f)
    elif 'minif2f' in dataset_name:
        data = []
        with open(dataset_path) as f:
            for line in f.readlines():
                data_ = json.loads(line)
                # assert data_['commit'] == 'd00c776260c77de7e70125ef0cd119de6c0ff1de'
                data.append(data_)

        if 'valid' in dataset_name:
            data = [x for x in data if x['split'] == 'valid']
        else:
            data = [x for x in data if x['split'] == 'test']
        repo = LeanGitRepo(data[0]['url'], data[0]['commit'])
    else:
        raise NotImplementedError(dataset_name)

    return repo, data


# Command-line entry point: load one shard of a dataset, run agent_search on
# every example in it, and log running success statistics.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # _prompt_fn only accepts model names containing 'pythia'.
    parser.add_argument(
        '--model-name',
        default='wellecks/llmstep-mathlib4-pythia2.8b',
    )
    parser.add_argument(
        '--dataset-name',
        default='minif2f-valid',
        choices=['minif2f-valid', 'minif2f-test', 'leandojo']
    )
    parser.add_argument('--shard', type=int, required=True)
    parser.add_argument('--resume-from', type=str, default=None)
    parser.add_argument('--dataset-path', default='data/minif2f.jsonl')
    parser.add_argument('--output-dir', default='output/')
    parser.add_argument('--early-stop', action='store_true')
    parser.add_argument('--num-shards', type=int, default=1)
    parser.add_argument('--max-iters', type=int, default=100)
    parser.add_argument('--timeout', type=int, default=600)
    # NOTE(review): --num-examples and --clear-process-hours are parsed but
    # not read anywhere in this block.
    parser.add_argument('--num-examples', type=int, default=-1)
    parser.add_argument('--num-samples', type=int, default=32)
    parser.add_argument('--clear-process-hours', type=int, default=3)
    parser.add_argument('--temperatures', type=float, nargs='+', default=[0.0, 0.5])
    args = parser.parse_args()

    # NOTE(review): the directory is created, but nothing in this block
    # writes the results into it; they only live in memory.
    output_dir = make_output_dir(args.output_dir)

    repo, data = _load_data(args.dataset_name, args.dataset_path)
    shard_size = len(data) // args.num_shards
    shard_start = args.shard * shard_size
    # Integer division leaves len(data) % num_shards trailing examples
    # outside every shard; give them to the last shard so none are dropped.
    if args.shard == args.num_shards - 1:
        shard_end = len(data)
    else:
        shard_end = shard_start + shard_size
    data = data[shard_start:shard_end]
    print("Shard size: %d" % (len(data)))

    if args.resume_from is not None:
        results, data = resume_from(args.resume_from, data)
    else:
        results = []

    for example in tqdm(data, total=len(data)):
        file_path = example['file_path']
        theorem_name = example['full_name']
        formal_statement = example['statement']
        theorem = Theorem(repo, file_path, theorem_name)
        attempt_results = agent_search(
            theorem,
            formal_statement=formal_statement,
            max_iters=args.max_iters,
            prompt_fn=_prompt_fn(args.model_name),
            temperatures=args.temperatures,
            num_samples=args.num_samples,
            timeout=args.timeout,
            early_stop=args.early_stop
        )
        result = {
            'attempt_results': attempt_results,
            'success': any(x['success'] for x in attempt_results),
            'example': example
        }
        results.append(result)

        # Log cumulative statistics after every example.
        print_stats(results)

