from langchain import PromptTemplate
import json
import random
from agent.gpt_api import request_gpt
from re_assist import KnowledgeRetrieval
import re
from loguru import logger


class TemplateManager:
    """Registry mapping template names to prompt-template objects.

    Callers in this module fetch a template by name and immediately call
    ``.format(...)`` on it. As written, the registry is empty, so every
    lookup yields ``None``.

    Names requested by callers in this module: ``reverse_chain``,
    ``reverse_chain_premise``, ``reverse_chain_system``,
    ``reverse_chain_instruction``, ``reverse_chain_premise_system``,
    ``reverse_chain_premise_instruction``, ``reverse_chain_suffices`` and
    ``reverse_chain_suffices_premise``.
    """

    def __init__(self):
        # name -> template object (currently unpopulated).
        self.templates = dict()

    def get_template(self, template_name):
        """Return the template registered under *template_name*, else ``None``."""
        try:
            return self.templates[template_name]
        except KeyError:
            return None


class ReverseHypo:
    """Generate candidate Lean ``have`` statements for a proof state via an LLM.

    A prompt is built from templates supplied by a ``TemplateManager``. In
    prompt mode this is a single string. When ``is_conversation`` is true it
    is a system + user message list. The prompt is sent through
    ``request_gpt`` and the reply is parsed into a list of tactic strings.
    The returned tactics are NOT verified here.
    """

    def __init__(
        self,
        llm_engine="gpt-4-turbo-128k",
        k=8,
        curiosity="lv2",
        is_conversation=False,
        temperature=0.0,
    ):
        self.template_manager = TemplateManager()
        # Stored but not read anywhere in this class.
        self.curiosity = curiosity
        self.llm_engine = llm_engine
        # Templates receive k + 1. The reason for the extra one is not visible
        # in this module -- TODO confirm against the template text.
        self.k = k + 1
        self.is_conversation = is_conversation
        self.temperature = temperature

    def _build_input_constructor(self, pseudo_code, informal_statement, informal_proof, current_state, premise):
        """Return the format kwargs shared by every reverse-chain template.

        ``premises`` is only included when *premise* is not ``None``; only the
        premise-aware template variants are formatted with it.
        """
        input_constructor = {
            "pseudo_code": pseudo_code,
            "informal_statement": informal_statement,
            "informal_proof": informal_proof,
            "current_state": current_state,
            "k": self.k,
        }
        if premise is not None:
            input_constructor["premises"] = premise
        return input_constructor

    def prepare_input_conversation(self, pseudo_code, informal_statement, informal_proof, current_state, premise=None):
        """Build a chat-style request from the system/instruction templates.

        Returns:
            ``("", messages)``, where ``messages`` holds one system message and
            one user message. ``run`` passes the empty string to
            ``request_gpt`` as the prompt and ``messages`` as ``last_messages``.
        """
        input_constructor = self._build_input_constructor(
            pseudo_code, informal_statement, informal_proof, current_state, premise
        )
        if premise is None:
            system_name = "reverse_chain_system"
            instruction_name = "reverse_chain_instruction"
        else:
            system_name = "reverse_chain_premise_system"
            instruction_name = "reverse_chain_premise_instruction"
        system_prompt = self.template_manager.get_template(system_name).format()
        input_prompt = self.template_manager.get_template(instruction_name).format(**input_constructor)
        msg = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": input_prompt},
        ]
        return "", msg

    def prepare_input_prompt(self, pseudo_code, informal_statement, informal_proof, current_state, premise=None):
        """Build a single prompt string.

        Uses ``reverse_chain``, or ``reverse_chain_premise`` when *premise* is given.
        """
        input_constructor = self._build_input_constructor(
            pseudo_code, informal_statement, informal_proof, current_state, premise
        )
        template_name = "reverse_chain" if premise is None else "reverse_chain_premise"
        return self.template_manager.get_template(template_name).format(**input_constructor)

    def run(self, current_state, pseudo_code, informal_statement, informal_proof, premise=None, conv_msg=None):
        """Query the LLM once and parse its reply into ``have`` tactics.

        Args:
            current_state: Proof state inserted into the template.
            pseudo_code: Pseudo-code inserted into the template.
            informal_statement: Informal theorem statement.
            informal_proof: Informal proof text.
            premise: Optional premises; selects the premise-aware templates.
            conv_msg: Prior messages forwarded to ``request_gpt`` as
                ``last_messages`` in prompt mode. It is ignored in conversation
                mode, where the freshly built system/user messages are sent
                instead. Defaults to a new empty list on each call.

        Returns:
            ``(tactics, token_prompt, token_compli)``. The token counts are
            whatever ``request_gpt`` reports.
        """
        if conv_msg is None:
            # Fresh list per call. A shared ``[]`` default would carry any
            # downstream mutation of the list into later calls.
            conv_msg = []
        if self.is_conversation:
            input_prompt, conv_msg = self.prepare_input_conversation(
                pseudo_code, informal_statement, informal_proof, current_state, premise
            )
        else:
            input_prompt = self.prepare_input_prompt(
                pseudo_code, informal_statement, informal_proof, current_state, premise
            )
        _message_id, response, token_prompt, token_compli = request_gpt(
            input_prompt,
            model=self.llm_engine,
            last_messages=conv_msg,
            temperature=self.temperature,
        )
        reverse_hypo = self.parse_results(response)
        return reverse_hypo, token_prompt, token_compli

    def parse_results(self, result):
        """Extract ``have`` tactics from the fenced ```` ```lean ```` blocks in *result*.

        Lean block comments (``/- ... -/``) and line comments (``-- ...``) are
        removed, newlines are dropped, and empty tactics are discarded. For
        example, the fenced block::

            /- Assert that 1 / a is not zero to avoid division by zero issues. -/
            apply one_div_ne_zero, norm_num

        yields ``"apply one_div_ne_zero, norm_num"``.
        """
        logger.debug("parsing have statement")
        tactics = re.findall(r"```lean(.*?)```", result, re.S)
        if len(tactics) <= 2:
            # Few fenced blocks: presumably several ``have`` statements share
            # one block, so split them on a line-leading ``have``.
            new_tactics = []
            for tac in tactics:
                for t in tac.split("\nhave"):
                    # Strip block comments before line comments so that a Lean
                    # doc comment ``/-- ... -/`` is removed whole, instead of
                    # its ``--`` being consumed as a line comment.
                    t = re.sub(r"/-[\s\S]*?-/", "", t).strip()
                    t = re.sub(r"--.*", "", t).strip()
                    if t:
                        new_tactics.append("\nhave " + t)
            tactics = new_tactics

        cleaned = []
        for t in tactics:
            # ``[\s\S]`` already spans newlines, so no DOTALL flag is needed.
            t = re.sub(r"/-[\s\S]*?-/", "", t).strip()
            # Deliberately no DOTALL: ``--`` comments end at the line break.
            t = re.sub(r"--.*", "", t).strip()
            t = t.replace("```lean", "").strip()
            t = t.replace("```", "").strip()
            t = t.replace("\n", "").strip()
            if t:
                cleaned.append(t)
        tactics = cleaned
        logger.info(f"# {len(tactics)} Hypo generated (unverified)")
        return tactics

