import ast
import json
import sys
from typing import List

from sentence_transformers import SentenceTransformer, util
import numpy as np
import faiss
from rank_bm25 import BM25Okapi

from chatgpt import ChatGPT
from script_kb import ScriptKB

class RoleFact:
    """Role-playing responder that checks its own answer for unsupported facts.

    ``get_response`` generates a preliminary in-character answer, splits it
    into atomic facts with a function-calling LLM, verifies each fact against
    the retrieved scenes (non-parametric) and/or by a vote over repeated LLM
    checks (parametric), and, if any fact stays unverified, asks ``llm`` to
    rewrite the answer with the unsupported claims listed in the prompt.
    """

    def __init__(self, llm, story_kb, use_role_profile=True, use_retrieval=True, time_sensitive=True, retrieval_type="sbert", num_docs=5, scene_len=5, sample_size=5, threshold=0.8, anonymize=False, function_model="gpt-3.5-turbo-0125"):
        """Create the responder.

        Args:
            llm: Generator for in-character text; must provide
                ``generate(prompt)`` and ``get_cost()``.
            story_kb: Story knowledge base, passed through to ``RoleFactPrompt``.
            use_role_profile: Include the character profile in the prompt.
            use_retrieval: Retrieve scenes for the prompt and use them as
                evidence for non-parametric fact verification.
            time_sensitive: If True, retrieval is restricted with
                ``task["end_time"]``; if False, no time filter is applied.
            retrieval_type: ``"bm25"``, ``"contriever"`` or ``"sbert"``.
            num_docs: Number of scenes to retrieve.
            scene_len: Window size used to build each retrievable scene.
            sample_size: Number of parametric verification calls per fact.
            threshold: Minimum fraction of "supported" votes for a fact to
                count as parametrically verified.
            anonymize: Pass prompts through ``anonymize_prompt`` before
                generation.
            function_model: Model name for the function-calling ``ChatGPT``
                client used for fact extraction and verification.
        """
        self.llm = llm
        self.llm_function = ChatGPT(function_model)
        self.story_kb = story_kb
        self.use_role_profile = use_role_profile
        self.use_retrieval = use_retrieval
        self.time_sensitive = time_sensitive

        self.prompt_generator = RoleFactPrompt(
            story_kb,
            use_role_profile=use_role_profile,
            use_retrieval=use_retrieval,
            retrieval_type=retrieval_type,
            num_docs=num_docs,
            scene_len=scene_len,
        )
        self.sample_size = sample_size
        self.threshold = threshold
        self.anonymous = anonymize
        # Counters of successful / failed generation calls.
        self.success_count = 0
        self.failure_count = 0

        with open("config/prompts.json", "r", encoding="utf-8") as file:
            self.prompts = json.load(file)

        with open("config/functions.json", "r", encoding="utf-8") as file:
            self.functions = json.load(file)

    @staticmethod
    def _parse_function_args(argument):
        """Parse the argument string returned by a function-calling LLM call.

        Strict JSON is tried first. If that fails, the string is parsed as a
        Python literal (e.g. single-quoted keys). Neither path executes code
        contained in the model output, unlike the ``eval`` used previously.
        """
        if isinstance(argument, dict):
            return argument
        try:
            return json.loads(argument)
        except json.JSONDecodeError:
            return ast.literal_eval(argument)

    @staticmethod
    def _collapse_blank_lines(text):
        """Collapse every run of three or more newlines into exactly two."""
        while "\n\n\n" in text:
            text = text.replace("\n\n\n", "\n\n")
        return text

    def anonymize_prompt(self, text, role_name):
        """Ask ``llm`` to anonymize ``text`` with respect to ``role_name``.

        The ``anonymize_prompt`` template is filled with the role name and
        followed by ``text``; the generated output is returned as-is.
        """
        prompt = self.prompts["anonymize_prompt"].replace("<role_name>", role_name)
        prompt += "Prompt: " + text
        return self.llm.generate(prompt)

    def get_prelim_response(self, task):
        """Generate the preliminary, unverified in-character response.

        Args:
            task: Task dict. ``task_type`` selects where the question comes
                from. ``target_character`` is always read; ``end_time`` is
                read only when ``time_sensitive`` is True.

        Returns:
            ``(question, response, knowledge_list)``, where ``knowledge_list``
            holds the retrieved scene strings used in the prompt.

        Raises:
            ValueError: If ``task["task_type"]`` is not a known task type.
            Exception: Any error from prompt building or generation is
                re-raised after ``failure_count`` is incremented.
        """
        task_type = task["task_type"]
        if task_type in ("adversarial_interview", "open_ended_interview"):
            question = "CURIOUS PERSON: " + task["question"]
        elif task_type == "dialogue_completion":
            question = task["scene_context"]
        elif task_type == "scene_grounded_interview":
            question = "CURIOUS PERSON: " + task["second_person"]
        else:
            raise ValueError("Invalid task type name: " + repr(task_type))

        role_name = task["target_character"]

        try:
            self.prompt_generator.set_role(role_name)
            # Without time sensitivity, pass no end time so retrieval is unfiltered.
            end_time = task["end_time"] if self.time_sensitive else None
            prompt, knowledge_list = self.prompt_generator.build_prompt(question, end_time)
            if self.anonymous:
                prompt = self.anonymize_prompt(prompt, role_name)
            response = self.llm.generate(prompt)
        except Exception:
            # Record the failure, then surface the real error rather than an
            # UnboundLocalError on `response` at the return below.
            self.failure_count += 1
            raise
        self.success_count += 1

        return question, response, knowledge_list

    def get_atomic_facts(self, response):
        """Split ``response`` into atomic facts via the function-calling LLM.

        Returns:
            The ``atomic_fact_list`` from the function-call arguments, or
            ``[]`` if the call or parsing fails (a message is printed).
        """
        prompt = self.prompts["atomic_fact_generation"] + "Utterance:\n" + response
        function = self.functions["atomic_fact_generation"]

        try:
            argument = self.llm_function.generate_with_function(prompt, [function], function_call={"name": "atomic_fact_generation"})
            atomic_fact_list = self._parse_function_args(argument)["atomic_fact_list"]
        except Exception:
            print("Exception during atomic fact generation")
            return []

        return atomic_fact_list

    def _run_verification(self, prompt):
        """Run an ``atomic_fact_verification`` call and return its verdict.

        Returns:
            The ``is_statement_supported`` value from the function-call
            arguments, or ``None`` if the call or parsing fails (a message is
            printed).
        """
        function = self.functions["atomic_fact_verification"]
        try:
            argument = self.llm_function.generate_with_function(prompt, [function], function_call={"name": "atomic_fact_verification"})
            return self._parse_function_args(argument)["is_statement_supported"]
        except Exception:
            print("Exception during atomic fact verification")
            return None

    def verify_non_parametric(self, knowledge_list, atomic_fact):
        """Check ``atomic_fact`` against the retrieved scenes in ``knowledge_list``.

        Returns the model's verdict, or ``None`` on failure.
        """
        prompt = self.prompts["atomic_fact_verification"] + "Statement: " + atomic_fact + "\n"
        for ki, knowledge in enumerate(knowledge_list, start=1):
            prompt += "Evidence Knowledge " + str(ki) + ":\n" + knowledge + "\n"
        return self._run_verification(prompt)

    def verify_parametric(self, atomic_fact, role_name, story_name):
        """Check ``atomic_fact`` without evidence, using only the model's own knowledge.

        Returns the model's verdict, or ``None`` on failure.
        """
        prompt = self.prompts["atomic_fact_verification_script_free"].replace("<story_title>", story_name).replace("<role_name>", role_name)
        prompt += "Statement: " + atomic_fact + "\n"
        return self._run_verification(prompt)

    def _fact_is_verified(self, fact, knowledge_list, role_name, story_name):
        """Return True if ``fact`` passes non-parametric or parametric verification.

        Non-parametric verification is tried first, only when retrieval is
        enabled. Otherwise ``verify_parametric`` is sampled ``sample_size``
        times, and the fact passes if the fraction of truthy verdicts is at
        least ``threshold``. Failed calls (``None``) count as "not supported".
        """
        if self.use_retrieval and self.verify_non_parametric(knowledge_list, fact):
            return True
        if self.sample_size <= 0:
            return False
        param_yes = sum(
            1 for _ in range(self.sample_size)
            if self.verify_parametric(fact, role_name, story_name)
        )
        return (param_yes / self.sample_size) >= self.threshold

    def get_response(self, task):
        """Produce a fact-checked in-character response for ``task``.

        Returns:
            ``(prelim_response, rewrite_prompt, final_response)``. When every
            atomic fact is verified, ``rewrite_prompt`` is ``None`` and
            ``final_response`` is the preliminary response.

        Raises:
            Exception: Errors from preliminary generation or from the rewrite
                call are re-raised after ``failure_count`` is incremented.
        """
        _question, prelim_response, knowledge_list = self.get_prelim_response(task)
        role_name = task["target_character"]
        story_name = task["story_title"]

        unverified_fact_list = [
            fact for fact in self.get_atomic_facts(prelim_response)
            if not self._fact_is_verified(fact, knowledge_list, role_name, story_name)
        ]

        if not unverified_fact_list:
            return prelim_response, None, prelim_response

        role_prompt = self.prompts["rolefact_rewrite"]
        role_prompt = role_prompt.replace("<role_name>", role_name)
        role_prompt = role_prompt.replace("<story_title>", story_name)

        role_prompt += "\nRESPONSE: \n" + prelim_response + "\n"
        role_prompt += "\nUnsupported Claim List:\n\n"
        for fi, fact in enumerate(unverified_fact_list):
            role_prompt += "Unsupported Claim " + str(fi) + ": " + fact + "\n"
        role_prompt += "\n"

        role_prompt = self._collapse_blank_lines(role_prompt)

        try:
            if self.anonymous:
                role_prompt = self.anonymize_prompt(role_prompt, role_name)
            response = self.llm.generate(role_prompt)
        except Exception:
            # Record the failure, then surface the real error rather than an
            # UnboundLocalError on `response` at the return below.
            self.failure_count += 1
            raise
        self.success_count += 1

        return prelim_response, role_prompt, response

    def get_cost(self):
        """Return the combined cost reported by ``llm`` and the function-calling client."""
        return self.llm.get_cost() + self.llm_function.get_cost()
        
