from accelerate import Accelerator
from accelerate.utils import gather_object
from torch.utils.data import DataLoader
import json
import fire
from tqdm import tqdm
import transformers
import torch
from pprint import pprint 



class GeneratorDataset:
    """Map-style dataset that turns QA records into generation prompts.

    Records are loaded from ``data_path`` and each one gets a ``"prompt"``
    field built by the template selected with ``baseline_mode``.
    ``__getitem__`` returns ``(prompt, record)`` pairs and ``collate_fn``
    turns a batch of such pairs into ``(list_of_prompts, list_of_records)``.

    Depending on the template, records must provide ``"question"``,
    ``"context"``, ``"mention_idx"`` (treated as an inclusive
    ``[start, end]`` character span inside the question), ``"entity_pred"``
    (a list whose first element is used) and, for ``ralm_cot_second``, the
    ``"prompt"`` / ``"prediction"`` fields written by a ``ralm_cot_first`` run.
    """

    # Every accepted ``baseline_mode`` mapped to the name of the method that
    # builds its prompt. The "*in_context*" modes are accepted names whose
    # builder methods are not defined in this class; selecting one raises
    # NotImplementedError at construction time.
    _PROMPT_BUILDERS = {
        "lm": "_get_lm_prompt",
        "ralm": "_get_ralm_prompt",
        "ralm_instruct_first": "_get_ralm_instruct_first_prompt",
        "ralm_cot_first": "_get_ralm_cot_first_prompt",
        "ralm_cot_second": "_get_ralm_cot_second_prompt",
        "er_lm": "_get_er_lm_prompt",
        "er_ralm": "_get_er_ralm_prompt",
        "er_ralm_cot": "_get_er_ralm_cot_prompt",
        "er_lm_can": "_get_er_lm_can_prompt",
        "er_ralm_can": "_get_er_ralm_can_prompt",
        "er_lm_in_context": "_get_er_lm_in_context_prompt",
        "er_ralm_in_context": "_get_er_ralm_in_context_prompt",
        "er_ralm_in_context_instruct_first": "_get_er_ralm_in_context_instruct_first_prompt",
        "er_ralm_mention_in_context": "_get_er_ralm_mention_in_context_prompt",
    }

    def __init__(self,
                 month: int,
                 data_path: str,
                 baseline_mode: str,
                 er_top_k: int = 1,
                 ):
        """Load the records and attach a prompt to each of them.

        Args:
            month: Month identifier of the data split (stored as ``self.month``).
            data_path: Path to a JSON Lines file; see ``_load_dataset`` for
                the plain-JSON fallback.
            baseline_mode: Key of ``_PROMPT_BUILDERS`` selecting the template.
            er_top_k: Stored as ``self.er_top_k``; not used by the templates here.

        Raises:
            ValueError: if ``baseline_mode`` is not an accepted mode.
            NotImplementedError: if the mode's builder method is not defined.
        """
        self.month = month
        self.data_path = data_path
        self.er_top_k = er_top_k
        self.baseline_mode = baseline_mode
        # Resolve the template before reading the file so a bad mode fails fast.
        self.generate_prompt = self._resolve_prompt_builder(baseline_mode)
        self.dataset = self._load_dataset(data_path)
        self._prepare_prompts()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(prompt, record)`` for the record at ``idx``."""
        return self.dataset[idx]["prompt"], self.dataset[idx]

    def _resolve_prompt_builder(self, baseline_mode):
        """Return the bound prompt-building method for ``baseline_mode``.

        Raises:
            ValueError: if ``baseline_mode`` is not an accepted mode.
            NotImplementedError: if the mode is accepted but its builder
                method does not exist on this class.
        """
        try:
            method_name = self._PROMPT_BUILDERS[baseline_mode]
        except KeyError:
            raise ValueError(
                f"Unknown baseline_mode {baseline_mode!r}; "
                f"expected one of {sorted(self._PROMPT_BUILDERS)}"
            ) from None
        builder = getattr(self, method_name, None)
        if builder is None:
            raise NotImplementedError(
                f"baseline_mode {baseline_mode!r} requires {method_name}(), "
                "which is not implemented"
            )
        return builder

    @staticmethod
    def _load_dataset(data_path):
        """Load records from a JSON Lines file, falling back to plain JSON.

        ``data_path`` is first read as JSON Lines (one object per non-blank
        line). If that file cannot be opened or parsed, the path with its
        last character dropped (e.g. ``x.jsonl`` -> ``x.json``) is read as a
        single JSON document instead.
        """
        try:
            with open(data_path, encoding="utf-8") as f:
                return [json.loads(line) for line in f if line.strip()]
        except (OSError, ValueError):
            # json.JSONDecodeError and UnicodeDecodeError subclass ValueError.
            with open(data_path[:-1], encoding="utf-8") as f:
                return json.load(f)

    def _prepare_prompts(self):
        """Store ``self.generate_prompt(item)`` under ``item["prompt"]`` for every record."""
        for item in self.dataset:
            item["prompt"] = self.generate_prompt(item)

    @staticmethod
    def _get_mention(item):
        """Return the mention substring of ``item["question"]``.

        ``item["mention_idx"]`` is treated as an inclusive ``[start, end]``
        span, consistent with the ``end + 1`` insertion offset used by
        ``_insert_alias``.
        """
        start, end = item["mention_idx"][0], item["mention_idx"][1]
        return item["question"][start:end + 1]

    @staticmethod
    def _capitalize_first(text):
        """Upper-case only the first character (unlike ``str.capitalize``); safe on ``""``."""
        return text[:1].upper() + text[1:]

    @staticmethod
    def _insert_alias(item):
        """Return the question with ``(can be called as <entity>)`` inserted right after the mention."""
        question = item["question"]
        insert_at = item["mention_idx"][1] + 1
        entity = item["entity_pred"][0]
        return question[:insert_at] + f"(can be called as {entity})" + question[insert_at:]

    def _get_lm_prompt(self, item):
        """Question-only prompt."""
        return "Given a question, please provide a short answer. \n\nQuestion: " + item['question'] + "\n\nAnswer:"

    def _get_ralm_prompt(self, item):
        """Context first, then instruction and question."""
        return "Context: " + item['context'] + "\nGiven a question, please provide a short answer. \n\nQuestion: " + item['question'] + "\n\nAnswer:"

    def _get_ralm_instruct_first_prompt(self, item):
        """Instruction first, then context and question."""
        return "Given a question, please provide a short answer. \n\nContext: " + item['context'] + "\n\nQuestion: " + item['question'] + "\n\nAnswer:"

    def _get_ralm_cot_first_prompt(self, item):
        """First chain-of-thought step: ask the model to complete "<Mention> is"."""
        mention = self._get_mention(item)
        return "Context: " + item['context'] + "\n\nQuestion: " + item["question"] + f" {self._capitalize_first(mention)} is"

    def _get_ralm_cot_second_prompt(self, item):
        """Second chain-of-thought step: append the first-step completion and ask for the answer.

        Reads ``item["prompt"]`` and ``item["prediction"]`` written by a
        ``ralm_cot_first`` run.
        """
        first_prompt = item["prompt"]
        # The pipeline output echoes the prompt; keep only the continuation.
        first_prediction = item["prediction"][0]["generated_text"][len(first_prompt):]
        first_prediction = first_prediction.split("\n")[0]  # only consider the answer before \n
        if "." in first_prediction:
            # Drop a trailing sentence cut off by max_new_tokens.
            first_prediction = first_prediction[:first_prediction.rfind(".") + 1]
        return first_prompt + first_prediction + "\n\nAnswer:"

    def _get_er_lm_can_prompt(self, item):
        """Question-only prompt with the predicted entity inserted after the mention."""
        return "Question: " + self._insert_alias(item) + "\n\nAnswer:"

    def _get_er_lm_prompt(self, item):
        """Question-only prompt preceded by a sentence linking the mention to the predicted entity."""
        mention = self._get_mention(item)
        entity = item["entity_pred"][0]
        return f"The mention {mention} may also be referred to as {entity}. " + \
            "Given a question, please provide a short answer. " + \
            "\n\nQuestion: " + item["question"] + "\n\nAnswer:"

    def _get_er_ralm_can_prompt(self, item):
        """Context prompt with the predicted entity inserted after the mention."""
        return "Context: " + item['context'] + "\n\nQuestion: " + self._insert_alias(item) + "\n\nAnswer:"

    def _get_er_ralm_prompt(self, item):
        """Context prompt plus a sentence linking the (full) mention to the predicted entity."""
        mention = self._get_mention(item)
        entity = item["entity_pred"][0]
        return "Context: " + item['context'] + f"\nThe mention '{mention}' may also be referred to as '{entity}'. " + \
            "Given a question, please provide a short answer. \n\nQuestion: " + item["question"] + "\n\nAnswer:"

    def _get_er_ralm_cot_prompt(self, item):
        """Context prompt with a filled-in "<Mention> is <entity>." step before the answer."""
        mention = self._get_mention(item)
        return "Context: " + item['context'] + "\n\nQuestion: " + item["question"] + \
            f" {self._capitalize_first(mention)} is {item['entity_pred'][0]}." + "\n\nAnswer:"

    def collate_fn(self, batch):
        """Split a list of ``(prompt, record)`` pairs into ``(prompts, records)`` lists."""
        return [i[0] for i in batch], [i[1] for i in batch]
    

def main(month: int,
         model: str,
         data_root: str,
         save_root: str,
         baseline_mode: str,
         batch_size: int = 8,
        ):
    """Generate answers for one month of QA data, sharded across processes.

    Reads ``{data_root}/{month}.jsonl`` (``GeneratorDataset`` falls back to
    ``{month}.json``), builds prompts for ``baseline_mode`` and runs a
    Hugging Face text-generation pipeline on each process's shard. Each
    prediction is appended to
    ``{save_root}_{model}/{baseline_mode}/{month}_{rank}.jsonl`` as it is
    produced. After all shards finish, the main process writes the
    gathered, ``qa_id``-deduplicated records to
    ``{save_root}_{model}/{baseline_mode}/{month}.json``.

    Args:
        month: Month identifier used in input and output file names.
        model: ``"llama2"`` (Llama-2-7b-hf) or ``"llama3"``
            (Meta-Llama-3-8B-Instruct).
        data_root: Directory containing the input file.
        save_root: Output prefix; ``_{model}`` is appended to form the
            output directory.
        baseline_mode: Prompt template; see ``GeneratorDataset``.
        batch_size: Prompts per dataloader batch, per process.

    Raises:
        ValueError: if ``model`` is not a supported key.
    """
    import os  # only needed here, to create the output directory

    model_ids = {
        "llama2": "meta-llama/Llama-2-7b-hf",
        "llama3": "meta-llama/Meta-Llama-3-8B-Instruct",
    }
    # Validate before any expensive setup (data loading, model download).
    if model not in model_ids:
        raise ValueError(f"Unknown model {model!r}; expected one of {sorted(model_ids)}")

    accelerator = Accelerator()
    # Unlike torch.distributed.get_rank(), this also works in a plain
    # single-process run where no process group has been initialized.
    rank = accelerator.process_index

    data_path = f"{data_root}/{month}.jsonl"
    dataset = GeneratorDataset(month, data_path, baseline_mode)
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
    dataloader = accelerator.prepare(dataloader)

    pipeline = transformers.pipeline(
        "text-generation",
        model=model_ids[model],
        device=accelerator.device,  # this process's device
    )

    # The first chain-of-thought step only needs a short "<Mention> is ..."
    # continuation; every other mode gets a longer answer budget.
    if baseline_mode == "ralm_cot_first":
        temperature, max_new_tokens = 0.1, 10
    else:
        temperature, max_new_tokens = 0.3, 30

    out_dir = f"{save_root}_{model}/{baseline_mode}"
    os.makedirs(out_dir, exist_ok=True)
    shard_path = f"{out_dir}/{month}_{rank}.jsonl"

    # inference
    results = []
    # Append mode: the per-rank file is an incremental log that is not
    # truncated between runs.
    with open(shard_path, "a") as shard_file:
        for prompts, datapoints in tqdm(dataloader):
            outputs = pipeline(
                prompts,
                pad_token_id=pipeline.tokenizer.eos_token_id,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
            )
            for out, data in zip(outputs, datapoints):
                data["prediction"] = out
                results.append(data)
                shard_file.write(json.dumps(data) + "\n")
            # Flush per batch so progress survives a crash mid-run.
            shard_file.flush()

    # gather_object is a collective call: every process must reach it.
    gathered = gather_object(results)
    # Keep one record per qa_id (later duplicates win) — presumably to drop
    # samples repeated by the prepared dataloader to even out shards.
    unique = list({record["qa_id"]: record for record in gathered}.values())
    # Only one process writes the merged file, avoiding concurrent writers.
    if accelerator.is_main_process:
        with open(f"{out_dir}/{month}.json", "w") as f:
            json.dump(unique, f, indent=2)
        
        
        
if __name__ == "__main__":
    # Script entry point: python-fire exposes main()'s parameters as
    # command-line arguments.
    fire.Fire(main)