import sys
from langchain import PromptTemplate
import json
import random
from agent.gpt_api import request_gpt
from re_assist import KnowledgeRetrieval
import re
from loguru import logger

class TemplateManager:
    """Name-to-template registry used by the prompt builders in this module."""

    def __init__(self):
        # The registry starts out empty; entries are looked up by name
        # through get_template().
        self.templates = dict()

    def get_template(self, template_name):
        """Return the template stored under ``template_name``, or None if absent."""
        if template_name in self.templates:
            return self.templates[template_name]
        return None

class NextTacticFroamProof:
  """Prompt an LLM for candidate next Lean tactics for a proof state.

  The ``prepare_*`` methods build the individual prompt fields (several of
  them query the retrieval agent), ``run`` sends the prompt through
  ``request_gpt`` and ``parse_results`` extracts the tactics from the
  ```lean fenced blocks of the reply.
  """

  # The annotation is quoted so it is not evaluated when the class is defined.
  def __init__(self, retrive_agent: "KnowledgeRetrieval", llm_engine = 'gpt-4-turbo-128k',k=12,is_conversation=False,temperature=0.0):
    """Store the configuration.

    Args:
      retrive_agent: retrieval agent; kept as ``self.re_agent`` and queried
        by the ``prepare_*`` methods.
      llm_engine: model name forwarded to ``request_gpt``.
      k: value substituted for the ``k`` field of the prompt templates.
      is_conversation: when true, ``run`` sends a system/user message list
        instead of a single prompt string.
      temperature: sampling temperature forwarded to ``request_gpt``.
    """
    self.template_manager = TemplateManager()
    self.curiosity = "lv1"
    self.llm_engine = llm_engine
    self.re_agent = retrive_agent
    self.temperature = temperature
    self.is_conversation = is_conversation
    # Number of fixed tactics shown as worked samples in the prompt.
    self.in_context_example_k = 2
    self.k = k
    # Commented fallback tactics: used as in-context samples and always
    # appended to the parsed model output (see parse_results).
    self.in_context_tactic = [
      "/- Evaluate and normalize numeric expressions. -/\nnorm_num",
      "/- Solve non linear arithmetic problems in the state. -/\nnlinarith",
      "/- Normalize ring expressions (like those involving + and * over the integers). -/\nring",
      "/- Performs the calculation described in the theorem's statement automatically and verifies if the result is indeed as specified. -/\ndec_trivial!",
      "/- Current goal is exactly equal to one of the hypotheses in the local context, use assumption to succeed and close the goal. -/\nassumption"
    ]

  def _require_template(self, template_name):
    """Return the named template; raise KeyError if it is not registered."""
    template = self.template_manager.get_template(template_name)
    if template is None:
      raise KeyError("prompt template {!r} is not registered".format(template_name))
    return template

  @staticmethod
  def _format_premises(ranker, state, k):
    """Render the first ``k`` premises ranked by ``ranker`` for ``state``."""
    ranked = [ranker._responses[i] for i, _ in ranker.get_scores(state)]
    return "".join("[Premise]{}\n".format(p.strip()) for p in ranked[:k])

  def prepare_error_log_prompt_v1(self, state):
    """Return placeholder error fields; ``state`` is not used."""
    return {
      "episodic_error":"None",
      "cross_error":"None"
    }

  def prepare_error_log_prompt(self, full_pseudo_code,state,theorem_name):
    """Return the ``errors`` prompt field.

    No error summary is computed here, so the field is always the literal
    string "None" and the arguments are unused.
    """
    return {
      "errors":"None",
    }

  def prepare_similar_examples(self, state, k = 3):
    """Return up to ``k`` retrieved (state, tactic) pairs as ``similar_tactics``."""
    state_tactic_pairs = self.re_agent.reprover_aug.get_encode_scores(state)
    rendered = "".join(
      "[State]\n{}\n[Tactic]\n```lean\n{}\n```\n".format(similar_state, tactic)
      for similar_state, tactic in state_tactic_pairs[0][:k]
    )
    return {"similar_tactics":rendered}

  def prepare_premise_p1(self, state, k = 8):
    """Return the top ``k`` premises of ``do_premises_ranker`` as ``premises_p1``."""
    return {"premises_p1":self._format_premises(self.re_agent.do_premises_ranker, state, k)}

  def prepare_premise_p2(self, state, k = 2):
    """Return the top ``k`` premises of ``co_premises_ranker`` as ``premises_p2``."""
    return {"premises_p2":self._format_premises(self.re_agent.co_premises_ranker, state, k)}

  def prepare_premise_summary(self, full_pseudo_code, state):
    """Return the premise summarizer's text output as the ``premises`` field."""
    premises, _, _ = self.re_agent.premise_summerizer.run(full_pseudo_code, state)
    return {
      "premises":premises
    }

  def prepare_suggest_tactics_v1(self, state ,k = 3):
    """Return up to ``k`` generated tactics as numbered samples (``suggested_tactics``)."""
    tactics = self.re_agent.reprover_tactic_gen.tactic_gen(state)
    rendered = "".join(
      "#{}\n[State]\n{}\n[Tactic]\n```lean\n{}\n```\n".format(n + 1, state, tac)
      for n, tac in enumerate(tactics[:k])
    )
    return {"suggested_tactics":rendered}

  def prepare_suggest_tactics(self, state ,k = 3):
    """Return up to ``k`` generated tactics as unnumbered samples (``suggested_tactics``)."""
    tactics = self.re_agent.reprover_tactic_gen.tactic_gen(state)
    rendered = "".join(
      "[State]\n{}\n[Tactic]\n```lean\n{}\n```\n".format(state, tac)
      for tac in tactics[:k]
    )
    return {"suggested_tactics":rendered}

  def prepare_input_prompt(self, current_state, pseudo_code, informal_statement,theorem_name,premise_inject=None, input_template_name = "get_formal_tactic_retireval"):
    """Fill the named prompt template and return the resulting string.

    Args:
      current_state: proof state the tactics are requested for.
      pseudo_code: pseudo code of the proof, passed to the template and to
        the premise summarizer.
      informal_statement: informal text passed to the template.
      theorem_name: forwarded to ``prepare_error_log_prompt``.
      premise_inject: premises text to use verbatim; when None the premise
        summarizer of the retrieval agent is called instead.
      input_template_name: name of the template to fill.

    Raises:
      KeyError: if the template is not registered, or it references a field
        that is not supplied here.
    """
    # Resolve the template first so a bad name fails before any retrieval call.
    template = self._require_template(input_template_name)
    fields = {
      "k":self.k,
      "pseudo_code":pseudo_code,
      "informal_statement":informal_statement
    }
    fields.update(self.prepare_error_log_prompt(pseudo_code, current_state, theorem_name))
    if premise_inject is None:
      fields.update(self.prepare_premise_summary(pseudo_code, current_state))
    else:
      fields["premises"] = premise_inject

    if self.in_context_tactic is None:
      fields["current_state"] = "#{}\n[State]{}\n".format(1, current_state)
    else:
      # Shuffle a copy so the instance's fallback list keeps its order.
      pool = list(self.in_context_tactic)
      random.shuffle(pool)
      samples = "".join(
        "#{}\n[Current State]\n{}\n[Tactic]\n```lean\n{}\n```\n".format(idx + 1, current_state, tac)
        for idx, tac in enumerate(pool[:self.in_context_example_k])
      )
      fields["current_state"] = samples + '\n'

    input_prompt = template.format(**fields)
    logger.info(input_prompt)
    return input_prompt

  def prepare_input_conversation(self, current_state, pseudo_code, informal_statement,theorem_name,premise_inject=None):
    """Build the chat-style request.

    Returns:
      A tuple ``("", messages)`` where ``messages`` holds the system message
      followed by the user message; the empty string is the prompt argument
      that ``run`` passes to ``request_gpt`` in conversation mode.
    """
    system_prompt = self._require_template("get_formal_tactic_retireval_system").format(k=self.k)
    user_prompt = self.prepare_input_prompt(
      current_state, pseudo_code, informal_statement, theorem_name, premise_inject,
      input_template_name="get_formal_tactic_retireval_instruction")
    return "", [
      {"role":"system","content":system_prompt},
      {"role":"user","content":user_prompt},
    ]

  def run(self, state, pseudo_code, informal_statement,theorem_name,premise_inject=None,conv_msg = None):
    """Query the model and return ``(tactics, prompt_tokens, completion_tokens)``.

    Args:
      conv_msg: previous messages forwarded to ``request_gpt`` as
        ``last_messages``. Defaults to a fresh empty list per call; it is
        replaced by the generated messages in conversation mode.
    """
    if conv_msg is None:
      conv_msg = []
    if self.is_conversation:
      input_prompt, conv_msg = self.prepare_input_conversation(state, pseudo_code, informal_statement,theorem_name=theorem_name,premise_inject=premise_inject)
    else:
      input_prompt = self.prepare_input_prompt(state, pseudo_code, informal_statement,theorem_name=theorem_name,premise_inject=premise_inject)
    print(input_prompt)
    _, response, token_prompt, token_compli = request_gpt(input_prompt,model=self.llm_engine,last_messages=conv_msg,temperature=self.temperature)
    formal_proof = self.parse_results(response)
    return formal_proof, token_prompt, token_compli

  @staticmethod
  def _clean_tactic(snippet):
    """Strip comments, code fences and newlines from one tactic snippet."""
    text = re.sub(r'/-[\s\S]*?-/', '', snippet).strip()   # /- block comments -/
    text = re.sub(r"--.*", '', text).strip()              # -- line comments
    text = text.replace('```lean', '').strip()
    text = text.replace('```', '').strip()
    return text.replace('\n', '').strip()

  @staticmethod
  def _unique_tactics(tactics):
    """Drop empty entries and duplicates, keeping first-seen order."""
    return list(dict.fromkeys(t for t in tactics if t))

  def parse_results(self,result):
    """Extract the candidate tactics from a model reply.

    Every ```lean fenced block in ``result`` is taken as one tactic, e.g.

        /- Assert that 1 / a is not zero to avoid division by zero issues. -/
        apply one_div_ne_zero, norm_num

    becomes ``apply one_div_ne_zero, norm_num``. The fixed fallback tactics
    are appended after the model's own, and the result is de-duplicated.

    Returns:
      List of unique, non-empty tactic strings in first-seen order.
    """
    # in_context_tactic may be None (prepare_input_prompt allows for that).
    snippets = re.findall(r"```lean(.*?)```", result, re.S) + list(self.in_context_tactic or [])
    logger.info(f"# {len(snippets)} Tactics (unverified)")
    return self._unique_tactics(self._clean_tactic(s) for s in snippets)


