from db.reprover_assist_db.reprover_db import ReproverAug
from db.reprover_assist_db.reprover_gen import ReproverTacticGen

from db.mathlib_db.build_copra_aug import Lean3Bm25ReRanker,Lean3EncodeRanker
from transformers import AutoTokenizer, AutoModelForTextEncoding, AutoModelForSeq2SeqLM

from db.mathlib_db.src.lean_server.lean3_search_tool import Lean3SearchTool
from db.mathlib_db.src.tools.lean_cmd_executor import Constants
from agent.summerizer import ErrorSummarizer, PremiseSummarizer


class KnowledgeRetrieval:
    """Bundle the retrieval and tactic-generation helpers for Lean 3 goals.

    Constructing an instance loads the encoder and generator checkpoints and
    wires them into the project's retrieval helpers. After construction the
    following attributes are available:

    * ``gen_tokenizer`` / ``generator`` -- tokenizer and seq2seq model loaded
      from ``generator_path``.
    * ``enc_tokenizer`` / ``encoder`` -- tokenizer and text-encoding model
      loaded from ``encoder_path``.
    * ``reprover_aug`` -- ``ReproverAug`` after ``load_db()`` and
      ``load_encoder(model=encoder)`` have been called.
    * ``reprover_tactic_gen`` -- ``ReproverTacticGen`` built from an encoder
      ranker over ``dst_dir`` and the generator model.
    * ``co_premises_ranker`` -- ``Lean3EncodeRanker`` re-indexed with the
      stringified lemmas found by ``Lean3SearchTool`` under ``mathlib_path``.
    * ``do_premises_ranker`` -- ``Lean3EncodeRanker`` constructed with
      ``dst_dir``.
    * ``premise_summerizer`` -- ``PremiseSummarizer`` wrapping
      ``do_premises_ranker`` (attribute spelling kept for compatibility).
    """

    def __init__(
        self,
        dst_dir: str = "leandojo_minif2f",
        encoder_path: str = "leandojo-lean3-retriever-byt5-small",
        generator_path: str = "leandojo-lean3-retriever-tacgen-byt5-small",
        use_cuda: bool = True,
        device_id: int = 3,
        is_conversation: bool = True,
        mathlib_path: str = "Lean3miniF2F/_target/deps/mathlib",
    ) -> None:
        """Load the models and build every retrieval component.

        Args:
            dst_dir: Directory handed to the ``Lean3EncodeRanker`` instances
                used for tactic generation and premise ranking.
            encoder_path: Checkpoint name or path of the text encoder.
            generator_path: Checkpoint name or path of the seq2seq generator.
            use_cuda: When true the models are moved to ``cuda:<device_id>``,
                otherwise to ``cpu``.
            device_id: CUDA device index; also forwarded to the helpers that
                accept a ``device_id`` argument.
            is_conversation: Forwarded unchanged to ``PremiseSummarizer``.
            mathlib_path: Root of the mathlib checkout scanned by
                ``Lean3SearchTool``. Defaults to the path that used to be
                hard-coded here, so existing callers are unaffected.
        """
        self.dst_dir = dst_dir
        self.device_id = device_id
        self.device = f"cuda:{device_id}" if use_cuda else "cpu"

        self.gen_tokenizer = AutoTokenizer.from_pretrained(generator_path)
        self.generator = AutoModelForSeq2SeqLM.from_pretrained(generator_path).to(self.device)

        self.enc_tokenizer = AutoTokenizer.from_pretrained(encoder_path)
        self.encoder = AutoModelForTextEncoding.from_pretrained(encoder_path).to(self.device)

        # Retrieval-augmented tactic lookup.
        #   Input:  Lean state
        #           "σ : ℝ ≃ ℝ,\nh : σ.to_fun 2 = σ.inv_fun 2\n⊢ σ.to_fun (σ.to_fun 2) = 2"
        #   Output: scores and tactic, e.g. "[79.41515371159629]: simp [h]"
        #   Usage:  suggest, scores = retriv_aug.get_bm_scores(inp)
        #           suggest, scores = retriv_aug.get_encode_scores(inp)
        self.reprover_aug: ReproverAug = ReproverAug(
            use_cuda=use_cuda,
            device_id=device_id,
            lean_encoder_path=encoder_path,
        )
        self.reprover_aug.load_db()
        # Hand over the already-loaded encoder instance.
        self.reprover_aug.load_encoder(model=self.encoder)

        # Tactic generator. Usage: reprover_gen.tactic_gen(inp)
        # NOTE(review): this ranker and ``do_premises_ranker`` below are built
        # without ``device_id`` while ``co_premises_ranker`` receives it --
        # confirm that Lean3EncodeRanker's default matches ``self.device``.
        emb_ranker = Lean3EncodeRanker(model=self.encoder, dst_dir=self.dst_dir)
        self.reprover_tactic_gen: ReproverTacticGen = ReproverTacticGen(
            use_cuda=use_cuda,
            device_id=device_id,
            emb_ranker=emb_ranker,
            generator=self.generator,
        )

        # Premise ranker over the lemmas discovered in the mathlib checkout.
        # Usage:
        #   scores = nn_ranker.get_scores(inp)
        #   for idx, score in scores[:10]:
        #       print(f"[{score}]: {nn_ranker._responses[idx]}")
        lean3_search_tool = Lean3SearchTool(mathlib_path=mathlib_path)
        lean3_search_tool.initialize()
        # The BM25 re-ranker is still constructed but currently unused; its
        # indexing step is disabled in favour of the encoder-based ranker.
        lean3_bm25_reranker = Lean3Bm25ReRanker()
        index_lemmas = [str(lemma) for lemma in lean3_search_tool.lemmas]
        # lean3_bm25_reranker.reindex(index_lemmas)
        self.co_premises_ranker: Lean3EncodeRanker = Lean3EncodeRanker(
            model=self.encoder,
            device_id=self.device_id,
        )
        self.co_premises_ranker.reindex(index_lemmas)

        # Premise ranker constructed with ``dst_dir``; same usage pattern as
        # ``co_premises_ranker`` above.
        self.do_premises_ranker: Lean3EncodeRanker = Lean3EncodeRanker(
            model=self.encoder,
            dst_dir=dst_dir,
        )

        # premises summarizer
        self.premise_summerizer = PremiseSummarizer(
            self.do_premises_ranker,
            pre_k=20,
            is_conversation=is_conversation,
            temperature=0.4,
        )

        # error summarizer (disabled)
        # self.error_summerizer = ErrorSummarizer(re_model=self.encoder)
    