class ReverseGoal:
    """Ask an LLM for ``suffices``/``rw``-style tactics for a proof state.

    A prompt is built from the ``reverse_chain_suffices`` templates and sent
    through ``request_gpt``. The reply is parsed into a list of tactic strings,
    which are NOT verified here.
    """

    def __init__(self, llm_engine="gpt-4-turbo-128k", k=8, curiosity="lv1"):
        self.template_manager = TemplateManager()
        # Forwarded to ``request_gpt`` as ``curiosity`` in ``run``.
        self.curiosity = curiosity
        self.llm_engine = llm_engine
        # Templates receive k + 2. The reason for the offset is not visible
        # in this module -- TODO confirm against the template text.
        self.k = k + 2

    def prepare_input_prompt(self, informal_statement, informal_proof, current_state, premise=None, pseudo_code=None):
        """Build the suffices prompt string.

        Uses ``reverse_chain_suffices``, or ``reverse_chain_suffices_premise``
        (with a ``premises`` field) when *premise* is not ``None``.
        ``pseudo_code`` defaults to ``None``, and ``run`` never supplies it.
        """
        input_constructor = {
            "pseudo_code": pseudo_code,
            "informal_statement": informal_statement,
            "informal_proof": informal_proof,
            "current_state": current_state,
            "k": self.k,
        }
        if premise is None:
            template_name = "reverse_chain_suffices"
        else:
            template_name = "reverse_chain_suffices_premise"
            input_constructor["premises"] = premise
        return self.template_manager.get_template(template_name).format(**input_constructor)

    def run(self, current_state, informal_statement, informal_proof, premise=None):
        """Query the LLM once and parse its reply into tactics.

        Returns:
            ``(tactics, token_prompt, token_compli)``. The token counts are
            whatever ``request_gpt`` reports.
        """
        input_prompt = self.prepare_input_prompt(informal_statement, informal_proof, current_state, premise)
        _message_id, response, token_prompt, token_compli = request_gpt(
            input_prompt, curiosity=self.curiosity, model=self.llm_engine
        )
        reverse_hypo = self.parse_results(response)
        return reverse_hypo, token_prompt, token_compli

    def parse_results(self, result):
        """Extract tactics from the fenced ```` ```lean ```` blocks in *result*.

        Lean block comments (``/- ... -/``) and line comments (``-- ...``) are
        removed, newlines are dropped, and empty tactics are discarded. For
        example, the fenced block::

            /- Assert that 1 / a is not zero to avoid division by zero issues. -/
            apply one_div_ne_zero, norm_num

        yields ``"apply one_div_ne_zero, norm_num"``.
        """
        tactics = re.findall(r"```lean(.*?)```", result, re.S)
        cleaned = []
        for t in tactics:
            # Block comments first, so ``/-- doc -/`` is not split by ``--``.
            t = re.sub(r"/-[\s\S]*?-/", "", t).strip()
            t = re.sub(r"--.*", "", t).strip()
            t = t.replace("```lean", "").strip()
            t = t.replace("```", "").strip()
            t = t.replace("\n", "").strip()
            if t:
                cleaned.append(t)
        tactics = cleaned
        logger.info(f"# {len(tactics)} Suffices generated (unverified)")
        logger.info(f"# suffices/rw tactics\n{tactics}")
        return tactics


if __name__ == "__main__":
    # Smoke check: build a conversation-mode request from placeholder inputs
    # and print the returned ("", messages) tuple.
    agent = ReverseHypo(is_conversation=True)
    placeholder = "<place_holder>"
    request = agent.prepare_input_conversation(
        placeholder,
        placeholder,
        placeholder,
        placeholder,
    )
    print(request)