import heapq
import json
import os
import re
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path

import transformers
import vllm
from tqdm import tqdm, trange
from lean_dojo import *
# from local_pkg.lean_dojo import *
from loguru import logger
# logger.add(sys.stderr, level="DEBUG")

from utils.success_target import success_target,test_success,valid_success


def _save(model_name, results, args_dict, output_dir, shard):
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    output_file = os.path.join(
        output_dir,
        'results__%s__%s.json' % (model_name.replace('/', '_'), shard)
    )
    with open(output_file, 'w') as f:
        json.dump({
            'results': results,
            'args': args_dict
            }, f, indent=4)
        print(output_file)

def _make_save(model_name, results, args_dict, output_dir, shard):
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    output_file = os.path.join(
        output_dir,
        'results__%s__%s.json' % (model_name.replace('/', '_'), shard)
    )
    return output_file


def generate_vllm(prompt, model, tokenizer, temperatures, num_samples, stop, max_tokens=256):
    """Sample completions of ``prompt`` at each temperature, then pool and rank them.

    For every temperature in ``temperatures`` one ``vllm.SamplingParams`` request
    is issued with ``n=num_samples``, ``max_tokens`` and ``stop``. Beam search
    is enabled exactly when the temperature equals ``0.0``.

    Each generated text has the tokenizer's EOS token string removed (when the
    tokenizer defines one). It is scored by ``cumulative_logprob`` divided by
    its token count, with the count clamped to at least 1. Candidates from all
    temperatures are pooled, deduplicated and ordered best-first by
    ``_unique_sorted``.

    Returns:
        (texts, scores): parallel lists, highest score first.
    """
    texts, scores = [], []
    for temperature in temperatures:
        params = vllm.SamplingParams(
            n=num_samples,
            temperature=temperature,
            use_beam_search=temperature == 0.0,
            max_tokens=max_tokens,
            stop=stop,
        )
        outputs = model.generate([prompt], params, use_tqdm=False)
        if len(outputs) == 0:
            # Nothing for this temperature: keep candidates gathered at earlier
            # temperatures instead of discarding them all.
            continue
        eos = tokenizer.eos_token
        for output in outputs[0].outputs:
            text = output.text
            if eos:
                # eos_token may be None; str.replace(None, ...) would raise.
                text = text.replace(eos, '')
            score = output.cumulative_logprob / max(len(output.token_ids), 1)
            texts.append(text)
            scores.append(score)

    return _unique_sorted(texts, scores)


def _unique_sorted(texts, scores):
    texts_ = []
    scores_ = []
    for t, s in sorted(zip(texts, scores), key=lambda x: -x[1]):
        if t not in texts_:
            texts_.append(t)
            scores_.append(s)
    return texts_, scores_


def _tactic_state(state):
    """Return the goal text of a LeanDojo ``state``.

    A ``TacticState`` is rendered through its ``pp`` attribute. Any other state
    object is read through its ``unsolved_tactic_state`` attribute instead.
    """
    if not isinstance(state, TacticState):
        return state.unsolved_tactic_state
    return state.pp


def _prompt_proofstep(ts):
    prompt = f"[GOAL]{ts}[PROOFSTEP]"
    return prompt


def _prompt_fn(model_name):
    if 'pythia' in model_name:
        return _prompt_proofstep
    return ""


def print_stats(results):
    """Log the success rate and the number of successful entries in ``results``.

    ``results`` is a list of dicts; an entry counts as a success when its
    ``'success'`` value is truthy. An empty list logs a rate of ``0.0`` instead
    of raising ZeroDivisionError.
    """
    n_success = sum(1 for x in results if x['success'])
    rate = n_success / len(results) if results else 0.0
    logger.info(rate)
    # loguru formats extra positional args into '{}' placeholders (str.format style).
    # A print-style call with no placeholder silently drops the count.
    logger.info("# successes: {}", n_success)


def resume_from(results_filename, data):
    """Reload saved results and skip the items of ``data`` they already cover.

    Args:
        results_filename: JSON file whose top-level object has a ``"results"`` list.
        data: sliceable sequence of inputs; presumably in the same order that
            produced the saved results (not checked here).

    Returns:
        (results, remaining): the loaded list and ``data`` without its first
        ``len(results)`` items.
    """
    with open(results_filename) as f:  # close the handle instead of leaking it
        results = json.load(f)['results']
    data = data[len(results):]
    print("=== Resuming from %d" % (len(results)))
    return results, data


def make_output_dir(output_dir, dt=None):
    """Create ``<output_dir>/<dt>`` (including parents) and return its path.

    Args:
        output_dir: parent directory.
        dt: name of the run sub-directory. Defaults to ``"demo"`` when None.
            A caller-supplied value is now honoured; previously it was silently
            replaced by ``"demo"``.

    Returns:
        ``os.path.join(output_dir, dt)``.
    """
    if dt is None:
        # A fixed run name rather than a timestamp; strftime("demo") contains no
        # format codes and always produced this same literal.
        dt = "demo"
    output_dir = os.path.join(output_dir, dt)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    return output_dir



def _load_data(dataset_name, dataset_path, use_target_q=False):
    if dataset_name == 'leandojo':
        URL = "https://github.com/leanprover-community/mathlib4"
        repo = LeanGitRepo(URL, COMMIT)
        with open(dataset_path) as f:
            data = json.load(f)
    elif 'minif2f' in dataset_name:
        data = []
        with open(dataset_path) as f:
            for line in f.readlines():
                data_ = json.loads(line)
                data.append(data_)

        if 'valid' in dataset_name:
            target_q = valid_success
            data = [x for x in data if x['split'] == 'valid']
        else:
            target_q = test_success
            data = [x for x in data if x['split'] == 'test']
        if use_target_q:
            data = [x for x in data if x['full_name'] in target_q]
        repo = LeanGitRepo(data[0]['url'], data[0]['commit'])
    else:
        raise NotImplementedError(dataset_name)

    return repo, data


def hypo_overlap_check(state, old_state):
    """Return True if ``state`` and ``old_state`` have equally many distinct hypotheses.

    Both arguments are pretty-printed tactic-state strings. Every line that
    does not start with ``'⊢'`` is treated as a hypothesis. Its description is
    the text after the first ``':'``, with commas removed and surrounding
    whitespace stripped; for example, ``"x y : ℕ"`` and ``"z : ℕ"`` both become
    ``"ℕ"``. Only the *number* of distinct descriptions is compared, not the
    descriptions themselves.
    """
    def _describe(text):
        # Set of hypothesis descriptions (the part after the first ':').
        return {
            line.partition(':')[2].replace(',', '').strip()
            for line in text.split('\n')
            if not line.startswith('⊢')
        }

    # NOTE(review): a within-state duplicate check (len(list) != len(set)) used to
    # be computed here and then immediately overwritten, so it never affected the
    # result. Decide whether it was meant to be combined with this comparison.
    return len(_describe(old_state)) == len(_describe(state))