
import sys
from langchain import PromptTemplate
import json
import random
from agent.gpt_api import request_gpt
from glob import glob
from db.error_db.code.build_error_aug import ErrorEncodeRanker
import torch
from typing import Union, List
from transformers import AutoTokenizer, AutoModelForTextEncoding
from db.mathlib_db.build_copra_aug import Lean3Bm25ReRanker,Lean3EncodeRanker
import re
from loguru import logger

class TemplateManager:
    """Name-to-template registry used by the summarizer classes in this file."""

    def __init__(self):
        # The registry starts out empty in this file.
        self.templates = dict()

    def get_template(self, template_name):
        """Return the template stored under ``template_name``, or ``None`` if absent."""
        registry = self.templates
        if template_name in registry:
            return registry[template_name]
        return None

class ErrorSummarizer:
    """Build an error-summary prompt from retrieved error examples and query the LLM.

    Error examples come from an ``ErrorEncodeRanker`` built around ``re_model``;
    the prompt text comes from the ``"error_summary_prompt"`` template.
    """

    def __init__(
        self,
        re_model: AutoModelForTextEncoding,
        inner_k=10,
        outer_k=2,
        llm_engine="gpt-4-turbo-128k",
        curiosity="lv2",
    ):
        """Create the ranker (and reindex it) and remember the retrieval/LLM settings.

        Args:
            re_model: Model handed to ``ErrorEncodeRanker``.
            inner_k: ``k`` passed to ``get_inner_examples``.
            outer_k: ``k`` passed to ``get_outer_examples``.
            llm_engine: Model name forwarded to ``request_gpt``.
            curiosity: ``curiosity`` value forwarded to ``request_gpt``.
        """
        self.template_manager = TemplateManager()
        self.error_encode_ranker = ErrorEncodeRanker(model=re_model)
        self.error_encode_ranker.reindex()
        self.inner_k, self.outer_k = inner_k, outer_k
        self.llm_engine = llm_engine
        self.curiosity = curiosity

    def extend_error_db(self, error_msg: dict, theorem_name):
        """Forward a new error sample for ``theorem_name`` to the ranker."""
        self.error_encode_ranker.append_sample(error_msg, theorem_name=theorem_name)

    def prepare_errors(self, state, theorem_name):
        """Return ``(inner_text, outer_text)`` rendered from the ranked error examples.

        Each example is rendered as a ``[State]`` / ``[Error Msg]`` pair using its
        ``'tactic_state'`` and ``'error_msg'`` entries (both stripped).
        """
        ranker = self.error_encode_ranker
        ranked = ranker.get_scores(state, k=int(10e5))
        inner_hits = ranker.get_inner_examples(ranked, theorem_name, k=self.inner_k)
        outer_hits = ranker.get_outer_examples(ranked, theorem_name, k=self.outer_k)

        def render(examples):
            # Concatenate one formatted section per example, in the given order.
            return "".join(
                "[State]\n{}\n[Error Msg]\n{}\n".format(
                    ex['tactic_state'].strip(), ex["error_msg"].strip()
                )
                for ex in examples
            )

        inner_text = render(inner_hits)
        outer_text = render(outer_hits)
        return inner_text, outer_text

    def prepare_input_prompt(self, full_pseudo_code, proof_state, theorem_name):
        """Fill the ``"error_summary_prompt"`` template and return the resulting prompt."""
        template = self.template_manager.get_template("error_summary_prompt")
        inner_error, outer_error = self.prepare_errors(proof_state, theorem_name)
        return template.format(
            full_pseudo_code=full_pseudo_code,
            proof_state=proof_state,
            inner_error=inner_error,
            outer_error=outer_error,
        )

    def run(self, full_pseudo_code, proof_state, theorem_name):
        """Query the LLM and return ``(suggestion, token_prompt, token_compli)``."""
        prompt = self.prepare_input_prompt(full_pseudo_code, proof_state, theorem_name)
        reply = request_gpt(prompt, curiosity=self.curiosity, model=self.llm_engine)
        # request_gpt yields four values; the message id is not used here.
        _message_id, response, token_prompt, token_compli = reply
        return self.parse_results(response), token_prompt, token_compli

    def parse_results(self, result):
        """Return the LLM response unchanged (no post-processing is applied)."""
        return result

