import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['CONTAINER'] = 'native'
os.environ["PATH"] = "/home/jupyter/.elan/bin:"+os.getenv("PATH")
os.environ['NUM_PROCS'] = '1'
import json
import heapq
import transformers
import subprocess
import vllm
import time
import sys
from datetime import datetime
from tqdm import tqdm, trange
from pathlib import Path
from lean_dojo import *
# from local_pkg.lean_dojo import *
from loguru import logger
from db.reprover_assist_db.reprover_gen import ReproverTacticGen
from db.mathlib_db.build_copra_aug import Lean3Bm25ReRanker,Lean3EncodeRanker
from utils.common import print_stats,resume_from,make_output_dir,_save,_make_save,hypo_overlap_check

def generate_vllm_dojo(prompt, model, tokenizer, temperatures, num_samples, stop, max_tokens=256):
    """Collect tactic candidates by calling ``model.tactic_gen`` repeatedly.

    ``model.tactic_gen`` is called once for every entry in ``temperatures``,
    always with the same ``prompt`` and ``num_response=num_samples``. The
    temperature values themselves, as well as ``tokenizer``, ``stop`` and
    ``max_tokens``, are accepted for signature compatibility but are not
    used by this function.

    Returns:
        A ``(texts, scores)`` pair with the outputs of all calls
        concatenated in call order. If any single call yields no outputs,
        ``([], [])`` is returned immediately and earlier outputs are
        discarded.
    """
    all_texts = []
    all_scores = []
    for _ in temperatures:
        batch_texts, batch_scores = model.tactic_gen(
            prompt,
            output_scores=True,
            num_response=num_samples,
        )
        # An empty batch aborts the whole generation, not just this round.
        if len(batch_texts) == 0:
            return [], []
        all_texts += batch_texts
        all_scores += batch_scores
    return all_texts, all_scores


def _unique_sorted(texts, scores):
    texts_ = []
    scores_ = []
    for t, s in sorted(zip(texts, scores), key=lambda x: -x[1]):
        if t not in texts_:
            texts_.append(t)
            scores_.append(s)
    return texts_, scores_


def _tactic_state(state):
    """Return the textual tactic state carried by ``state``.

    For a ``TacticState`` this is its ``pp`` attribute; for any other
    object the ``unsolved_tactic_state`` attribute is read instead.
    """
    if not isinstance(state, TacticState):
        return state.unsolved_tactic_state
    return state.pp


def _prompt_proofstep(ts):
    prompt = f"[GOAL]{ts}[PROOFSTEP]"
    return ts


def _prompt_fn(model_name):
    if 'dojo' in model_name:
        return _prompt_proofstep
    
    raise NotImplementedError(model_name)


def best_first_search(
        theorem,
        model,
        tokenizer,
        max_iters,
        temperatures,
        num_samples,
        prompt_fn,
        timeout=600,
        early_stop=True,
        max_tokens=256
) -> list:
    """Run a best-first proof search for ``theorem`` inside a ``Dojo``.

    The frontier is a min-heap keyed by the running sum of negated
    candidate scores, so the state with the highest cumulative score is
    expanded first. Each expansion asks ``generate_vllm_dojo`` for
    candidate tactics and runs every candidate with ``dojo.run_tac``.

    Args:
        theorem: The theorem handed to ``Dojo``; ``theorem.full_name`` is
            recorded in every result.
        model: Generator passed through to ``generate_vllm_dojo``.
        tokenizer: Its ``eos_token`` is forwarded as the ``stop`` argument.
        max_iters: Maximum number of state expansions.
        temperatures: Forwarded as a whole to ``generate_vllm_dojo`` and
            recorded in successful results.
        num_samples: Number of candidates requested per generation call.
        prompt_fn: Maps a textual tactic state to the generator prompt.
        timeout: ``hard_timeout`` given to ``Dojo``.
        early_stop: If true, return as soon as a proof is found.
        max_tokens: Forwarded to ``generate_vllm_dojo``.

    Returns:
        A list of result dicts. A successful entry holds the proof steps,
        score, trace, elapsed seconds and iteration index. If nothing
        succeeded, the list holds one failure entry whose
        ``failure_reason`` is either the name of the caught exception type
        or ``'SearchEnded'``.
    """
    attempt_results = []
    try:
        with Dojo(theorem, hard_timeout=timeout) as (dojo, init_state):
            # NOTE: ``temperatures`` is wrapped in a one-element list, so this
            # loop body runs exactly once and ``temperature`` is the complete
            # list, which generate_vllm_dojo then iterates over itself.
            for temperature in [temperatures]:
                start = time.time()
                proof_finished = False
                # Heap entries: (cumulative score, tactics so far, state, trace)
                queue = [(0.0, [], init_state, [])]
                visited = set()

                for iteration in trange(max_iters):
                    if len(queue) == 0 or proof_finished:
                        break

                    total_score, steps, state, trace = heapq.heappop(queue)
                    ts = _tactic_state(state)
                    # NOTE: states are marked visited only when popped, so the
                    # same state may be queued several times before that.
                    visited.add(ts)

                    step_cands, step_scores = generate_vllm_dojo(
                        prompt_fn(ts),
                        model,
                        tokenizer,
                        temperature,
                        num_samples,
                        stop=tokenizer.eos_token,
                        max_tokens=max_tokens
                    )
                    step_cands = [s.strip() for s in step_cands]
                    logger.info(f"#total tac len\n{len(step_cands)} ......")

                    logger.info(f"#tactic\n{step_cands[:10]} ......")
                    for step, score in zip(step_cands, step_scores):
                        result = dojo.run_tac(state, step)
                        step_trace = {
                            "tactic": step,
                            "state_before": _tactic_state(state)
                        }
                        if isinstance(result, ProofFinished):
                            attempt_results.append({
                                'theorem': theorem.full_name,
                                'proof': steps + [step],
                                'score': total_score - score,
                                'success': True,
                                'failure_reason': '',
                                'trace': trace + [step_trace],
                                'temperature': temperatures,
                                # Fixed: was ``start - time.time()``, which is
                                # always negative.
                                'elapsed': time.time() - start,
                                'iteration': iteration
                            })
                            if early_stop:
                                return attempt_results
                            proof_finished = True
                            break
                        elif isinstance(result, TacticState):
                            if _tactic_state(result) not in visited:
                                # Score is negative log probability summed across steps
                                new_score = (total_score - score)
                                heapq.heappush(
                                    queue, (new_score, steps+[step], result, trace+[step_trace])
                                )
                        # Any other result type is ignored: that candidate is
                        # neither recorded nor queued.
    except (DojoInitError, DojoHardTimeoutError, DojoCrashError, subprocess.CalledProcessError) as e:
        # Keep any proof found before the failure; otherwise record why the
        # search aborted.
        if len(attempt_results) == 0:
            attempt_results.append({
                'theorem': theorem.full_name,
                'success': False,
                'failure_reason': type(e).__name__
            })

    if len(attempt_results) == 0:
        # Queue exhausted or iteration budget used up without a proof.
        attempt_results.append({
            'theorem': theorem.full_name,
            'success': False,
            'failure_reason': 'SearchEnded'
        })

    return attempt_results


