import os
import json
import sys
from baseline import Baseline
from rolefact import RoleFact
from chatgpt import ChatGPT, FastChat, LLAMA
from script_kb import ScriptKB
from tqdm.auto import tqdm
import argparse
    
# Suffix of the per-task checkpoint file kept while a story is still in progress.
# Only fully processed stories are written under their plain file name, which is
# what the skip/resume logic in main() treats as "done".
_PARTIAL_SUFFIX = ".partial"


def _build_parser():
    """Return the argument parser for the RoleFact response-generation CLI."""
    parser = argparse.ArgumentParser(description='Rolefact response generation')
    parser.add_argument('--model_name', type=str, default='gpt-3.5-turbo-0125',
                        choices=["alpaca-7b", "gpt-3.5-turbo-0125",
                                 "meta-llama/Meta-Llama-3-8B-Instruct", "lmsys/vicuna-7b-v1.5"])
    parser.add_argument('--script_folder', type=str, default='./data/sgr/kb')
    parser.add_argument('--task_folder', type=str, default='./data/sgr/eval_task/adversarial_interview')
    parser.add_argument('--ignore_profile', action='store_true')
    parser.add_argument('--ignore_retrieval', action='store_true')
    parser.add_argument('--ignore_temporal', action='store_true')
    parser.add_argument('--retrieval_type', type=str, default='bm25', choices=["bm25", "sbert", "contriever"])
    parser.add_argument('--num_docs', type=int, default=5)
    parser.add_argument('--sample_size', type=int, default=5)
    parser.add_argument('--threshold', type=float, default=0.6)
    parser.add_argument('--anonymize', action='store_true')
    return parser


def _load_model(model_name):
    """Instantiate the LLM wrapper selected by the prefix of ``model_name``.

    Exits the process with "Invalid model name" for an unrecognised prefix.
    """
    if model_name.startswith("gpt"):
        return ChatGPT(model_name)
    if model_name.startswith("alpaca"):
        return FastChat(model_name)
    if model_name.startswith("lmsys"):
        # FastChat is given the bare model name without the "lmsys/" org prefix.
        return FastChat(model_name.split("/")[-1])
    if model_name.startswith("meta-llama"):
        return LLAMA(model_name)
    sys.exit("Invalid model name")


def _read_json(path):
    """Load and return the UTF-8 JSON document stored at ``path``."""
    with open(path, 'r', encoding='utf-8') as infile:
        return json.load(infile)


def _write_json(path, obj):
    """Write ``obj`` to ``path`` as indented UTF-8 JSON, keeping non-ASCII text as-is."""
    # An explicit encoding is required because ensure_ascii=False emits raw
    # non-ASCII characters, which a non-UTF-8 platform default could not encode.
    with open(path, 'w', encoding='utf-8') as outfile:
        json.dump(obj, outfile, indent=4, ensure_ascii=False)


def main():
    """Generate RoleFact responses for every story listed in the config.

    Reads the story list from ``config/story_to_file.json`` (key ``"priority"``)
    and processes the stories in that order. For each story it loads the script
    knowledge base and the task file, queries RoleFact per task, and writes the
    results to ``responses/rolefact/<method_tag>/<task_name>/<model_tag>/``.

    Progress is checkpointed after every task to ``<story>.partial``. An
    interrupted story resumes from that checkpoint on the next run. A story is
    skipped only once its final (non-partial) output file exists.
    """
    args = _build_parser().parse_args()

    story_to_file = _read_json("config/story_to_file.json")

    model_name = args.model_name
    model = _load_model(model_name)
    model_tag = model_name.split("/")[-1]

    script_folder = args.script_folder
    # normpath drops any trailing separators, so basename() yields the task name.
    task_folder = os.path.normpath(args.task_folder)

    use_role_profile = not args.ignore_profile
    use_retrieval = not args.ignore_retrieval
    time_sensitive = not args.ignore_temporal

    # The tag encodes every setting, so runs with different configurations never
    # share an output folder (and never resume from each other's checkpoints).
    method_tag = "_".join(str(value) for value in (
        use_role_profile, use_retrieval, time_sensitive, args.retrieval_type,
        args.num_docs, args.sample_size, args.threshold, args.anonymize,
    ))

    output_folder = os.path.join("responses", "rolefact", method_tag,
                                 os.path.basename(task_folder), model_tag)
    os.makedirs(output_folder, exist_ok=True)

    completed = {name for name in os.listdir(output_folder)
                 if not name.endswith(_PARTIAL_SUFFIX)}
    for name in sorted(completed):
        print("Skipping:", name)

    # dict.fromkeys de-duplicates while keeping the config's priority order;
    # a set would make the processing order vary from run to run.
    file_names = [name for name in dict.fromkeys(story_to_file["priority"].values())
                  if name not in completed]

    for file_name in tqdm(file_names):
        story_kb = ScriptKB(os.path.join(script_folder, file_name))

        rolefact = RoleFact(
            llm=model,
            story_kb=story_kb,
            use_role_profile=use_role_profile,
            use_retrieval=use_retrieval,
            time_sensitive=time_sensitive,
            retrieval_type=args.retrieval_type,
            num_docs=args.num_docs,
            sample_size=args.sample_size,
            threshold=args.threshold,
            anonymize=args.anonymize,
        )

        task_list = _read_json(os.path.join(task_folder, file_name))

        final_path = os.path.join(output_folder, file_name)
        partial_path = final_path + _PARTIAL_SUFFIX

        # Resume an interrupted story from its checkpoint. This assumes the task
        # file has not been edited since the checkpoint was written.
        response_list = _read_json(partial_path) if os.path.exists(partial_path) else []

        for task in tqdm(task_list[len(response_list):]):
            prelim_response, _, response = rolefact.get_response(task)
            response_obj = task.copy()
            response_obj["llm"] = model_tag
            response_obj["method"] = method_tag
            response_obj["responses"] = {
                "Rolefact": {
                    "prelim_response": prelim_response,
                    "response": response,
                }
            }
            response_list.append(response_obj)
            # Checkpoint after every task so a crash loses at most one response.
            _write_json(partial_path, response_list)

        # Only a complete story gets the plain file name that marks it as done.
        _write_json(final_path, response_list)
        if os.path.exists(partial_path):
            os.remove(partial_path)

        print(rolefact.get_cost())
        print(rolefact.success_count, rolefact.failure_count)

if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()