from dataclasses import dataclass, field
from typing import Dict, Optional
from common.vh_invariants import n_skills, available, id2skill
from config.args import BaseDataArguments, BaseModelArguments, BaseTrainingArguments
from models.mm_student import SkillDecoder
from models.multimodal_encoders import VitBertMultiModalEncoderForCaption
import sys
import transformers
import jax
import json
import torch

# Add the "Virtualhome" directory to the import path, presumably so that modules
# inside the VirtualHome checkout can import each other by top-level name.
# NOTE(review): the path is relative to the current working directory, so the
# script presumably must be launched from the directory that contains
# "Virtualhome" — confirm.
sys.path.append('Virtualhome')
from Virtualhome.environment.unity_environment import UnityEnvironment
# Torch device (GPU when available). Not referenced elsewhere in the visible
# code — confirm whether it is still needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@dataclass
class ModelArguments(BaseModelArguments):
    """Model arguments for the evaluated skill decoder.

    Extends ``BaseModelArguments`` with ``skill_decoder_cfg``, the nested
    configuration dict handed to ``SkillDecoder(cfg=...)``. A fresh dict is
    built for every instance via ``default_factory``.
    """

    skill_decoder_cfg: Optional[Dict] = field(
        default_factory=lambda: dict(
            mode="rl",
            arch="transformer",
            # Settings for the GPT-2 style transformer backbone.
            gpt2_config=dict(
                vocab_size=1,  # Doesn't matter
                n_positions=768 * 2,
                n_layer=2,
                n_head=4,
                activation_function="relu",
                resid_pdrop=0.1,
                embd_pdrop=0.1,
                attn_pdrop=0.1,
                # NOTE(review): an epsilon of 0 leaves layer norm without a
                # guard against zero variance; presumably kept to match the
                # trained checkpoint — confirm before changing.
                layer_norm_epsilon=0,
            ),
            multimodal_embed_dim=768,
            n_skills=n_skills,
            lr=1e-4,
            net_arch=[512, 512],
        )
    )

@dataclass
class TrainingArguments(BaseTrainingArguments):
    """Evaluation/training defaults overriding ``BaseTrainingArguments``.

    ``HfArgumentParser`` derives each command-line option's type from the
    field annotation, so the annotations must match the value types. The
    batch sizes are integers; annotating them as ``str`` would make an
    overridden value arrive as a string.
    """

    training_mode: Optional[str] = field(default="rl")
    per_device_train_batch_size: Optional[int] = field(default=16)
    per_device_eval_batch_size: Optional[int] = field(default=16)
    logging_steps: Optional[int] = field(default=10)
    eval_steps: Optional[int] = field(default=10)
    num_train_epochs: Optional[int] = field(default=10000)
    save_steps: Optional[int] = field(default=50000)
    
    
def grab_processing(history):
    """Work out which objects the agent is currently holding.

    Replays the executed skills in ``history`` in order:
    ``"grab <item>"`` adds ``<item>`` to the held list, while
    ``"put <item> <target>"`` and ``"put in <item> <target>"`` remove one
    held copy of ``<item>`` (if held). Every other skill is ignored.

    Args:
        history: Executed skill strings, oldest first.

    Returns:
        A ``(grab_prompt, grab_item)`` tuple. ``grab_prompt`` is the held
        items joined with ``", "``, or ``"nothing"`` when nothing is held.
        ``grab_item`` is the list of held item names in grab order.

    Raises:
        ValueError: if a grab/put skill does not have the token layout above.
    """
    grab_item = []
    for skill in history:
        tokens = skill.split()
        if not tokens:
            continue
        # Match on the verb token rather than a substring: "put" also occurs
        # inside object names such as "computer", and a substring test would
        # send "grab computer" down the put branch and fail to unpack.
        verb = tokens[0]
        if verb == "grab":
            _, item = tokens
            grab_item.append(item)
        elif verb == "put":
            if tokens[1:2] == ["in"]:
                _, _, item, _ = tokens
            else:
                _, item, _ = tokens
            if item in grab_item:
                grab_item.remove(item)

    grab_prompt = ", ".join(grab_item) if grab_item else "nothing"
    return grab_prompt, grab_item

def object_processing(objects_list, history, grab_items, available_objects=None):
    """Build the comma-separated visible-objects string for the prompt.

    Keeps only names in the object vocabulary, adds the target of the most
    recent skill when it was ``"find <item>"``, drops every held item, and
    removes duplicates while keeping first-seen order.

    Args:
        objects_list: Object names reported as visible.
        history: Executed skill strings, oldest first.
        grab_items: Names of currently held items (as returned by
            ``grab_processing``).
        available_objects: Optional collection used instead of the
            module-level ``available`` as the vocabulary filter.

    Returns:
        The remaining object names joined with ``", "``; ``""`` if none.
    """
    vocabulary = available if available_objects is None else available_objects
    visible = [obj for obj in objects_list if obj in vocabulary]

    # The object just found is kept even when it is outside the vocabulary.
    if history and history[-1].split()[:1] == ["find"]:
        _, item = history[-1].split()
        visible.append(item)

    held = set(grab_items)
    # Drop every copy of a held item (not just the first one), then
    # de-duplicate with dict.fromkeys so the order is deterministic; iterating
    # a set of strings depends on per-process hash randomization.
    remaining = dict.fromkeys(obj for obj in visible if obj not in held)
    return ", ".join(remaining)
    
