import json
from prompt import naive, Cot, ICL
import torch
import google.generativeai as genai
import random
import os

# The API key is read from the environment instead of being hard-coded:
# a key committed to source control is exposed to anyone who can read the
# repository and should be treated as compromised (revoke the old one).
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shared model handle used by the naive/cot/icl query functions below.
model = genai.GenerativeModel('gemini-pro')
import natsort


def load_data():
    """Load the task configuration, task list, observations and prompt samples.

    All paths are relative to the current working directory: the script must
    be launched from a directory that contains ``Data/`` and whose parent
    directory contains ``Task_Data/``.

    Returns:
        A 7-tuple ``(config, TASK, obs_dict, next_obs_dict, done_traj,
        cot_samples, icl_samples)`` holding the decoded JSON contents of
        ``vh_config.json``, ``TASK.json``, ``obs_all_dict.json``,
        ``trajectory.json``, ``done_trajectory.json``, ``cot_samples.json``
        and ``icl_samples.json`` respectively.

    Raises:
        FileNotFoundError: if any of the files is missing.
        json.JSONDecodeError: if any file is not valid JSON.
    """
    def read_json(path):
        # Explicit UTF-8: the default text encoding is locale dependent
        # (e.g. cp1252 on Windows) and can mis-decode non-ASCII content.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    # Files are read in the original order, so the first missing file that
    # gets reported is unchanged.
    config = read_json('../Task_Data/vh_config.json')
    TASK = read_json('../Task_Data/TASK.json')
    obs_dict = read_json('Data/obs_all_dict.json')
    cot_samples = read_json('Data/cot_samples.json')
    icl_samples = read_json('Data/icl_samples.json')
    done_traj = read_json('Data/done_trajectory.json')
    next_obs_dict = read_json('Data/trajectory.json')

    return config, TASK, obs_dict, next_obs_dict, done_traj, cot_samples, icl_samples

def object_processing( low_objects, history, grab_items, available):
    """Build the comma-separated "visible objects" string for the prompt.

    Held objects are removed (one occurrence per held item). The remainder is
    filtered to names present in ``available``, and duplicates are dropped.

    Args:
        low_objects: Object names from the current observation, or ``None``
            when no observation was found. This list is NOT modified.
        history: Unused; kept for interface compatibility.
        grab_items: Names of the objects currently held (may be empty).
        available: Collection of object names allowed in the prompt.

    Returns:
        The surviving names joined with ", " in first-seen order, or
        ``"nothing"`` if none remain.
    """
    # Work on a copy: callers pass the "visible" list returned by search_obs,
    # which is the very list stored inside the loaded observation data, so
    # removing from it in place would corrupt every later lookup.
    remaining = list(low_objects or [])
    for grab_item in grab_items or []:
        if grab_item in remaining:
            remaining.remove(grab_item)

    # dict.fromkeys de-duplicates while keeping first-seen order; the former
    # list(set(...)) yielded a hash-seed dependent order that changed the
    # prompt text from run to run.
    objects_set = list(dict.fromkeys(obj for obj in remaining if obj in available))

    if not objects_set:
        return "nothing"
    return ', '.join(objects_set)

def grab_processing( history):
    """Work out which objects are still held after executing ``history``.

    Each step is a whitespace-separated skill string, and the first token is
    the verb:

    * ``grab <obj>`` appends ``obj`` to the held list.
    * ``put <obj> <target>`` (3 tokens) releases ``obj`` if it is held.
    * ``put in <obj> <target>`` (4 tokens) releases ``obj`` if it is held.
      Any verb starting with ``put`` is treated this way.
    * Every other step is ignored.

    Args:
        history: Sequence of executed skill strings (may be empty or None).

    Returns:
        ``(grab_prompt, grab_item)``. ``grab_item`` lists the held objects in
        grab order. ``grab_prompt`` is that list joined with ", ", or
        ``"nothing"`` when it is empty.

    Raises:
        ValueError: if a ``grab``/``put`` step has an unexpected token count.
    """
    grab_item = []
    for step in history or []:
        tokens = step.split()
        if not tokens:
            continue
        # Dispatch on the leading verb token. The old substring tests
        # ("put" in step) also matched object names such as "computer" and
        # then crashed unpacking steps like "walk computer".
        verb = tokens[0]
        if verb == "grab":
            _, item = tokens
            grab_item.append(item)
        elif verb.startswith("put"):
            if len(tokens) == 4:      # put in <obj> <container>
                obj = tokens[2]
            elif len(tokens) == 3:    # put <obj> <target>
                obj = tokens[1]
            else:
                raise ValueError(f"Unexpected put step: {step!r}")
            if obj in grab_item:
                grab_item.remove(obj)

    grab_prompt = ", ".join(grab_item) if grab_item else "nothing"
    return grab_prompt, grab_item

