from langchain import PromptTemplate
import json
import random
from agent.gpt_api import request_gpt
from glob import glob



class TemplateManager:
    """Registry of named prompt templates.

    The registry starts empty in this file; looking up a name that has not
    been registered yields ``None`` instead of raising.
    """

    def __init__(self):
        # template name -> template object; nothing is registered here.
        self.templates = dict()

    def get_template(self, template_name):
        """Return the template registered as *template_name*, or ``None``."""
        if template_name in self.templates:
            return self.templates[template_name]
        return None
      
class PseudoProofDraft:
  """Ask the LLM for a pseudo-proof from an informal theorem/proof and a formal theorem.

  In conversation mode a ``[system, user]`` message list is built from the
  ``get_pseudo_proof_system`` / ``get_pseudo_proof_instruction`` templates;
  otherwise a single prompt is rendered from the ``get_pseudo_proof`` template.
  """

  def __init__(self, llm_engine='gpt-4-turbo-128k', temperature=0.4, is_conversation=False):
    """
    Args:
      llm_engine: model name forwarded to ``request_gpt`` as ``model``.
      temperature: sampling temperature forwarded to ``request_gpt``.
      is_conversation: build a chat message list instead of a single prompt.
    """
    self.is_conversation = is_conversation
    self.template_manager = TemplateManager()
    self.llm_engine = llm_engine
    self.temperature = temperature

  def prepare_input_conversation(self, informal_theorem, informal_proof, theorem):
    """Build the chat messages used in conversation mode.

    All three text inputs are stripped before being substituted.

    Returns:
      ``(user_input, msg)``: the rendered instruction text and the message
      list ``[system_message, user_message]``.
    """
    msg = [
      {"role": "system", "content": self.template_manager.get_template("get_pseudo_proof_system").format()},
    ]
    user_input = self.template_manager.get_template("get_pseudo_proof_instruction").format(
      informal_theorem=informal_theorem.strip(),
      informal_proof=informal_proof.strip(),
      theorem=theorem.strip(),
    )
    msg.append({"role": "user", "content": user_input})
    return user_input, msg

  def prepare_input_prompt(self, informal_theorem, informal_proof, theorem):
    """Render the single-prompt (non-conversation) input from ``get_pseudo_proof``."""
    input_prompt = self.template_manager.get_template("get_pseudo_proof").format(
      informal_theorem=informal_theorem.strip(),
      informal_proof=informal_proof.strip(),
      theorem=theorem.strip(),
    )
    return input_prompt

  def run(self, informal_theorem, informal_proof, theorem, max_try=5, conv_msg=None):
    """Query the LLM once and return the parsed pseudo-proof.

    Args:
      max_try: accepted for backward compatibility; currently unused.
      conv_msg: prior messages passed to ``request_gpt`` as ``last_messages``
        in prompt mode. In conversation mode they are replaced by the freshly
        built messages. ``None`` (the default) means an empty list.

    Returns:
      ``(pseudo_code, token_prompt, token_compli)``: the parsed response plus
      the two token counts returned by ``request_gpt`` (presumably prompt and
      completion tokens; confirm against ``request_gpt``).
    """
    if conv_msg is None:
      # Fresh list per call: the former `[]` default was a single shared object
      # and would accumulate state across calls if request_gpt mutates it.
      conv_msg = []
    if self.is_conversation:
      input_prompt, conv_msg = self.prepare_input_conversation(informal_theorem, informal_proof, theorem)
    else:
      input_prompt = self.prepare_input_prompt(informal_theorem, informal_proof, theorem)
    _message_id, response, token_prompt, token_compli = request_gpt(
      input_prompt, model=self.llm_engine, last_messages=conv_msg, temperature=self.temperature)
    pseudo_code = self.parse_results(response)
    return pseudo_code, token_prompt, token_compli

  def parse_results(self, result):
    """Return the raw LLM response unchanged (hook for future post-processing)."""
    return result
  