if __name__ == "__main__":
    # Manual smoke run: build the retrieval stack on CUDA device 1 and print
    # the 24 best-scoring premises for a sample Lean 3 goal state.
    KR = KnowledgeRetrieval(device_id=1)
    inp = """σ : ℝ ≃ ℝ,
h₀ : σ.inv_fun 2 = 10,
h₁ : σ.inv_fun 10 = 1,
h₂ : σ.inv_fun 1 = 2,
⊢ σ.to_fun (σ.to_fun (σ.inv_fun 2)) = 1"""

    # get dojo premises
    premise_ranker = KR.do_premises_ranker
    ranked_premises = premise_ranker.get_scores(inp, k=24)
    print(ranked_premises)

    # Each entry is unpacked as (index into the ranker's responses, score).
    for premise_idx, premise_score in ranked_premises:
        print(f"[{premise_score}]: {premise_ranker._responses[premise_idx]}")

    # compute tokens
    

    # append state

    # get 64 tactics



    # results = KR.reprover_aug.get_encode_scores(inp)
    # print(results[0])
    # ['norm_num', 'simp [h]', 'norm_num', 'norm_num', 'ring', 'linarith', 'norm_num', 'ring', 'norm_num', 'norm_num']
    # [('⊢ 91 ^ 2 = 8281', 'norm_num'), ('σ : ℝ ≃ ℝ,\nh : σ.to_fun 2 = σ.inv_fun 2\n⊢ σ.to_fun (σ.to_fun 2) = 2', 'simp [h]'), ('⊢ 71 % 3 = 2', 'norm_num'), ('⊢ 121 * 122 * 123 % 4 = 2', 'norm_num'), ('a b : ℂ\n⊢ (a + a) * (a + b) = 2 * a ^ 2 + 2 * (a * b)', 'ring'), ('x : ℝ,\nh₀ : 125 / 8 = x / 12\n⊢ x = 375 / 2', 'linarith'), ('⊢ 17 * 18 % 4 = 2', 'norm_num'), ('x : ℝ\n⊢ (x + 1) ^ 2 * x = x ^ 3 + 2 * x ^ 2 + x', 'ring'), ('⊢ 7! % 23 = 3', 'norm_num'), ('⊢ 8 * 9 ^ 2 + 5 * 9 + 2 = 695', 'norm_num')]
    # results = KR.reprover_tactic_gen.tactic_gen(inp)
    # print(results)
    # ['simp', 'classical', 'split', 'refl', 'dsimp', 'apply witt_injective', 'rw witt_strict_mono', 'rw witt_vector.witt_vector', 'simp [witt_vector.witt_vector]', 'rw [witt_vector.witt_vector]', 'simp only [witt_vector]', 'apply witt_vector.witt_vector', 'simp only [witt_vector.witt_vector]', 'exact witt_strict_mono _ _', 'simp only [witt_vitt, witt_vitt]', 'exact witt_strict_mono_right _ _', 'rw [witt_vector.witt_vector, witt_vector.witt_vector]', 'exact witt_vector.witt_vector', 'rw witt_vector.witt_vector_witt', 'exact witt_strict_mono witt_strict_mono', 'apply witt_vector.witt_strict_mono', 'simp only [witt_vector, witt_vector]', 'exact witt_vector.witt_vector _', 'exact witt_strict_mono (witt_strict_mono _)', 'simp only [witt_vector.witt_verts]', 'exact witt_strict_mono_right witt_strict_mono_right', 'simp only [witt_vector.witt_to_witt]', 'exact witt_strict_mono _ (witt_strict_mono _)', 'simp only [witt_vector.witt_strict_mono]', 'simp only [witt_vitt, witt_vitt, witt_vitt]', 'exact witt_strict_mono witt_strict_mono_right', 'exact witt_strict_mono_right witt_strict_mono']

    # results = KR.co_premises_ranker.get_scores(inp)
    # print(results)
    # [(1127, 0.89488285779953), (1161, 0.7934194207191467), (1160, 0.7576545476913452), (228, 0.7310068607330322), (48, 0.7002017498016357), (1213, 0.695798397064209), (772, 0.6679669618606567), (1060, 0.5790637731552124), (846, 0.5298870205879211), (610, 0.5104066133499146)]


    # for idx, score in results[:10]:
    #     print(f"[{score}]: {KR.co_premises_ranker._responses[idx]}")
