import json
import os
import pickle
import torch
import argparse
from sentence_transformers import SentenceTransformer
# Sentence encoder used below to embed each prompt field; it is instantiated at
# import time, so importing this module already loads the model.
rag_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
# Preferred torch device (CUDA when available, CPU otherwise).
# NOTE(review): nothing in this file references `device` afterwards — confirm
# whether the encoder was meant to be moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load():
    """Read the four JSON inputs this script depends on.

    Paths are relative to the current working directory.

    Returns:
        A tuple ``(expert, objs, config, dataset)`` holding the parsed
        contents of ``Data/human.json``, ``Data/trajectory.json``,
        ``../Task_Data/vh_config.json`` and ``Data/MDP_dataset.json``.
    """
    def _read_json(path):
        with open(path, 'r') as handle:
            return json.load(handle)

    expert = _read_json('Data/human.json')
    objs = _read_json('Data/trajectory.json')
    config = _read_json('../Task_Data/vh_config.json')
    dataset = _read_json('Data/MDP_dataset.json')
    return expert, objs, config, dataset

def rag_robot_prompt(instruction, history, objects, grab):
    """Build the robot-style prompt text.

    Args:
        instruction: Task description text.
        history: Text describing the previously completed actions.
        objects: Text listing the currently visible objects.
        grab: Text listing the currently grabbed objects.

    Returns:
        A five-line string starting with ``Human:`` followed by one labelled
        line per argument.
    """
    sections = (
        "Human:",
        f"Task Description: {instruction}",
        f"Previously Completed Actions: {history}",
        f"Visible objects: {objects}",
        f"Grabbed: {grab}",
    )
    return "\n".join(sections)

def rag_general_prompt(instruction, history, answer, rationale):
    """Build the general prompt text.

    Args:
        instruction: Task description text.
        history: Text describing the previously completed actions.
        answer: Response text, placed on its own line(s) under ``Response: ``.
        rationale: Text placed after the ``Relevant objects:`` label.

    Returns:
        The assembled multi-line prompt string.
    """
    sections = (
        f"Task Description: {instruction}",
        f"Previously Completed Actions: {history}",
        # The label line keeps its trailing space.
        "Response: ",
        f"{answer}",
        "",
        f"Relevant objects: {rationale}",
    )
    return "\n".join(sections)

def object_processing(low_objects, history, grab_items, available):
    """Describe the visible objects that are available and not currently held.

    Unlike the previous implementation this does not modify ``low_objects``,
    drops every occurrence of a grabbed item (not just the first), and keeps
    the objects in first-seen order so the resulting text is the same on
    every run.

    Args:
        low_objects: Iterable of visible object names.
        history: Unused; kept so existing callers keep working.
        grab_items: Names of objects currently held; these are excluded.
        available: Container of object names allowed to appear in the result.

    Returns:
        The remaining object names joined by ``", "``, or ``"nothing"`` when
        none remain.
    """
    visible = []
    seen = set()
    for obj in low_objects:
        if obj in grab_items or obj not in available:
            continue
        if obj in seen:
            # Duplicate entry: keep only the first occurrence.
            continue
        seen.add(obj)
        visible.append(obj)

    if not visible:
        return "nothing"
    return ', '.join(visible)

def grab_processing(history):
    """Work out which objects are still held after the given action history.

    Actions are recognised by their leading verb token rather than by
    substring search, so object names that merely contain ``put`` or ``grab``
    (for example ``computer``) no longer get misread as put/grab actions and
    crash the tuple unpacking.

    Recognised forms:
        ``grab <obj>``             -- the object becomes held
        ``put in <obj> <target>``  -- the object is released
        ``put <obj> <target>``     -- the object is released

    Args:
        history: Sequence of action strings, oldest first.

    Returns:
        A tuple ``(grab_prompt, grab_item)`` where ``grab_item`` is the list
        of held object names in grab order and ``grab_prompt`` is that list
        joined by ``", "``, or ``"nothing"`` when it is empty.
    """
    grab_item = []
    for action in history:
        tokens = action.split()
        if not tokens:
            continue

        verb = tokens[0]
        if verb == "grab" and len(tokens) >= 2:
            grab_item.append(tokens[1])
        elif verb == "put":
            if len(tokens) >= 4 and tokens[1] == "in":
                released = tokens[2]
            elif len(tokens) >= 3:
                released = tokens[1]
            else:
                # Malformed put action: nothing identifiable to release.
                continue
            if released in grab_item:
                grab_item.remove(released)

    grab_prompt = ", ".join(grab_item) if grab_item else "nothing"
    return grab_prompt, grab_item

