# import google.generativeai as genai
import argparse
import json
import os
import pickle

import google.generativeai as genai
from sentence_transformers import SentenceTransformer, util

from prompt import rag_robot, rag_general
# Sentence encoder used to embed the prompt fields (instruction, history,
# visible objects, grabbed objects) for retrieval.
rag_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Read the key from the environment rather than hard-coding a secret in
# source. The placeholder is kept as the fallback, so behavior is unchanged
# when GOOGLE_API_KEY is unset.
genai.configure(api_key=os.environ.get('GOOGLE_API_KEY', 'YOUR_API_KEY'))

# Text model passed to genai.generate_text.
model_id = 'models/text-bison-001'


def load_data(args):
    """Read every input the generation loop needs.

    Paths are resolved relative to the current working directory.
    ``args.prompt`` selects which retrieval database
    (``Data/DB_<prompt>.pkl``) is loaded.

    Returns:
        A 4-tuple ``(config, TASK, obs_dict, rag_db)``: the parsed
        ``vh_config.json``, ``TASK.json`` and ``obs_all_dict.json``, followed
        by the unpickled retrieval database.
    """
    def read_json(path):
        with open(path, 'r') as handle:
            return json.load(handle)

    # Read in a fixed order: environment config, task list, observations.
    config, tasks, observations = (
        read_json(path)
        for path in (
            '../Task_Data/vh_config.json',
            '../Task_Data/TASK.json',
            'Data/obs_all_dict.json',
        )
    )

    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only point this at database files you produced yourself.
    db_path = f'Data/DB_{args.prompt}.pkl'
    with open(db_path, 'rb') as handle:
        retrieval_db = pickle.load(handle)

    return config, tasks, observations, retrieval_db

def object_processing(low_objects, grab_items, available):
    """Build the "Visible objects" prompt fragment.

    Keeps the objects from ``low_objects`` that are listed in ``available``
    and are not currently held. Duplicates are removed and the rest are
    joined with ", ".

    Args:
        low_objects: Object names visible in the current observation. This
            list is not modified.
        grab_items: Object names currently held. Every occurrence of these
            is excluded.
        available: Object names the prompt is allowed to mention.

    Returns:
        A comma-separated string of objects in first-seen order, or
        ``"nothing"`` if none remain.
    """
    held = set(grab_items)
    allowed = set(available)
    # dict.fromkeys de-duplicates while keeping first-seen order. That keeps
    # the prompt (and its embedding) reproducible across runs; set() order
    # depends on per-process string hash randomization.
    visible = dict.fromkeys(
        obj for obj in low_objects if obj in allowed and obj not in held
    )
    return ', '.join(visible) if visible else "nothing"

def grab_processing(history):
    """Work out which objects the agent holds after executing ``history``.

    Each history entry is a whitespace-separated skill string whose first
    token is the action:
    - ``"grab <obj>"`` adds ``<obj>`` to the held list.
    - ``"put <obj> <target>"``, ``"put in <obj> <target>"`` and
      ``"putin <obj> <target>"`` (any action starting with ``put``) release
      ``<obj>``.
    - Every other action is ignored.

    The action is matched on the first token, not by substring, so skills
    such as ``"walk computer"`` are not mistaken for ``put`` actions.

    Returns:
        ``(grab_prompt, grab_item)``: a ", "-joined string of held objects
        (``"nothing"`` when empty) and the list itself.
    """
    grab_item = []
    for skill in history:
        tokens = skill.split()
        if not tokens:
            continue
        action = tokens[0]
        if action == "grab" and len(tokens) >= 2:
            grab_item.append(tokens[1])
        elif action.startswith("put"):
            # "put in <obj> <target>" names the object third; every other
            # put form names it second.
            if len(tokens) >= 3 and tokens[1] == "in":
                released = tokens[2]
            elif len(tokens) >= 2:
                released = tokens[1]
            else:
                continue
            if released in grab_item:
                grab_item.remove(released)

    grab_prompt = ", ".join(grab_item) if grab_item else "nothing"
    return grab_prompt, grab_item

# Per-category safety thresholds, with blocking disabled for every category
# listed.
# NOTE(review): not referenced elsewhere in this file (the generate_text call
# does not pass it). The category names also look like Gemini-API categories
# rather than PaLM text-model ones; confirm before wiring this in.
SAFETY_SETTINGS = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]