class AlignProofDraft:
  """Ask the LLM to align a pseudo-proof with the current proof state.

  In conversation mode a ``[system, user]`` message list is built from the
  ``align_pseudo_to_state_system`` / ``align_pseudo_to_state_instruction``
  templates; otherwise a single prompt is rendered from ``align_pseudo_to_state``.
  """

  def __init__(self, llm_engine='gpt-4-turbo-128k', is_conversation=False, temperature=0.4):
    """
    Args:
      llm_engine: model name forwarded to ``request_gpt`` as ``model``.
      is_conversation: build a chat message list instead of a single prompt.
      temperature: sampling temperature forwarded to ``request_gpt``.
    """
    self.is_conversation = is_conversation
    self.template_manager = TemplateManager()
    self.llm_engine = llm_engine
    self.temperature = temperature

  def prepare_input_prompt(self, pseudo_code, current_state):
    """Render the single-prompt input; also prints it to stdout (debug output)."""
    input_prompt = self.template_manager.get_template('align_pseudo_to_state').format(
      pseudo=pseudo_code,
      state=current_state,
    )
    print(input_prompt)
    return input_prompt

  def prepare_input_conversation(self, pseudo_code, current_state):
    """Build the chat messages used in conversation mode.

    Returns:
      ``(input_prompt, msg)``: the rendered instruction text and the message
      list ``[system_message, user_message]``.
    """
    msg = [
      {"role": "system", "content": self.template_manager.get_template("align_pseudo_to_state_system").format()}
    ]
    input_prompt = self.template_manager.get_template('align_pseudo_to_state_instruction').format(
      pseudo=pseudo_code,
      state=current_state,
    )
    msg.append({"role": "user", "content": input_prompt})
    return input_prompt, msg

  def run(self, pseudo_code, current_state, conv_msg=None):
    """Query the LLM once and return its stripped response.

    Args:
      conv_msg: prior messages passed to ``request_gpt`` as ``last_messages``
        in prompt mode. In conversation mode they are replaced by the freshly
        built messages. ``None`` (the default) means an empty list.

    Returns:
      ``(response, token_prompt, token_compli)``: the whitespace-stripped
      response plus the two token counts returned by ``request_gpt``.
    """
    if conv_msg is None:
      # Fresh list per call instead of a shared mutable default.
      conv_msg = []
    if self.is_conversation:
      input_prompt, conv_msg = self.prepare_input_conversation(pseudo_code, current_state)
    else:
      input_prompt = self.prepare_input_prompt(pseudo_code, current_state)
    _message_id, response, token_prompt, token_compli = request_gpt(
      input_prompt, model=self.llm_engine, last_messages=conv_msg, temperature=self.temperature)
    return response.strip(), token_prompt, token_compli

class StateExplainer:
  """Ask the LLM for an informal explanation of the current proof state.

  In conversation mode a ``[system, user]`` message list is built from the
  ``current_state_informal_system`` / ``current_state_informal_instruction``
  templates; otherwise a single prompt is rendered from ``current_state_informal``.
  """

  def __init__(self, llm_engine='gpt-4-turbo-128k', is_conversation=False, temperature=0.4):
    """
    Args:
      llm_engine: model name forwarded to ``request_gpt`` as ``model``.
      is_conversation: build a chat message list instead of a single prompt.
      temperature: sampling temperature forwarded to ``request_gpt``.
    """
    self.is_conversation = is_conversation
    self.template_manager = TemplateManager()
    self.llm_engine = llm_engine
    self.temperature = temperature

  def prepare_input_prompt(self, current_state):
    """Render the single-prompt (non-conversation) input."""
    input_prompt = self.template_manager.get_template('current_state_informal').format(
      state=current_state,
    )
    return input_prompt

  def prepare_input_conversation(self, current_state):
    """Build the chat messages used in conversation mode.

    Returns:
      ``(input_prompt, msg)``: the rendered instruction text and the message
      list ``[system_message, user_message]``.
    """
    msg = [
      {"role": "system", "content": self.template_manager.get_template("current_state_informal_system").format()}
    ]
    input_prompt = self.template_manager.get_template('current_state_informal_instruction').format(
      state=current_state,
    )
    msg.append({"role": "user", "content": input_prompt})
    return input_prompt, msg

  def run(self, current_state, conv_msg=None):
    """Query the LLM once and return its stripped response.

    Args:
      conv_msg: prior messages passed to ``request_gpt`` as ``last_messages``
        in prompt mode. In conversation mode they are replaced by the freshly
        built messages. ``None`` (the default) means an empty list.

    Returns:
      ``(response, token_prompt, token_compli)``: the whitespace-stripped
      response plus the two token counts returned by ``request_gpt``.
    """
    if conv_msg is None:
      # Fresh list per call instead of a shared mutable default.
      conv_msg = []
    if self.is_conversation:
      input_prompt, conv_msg = self.prepare_input_conversation(current_state)
    else:
      input_prompt = self.prepare_input_prompt(current_state)
    _message_id, response, token_prompt, token_compli = request_gpt(
      input_prompt, model=self.llm_engine, last_messages=conv_msg, temperature=self.temperature)
    return response.strip(), token_prompt, token_compli