def search_obs(objs, task_id, history):
    """Look up the objects visible after the most recent action.

    Args:
        objs: Iterable of records, each with a ``"path"`` and a ``"visible"``
            entry.
        task_id: Task identifier, used in the image path when more than one
            action has been taken.
        history: Sequence of action strings, oldest first.

    Returns:
        A fixed default list when ``history`` is empty; otherwise the
        de-duplicated ``"visible"`` list of the first record whose ``"path"``
        matches the image path derived from the last action, or ``None`` when
        no record matches.
    """
    if len(history) == 0:
        return ['character', 'folder', 'wallpictureframe', 'cabinet', 'livingroom', 'bookshelf']

    skill_slug = history[-1].replace(' ', '_')
    if len(history) == 1:
        # The very first action's observation is stored under the shared Init folder.
        target_path = f'EMNLP/EMNLP_DATA/env0/Init/{skill_slug}.png'
    else:
        target_path = f'EMNLP/EMNLP_DATA/env0/{task_id}/{skill_slug}/{skill_slug}.png'

    match = next((record for record in objs if record["path"] == target_path), None)
    if match is None:
        return None
    return list(set(match["visible"]))

def search_rationale(dataset, task_id, history):
    """Return the ``"answer"`` of the first dataset record for this step.

    A record matches when both its ``"task_id"`` and its ``"history"`` equal
    the given values.

    Args:
        dataset: Iterable of records with ``"task_id"``, ``"history"`` and
            ``"answer"`` entries.
        task_id: Task identifier to look for.
        history: Action history to look for.

    Returns:
        The matching record's ``"answer"``.

    Raises:
        SystemExit: When no record matches or the first match has an empty
            answer. The exit carries a message, so the process ends with a
            non-zero status instead of the success status that the bare
            ``exit()`` helper produced.
    """
    rationale = next(
        (
            line["answer"]
            for line in dataset
            if task_id == line["task_id"] and history == line["history"]
        ),
        "",
    )
    if rationale == "":
        print("ERROR")
        raise SystemExit(
            f"no rationale found for task_id={task_id!r}, history={history!r}"
        )
    return rationale

def save(data, args):
    """Pickle ``data`` to ``Data/DB_<prompt>.pkl``.

    Args:
        data: Object to serialise.
        args: Object whose ``prompt`` attribute selects the output file name.
    """
    out_path = f'Data/DB_{args.prompt}.pkl'
    with open(out_path, 'wb') as sink:
        pickle.dump(data, sink)
    
def _format_reward_prompt(reward_by_skill):
    """Render skills with their rewards, highest reward first, one per line."""
    lines = []
    ranked = sorted(reward_by_skill.items(), key=lambda item: item[1], reverse=True)
    for skill, reward in ranked:
        tokens = skill.split()
        # Decide on the verb token, not a substring: names such as "computer"
        # contain "put" and used to break the tuple unpacking.
        if len(tokens) == 4 and tokens[:2] == ["put", "in"]:
            _, how, obj1, obj2 = tokens
            skill = f"put {obj1} {how} {obj2}"
        elif len(tokens) == 3 and tokens[0] == "put":
            _, obj1, obj2 = tokens
            skill = f"put {obj1} on {obj2}"
        lines.append(f"{skill}: {reward}\n")
    # NOTE(review): "remainig" is misspelled, but the text is stored in the DB
    # and may be matched by other prompts in the project, so it is left as-is.
    return "".join(lines) + "remainig find skills: -1\nremainig other skills: -2"


def _format_history_prompt(history):
    """Render the history as "1. a, 2. b", or "nothing" when it is empty."""
    if not history:
        return "nothing"
    return ", ".join(f"{idx}. {step}" for idx, step in enumerate(history, start=1))


if __name__ == "__main__":
    # Build the retrieval DB: one entry per expert step, holding the prompt
    # text, the embeddings of its fields, the reward text and the rationale.
    parser = argparse.ArgumentParser(description="palm")
    parser.add_argument(
        "--prompt", type=str, default="rag_general"
    )
    args = parser.parse_args()

    expert, objs, config, dataset = load()
    available = config["available"]

    with open('../Task_Data/TASK.json', 'r') as f:
        human_reward = json.load(f)

    data = []
    for file in expert:
        task_id = file["task_id"]
        history = file["history"]
        task_description = file["instruction"]

        # Rewards for the current step are indexed by how many actions are done.
        answer_dict = human_reward[task_id]["reward"][len(history)]["reward"]
        answer_prompt = _format_reward_prompt(answer_dict)
        history_prompt = _format_history_prompt(history)

        grab, grab_item = grab_processing(history)
        low_objs = search_obs(objs, task_id, history)
        objects = object_processing(low_objs, history, grab_item, available)

        rationale = search_rationale(dataset, task_id, history)

        if args.prompt == "rag_robot":
            # rag_robot_prompt expects (instruction, history, objects, grab);
            # it was previously handed the answer and rationale texts instead.
            prompt = rag_robot_prompt(task_description, history_prompt, objects, grab)
        else:
            prompt = rag_general_prompt(task_description, history_prompt, answer_prompt, rationale)

        prompt_emb = {
            "instruction_emb": rag_model.encode(task_description),
            "history_emb": rag_model.encode(history_prompt),
            "obj_emb": rag_model.encode(objects),
            "grab_emb": rag_model.encode(grab),
        }

        data.append({
            "task_id": task_id,
            "prompt": prompt,
            "embedding": prompt_emb,
            "answer": answer_prompt,
            "rationale": rationale,
        })

    save(data, args)
            
            
        