def search_obs(objs, task_id, history):
    """Return the objects visible after the most recent skill in ``history``.

    - With an empty history, a fixed initial object list is returned.
    - After exactly one skill, the observation is looked up under the shared
      ``Init`` directory.
    - After more skills, it is looked up under the task's own directory.

    Args:
        objs: Observation records, each with a ``"path"`` (image path) and a
            ``"visible"`` list of object names.
        task_id: Task identifier used in the per-task image path.
        history: Skill strings executed so far; spaces become underscores in
            the image file name.

    Returns:
        A new list of visible object names, de-duplicated in first-seen order.

    Raises:
        LookupError: if no record in ``objs`` has the expected image path.
            Previously this case returned ``None`` and failed later with an
            unrelated ``TypeError``.
    """
    if not history:
        # Hard-coded observation before any skill has run. A fresh list is
        # returned each call so callers may modify it.
        return ['character', 'folder', 'wallpictureframe', 'cabinet', 'livingroom', 'bookshelf']

    obs_skill = history[-1].replace(' ', '_')
    if len(history) == 1:
        searching_path = f'EMNLP/EMNLP_DATA/env0/Init/{obs_skill}.png'
    else:
        searching_path = f'EMNLP/EMNLP_DATA/env0/{task_id}/{obs_skill}/{obs_skill}.png'

    for line in objs:
        if line["path"] == searching_path:
            # Order-preserving de-dup keeps the downstream prompt deterministic.
            return list(dict.fromkeys(line["visible"]))

    raise LookupError(f"no observation recorded for path {searching_path!r}")
        
def cal_rag(task_id, current_prompt, db, args, top_k=3):
    """Retrieve the ``top_k`` most similar few-shot examples from ``db``.

    Each candidate's score is the mean cosine similarity over four
    embeddings: instruction, history, visible objects and grabbed objects.
    Entries whose ``task_id`` equals ``task_id`` (compared as strings) are
    skipped entirely, so the task being solved never appears in its own
    few-shot context. Previously such entries were only down-weighted and
    could still be picked when fewer than ``top_k`` other-task entries
    existed.

    Args:
        task_id: Identifier of the task being solved.
        current_prompt: Mapping with ``instruction_emb``, ``history_emb``,
            ``obj_emb`` and ``grab_emb`` for the current step.
        db: Sequence of examples. Each has ``task_id``, ``prompt``,
            ``answer``, ``rationale`` and an ``embedding`` mapping with the
            same four keys.
        args: Namespace whose ``prompt`` is ``'rag_robot'`` or
            ``'rag_general'``; it selects the example formatting.
        top_k: Maximum number of examples to return (default 3).

    Returns:
        A list of formatted example strings, best match first.

    Raises:
        ValueError: if ``args.prompt`` is not a supported mode. Previously
            this failed with ``UnboundLocalError``.
    """
    if args.prompt not in ('rag_robot', 'rag_general'):
        raise ValueError(f"unsupported prompt mode: {args.prompt!r}")

    emb_keys = ("instruction_emb", "history_emb", "obj_emb", "grab_emb")
    scored = []
    for idx, data in enumerate(db):
        if str(data["task_id"]) == str(task_id):
            continue  # never retrieve examples from the task being solved
        similarity = sum(
            util.pytorch_cos_sim(current_prompt[key], data["embedding"][key])[0][0]
            for key in emb_keys
        ) / len(emb_keys)
        scored.append((similarity, idx))

    # sorted() is stable, so ties keep database order (as before).
    ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)[:top_k]
    rag_idx = [idx for _, idx in ranked]
    print(rag_idx)

    rag_prompts = []
    for idx in rag_idx:
        example = db[idx]
        prompt = example["prompt"]
        answer = example["answer"]
        rationale = example["rationale"]

        if args.prompt == 'rag_robot':
            rag_prompt = f"{prompt}\n\nRobot: {answer}Relevant objects: {rationale}"
        else:  # 'rag_general'
            # General-style examples drop the chat-style "Human:" header.
            prompt = prompt.replace('Human:\n', '')
            rag_prompt = f"{prompt}\nResponse: \n{answer}Relevant objects: {rationale}"

        rag_prompts.append(rag_prompt)

    print(rag_prompts)
    return rag_prompts
    