def search_next_obs(objs, task_id, history, skill, action=None):
    """Look up the visible objects recorded for ``action`` at ``skill``.

    Args:
        objs: Iterable of records with ``"path"`` and ``"visible"`` keys.
        task_id: Task identifier used as a directory name in the path.
        history: Skills executed so far. Only whether it is empty matters:
            an empty history searches under ``Init/``.
        skill: Current skill (spaces become underscores). For a non-empty
            history it names the sub-directory searched.
        action: Action whose image is looked up; defaults to ``skill``.

    Returns:
        The ``"visible"`` value of the first record whose ``"path"`` matches,
        or ``None`` when nothing matches. This is the same convention as
        ``search_obs``; before, the no-match case raised UnboundLocalError.
    """
    skill = skill.replace(' ', '_')
    if action is None:
        action = skill
    action = action.replace(' ', '_')

    if history:
        searching_path = f'EMNLP/EMNLP_DATA/env0/{task_id}/{skill}/{action}.png'
    else:
        searching_path = f'EMNLP/EMNLP_DATA/env0/Init/{action}.png'

    return next(
        (line["visible"] for line in objs if line["path"] == searching_path),
        None,
    )

def search_obs(objs, task_id, history):
    """Return the objects visible after the last step of ``history``.

    An empty history yields a fixed list of initial-scene objects.
    Otherwise the last step, with spaces turned into underscores, picks an
    image path:

    * ``Init/<step>.png`` for a one-step history;
    * ``<task_id>/<step>/<step>.png`` for longer histories.

    The ``"visible"`` value of the first record in ``objs`` whose ``"path"``
    equals that image path is returned, or ``None`` when none matches.
    """
    if len(history) == 0:
        return ['character', 'folder', 'wallpictureframe', 'cabinet', 'livingroom', 'bookshelf']

    last_step = history[-1].replace(' ', '_')
    if len(history) == 1:
        wanted = f'EMNLP/EMNLP_DATA/env0/Init/{last_step}.png'
    else:
        wanted = f'EMNLP/EMNLP_DATA/env0/{task_id}/{last_step}/{last_step}.png'

    matches = (record["visible"] for record in objs if record["path"] == wanted)
    return next(matches, None)
        

def naive(instruction, history, skill_list,  objects, grabs, args):
    """Query the model with the prompt built by ``prompt.naive``.

    Args:
        instruction: Task description.
        history: Executed steps. process_skill_transitions passes a
            pre-formatted string ("1. a, 2. b" or "nothing").
        skill_list: Comma-separated available skills.
        objects: Comma-separated visible objects (or "nothing").
        grabs: Comma-separated held objects (or "nothing").
        args: Parsed CLI arguments; not used here.

    Returns:
        The model's raw response text.
    """
    # This function's own name shadows the ``naive`` imported from
    # ``prompt`` at module level. The bare name therefore referred to this
    # function, which recursed with a missing ``args`` argument and raised
    # TypeError. Import the prompt builder under an alias instead.
    from prompt import naive as naive_prompt

    prompt_text = naive_prompt(instruction, history, skill_list, objects, grabs)
    completion = model.generate_content(
        prompt_text,
        generation_config=genai.types.GenerationConfig(
            temperature=0.7)
        )
    return completion.text

def cot(task_id, instruction, history, skill_list,  objects, grabs, cot_samples, num_shots=3):
    """Query the model with a ``prompt.Cot`` prompt plus random examples.

    Args:
        task_id: Current task id. The sample stored under ``str(task_id)``
            is excluded from the examples.
        instruction: Task description.
        history: Executed steps (passed through to ``Cot`` unchanged).
        skill_list: Comma-separated available skills.
        objects: Comma-separated visible objects (or "nothing").
        grabs: Comma-separated held objects (or "nothing").
        cot_samples: Mapping from task-id string to an example.
        num_shots: Maximum number of examples to include (default 3). When
            fewer candidates exist, all of them are used.

    Returns:
        The model's raw response text.
    """
    # Keys are compared as strings because JSON object keys always decode
    # to str.
    candidates = [val for key, val in cot_samples.items() if key != str(task_id)]
    # random.sample raises ValueError when k exceeds the population, so
    # clamp k to the number of available candidates.
    sample_list = random.sample(candidates, min(num_shots, len(candidates)))
    prompt_text = Cot(instruction, history, skill_list, objects, grabs, sample_list)

    completion = model.generate_content(
        prompt_text,
        generation_config=genai.types.GenerationConfig(
            temperature=0.7),
        )
    return completion.text

