from torch.utils.data import Dataset, DataLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from accelerate import Accelerator
from accelerate.utils import gather_object
from tqdm import tqdm
from string import Template
import os
import fire
import json
import warnings
from typing import Optional, Union
from glob import glob

# Suppress every warning raised from this point on, for the whole process.
# NOTE(review): this also hides deprecation notices from third-party libraries.
warnings.filterwarnings("ignore")

def load_vectorstore(
        month: int,
        db_root: str,
        model_name: str = None
    ) -> FAISS:
    """Load the FAISS vector store saved for ``month`` under ``db_root``.

    The store is looked up in ``{db_root}/{month}`` and is considered present
    when that directory contains an ``index.faiss`` file.

    Args:
        month: Name of the sub-directory of ``db_root`` holding the index.
        db_root: Root directory containing one index directory per month.
        model_name: Name of the HuggingFace embedding model handed to
            ``HuggingFaceEmbeddings``.

    Returns:
        The loaded ``FAISS`` vector store.

    Raises:
        Exception: If ``{db_root}/{month}/index.faiss`` does not exist.
    """
    index_dir = f'{db_root}/{month}'
    index_file = f'{index_dir}/index.faiss'

    # Bail out early when there is nothing to load.
    if not os.path.exists(index_file):
        raise Exception(f'DB directory {index_dir} is invalid.')

    # Both the model and the encoding step are pinned to the GPU.
    model_options = {'device': 'cuda'}
    encode_options = {'batch_size': 2048, 'device': 'cuda'}
    embedder = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_options,
        encode_kwargs=encode_options,
    )

    # NOTE(review): allow_dangerous_deserialization=True opts in to loading
    # serialized (pickled) data -- only point this at trusted index directories.
    return FAISS.load_local(
        index_dir,
        embeddings=embedder,
        allow_dangerous_deserialization=True,
    )


class RetrievalDataset(Dataset):
    """In-memory dataset of QA items annotated with a retrieval query.

    The JSON file at ``data_path`` is loaded in full and every item receives a
    ``"retrieval_query"`` entry built according to ``baseline_mode``:

    * ``"question"``    -- the item's ``"question"`` unchanged;
    * ``"parentheses"`` -- the question with `` (also known as <entity>)``
      inserted right after the mention;
    * ``"replacement"`` -- the question with the mention replaced by the
      entity.

    ``<entity>`` is the first element of the item's ``"entity_pred"`` and the
    mention is located by ``"mention_idx"``, a ``(start, end)`` pair of
    character offsets into the question where ``end`` is inclusive.
    """

    def __init__(
        self,
        data_path: str,
        baseline_mode: str
    ):
        """Load the data and pre-compute the retrieval query of every item.

        Args:
            data_path: Path to a JSON file holding a sequence of item dicts.
            baseline_mode: One of ``"question"``, ``"parentheses"`` or
                ``"replacement"``.

        Raises:
            ValueError: If ``baseline_mode`` is not a supported mode.
        """
        # read data (JSON is read as UTF-8 rather than the locale's encoding)
        self.data_path = data_path
        with open(self.data_path, encoding="utf-8") as f:
            self.dataset = json.load(f)
        print(f'Total {len(self.dataset)} data points')

        # set baseline mode
        self.baseline_mode = baseline_mode
        print(f"The baseline mode is {baseline_mode}")
        query_builders = {
            "question": self._get_question,
            "parentheses": self._get_parentheses_query,
            "replacement": self._get_replacement_query,
        }
        if baseline_mode not in query_builders:
            # Fail with a clear message instead of an AttributeError on the
            # never-assigned ``get_retrieval_query`` further down.
            raise ValueError(
                f"Unknown baseline_mode {baseline_mode!r}; "
                f"expected one of {sorted(query_builders)}."
            )
        self.get_retrieval_query = query_builders[baseline_mode]

        # set retrieval query
        self._prepare_retrieval_query()

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        return self.dataset[idx]

    def __len__(self):
        """Return the number of loaded items."""
        return len(self.dataset)

    def _prepare_retrieval_query(self):
        """Store the retrieval query of every item under ``"retrieval_query"``."""
        for item in self.dataset:
            item["retrieval_query"] = self.get_retrieval_query(item)

    def _get_question(self, item):
        """Use the raw question as the retrieval query."""
        return item["question"]

    def _get_parentheses_query(self, item):
        """Insert `` (also known as <entity>)`` right after the mention."""
        question = item["question"]
        entity_pred = item["entity_pred"][0]
        # The end offset is inclusive, so the insertion point is one past it.
        insert_at = item["mention_idx"][1] + 1
        return f"{question[:insert_at]} (also known as {entity_pred}){question[insert_at:]}"

    def _get_replacement_query(self, item):
        """Replace the mention span of the question with the predicted entity."""
        question = item["question"]
        entity_pred = item["entity_pred"][0]
        start_idx, end_idx = item["mention_idx"]
        # ``end_idx`` is inclusive, hence the ``+ 1`` when resuming the question.
        return question[:start_idx] + entity_pred + question[end_idx + 1:]

    def collate_fn(self, batch):
        """Return the first element of ``batch`` as-is.

        Only the first element is kept, so this is meant to be used with a
        batch size of 1.
        """
        return batch[0]