def _save(model_name, results, args_dict, output_dir, shard):
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    output_file = os.path.join(
        output_dir,
        'results__%s__%s.json' % (model_name.replace('/', '_'), shard)
    )
    with open(output_file, 'w') as f:
        json.dump({
            'results': results,
            'args': args_dict
            }, f, indent=4)
        print(output_file)


def _load_model(model_name):
    if 'pythia' in model_name:
        model = vllm.LLM(
            model=model_name,
            tensor_parallel_size=1,
            dtype='float16'
        )
        tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(model_name)
    if 'dojo' in model_name:
        nn_ranker = Lean3EncodeRanker(dst_dir=\
        "db/mathlib_db/leandojo_minif2f",
        device_id=7)
        model = ReproverTacticGen(emb_ranker=nn_ranker,device_id=7)
        tokenizer = model.tokenizer
    else:
        raise NotImplementedError(model_name)

    return model, tokenizer


def _load_data(dataset_name, dataset_path):
    if dataset_name == 'leandojo':
        URL = "https://github.com/leanprover-community/mathlib4"
        COMMIT = "5a919533f110b7d76410134a237ee374f24eaaad"
        repo = LeanGitRepo(URL, COMMIT)
        with open(dataset_path) as f:
            data = json.load(f)
    elif 'minif2f' in dataset_name:
        data = []
        with open(dataset_path) as f:
            for line in f.readlines():
                data_ = json.loads(line)
                # assert data_['commit'] == 'd00c776260c77de7e70125ef0cd119de6c0ff1de'
                data.append(data_)

        if 'valid' in dataset_name:
            data = [x for x in data if x['split'] == 'valid']
        else:
            data = [x for x in data if x['split'] == 'test']
        repo = LeanGitRepo(data[0]['url'], data[0]['commit'])
    elif 'proofnet' in dataset_name:
        data = []
        with open(dataset_path) as f:
            for line in f.readlines():
                data_ = json.loads(line)
                # assert data_['commit'] == 'd00c776260c77de7e70125ef0cd119de6c0ff1de'
                data.append(data_)
        repo = LeanGitRepo(data[0]['url'], data[0]['commit'])
    else:
        raise NotImplementedError(dataset_name)

    return repo, data


def print_stats(results):
    """Print the success rate and the number of successes of ``results``.

    Args:
        results: Sequence of dicts, each with a truthy or falsy
            ``'success'`` entry.

    The first printed line is the fraction of successful results (``0.0``
    for an empty sequence, which used to raise ``ZeroDivisionError``); the
    second line is the absolute count.
    """
    num_success = sum(1 for x in results if x['success'])
    rate = num_success / len(results) if results else 0.0
    print(rate)
    print("# successes: ", num_success, sep="\t")