class InformalProofDraft:
  """Draft an informal (natural-language) proof of a formal theorem with an LLM.

  Few-shot examples are read at construction time from JSON files matching
  ``data/informal/test/*.json`` and ``data/informal/valid/*.json`` (relative to
  the current working directory) and indexed by their ``problem_name`` field.

  Two pools are kept:
    * a number-theory pool, holding files whose path contains ``"numbertheory"``;
    * an "algebra" pool, which as loaded contains *every* matched file.
  """

  def __init__(self, llm_engine='gpt-4-turbo-128k', is_conversation=False, temperature=0.4):
    """
    Args:
      llm_engine: model name forwarded to ``request_gpt`` as ``model``.
      is_conversation: build few-shot chat messages instead of a single prompt.
      temperature: sampling temperature forwarded to ``request_gpt``.
    """
    self.template_manager = TemplateManager()
    self.is_conversation = is_conversation
    self.llm_engine = llm_engine
    self.temperature = temperature

    example_globs = [
      "data/informal/test/*.json",
      "data/informal/valid/*.json",
    ]
    algebra_list = self.load_dir_examples(example_globs, example_type="algebra")
    numbert_list = self.load_dir_examples(example_globs, example_type="numbertheory")

    # Index by problem name; a later file with the same name overwrites an earlier one.
    self.full_algebra_examples = {ex['problem_name']: ex for ex in algebra_list}
    self.full_numbert_examples = {ex['problem_name']: ex for ex in numbert_list}

    # The sampling pools alias the full example dicts (same objects, not copies).
    self.algebra_examples = self.full_algebra_examples
    self.numbert_examples = self.full_numbert_examples

  def load_dir_examples(self, paths, example_type):
    """Load JSON example files matching the glob patterns in *paths*.

    Args:
      paths: iterable of glob patterns.
      example_type: ``"numbertheory"`` keeps only files whose path contains
        ``"numbertheory"``. ``"algebra"`` keeps every matched file. Any other
        value matches nothing.

    Returns:
      A list of parsed JSON objects, in glob order for each pattern.
    """
    examples = []
    for path in paths:
      for f in glob(pathname=path):
        wanted = example_type == "algebra" or (
          example_type == "numbertheory" and "numbertheory" in f)
        if wanted:
          # Close each handle promptly (the old `json.load(open(f))` leaked it).
          # Read as UTF-8, the JSON standard encoding, rather than the locale default.
          with open(f, encoding="utf-8") as fh:
            examples.append(json.load(fh))
    return examples

  def exact_math_example(self, theorem_name):
    """Look up a known informal statement/proof pair for *theorem_name*.

    The number-theory pool is consulted before the algebra pool.

    Returns:
      ``(informal_statement, informal_proof)`` if found, otherwise ``(False, False)``.
    """
    for pool in (self.full_numbert_examples, self.full_algebra_examples):
      if theorem_name in pool:
        ex = pool[theorem_name]
        return ex['informal_statement'], ex['informal_proof']
    return False, False

  def _sample_examples(self, theorem_name, example_num):
    """Randomly pick up to *example_num* few-shot examples, never *theorem_name* itself."""
    # Names containing "numbertheory" draw from the number-theory pool, all others from algebra.
    source = self.numbert_examples if "numbertheory" in theorem_name else self.algebra_examples
    candidates = [ex for name, ex in source.items() if name != theorem_name]
    # Clamp so that a small pool yields fewer examples instead of random.sample raising ValueError.
    return random.sample(candidates, min(example_num, len(candidates)))

  def prepare_example_prompt(self, theorem_name="mathd_algebra_478", example_num=3):
    """Render sampled few-shot examples as one text block for the single-prompt mode."""
    example_template = "### Formal theorem:\n{}\n### Informal statement:\n{}\n### Informal proof:\n{}\n\n"
    return "".join(
      example_template.format(
        ex['formal_statement'].strip(),
        ex['informal_statement'].strip(),
        ex['informal_proof'].strip(),
      )
      for ex in self._sample_examples(theorem_name, example_num)
    )

  def prepare_input_conversation(self, statement, theorem_name, informal_statement, example_num=3):
    """Build few-shot chat messages for conversation mode.

    The messages are the system prompt, then one user/assistant pair per
    sampled example, then the final user instruction.

    Returns:
      ``('', msg)``. The plain prompt is empty because everything is carried
      by the message list.
    """
    msg = [
      {"role": "system", "content": self.template_manager.get_template("get_informal_proof_system").format()}
    ]
    for ex in self._sample_examples(theorem_name, example_num):
      msg.append({"role": "user", "content": "### Formal theorem:\n{}\n### Informal statement:\n{}\n### Informal proof:".format(
        ex['formal_statement'].strip(), ex['informal_statement'].strip())})
      msg.append({"role": "assistant", "content": "\n{}\n".format(ex['informal_proof'].strip())})
    msg.append(
      {"role": "user", "content": self.template_manager.get_template("get_informal_proof_instruction").format(
        theorem=statement.strip(), informal_statement=informal_statement)},
    )
    input_prompt = ''
    return input_prompt, msg

  def prepare_input_prompt(self, statement, theorem_name, informal_statement):
    """Render the single-prompt input (with 3 sampled few-shot examples)."""
    in_context_examples = self.prepare_example_prompt(theorem_name)
    input_prompt = self.template_manager.get_template("get_informal_proof").format(
      examples=in_context_examples, theorem=statement.strip(), informal_statement=informal_statement)
    return input_prompt

  def run(self, statement, theorem_name, exact_match=False, conv_msg=None):
    """Produce an informal proof for *statement*.

    Args:
      statement: the formal theorem text.
      theorem_name: key used to look up stored examples and to exclude the
        theorem from its own few-shot examples.
      exact_match: if True and a stored example exists, return it without
        calling the LLM.
      conv_msg: prior messages passed to ``request_gpt`` as ``last_messages``
        in prompt mode (replaced in conversation mode). ``None`` means an empty list.

    Returns:
      ``(informal_statement, informal_proof, token_prompt, token_compli)``.
      The token counts are ``0, 0`` when a stored example is returned.
    """
    if conv_msg is None:
      # Fresh list per call instead of a shared mutable default.
      conv_msg = []
    informal_statement, informal_proof = self.exact_math_example(theorem_name)
    if exact_match and informal_statement:
      return informal_statement, informal_proof, 0, 0
    # NOTE(review): for unknown theorems informal_statement is False here and is
    # rendered literally into the prompt by str.format; confirm this is intended.
    if self.is_conversation:
      input_prompt, conv_msg = self.prepare_input_conversation(
        statement=statement, theorem_name=theorem_name, informal_statement=informal_statement, example_num=3)
    else:
      input_prompt = self.prepare_input_prompt(
        statement=statement, theorem_name=theorem_name, informal_statement=informal_statement)
    # Debug output kept from the original implementation.
    print(input_prompt)
    print(conv_msg)

    _message_id, response, token_prompt, token_compli = request_gpt(
      input_prompt, model=self.llm_engine, last_messages=conv_msg, temperature=self.temperature)
    informal_proof = self.parse_results(response)

    return informal_statement, informal_proof, token_prompt, token_compli

  def parse_results(self, result):
    """Extract the text after the first '### Informal proof:' marker.

    Only the segment up to a second marker (if any) is kept, and it is stripped.
    If the marker is absent, or *result* is not a string (e.g. ``None``),
    *result* is returned unchanged.
    """
    try:
      return result.split('### Informal proof:')[1].strip()
    except (IndexError, AttributeError, TypeError):
      return result