class Retrieval:
    """Callable that augments an input item with passages retrieved from a store.

    Calling an instance runs a similarity search for the item's
    ``"retrieval_query"`` and writes the result back into the same item.
    """

    def __init__(
        self,
        db,
        baseline_mode: str,
        search_kwargs = None,
    ):
        """Store the search configuration.

        Args:
            db: Object exposing ``similarity_search_with_score(query, **kwargs)``
                that returns ``(document, score)`` pairs, where each document
                has ``page_content`` and ``metadata`` attributes.
            baseline_mode: Suffix of the ``"retrieval_<baseline_mode>"`` key
                under which the retrieved records are stored.
            search_kwargs: Extra keyword arguments forwarded to the search
                call. ``None`` means no extra arguments.
        """
        self.db = db
        self.baseline_mode = baseline_mode
        self.search_kwargs = search_kwargs

    def _get_context(self, inputs):
        """Retrieve passages for ``inputs`` and attach them to it.

        Args:
            inputs: Mutable mapping with a ``"retrieval_query"`` entry.

        Returns:
            The same ``inputs`` object, with two entries added:
            ``"context"`` (the retrieved passages joined by newlines) and
            ``"retrieval_<baseline_mode>"`` (one dict per passage holding its
            metadata plus the passage text under ``"document"``).
        """
        # ``search_kwargs`` defaults to None, which cannot be ``**``-unpacked.
        search_kwargs = self.search_kwargs or {}
        hits = self.db.similarity_search_with_score(
            inputs["retrieval_query"], **search_kwargs
        )

        passages = []
        retrieved = []
        for doc, _score in hits:
            passages.append(doc.page_content)
            # Build a new dict so the document's own metadata is left untouched.
            retrieved.append({**doc.metadata, 'document': doc.page_content})

        inputs['context'] = '\n'.join(passages)
        inputs[f"retrieval_{self.baseline_mode}"] = retrieved
        return inputs

    def __call__(self, inputs):
        """Shorthand for :meth:`_get_context`."""
        return self._get_context(inputs)
        
        
         
def retrieve(
        baseline_mode: str,
        month: str,
        data_root: str,
        save_root: str,
        db_faiss_dir: str,
        top_k: int = 3,
        er_top_k: int = 1,
        model_name: str = "intfloat/e5-base"
    ):
    """Run retrieval for one month of QA data and write the results as JSONL.

    Reads ``{data_root}/resolved_qa_{month}.json``, retrieves ``top_k``
    passages per item from the vector store of that month and appends one JSON
    object per item to ``{save_root}/{baseline_mode}/{month}.jsonl``.

    Args:
        baseline_mode: How the retrieval query is built; one of
            ``"question"``, ``"parentheses"`` or ``"replacement"``.
        month: Identifier used in the input file name, the vector store
            sub-directory and the output file name.
        data_root: Directory containing the input JSON file.
        save_root: Root directory for the output files.
        db_faiss_dir: Root directory of the per-month vector stores.
        top_k: Number of passages to retrieve per item.
        er_top_k: Currently unused; kept for interface compatibility.
        model_name: Embedding model name passed to ``load_vectorstore``.

    Raises:
        ValueError: If ``baseline_mode`` is not a supported mode.
        FileExistsError: If the output file already exists.
    """
    # check baseline (explicit raise: ``assert`` is stripped under ``python -O``)
    valid_modes = ("question", "parentheses", "replacement")
    if baseline_mode not in valid_modes:
        raise ValueError(
            f"Unknown baseline_mode {baseline_mode!r}; expected one of {list(valid_modes)}."
        )

    # path to save -- checked before any expensive loading so that a clash
    # with a previous run is reported immediately
    save_path = f'{save_root}/{baseline_mode}/{month}.jsonl'
    if os.path.exists(save_path):
        raise FileExistsError(f"Output file {save_path} already exists.")

    # load dataset
    data_path = f"{data_root}/resolved_qa_{month}.json"
    dataset = RetrievalDataset(data_path, baseline_mode)

    # make chain class
    retriever_db = load_vectorstore(month, db_faiss_dir, model_name=model_name)
    retrieval = Retrieval(retriever_db, baseline_mode, search_kwargs={"k": top_k})

    # get dataset
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            collate_fn=dataset.collate_fn
                            )

    # make sure the per-mode output directory exists before the first write
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    print(f"Save to {save_path}")

    # execute inference; flush after every record so that partial results
    # survive an interruption
    with open(save_path, 'a') as f:
        for batch in tqdm(dataloader, total=len(dataloader)):
            output = retrieval(batch)
            f.write(json.dumps(dict(output)) + '\n')
            f.flush()


# Script entry point: hand `retrieve` to python-fire so that it can be invoked
# from the command line.
if __name__ == "__main__":
    fire.Fire(retrieve)
    
    