def resume_from(results_filename, data):
    """Reload saved results and drop the examples they already cover.

    Args:
        results_filename: JSON file with a top-level ``'results'`` list.
        data: The full list of examples for this run.

    Returns:
        ``(results, remaining)`` where ``remaining`` is ``data`` without its
        first ``len(results)`` entries. This relies on the saved results
        being in the same order as ``data``.
    """
    # Context manager so the file handle is closed deterministically.
    with open(results_filename) as f:
        results = json.load(f)['results']
    data = data[len(results):]
    print("=== Resuming from %d" % (len(results)))
    return results, data


def make_output_dir(output_dir):
    """Create the ``demo`` sub-directory of ``output_dir`` and return its path.

    Missing parent directories are created as well; an already existing
    directory is not an error.
    """
    run_dir = os.path.join(output_dir, 'demo')
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


if __name__ == '__main__':
    # Command-line driver: load the model and dataset, select this shard,
    # reuse cached results, then run best-first search on each remaining
    # theorem and save after every example.
    import argparse

    parser = argparse.ArgumentParser()
    # --model-name is deliberately unrestricted; _load_model decides which
    # names are supported.
    parser.add_argument(
        '--model-name',
        default='dojo',
    )
    parser.add_argument(
        '--dataset-name',
        default='minif2f-valid',
        choices=['minif2f-valid', 'minif2f-test', 'leandojo','proofnet']
    )
    parser.add_argument('--shard', type=int, required=True)
    parser.add_argument('--resume-from', type=str, default=None)
    parser.add_argument('--dataset-path', default='data/minif2f.jsonl')
    parser.add_argument('--output-dir', default='output')
    parser.add_argument('--early-stop', action='store_true')
    parser.add_argument('--num-shards', type=int, default=1)
    parser.add_argument('--max-iters', type=int, default=100)
    parser.add_argument('--timeout', type=int, default=600)
    # NOTE: --num-examples and --clear-process-hours are parsed (and saved
    # with the results via args.__dict__) but not otherwise used below.
    parser.add_argument('--num-examples', type=int, default=-1)
    parser.add_argument('--num-samples', type=int, default=32)
    parser.add_argument('--clear-process-hours', type=int, default=3)
    parser.add_argument('--temperatures', type=float, nargs='+', default=[0.0, 0.5])
    args = parser.parse_args()

    model, tokenizer = _load_model(args.model_name)

    output_dir = make_output_dir(args.output_dir)

    repo, data = _load_data(args.dataset_name, args.dataset_path)
    shard_size = len(data) // args.num_shards
    shard_start = args.shard * shard_size
    # Fixed: with integer division the trailing ``len(data) % num_shards``
    # examples belonged to no shard; the last shard now takes them.
    if args.shard == args.num_shards - 1:
        shard_end = len(data)
    else:
        shard_end = shard_start + shard_size
    data = data[shard_start:shard_end]
    print("Shard size: %d" % (len(data)))

    if args.resume_from is not None:
        results, data = resume_from(args.resume_from, data)
    else:
        results = []

    # NOTE(review): _make_save comes from utils.common; its return value is
    # used below as the path of a previously written results file -- confirm
    # against that helper.
    cache_log_file = _make_save(
        model_name=args.model_name,
        results=results,
        args_dict=args.__dict__,
        output_dir=output_dir,
        shard=args.shard
    )
    skip_examples = []
    # Theorem names listed here are skipped outright.
    give_up = []
    skip_examples = skip_examples+give_up
    if os.path.exists(cache_log_file):
        with open(cache_log_file) as f:
            cache_log_content = json.load(f)
        for res in cache_log_content['results']:
            # Every cached result is carried over, successful or not, and its
            # theorem is not attempted again.
            results.append(res)
            cached_name = res['example']['full_name']
            if res['success'] or cached_name not in give_up:
                skip_examples.append(cached_name)

    skip_names = set(skip_examples)
    # Fixed: this used to print ``244 - len(skip_examples)``, which is only
    # meaningful for one particular dataset size.
    remaining = sum(1 for ex in data if ex['full_name'] not in skip_names)
    print(f"rest: {remaining}")
    for example in tqdm(data, total=len(data)):
        file_path = example['file_path']
        theorem_name = example['full_name']
        if theorem_name in skip_names:
            continue
        logger.info(f"Running on {theorem_name}")
        theorem = Theorem(repo, file_path, theorem_name)
        attempt_results = best_first_search(
            theorem, model, tokenizer,
            max_iters=args.max_iters,
            prompt_fn=_prompt_fn(args.model_name),
            temperatures=args.temperatures,
            num_samples=args.num_samples,
            timeout=args.timeout,
            early_stop=args.early_stop
        )
        result = {
            'attempt_results': attempt_results,
            'success': any(x['success'] for x in attempt_results),
            'example': example
        }
        results.append(result)

        # Save after every example so an interrupted run loses little work.
        _save(
            model_name=args.model_name,
            results=results,
            args_dict=args.__dict__,
            output_dir=output_dir,
            shard=args.shard
        )
        print_stats(results)
        result_success = result['success']
        logger.info(f"Success? {result_success}")