class FormalProof:
  """Build a few-shot prompt and ask an LLM for a complete formal proof.

  In-context examples are read from ``algebra.json`` and
  ``numbertheory.json`` (relative to the working directory) and indexed by
  their ``theorem_name``.
  """

  def __init__(self, llm_engine = 'gpt-3.5-turbo'):
    """Load the example files and store the model name.

    Args:
      llm_engine: model name forwarded to ``request_gpt``.

    Raises:
      OSError: if one of the example files cannot be opened.
    """
    self.template_manager = TemplateManager()
    self.algebra_examples = self._load_examples("algebra.json")
    self.numbert_examples = self._load_examples("numbertheory.json")
    self.curiosity = "lv1"
    self.llm_engine = llm_engine

  @staticmethod
  def _load_examples(path):
    """Read a JSON list of examples and index it by ``theorem_name``."""
    # The context manager closes the handle; JSON text is read as UTF-8.
    with open(path, encoding="utf-8") as handle:
      return {ex['theorem_name']: ex for ex in json.load(handle)}

  def _require_template(self, template_name):
    """Return the named template; raise KeyError if it is not registered."""
    template = self.template_manager.get_template(template_name)
    if template is None:
      raise KeyError("prompt template {!r} is not registered".format(template_name))
    return template

  def prepare_example_prompt(self,theorem_name="mathd_algebra_478",example_num=2):
    """Return randomly chosen in-context examples rendered as one string.

    Examples come from the number-theory set when ``theorem_name`` contains
    "numbertheory", otherwise from the algebra set. The example belonging to
    ``theorem_name`` itself is never used.

    Args:
      theorem_name: name of the theorem being proved.
      example_num: number of examples wanted; fewer are returned when the
        set does not hold that many other examples.
    """
    source = self.numbert_examples if "numbertheory" in theorem_name else self.algebra_examples
    pool = [ex for name, ex in source.items() if name != theorem_name]
    # Clamp so a small example set yields fewer examples instead of ValueError.
    examples = random.sample(pool, min(example_num, len(pool)))

    example_template = "### Informal proof:\n{}\n### Formal proof:\n```lean\n{}\n```\n"
    return "".join(
      example_template.format(ex['informal_proof'].strip(), ex['proof'].strip())
      for ex in examples
    )

  def prepare_input_prompt(self, statement,informal_proof, theorem_name="mathd_algebra_478", error_db=None):
    """Fill the proof-generation template.

    Args:
      statement: formal theorem statement.
      informal_proof: informal proof text.
      theorem_name: used to pick in-context examples.
      error_db: optional collection of earlier failed proofs; when it is
        non-empty, up to three random entries are included and the
        ``get_formal_proof_with_error`` template is used.

    Raises:
      KeyError: if the required template is not registered.
    """
    fields = {
      "examples": self.prepare_example_prompt(theorem_name),
      "informal_proof": informal_proof.strip(),
      "theorem": statement.strip(),
    }
    if not error_db:
      return self._require_template("get_formal_proof").format(**fields)

    failed = list(error_db)
    picked = random.sample(failed, min(3, len(failed)))
    fields["error"] = "\n".join("```lean\n{}\n```\n".format(s) for s in picked)
    return self._require_template("get_formal_proof_with_error").format(**fields)

  def run(self, statement, informal_proof ,theorem_name):
    """Query the model and return ``(proof, prompt_tokens, completion_tokens)``."""
    input_prompt = self.prepare_input_prompt(statement=statement, informal_proof=informal_proof, theorem_name=theorem_name)
    _, response, token_prompt, token_compli = request_gpt(input_prompt,curiosity=self.curiosity,model=self.llm_engine)
    formal_proof = self.parse_results(response)
    return formal_proof, token_prompt, token_compli

  def parse_results(self,result):
    """Return the model reply unchanged."""
    return result