def _select_skill(skill_decoder, instruction, history, caption):
    """Return the skill the decoder scores highest for the current state."""
    import jax.numpy as jnp

    history_prompt = ", ".join(history) + "." if history else ""

    grab_prompt, grab_items = grab_processing(history)
    obj_prompt = object_processing(caption, history, grab_items)
    obj_prompt = f"Visible objects: {obj_prompt}\nGrabbed: {grab_prompt}"

    # Right after a "find", hint at the skill types that act on the found object.
    if history and history[-1].split()[:1] == ["find"]:
        history_prompt += "\nAvailable skill type: grab, put, sit."

    prompts = f"Instruction: {instruction}\nHistory: {history_prompt}\nNext Skill: [MASK]"
    predictions = skill_decoder.predict_action(prompts=[prompts], captions=[obj_prompt], deterministic=True)
    all_q_values = jax.nn.softmax(predictions[0], axis=-1)
    # argmax returns the first maximal index, so ties go to the lowest skill id.
    return id2skill[int(jnp.argmax(all_q_values))]


def _run_episode(env, skill_decoder, instruction, goal, max_length):
    """Roll out the policy for one instruction and return the executed skills.

    The environment is reset and a fixed "find wallpictureframe" step is run
    first; that step is not recorded in the returned history. The rollout
    stops after ``max_length`` skills, when a skill cannot be executed, or
    when ``env.goalCond(goal) == 1``.
    """
    env.reset(environment_id=0, init_rooms=["bedroom"])
    env.step("find wallpictureframe")
    caption = list(env.get_visible_objects()[1].values())
    history = []

    for _ in range(max_length):
        skill = _select_skill(skill_decoder, instruction, history, caption)
        print(skill)

        try:
            env.step(skill)
        except Exception as exc:  # the simulator may reject an action; end this episode
            print(f"can't action: {exc!r}")
            break
        history.append(skill)

        caption = list(env.get_visible_objects()[1].values())
        if env.goalCond(goal) == 1:
            break

    return history


def _plan_prefix_match(history, plan):
    """Return the fraction of ``plan`` matched by the leading steps of ``history``."""
    if not plan:
        return 0.0
    matched = 0
    for executed, expected in zip(history, plan):
        if executed != expected:
            break
        matched += 1
    return matched / len(plan)


def _percent(total, count):
    """Return ``total / count`` as a percentage, or 0.0 when ``count`` is 0."""
    return total / count * 100 if count else 0.0


def program():
    """Evaluate a trained skill decoder in VirtualHome and print metrics.

    For every entry of ``../Task_Data/Sample.json`` the policy is rolled out,
    then three metrics are accumulated overall and per ``task_id``:
    success rate (episodes where ``env.goalCond(goal) == 1``), cgc (the mean
    of ``env.goalCond(goal)``) and planning (the fraction of the reference
    plan matched by the executed skills' prefix). All are printed as
    percentages.
    """
    ############################# MODEL & ENV #############################
    parser = transformers.HfArgumentParser((TrainingArguments, ModelArguments))
    train_args, model_args = parser.parse_args_into_dataclasses()
    multimodal_encoder = VitBertMultiModalEncoderForCaption(model_args.multimodal_encoder_cfg)
    skill_decoder = SkillDecoder(seed=train_args.seed, cfg=model_args.skill_decoder_cfg, init_build_model=True)

    model_file = "YOUR_MODEL_PATH"  # placeholder: set to the trained checkpoint
    skill_decoder = skill_decoder.load(model_file)
    skill_decoder.multimodal_encoder = multimodal_encoder
    env = UnityEnvironment(base_port=8080, num_agents=1, max_episode_length=10)

    ############################# DATA #############################
    with open('../Task_Data/Sample.json', 'r') as f:
        test_data = json.load(f)
    print(f"len test dataset: {len(test_data)}")

    ############################# CONFIG ###########################
    max_length = 10
    success_rate = 0
    cgc = 0
    planning = 0
    total_result = {
        str(data["task_id"]): {"success_rate": 0, "cgc": 0, "planning": 0, "data_len": 0}
        for data in test_data
    }

    for data in test_data:
        instruction = data["instruction"]
        goal = data["goal"]
        plan = data["plan"]
        task_result = total_result[str(data["task_id"])]

        history = _run_episode(env, skill_decoder, instruction, goal, max_length)
        print(f"{instruction}: [{history}]")

        ########### one ins done ##########
        goal_score = env.goalCond(goal)  # query once; reused for every metric
        task_result["data_len"] += 1
        task_result["cgc"] += goal_score
        cgc += goal_score
        if goal_score == 1:
            success_rate += goal_score
            task_result["success_rate"] += goal_score

        plan_score = _plan_prefix_match(history, plan)
        planning += plan_score
        task_result["planning"] += plan_score

    #################    total    #################
    for metric in total_result.values():
        data_len = metric["data_len"]
        for key in ("success_rate", "cgc", "planning"):
            metric[key] = _percent(metric[key], data_len)

    num_data = len(test_data)
    print("---------------------------------------")
    for task_num, metric in total_result.items():
        print(f"task {task_num}: {metric}")
    print(f'success_rate: {_percent(success_rate, num_data)}')
    print(f'cgc: {_percent(cgc, num_data)}')
    print(f'planning: {_percent(planning, num_data)}')


if __name__ == "__main__":
    program()