class BM25:
    """Lexical retriever over a fixed corpus using Okapi BM25.

    Documents and queries are lower-cased and split on whitespace before
    scoring.
    """

    def __init__(self, documents: List[str]):
        # One lower-cased, whitespace-tokenised token list per document.
        self.corpus = [doc.lower().split() for doc in documents]
        self.bm25 = BM25Okapi(self.corpus)

    def search(self, query: str, num_docs: int) -> List[int]:
        """Return indices of the ``num_docs`` highest-scoring documents, best first.

        The indices are positions in the ``documents`` list given to the
        constructor (not the documents themselves). Equal scores keep their
        original document order, because ``sorted`` is stable even with
        ``reverse=True``.
        """
        doc_scores = self.bm25.get_scores(query.lower().split())
        ranked = sorted(range(len(doc_scores)), key=lambda i: doc_scores[i], reverse=True)
        return ranked[:num_docs]

class SBERT:
    """Dense retriever ranking a fixed corpus by cosine similarity of
    ``all-MiniLM-L6-v2`` Sentence-Transformers embeddings."""

    def __init__(self, documents: List[str]):
        """Load the encoder and embed every document once, up front."""
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.document_embeddings = self.model.encode(documents, convert_to_tensor=True)
        # Stored for reference; search() works from the embeddings only.
        self.documents = documents

    def search(self, query: str, num_docs: int) -> List[int]:
        """Return indices into ``documents`` of the ``num_docs`` most similar entries, best first."""
        encoded_query = self.model.encode(query, convert_to_tensor=True)
        similarity_row = util.pytorch_cos_sim(encoded_query, self.document_embeddings)[0]
        best_first = similarity_row.argsort(descending=True)
        return best_first[:num_docs].tolist()