class PremiseSummarizer:
    """Retrieve candidate premises for a proof state and have the LLM summarise them.

    Candidates come from ``re_model`` (``get_scores`` / ``_responses``); the LLM
    reply is post-processed by :meth:`parse_results`, which appends the selected
    premises and replaces ``Premise_<i>`` labels with the premise names.
    """

    def __init__(
        self, re_model: "Lean3EncodeRanker",
        pre_k = 10,  llm_engine="gpt-4-turbo-128k",curiosity = "lv1",is_conversation=False,
        temperature = 0.4,

        ):
        """Store the ranker and the retrieval / LLM settings.

        Args:
            re_model: Ranker used as-is (it is not reindexed here).
            pre_k: Number of premises requested from the ranker and kept.
            llm_engine: Model name forwarded to ``request_gpt``.
            curiosity: Stored on the instance; not read elsewhere in this class.
            is_conversation: When true, ``run`` builds a system/user message
                list instead of a single prompt string.
            temperature: Sampling temperature forwarded to ``request_gpt``.
        """
        self.template_manager = TemplateManager()
        self.premise_encode_ranker = re_model
        self.re_pre_k = pre_k
        self.llm_engine = llm_engine
        self.curiosity = curiosity
        # Passed to the templates as the "k" placeholder.
        self.at_least_k = 4
        self.is_conversation = is_conversation
        self.temperature = temperature

    def prepare_premise(self, state):
        """Return up to ``re_pre_k`` premises for ``state`` labelled ``[Premise_<idx>] ...``."""
        scores = self.premise_encode_ranker.get_scores(state, self.re_pre_k)
        # Each score entry is an (index, score) pair; only the index is needed.
        premises = [self.premise_encode_ranker._responses[i] for i, _score in scores]
        return [
            f"[Premise_{idx}] {p.strip()}"
            for idx, p in enumerate(premises[:self.re_pre_k])
        ]

    def _build_prompt_fields(self, full_pseudo_code, proof_state):
        """Return ``(fields, premises)``: template placeholders plus the labelled premises."""
        premises = self.prepare_premise(proof_state)
        fields = {
            "full_pseudo_code": full_pseudo_code,
            "proof_state": proof_state,
            "k": self.at_least_k,
            "premises": "\n".join(premises),
        }
        return fields, premises

    def prepare_input_conversation(self, full_pseudo_code, proof_state):
        """Build the chat-style input.

        Returns:
            ``("", messages, premises)`` where ``messages`` holds a system
            message and a user message rendered from the
            ``premise_summary_prompt_system`` / ``..._instruction`` templates.
        """
        fields, premises = self._build_prompt_fields(full_pseudo_code, proof_state)
        msg = [
            {"role":"system",
            "content":self.template_manager.get_template("premise_summary_prompt_system").format()}
        ]
        input_prompt = self.template_manager.get_template("premise_summary_prompt_instruction").format(**fields)
        msg.append({"role":"user","content":input_prompt})
        return "", msg, premises

    def prepare_input_prompt(self, full_pseudo_code, proof_state):
        """Return ``(prompt, premises)`` using the ``premise_summary_prompt`` template."""
        template = self.template_manager.get_template("premise_summary_prompt")
        fields, premises = self._build_prompt_fields(full_pseudo_code, proof_state)
        return template.format(**fields), premises

    def run(self, full_pseudo_code, proof_state, conv_msg=None):
        """Query the LLM and return ``(suggestion, token_prompt, token_compli)``.

        Args:
            full_pseudo_code: Value for the ``full_pseudo_code`` placeholder.
            proof_state: Proof state used for retrieval and for the prompt.
            conv_msg: Previous messages forwarded as ``last_messages``. Defaults
                to a fresh empty list. In conversation mode it is replaced by
                the freshly built system/user messages.
        """
        # A fresh list per call avoids sharing one default list across calls.
        if conv_msg is None:
            conv_msg = []
        if not self.is_conversation:
            input_prompt, premises = self.prepare_input_prompt(full_pseudo_code, proof_state)
        else:
            input_prompt, conv_msg, premises = self.prepare_input_conversation(full_pseudo_code, proof_state)
        _message_id, response, token_prompt, token_compli = request_gpt(
            input_prompt,
            model=self.llm_engine,
            last_messages=conv_msg,
            temperature=self.temperature,
        )
        suggest = self.parse_results(response, premises)
        logger.info(f"# Premises summary\n{suggest}")
        return suggest, token_prompt, token_compli

    def parse_results(self, result, premises):
        """Post-process the LLM reply.

        A premise is selected when the reply mentions ``Premise_<i>`` /
        ``Premise <i>`` for its index, or when text the reply wraps in
        ``<a>...</a>`` occurs inside the premise. The selected premises are
        appended to the reply, ``Analysis_<i>`` / ``Selection_<i>`` markers
        (i < 30) are removed, ``Premise_<i>`` labels are replaced with the name
        found in the premise's first ``<a>...</a>`` tag, and ``**`` / ``[]``
        are stripped.

        Args:
            result: Raw LLM reply text.
            premises: Labelled premises as produced by ``prepare_premise``.

        Returns:
            The cleaned reply with the selected-premise list appended.
        """
        selected = []
        for i, p in enumerate(premises):
            if re.search(rf"\bPremise_{i}\b", result) or re.search(rf"\bPremise {i}\b", result):
                selected.append(p)
        for cited in re.findall("<a>(.*?)</a>", result):
            selected.extend(p for p in premises if cited in p)
        # Order-preserving de-duplication keeps the output stable between runs.
        selected = list(dict.fromkeys(selected))

        premise_names = {}
        for idx, p in enumerate(premises):
            # Drop the leading "[Premise_<idx>]" token before looking for the tag.
            body = ' '.join(p.split(' ')[1:])
            tag = re.search("<a>(.*?)</a>", body)
            # A premise without an <a> tag keeps its "Premise_<idx>" label.
            if tag:
                premise_names[idx] = tag.group(1)

        premises_string = "\n".join(selected) if selected else "See above."
        response = f"{result}\nTherefore potentially useful premises are:\n{premises_string}"

        for label in ("Analysis", "Selection"):
            for i in range(30):
                response = re.sub(rf"\b{label}_{i}\b", "", response)

        for i in range(30):
            name = premise_names.get(i)
            if name is None:
                continue
            # A callable replacement inserts the name literally, so backslashes
            # in it are not interpreted as regex escapes or group references.
            response = re.sub(rf"\bPremise_{i}\b", lambda _m, name=name: name, response)

        response = response.replace("**", "")
        response = response.replace("[]", "")

        return response
    

if __name__ == "__main__":
    # Manual smoke run: build a premise ranker from the local mathlib index,
    # print the single-string prompt for a sample proof state, then query the
    # LLM in conversation mode. (An ErrorSummarizer variant of this demo
    # existed here only as commented-out code.)
    demo_state = "n : ℕ\n⊢ gcd n n = n"
    ranker = Lean3EncodeRanker(dst_dir="db/mathlib_db/leandojo_minif2f")
    summarizer = PremiseSummarizer(ranker, is_conversation=True)

    prompt_text, _premises = summarizer.prepare_input_prompt("pseudo_holder", demo_state)
    print(prompt_text)

    run_output = summarizer.run("pseudo_holder", demo_state)




    

        