def icl(task_id, instruction, history, skill_list,  objects, grabs, icl_samples, num_shots=3):
    """Query the model with a ``prompt.ICL`` prompt plus random examples.

    Args:
        task_id: Current task id. The sample stored under ``str(task_id)``
            is excluded from the examples.
        instruction: Task description.
        history: Executed steps (passed through to ``ICL`` unchanged).
        skill_list: Comma-separated available skills.
        objects: Comma-separated visible objects (or "nothing").
        grabs: Comma-separated held objects (or "nothing").
        icl_samples: Mapping from task-id string to an example.
        num_shots: Maximum number of examples to include (default 3). When
            fewer candidates exist, all of them are used.

    Returns:
        The model's raw response text.
    """
    # Keys are compared as strings because JSON object keys always decode
    # to str.
    candidates = [val for key, val in icl_samples.items() if key != str(task_id)]
    # random.sample raises ValueError when k exceeds the population, so
    # clamp k to the number of available candidates.
    sample_list = random.sample(candidates, min(num_shots, len(candidates)))
    prompt_text = ICL(instruction, history, skill_list, objects, grabs, sample_list)
    completion = model.generate_content(
        prompt_text,
        generation_config=genai.types.GenerationConfig(
            temperature=0.7),
        )
    return completion.text

def process_skill_transitions(task_id, task_info, obs_dict, VH_SKILL, args, available, cot_samples, icl_samples, max_steps=1):
    """Query the model at the leading steps of one task's skill plan.

    For each of the first ``max_steps`` steps of ``task_info["skill_plan"]``:

    1. The plan prefix before that step is taken as the executed history.
    2. The held and visible objects for that history are derived.
    3. The ``args.prompt`` query function is called, and its answer is
       printed and recorded.

    Args:
        task_id: Task identifier.
        task_info: Dict with ``"task_description"`` and ``"skill_plan"``.
        obs_dict: Observation records passed to ``search_obs``.
        VH_SKILL: Skill vocabulary; it is natural-sorted and comma-joined.
        args: Parsed CLI arguments; ``args.prompt`` selects the query kind.
        available: Object names allowed in the prompt.
        cot_samples: Example mapping for the "cot" prompt.
        icl_samples: Example mapping for the "icl" prompt.
        max_steps: Number of plan steps to query. The default of 1 keeps the
            previous hard-coded behaviour; ``None`` queries every step.

    Returns:
        One dict per queried step, with keys ``task_id``, ``instruction``,
        ``history``, ``env_id``, ``answer`` and ``visible``.

    Raises:
        ValueError: if ``args.prompt`` is not "naive", "cot" or "icl".
    """
    task_description = task_info["task_description"]
    skill_plan = task_info["skill_plan"]
    env_id = "env0"
    # The sorted skill vocabulary does not change per step, so build it once.
    skill_list = ", ".join(natsort.natsorted(VH_SKILL))

    datas = []
    # enumerate yields each step's true position. skill_plan.index(skill)
    # returned the FIRST occurrence, which gives the wrong history for a
    # skill that appears more than once in the plan.
    for step_idx, _skill in enumerate(skill_plan[:max_steps]):
        history = skill_plan[:step_idx]

        grab_prompt, grab_item = grab_processing(history)
        low_objs = search_obs(obs_dict, task_id, history)
        objects = object_processing(low_objs, history, grab_item, available)

        # Numbered form "1. a, 2. b" used by the naive prompt.
        if history:
            history_prompt = ", ".join(
                f"{idx}. {his}" for idx, his in enumerate(history, start=1)
            )
        else:
            history_prompt = "nothing"

        # NOTE(review): cot/icl receive the raw history list while naive
        # receives the formatted string; confirm Cot/ICL format it themselves.
        if args.prompt == "naive":
            answer_dict = naive(task_description, history_prompt, skill_list, objects, grab_prompt, args)
        elif args.prompt == "cot":
            answer_dict = cot(task_id, task_description, history, skill_list, objects, grab_prompt, cot_samples)
        elif args.prompt == "icl":
            answer_dict = icl(task_id, task_description, history, skill_list, objects, grab_prompt, icl_samples)
        else:
            # Previously fell through and raised UnboundLocalError below.
            raise ValueError(
                f"Unknown prompt type {args.prompt!r}; expected 'naive', 'cot' or 'icl'"
            )

        print(answer_dict)
        datas.append({
            "task_id": task_id,
            "instruction": task_description,
            "history": history,
            "env_id": env_id,
            "answer": answer_dict,
            "visible": objects,
        })

    return datas

import argparse
if __name__ == "__main__":
    
    parser = argparse.ArgumentParser(description="palm")
    parser.add_argument(
        "--model", type=str, default="gemini" 
    ) 
    parser.add_argument(
        "--prompt", type=str, default="naive"  # [naive, icl, cot]
    )   
    args = parser.parse_args()
    
    config, TASK, obs_dict, next_obs_dict, done_traj, cot_samples, icl_samples = load_data()

    datas = []
    VH_SKILL = config["VH_SKILL"]
    available = config["available"]
    for task_info in TASK[:1]:
        task_id = task_info["task_id"]
        
        data = process_skill_transitions(task_id, task_info, obs_dict, VH_SKILL,args, available, cot_samples, icl_samples)
        datas+= data

        
    with open(f'output/{args.prompt}.json', 'w') as f:
        json.dump(datas, f, indent=4)