class Contriever:
    """Dense retriever using the ``facebook/contriever`` encoder loaded through
    Sentence-Transformers, ranking documents by cosine similarity.

    NOTE(review): Contriever's reference usage scores with a raw dot product
    rather than cosine similarity; confirm cosine is intended here.
    """

    def __init__(self, documents: List[str]):
        """Load the encoder and embed every document once, up front."""
        self.model = SentenceTransformer('facebook/contriever')
        self.document_embeddings = self.model.encode(documents, convert_to_tensor=True)
        # Stored for reference; search() works from the embeddings only.
        self.documents = documents

    def search(self, query: str, num_docs: int) -> List[int]:
        """Return indices into ``documents`` of the ``num_docs`` most similar entries, best first."""
        encoded_query = self.model.encode(query, convert_to_tensor=True)
        similarity_row = util.pytorch_cos_sim(encoded_query, self.document_embeddings)[0]
        best_first = similarity_row.argsort(descending=True)
        return best_first[:num_docs].tolist()

class RoleFactPrompt:
    """Build role-playing prompts, optionally augmented with retrieved scenes.

    Every entry of ``story_kb.story["content"]`` gets one retrievable memory
    string (see ``preprocess_docs``), so ``self.documents[i]`` lines up with
    ``story_kb.story["content"][i]``. ``build_prompt`` relies on that
    alignment when it filters by timestep.
    """

    def __init__(self, story_kb, use_role_profile=True, use_retrieval=True, retrieval_type="contriever", num_docs=5, scene_len=5):
        """Create the prompt builder.

        Args:
            story_kb: Story knowledge base providing ``story["content"]``,
                ``get_scene_content``, ``get_character_profile`` and
                ``get_title``.
            use_role_profile: Append the role's ``second_person`` profile.
            use_retrieval: Add retrieved scenes to the prompt.
            retrieval_type: ``"bm25"``, ``"contriever"`` or ``"sbert"``. An
                unknown value exits the process when retrieval is enabled.
            num_docs: Maximum number of scenes added to a prompt.
            scene_len: Number of content items in each memory window.
        """
        self.role_name = None
        self.story_kb = story_kb
        # Items taken on each side of the centre item (scene_len=5 -> 2).
        self.scene_width = (scene_len - 1) // 2
        self.role_profile = None
        self.use_role_profile = use_role_profile
        self.use_retrieval = use_retrieval
        self.num_docs = num_docs
        self.documents = self.preprocess_docs()

        if not use_retrieval:
            # build_prompt never queries the retriever in this mode, so skip
            # loading a model and encoding the whole corpus.
            self.retrieval_system = None
        elif retrieval_type == "bm25":
            self.retrieval_system = BM25(self.documents)
        elif retrieval_type == "contriever":
            self.retrieval_system = Contriever(self.documents)
        elif retrieval_type == "sbert":
            self.retrieval_system = SBERT(self.documents)
        else:
            sys.exit("invalid retriever")

        with open("config/prompts.json", "r", encoding="utf-8") as file:
            self.prompts = json.load(file)

    def preprocess_docs(self):
        """Build one memory string per story content item.

        The window for item ``mi`` covers items ``mi - scene_width`` through
        ``mi + scene_width`` inclusive, clipped to the story bounds. Each
        non-setting item contributes ``get_scene_content(item) + "\\n"``.
        Setting entries are skipped while nothing has been collected yet, and
        the first setting after collected content ends the memory.

        NOTE(review): collection starts at the left edge of the window, so a
        setting between the window start and ``mi`` ends the memory before
        ``mi`` is reached. The memory for ``mi`` can then omit ``mi``'s own
        content. Confirm this is intended.
        """
        content = self.story_kb.story["content"]
        memories = []

        for mi in range(len(content)):
            memory = ""
            start = max(0, mi - self.scene_width)
            # +1 makes the window inclusive of mi + scene_width, i.e. centred
            # on mi with scene_len items (the previous range dropped the last one).
            stop = min(len(content), mi + self.scene_width + 1)
            for item in content[start:stop]:
                if item["content_type"] == "setting":
                    if memory == "":
                        continue
                    break
                memory += self.story_kb.get_scene_content(item) + "\n"
            memories.append(memory)

        return memories

    def set_role(self, role_name):
        """Select the character to play and load its profile from ``story_kb``."""
        self.role_name = role_name
        self.role_profile = self.story_kb.get_character_profile(role_name)

    def build_prompt(self, query, end_time=None):
        """Assemble the role-playing prompt for ``query``.

        Args:
            query: Dialogue context appended at the end of the prompt.
            end_time: If not None, only scenes whose centre item has
                ``timestep <= end_time`` are eligible for retrieval.

        Returns:
            ``(prompt, memories)``, where ``memories`` holds the retrieved
            scene strings (empty when retrieval is disabled).

        NOTE(review): only the centre item's timestep is checked. A memory can
        also contain up to ``scene_width`` later items, whose timesteps may
        exceed ``end_time``.
        """
        role_prompt = self.prompts["profile_prompting"]
        role_prompt = role_prompt.replace("<role_name>", self.role_name)
        role_prompt = role_prompt.replace("<story_title>", self.story_kb.get_title())

        if self.use_role_profile:
            role_prompt += "\n" + self.role_profile["second_person"] + "\n"

        memories = []

        if self.use_retrieval:
            if end_time is None:
                memory_ids = self.retrieval_system.search(query, num_docs=self.num_docs)
            else:
                # Rank the whole corpus, drop scenes later than end_time, and
                # only then keep the top num_docs. Filtering after truncation
                # could return fewer scenes even when enough earlier ones exist.
                content = self.story_kb.story["content"]
                ranked = self.retrieval_system.search(query, num_docs=len(self.documents))
                memory_ids = [idx for idx in ranked if content[idx]["timestep"] <= end_time][:self.num_docs]

            memories = [self.documents[idx] for idx in memory_ids]

            role_prompt += "\nRelevant scenes for the given context are as follows:\n\n"
            for mi, memory in enumerate(memories):
                role_prompt += "SCENE " + str(mi) + "\n"
                role_prompt += memory + "\n"
            role_prompt += "\n"

        role_prompt += "\nDIALOGUE CONTEXT:\n"
        role_prompt += query

        # Collapse every run of 3+ newlines to a single blank line. The old
        # pair of replace() calls left runs of four newlines as three.
        while "\n\n\n" in role_prompt:
            role_prompt = role_prompt.replace("\n\n\n", "\n\n")

        return role_prompt, memories