def _format_history(history):
    """Render completed skills as "1. a, 2. b", or "nothing" when empty."""
    if not history:
        return "nothing"
    return ", ".join(f"{idx}. {his}" for idx, his in enumerate(history, start=1))


def process_skill_transitions(task_id, task_info, obs_dict, args, db, available):
    """Query the language model once per step of a task's skill plan.

    For step ``i`` the history is the first ``i`` skills of the plan. Each
    step goes through the same pipeline:
    1. Derive the held objects and the visible objects for the step.
    2. Embed the prompt fields.
    3. Retrieve few-shot examples with ``cal_rag``.
    4. Build the prompt for ``args.prompt``.
    5. Call ``genai.generate_text``.

    Args:
        task_id: Identifier of the task.
        task_info: Mapping with ``task_description`` and ``skill_plan`` (a
            list of skill strings).
        obs_dict: Observation records passed to ``search_obs``.
        args: Namespace whose ``prompt`` is ``'rag_robot'`` or
            ``'rag_general'``.
        db: Retrieval database passed to ``cal_rag``.
        available: Object names allowed in the visible-objects prompt.

    Returns:
        One dict per step with ``task_id``, ``instruction``, ``history``,
        ``answer`` (the model's ``completion.result``) and ``visible``.

    Raises:
        ValueError: if ``args.prompt`` is not a supported mode. Previously
            this failed with ``UnboundLocalError``.
    """
    if args.prompt not in ('rag_robot', 'rag_general'):
        raise ValueError(f"unsupported prompt mode: {args.prompt!r}")

    task_description = task_info["task_description"]
    skill_plan = task_info["skill_plan"]
    if not skill_plan:
        return []

    # The instruction is the same for every step, so encode it once.
    instruction_emb = rag_model.encode(task_description)

    datas = []
    # Index by position rather than skill_plan.index(skill): with index(),
    # a plan that repeats a skill reused the first occurrence's history.
    for step, _ in enumerate(skill_plan):
        history = skill_plan[:step]

        grab_prompt, grab_item = grab_processing(history)
        low_objs = search_obs(obs_dict, task_id, history)
        objects = object_processing(low_objs, grab_item, available)
        history_prompt = _format_history(history)

        prompt_emb = {
            "instruction_emb": instruction_emb,
            "history_emb": rag_model.encode(history_prompt),
            "obj_emb": rag_model.encode(objects),
            "grab_emb": rag_model.encode(grab_prompt),
        }

        rags = cal_rag(task_id, prompt_emb, db, args)

        if args.prompt == 'rag_robot':
            current_prompt = f"Human:\nTask Description: {task_description}\nPreviously Completed Actions: {history_prompt}\nVisible objects: {objects}\nGrabbed: {grab_prompt}"
            input_prompt = rag_robot(current_prompt, rags)
        else:  # 'rag_general'
            input_prompt = rag_general(task_description, history_prompt, objects, grab_prompt, rags)

        completion = genai.generate_text(
            model=model_id,
            prompt=input_prompt,
            temperature=0.7
        )
        datas.append({
            "task_id": task_id,
            "instruction": task_description,
            "history": history,
            "answer": completion.result,
            "visible": objects
        })

    return datas

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="palm")
    parser.add_argument(
        "--prompt", type=str, default="rag_general",
        choices=["rag_robot", "rag_general"],
        help="prompt style; also selects Data/DB_<prompt>.pkl",
    )
    args = parser.parse_args()
    # load_data needs args to pick the retrieval DB; calling it with no
    # arguments raised TypeError.
    config, TASK, obs_dict, db = load_data(args)

    datas = []
    available = config["available"]
    for task_info in TASK:
        task_id = task_info["task_id"]
        datas.extend(
            process_skill_transitions(task_id, task_info, obs_dict, args, db, available)
        )

    # Create the output directory so the final write cannot fail after a
    # long run of model calls.
    os.makedirs('output', exist_ok=True)
    with open(f'output/{args.prompt}_all.json', 'w') as f:
        json.dump(datas, f, indent=4)