def parse_formal_steps(proof:str):
  """Split a formal proof into its tactic lines.

  Lines that are blank, start with ``:=``, ``theorem``, ``import``, ``--``,
  ``begin`` or ``end``, or start with the ``by`` keyword are dropped.
  ``by`` is matched as a whole word, so tactics such as ``by_contra`` and
  ``by_cases`` are kept.

  Args:
    proof: proof text, one step per line.

  Returns:
    List of ``(row, step)`` tuples: ``row`` is the 0-based line index in
    ``proof`` and ``step`` is the line with trailing whitespace and one
    trailing comma removed (leading indentation is kept).
  """
  skipped_prefixes = (":=", "theorem", "import", "--", "begin", "end")
  clean_steps = []
  for row, line in enumerate(proof.split('\n')):
    text = line.strip()
    if not text or text.startswith(skipped_prefixes):
      continue
    # `\b` keeps identifiers that merely begin with "by" (by_contra, by_cases).
    if re.match(r"by\b", text):
      continue
    step = line.rstrip()
    if step.endswith(','):
      step = step[:-1].rstrip()
    clean_steps.append((row, step))

  return clean_steps

if __name__ == "__main__":
  # Manual smoke test. FormalProof() opens algebra.json and numbertheory.json
  # from the working directory, and both run() calls below go through
  # request_gpt.
  testPrompt = FormalProof()
  test_formal_theorem = """theorem mathd_algebra_478 (b h v : ℝ) (h₀ : 0 < b ∧ 0 < h ∧ 0 < v) (h₁ : v = 1 / 3 * (b * h))
      (h₂ : b = 30) (h₃ : h = 13 / 2) : v = 65"""
  informal_theorem = "We are given that $B = 30$ and $h = 6.5$ and asked to find $\\frac{1}{3}Bh$.  We find that \\[\\frac{1}{3}Bh = \\frac{1}{3}(30)(6.5) = (10)(6.5) = 65.\\]"
  print(informal_theorem)
  # 1) Build and show the whole-proof prompt without calling the model.
  input_prompt = testPrompt.prepare_input_prompt(test_formal_theorem,informal_theorem,"mathd_algebra_478")
  print(input_prompt)

  # 2) Ask the model for a complete formal proof.
  testProve = FormalProof()
  # gpt4 got this one right
  print(testProve.run(test_formal_theorem,informal_theorem,"mathd_algebra_478"))

  # 3) Ask for next-step tactics for a single proof state, in conversation mode.
  retriver = KnowledgeRetrieval(device_id=2)
  state_test= "σ : ℝ ≃ ℝ,\nh : σ.to_fun 2 = σ.inv_fun 2\n⊢ σ.to_fun (σ.to_fun 2) = 2"
  tac_generator = NextTacticFroamProof(retrive_agent=retriver,llm_engine="gpt-4 8K",is_conversation=True)
  pseudo_code = "/- We start by declaring our theorem and its parameters. -/\n-- theorem mathd_algebra_478 (b h v : ℝ) (h₀ : 0 < b ∧ 0 < h ∧ 0 < v) (h₁ : v = 1 / 3 * (b * h)) (h₂ : b = 30) (h₃ : h = 13 / 2) : v = 65\nbegin\n/- We enter into our proof environment. -/\n-- begin\n\n  /- Our strategy is to substitute the given values and simplify the expression. We replace the value of 'b' in the equation 'v = 1/3 * (b * h)' with 30. -/\n  -- rw h₂ at h₁,\n\n  /- Then, we replace the value of 'h' in the equation 'v = 1/3 * (b * h)' with 13/2 to get 'v = 1 / 3 * (30 * 13 / 2)'. -/\n  -- rw h₃ at h₁,\n\n  /- Simplifying the RHS, we get 65. -/\n  -- simplify equation at h₁,\n\n  /- Therefore, we have established that v = 65 as required. -/\n  -- exact h₁,\n\n/- We close our proof block. -/\n-- end\n\n/- This completes our Lean Theorem Prover pseudo code for the given problem. -/\n-- end mathd_algebra_478"
  # NOTE(review): informal_proof is not used below (informal_theorem is passed
  # instead). In this non-raw literal "\neq" and "\text" contain the escapes
  # \n and \t, so the value holds a newline and a tab - confirm whether a raw
  # string was intended.
  informal_proof = "Denote $a$ and $b$ as the tens and units digit of $N$, respectively. Then $N = 10a+b$. It follows that $10a+b=ab+a+b$, which implies that $9a=ab$. Since $a\neq0$, $b=9$. So the units digit of $N$ is $(\text{E})9$."
  # prepare_input_conversation returns ("", messages), so this prints an empty
  # line; msg is not used afterwards.
  input_prompt,msg = tac_generator.prepare_input_conversation(state_test,pseudo_code,informal_theorem,theorem_name="mathd_algebra_478")
  print(input_prompt)
  output = tac_generator.run(state_test,pseudo_code,informal_theorem,theorem_name="mathd_algebra_478")
  # output[0] is the list of parsed tactics.
  print(output[0])