if __name__ == "__main__":
  # Manual smoke test: draft an informal proof for a single theorem.
  # Needs the example JSON files under data/informal/ (loaded by
  # InformalProofDraft) and whatever backend request_gpt talks to.
  test_formal_theorem = """theorem mathd_algebra_209
  (σ : equiv ℝ ℝ)
  (h₀ : σ.2 2 = 10)
  (h₁ : σ.2 10 = 1)
  (h₂ : σ.2 1 = 2) :
   1 = σ.1 (σ.1 10) :="""
  drafter = InformalProofDraft(is_conversation=True)

  # Stage 1: formal theorem -> (informal statement, informal proof, token counts).
  informal_theorem, informal_proof, _prompt_tokens, _completion_tokens = drafter.run(
    test_formal_theorem, "mathd_algebra_209", exact_match=False)
  print(informal_proof)

  # Later pipeline stages, disabled by default (uncomment to try them):
  # pseudo_drafter = PseudoProofDraft(llm_engine="gpt-4 8K", is_conversation=True)
  # full_pseudo_code, _, _ = pseudo_drafter.run(informal_theorem, informal_proof, test_formal_theorem)
  # print(full_pseudo_code)
  #
  # explainer = StateExplainer(is_conversation=True)
  # state_explain, _, _ = explainer.run(test_formal_theorem)
  # print(state_explain)
  #
  # aligner = AlignProofDraft(is_conversation=True)
  # next_step_align, _, _ = aligner.run(full_pseudo_code, test_formal_theorem)
  # print(next_step_align)

