import json
from chatgpt import ChatGPT
from script_kb import ScriptKB
import sys
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss

class Baseline:
    """Build role-play prompts for one character of a story, optionally with retrieval.

    Every item of ``story_kb.story["content"]`` is embedded with a
    SentenceTransformer model and stored in a faiss index. At query time the
    most similar items are looked up, and a window of neighbouring items
    around each hit is rendered into the prompt as a "SCENE".
    """

    def __init__(self,story_kb,use_role_profile=False,use_retrieval=False,use_knowledge_supression=None,num_scenes=5,scene_len=5):
        """Create the baseline and index the story's content.

        Args:
            story_kb: story knowledge base. It must expose ``story["content"]``
                (a list of dicts with ``"text"`` and ``"content_type"`` keys),
                ``get_character_profile``, ``get_title`` and ``get_scene_content``.
            use_role_profile: append the role's ``"second_person"`` profile
                text to the prompt.
            use_retrieval: append retrieved scenes to the prompt.
            use_knowledge_supression: ``None``, ``"instruction"`` or
                ``"retrieval"``. Selects which prompt template is used.
            num_scenes: number of nearest items retrieved per query.
            scene_len: length of the item window rendered around each hit.
                ``(scene_len - 1) // 2`` items are taken on each side.
        """
        self.role_name = None
        self.story_kb = story_kb
        self.role_profile = None
        self.use_knowledge_supression = use_knowledge_supression
        self.use_role_profile = use_role_profile
        self.use_retrieval = use_retrieval
        self.num_scenes = num_scenes
        # Number of neighbouring items taken on each side of a retrieval hit.
        self.scene_width = (scene_len-1)//2
        self.memory_contents = self.story_kb.story["content"]
        self.retreival_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.index = None
        self.build_memories()
        self.tag = "profile_" + str(use_role_profile)  + "_retrieval_" +str(use_retrieval) + "_supression_"+str(use_knowledge_supression) + "_num_scenes_"+ str(num_scenes) + "_scene_len_"+str(scene_len)

        # The path is relative to the current working directory.
        with open("config/prompts.json", 'r', encoding="utf-8") as file:
            self.prompts = json.load(file)

    def set_role(self,role_name):
        """Select the character to play and load its profile from the story KB."""
        self.role_name = role_name
        self.role_profile = self.story_kb.get_character_profile(role_name)

    def get_tag(self):
        """Return the string that encodes this baseline's configuration."""
        return self.tag

    def build_prompt(self,query):
        """Assemble the full prompt for ``query`` (the dialogue context so far).

        ``set_role`` must have been called first.
        Exits the process via ``sys.exit`` if ``use_knowledge_supression``
        is not one of the supported values.
        """
        if(self.use_knowledge_supression is None):
            role_prompt = self.prompts["profile_prompting"]
        elif(self.use_knowledge_supression=="instruction"):
            role_prompt = self.prompts["profile_prompting_instruction_restriction"]
        elif(self.use_knowledge_supression=="retrieval"):
            role_prompt = self.prompts["profile_prompting_retreival_restriction"]
        else:
            sys.exit("Knowledge supression undefined")

        role_prompt = role_prompt.replace("<role_name>",self.role_name)
        role_prompt = role_prompt.replace("<story_title>",self.story_kb.get_title())

        if(self.use_role_profile):
            role_prompt += "\n" + self.role_profile["second_person"] + "\n"

        if(self.use_retrieval):
            memories = self.get_similar_memories(query,k=self.num_scenes)
            role_prompt += "\nRelevant scenes for the given context are as follows:\n\n"
            for mi,memory in enumerate(memories):
                role_prompt+= "SCENE "+str(mi)+"\n"
                role_prompt+= memory+"\n"
            role_prompt+= "\n"

        role_prompt+= "\nDIALOGUE CONTEXT:\n"
        role_prompt+=query

        # Collapse every run of three or more newlines into one blank line.
        # A single str.replace pass turns a run of four into three (it works
        # left-to-right on non-overlapping matches), so repeat until no run
        # remains. Each pass shortens the string, so the loop terminates.
        while "\n\n\n" in role_prompt:
            role_prompt = role_prompt.replace("\n\n\n","\n\n")

        return role_prompt

    def build_memories(self):
        """Embed every story item's ``"text"`` and build the search index."""
        documents = {xi: x["text"] for xi, x in enumerate(self.memory_contents)}
        embeddings, _ = self.encode_memories(documents)
        self.build_index(embeddings)

    def encode_memories(self, documents):
        """Embed a ``{doc_id: doc_text}`` dict.

        Returns ``(embeddings, doc_ids)``. Rows of ``embeddings`` follow the
        dict's iteration order, and ``doc_ids`` is a numpy array of the keys
        in that same order.
        """
        doc_ids = list(documents.keys())
        doc_texts = list(documents.values())

        # SentenceTransformer encodes the whole list as a batch.
        embeddings = self.retreival_model.encode(doc_texts)

        return embeddings, np.array(doc_ids)

    def build_index(self, embeddings):
        """Store L2-normalised ``embeddings`` in a flat (exact) L2 faiss index.

        On unit vectors, L2 distance ranks neighbours the same way cosine
        similarity does.
        """
        # faiss requires C-contiguous float32 input. This is a no-op (same
        # object) when the array already has that form.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)

    def get_similar_memories(self, query, k=5):
        """Return rendered scene windows around the ``k`` items most similar to ``query``.

        The result holds one string per hit, in faiss rank order. Result ids of
        ``-1`` (faiss's placeholder when fewer than ``k`` results exist) are
        skipped.
        """
        query_embedding, _ = self.encode_memories({0: query})
        query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
        faiss.normalize_L2(query_embedding)
        _, neighbours = self.index.search(query_embedding, k)

        return [self._collect_scene(mi) for mi in neighbours[0].tolist() if mi != -1]

    def _collect_scene(self, center):
        """Render the story items within ``scene_width`` of ``center`` as text.

        The window ``[center - scene_width, center + scene_width]`` is
        inclusive and clipped to the story bounds. A window near the start
        therefore never wraps around to the end via negative indexing, and one
        near the end never indexes past it. "setting" items seen before any
        content are skipped. The first "setting" item after content has been
        collected ends the window, so the text stays within one scene.
        Each item's rendered text is followed by a newline.

        NOTE(review): the window is scanned from its left edge. A "setting"
        item between ``start`` and ``center`` can therefore end the window
        before ``center`` (the retrieval hit itself) is included. Confirm this
        is intended.
        """
        start = max(0, center - self.scene_width)
        stop = min(len(self.memory_contents), center + self.scene_width + 1)

        parts = []
        for item in self.memory_contents[start:stop]:
            if item["content_type"] == "setting":
                if parts:
                    break
                continue
            parts.append(self.story_kb.get_scene_content(item) + "\n")
        return "".join(parts)
