import numpy as np
import os
import sys
# Put the ``metaworld`` directory that sits next to this file at the front of
# sys.path so that the ``from metaworld import ...`` lines below are searched
# there first (presumably a source checkout containing the ``metaworld``
# package), ahead of any installed copy. This must run before those imports.
# NOTE(review): because it is inserted at index 0, this directory is also
# searched before the script's own directory for later imports such as
# ``utils`` — confirm there is no name clash.
current_dir = os.path.dirname(os.path.abspath(__file__))
metaworld_dir = os.path.join(current_dir, 'metaworld')
if metaworld_dir not in sys.path:
    sys.path.insert(0, metaworld_dir)
import torch
from metaworld import policies
from sentence_transformers import SentenceTransformer
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE as env_dict
import argparse
from typing import List
import imageio
import random
from utils import agent_step,get_state,get_action,get_action_space,get_f_language,get_h_language
# MetaWorld v2 goal-observable environment ids, in the same naming scheme as the
# 'hammer-v2-goal-observable' key looked up in env_dict by the simulator.
# Not referenced by the code shown in this file.
envList=['assembly-v2-goal-observable', 'basketball-v2-goal-observable', 'bin-picking-v2-goal-observable', 'box-close-v2-goal-observable', 'button-press-topdown-v2-goal-observable', 'button-press-topdown-wall-v2-goal-observable', 'button-press-v2-goal-observable', 'button-press-wall-v2-goal-observable', 'coffee-button-v2-goal-observable', 'coffee-pull-v2-goal-observable', 'coffee-push-v2-goal-observable', 'dial-turn-v2-goal-observable', 'disassemble-v2-goal-observable', 'door-close-v2-goal-observable', 'door-lock-v2-goal-observable', 'door-open-v2-goal-observable', 'door-unlock-v2-goal-observable', 'hand-insert-v2-goal-observable', 'drawer-close-v2-goal-observable', 'drawer-open-v2-goal-observable', 'faucet-open-v2-goal-observable', 'faucet-close-v2-goal-observable', 'hammer-v2-goal-observable', 'handle-press-side-v2-goal-observable', 'handle-press-v2-goal-observable', 'handle-pull-side-v2-goal-observable', 'handle-pull-v2-goal-observable', 'lever-pull-v2-goal-observable', 'peg-insert-side-v2-goal-observable', 'pick-place-wall-v2-goal-observable', 'pick-out-of-hole-v2-goal-observable', 'reach-v2-goal-observable', 'push-back-v2-goal-observable', 'push-v2-goal-observable', 'pick-place-v2-goal-observable', 'plate-slide-v2-goal-observable', 'plate-slide-side-v2-goal-observable', 'plate-slide-back-v2-goal-observable', 'plate-slide-back-side-v2-goal-observable', 'peg-unplug-side-v2-goal-observable', 'soccer-v2-goal-observable', 'stick-push-v2-goal-observable', 'stick-pull-v2-goal-observable', 'push-wall-v2-goal-observable', 'reach-wall-v2-goal-observable', 'shelf-place-v2-goal-observable', 'sweep-into-v2-goal-observable', 'sweep-v2-goal-observable', 'window-open-v2-goal-observable', 'window-close-v2-goal-observable']

# Ten 'llf-metaworld-*' task ids (presumably LLF-Bench style env ids). The
# selection appears to match the MetaWorld MT10 task set — TODO confirm.
MT_10=['llf-metaworld-window-open-v2', 'llf-metaworld-window-close-v2','llf-metaworld-door-open-v2','llf-metaworld-peg-insert-side-v2','llf-metaworld-drawer-open-v2','llf-metaworld-pick-place-v2','llf-metaworld-reach-v2','llf-metaworld-button-press-topdown-v2','llf-metaworld-push-v2','llf-metaworld-drawer-close-v2']

# A larger set of 'llf-metaworld-*' task ids, grouped by skill family as
# labelled by the inline headings below.
MT_10_modify=[
# Common Open and Close
    'llf-metaworld-window-open-v2', 'llf-metaworld-window-close-v2','llf-metaworld-door-open-v2','llf-metaworld-door-close-v2','llf-metaworld-drawer-close-v2', 'llf-metaworld-drawer-open-v2','llf-metaworld-faucet-open-v2', 'llf-metaworld-faucet-close-v2',

# Common Pull and Push
    'llf-metaworld-coffee-pull-v2', 'llf-metaworld-coffee-push-v2','llf-metaworld-stick-push-v2', 'llf-metaworld-stick-pull-v2','llf-metaworld-push-back-v2', 'llf-metaworld-push-v2',
    
# Common Press
    'llf-metaworld-button-press-v2','llf-metaworld-button-press-topdown-v2','llf-metaworld-handle-press-side-v2', 'llf-metaworld-handle-press-v2','llf-metaworld-button-press-topdown-wall-v2','llf-metaworld-button-press-wall-v2',
    
# Common Pick and Place
    'llf-metaworld-bin-picking-v2', 'llf-metaworld-pick-place-wall-v2', 'llf-metaworld-pick-out-of-hole-v2','llf-metaworld-pick-place-v2','llf-metaworld-shelf-place-v2'
]

def process_data(trajectory,name):
    """Add returns-to-go and sentence embeddings to ``trajectory`` and save it.

    NOTE(review): a second ``process_data`` is defined later in this module and
    replaces this one when the module is executed, so this version is not the
    one the ``__main__`` block calls. It handles a multi-granularity feedback
    format: each entry of ``trajectory["languages"]`` maps ``"hp"``, ``"hn"``
    and ``"fp"`` to dicts keyed by ``"1500"``, ``"1000"``, ``"500"``, ``"200"``
    and ``"100"``. Empty entries (``{}`` or ``""``) are treated as empty
    feedback at every level.

    Args:
        trajectory: dict with at least ``"reward"``, ``"languages"`` and
            ``"manual"``. ``"return_to_go"``, ``"encoded_manual"`` and the
            ``*_language`` / ``*_embedding`` fields are written into it.
        name: destination path passed to ``torch.save``.

    Uses the module-level ``model`` created in ``__main__`` and assumes its
    embeddings are 768-dimensional (enforced by the reshape/assert below).
    """
    levels = ("1500", "1000", "500", "200", "100")
    # Placeholder for steps without feedback. The former flat
    # {"hp": "", "hn": "", "fp": ""} placeholder (and the "" entries produced
    # below) crashed on the nested ["1500"] lookups with TypeError.
    empty_feedback = {part: dict.fromkeys(levels, "") for part in ("hp", "hn", "fp")}

    rewards = trajectory["reward"]
    if len(trajectory["languages"]) == 0:
        trajectory["languages"] = [""] * len(rewards)

    # Undiscounted (gamma = 1) sum of the rewards from each step to the end.
    returns = []
    gamma = 1
    cumulative_return = 0
    for reward in reversed(rewards):
        cumulative_return = reward + gamma * cumulative_return
        returns.append(cumulative_return)
    returns.reverse()
    trajectory["return_to_go"] = returns

    hf, f, h, hf1000, hf500, hf200, hf100 = [], [], [], [], [], [], []
    for text in trajectory["languages"]:
        if not text:
            text = empty_feedback
        hf.append(text["hp"]["1500"] + text["hn"]["1500"] + text["fp"]["1500"])
        hf1000.append(text["hp"]["1000"] + text["hn"]["1000"] + text["fp"]["1000"])
        hf500.append(text["hp"]["500"] + text["hn"]["500"] + text["fp"]["500"])
        hf200.append(text["hp"]["200"] + text["hn"]["200"] + text["fp"]["200"])
        hf100.append(text["hp"]["100"] + text["hn"]["100"] + text["fp"]["100"])
        f.append(text["fp"]["1500"])
        h.append(text["hp"]["1500"] + text["hn"]["1500"])
    # NOTE: the 1500-level combined text is stored under "rhf_language".
    trajectory["rhf_language"] = hf
    trajectory["hf1000_language"] = hf1000
    trajectory["hf500_language"] = hf500
    trajectory["hf200_language"] = hf200
    trajectory["hf100_language"] = hf100
    trajectory["f_language"] = f
    trajectory["h_language"] = h

    n_steps = len(trajectory["languages"])
    # Encode all seven text groups in one batch, then split them back apart.
    l = hf + f + h + hf1000 + hf500 + hf200 + hf100
    with torch.no_grad():  # No gradient is needed (inference mode)
        l_embedding = torch.tensor(model.encode(l)).reshape(7, n_steps, 1, 768)
        hf_embedding = l_embedding[0]
        f_embedding = l_embedding[1]
        h_embedding = l_embedding[2]
        hf_embedding_1000 = l_embedding[3]
        hf_embedding_500 = l_embedding[4]
        hf_embedding_200 = l_embedding[5]
        hf_embedding_100 = l_embedding[6]
    assert hf_embedding.shape == (n_steps, 1, 768)
    manual_embedding = torch.unsqueeze(torch.tensor(model.encode(trajectory["manual"])), dim=0)
    trajectory["encoded_manual"] = manual_embedding
    trajectory["hf_embedding"] = hf_embedding
    trajectory["f_embedding"] = f_embedding
    trajectory["h_embedding"] = h_embedding
    trajectory["hf_embedding_1000"] = hf_embedding_1000
    trajectory["hf_embedding_500"] = hf_embedding_500
    trajectory["hf_embedding_200"] = hf_embedding_200
    trajectory["hf_embedding_100"] = hf_embedding_100
    torch.save(trajectory, name)

def get_task_text(env_name):
    """Return the task words of an environment id, joined by single spaces.

    The id is split on '-' and its last three fields are dropped, e.g.
    ``"button-press-topdown-v2-goal-observable"`` -> ``"button press topdown"``.
    Ids with three or fewer fields yield the empty string.
    """
    words = env_name.split('-')
    del words[-3:]
    return ' '.join(words)

def transform(strings:List[str])->List[str]:
    """Upper-case the first character of every string in ``strings``.

    The list is modified in place and also returned. Unlike
    ``str.capitalize``, the remaining characters are left untouched. Empty
    strings pass through unchanged; the previous ``s[0]`` indexing raised
    ``IndexError`` on them.
    """
    for index, text in enumerate(strings):
        # Slicing (rather than text[0]) is safe for the empty string.
        strings[index] = text[:1].upper() + text[1:]
    return strings

def get_policy(env_name):
    """Instantiate the scripted policy class matching ``env_name``.

    The class name is ``"Sawyer" + <task words in CamelCase> + "V2Policy"``,
    looked up on the ``policies`` module (e.g. ``hammer-v2-goal-observable``
    -> ``SawyerHammerV2Policy``).

    Returns:
        A new policy instance, or ``None`` when ``policies`` has no class of
        that name. Errors raised by the policy constructor itself now
        propagate; the former bare ``except:`` also swallowed those, along
        with KeyboardInterrupt/SystemExit.
    """
    name = "".join(transform(get_task_text(env_name).split(" ")))
    # Presumably because the peg task's policy class is named "...Insertion...".
    # NOTE(review): this also turns "hand-insert" into "HandInsertion"; confirm
    # a class of that name exists in the vendored policies.
    name = name.replace("Insert", "Insertion")
    policy_name = "Sawyer" + name + "V2Policy"
    policy_cls = getattr(policies, policy_name, None)
    if policy_cls is None:
        return None
    return policy_cls()

# Disable the HuggingFace ``tokenizers`` library's internal parallelism for this
# process; it is set at import time, before ``__main__`` loads the model.
os.environ.update(TOKENIZERS_PARALLELISM="false")
def _returns_to_go(rewards, gamma=1):
    """Return, for every step t, ``rewards[t] + gamma * result[t + 1]``.

    With the default ``gamma=1`` this is the plain (undiscounted) sum of the
    rewards from step t to the end of the list.
    """
    returns = []
    running = 0
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def process_data(trajectory,name):
    """Add returns-to-go, language fields and their embeddings to ``trajectory`` and save it.

    Each entry of ``trajectory["languages"]`` is a dict whose ``"h"`` and
    ``"f"`` values are indexable at 0 and 1 (the simulator stores one
    ``{"h": ..., "f": ...}`` per step). Index 1 feeds ``f_language`` /
    ``h_language`` / ``hf_language``; index 0 feeds the ``r``-prefixed fields.
    For the combined ``hf``/``rhf`` fields, each step randomly uses h+f
    (``randint`` draws 0-69), h only (70-84) or f only (85-100).

    Args:
        trajectory: dict with at least ``"reward"``, ``"languages"`` and
            ``"manual"``; the derived fields are written into it.
        name: destination path passed to ``torch.save``.

    Uses the module-level ``model`` created in ``__main__`` and assumes its
    embeddings are 768-dimensional (enforced by the reshape/assert below).
    """
    rewards = trajectory["reward"]
    if len(trajectory["languages"]) == 0:
        # One empty feedback pair per step. The former "" placeholders failed
        # on text["f"][1] below with TypeError.
        trajectory["languages"] = [{"h": ("", ""), "f": ("", "")} for _ in rewards]
    trajectory["return_to_go"] = _returns_to_go(rewards)

    f, h, hf, rhf = [], [], [], []
    rf, rh = [], []
    for text in trajectory["languages"]:
        f.append(text["f"][1])
        h.append(text["h"][1])
        rf.append(text["f"][0])
        rh.append(text["h"][0])
        # NOTE(review): randint is inclusive, so this draws from 101 values and
        # the split is 70/15/16 rather than 70/15/15. Left unchanged so seeded
        # datasets stay reproducible.
        random_probability = random.randint(0, 100)
        if random_probability < 70:
            hf.append(text["h"][1] + text["f"][1])
            rhf.append(text["h"][0] + text["f"][0])
        elif random_probability < 85:
            rhf.append(text["h"][0])
            hf.append(text["h"][1])
        else:
            rhf.append(text["f"][0])
            hf.append(text["f"][1])
    trajectory["f_language"] = f
    trajectory["h_language"] = h
    trajectory["hf_language"] = hf
    trajectory["rhf_language"] = rhf

    n_steps = len(trajectory["languages"])
    # Encode all six text groups in one batch, then split them back apart.
    l = f + h + hf + rhf + rf + rh
    with torch.no_grad():  # No gradient is needed (inference mode)
        l_embedding = torch.tensor(model.encode(l)).reshape(6, n_steps, 1, 768)
        f_embedding = l_embedding[0]
        h_embedding = l_embedding[1]
        hf_embedding = l_embedding[2]
        rhf_embedding = l_embedding[3]
        rf_embedding = l_embedding[4]
        rh_embedding = l_embedding[5]
    assert f_embedding.shape == (n_steps, 1, 768)
    manual_embedding = torch.unsqueeze(torch.tensor(model.encode(trajectory["manual"])), dim=0)
    trajectory["encoded_manual"] = manual_embedding
    trajectory["f_embedding"] = f_embedding
    trajectory["h_embedding"] = h_embedding
    trajectory["hf_embedding"] = hf_embedding
    trajectory["rhf_embedding"] = rhf_embedding
    trajectory["rf_embedding"] = rf_embedding
    trajectory["rh_embedding"] = rh_embedding
    torch.save(trajectory, name)

def calculate_angle(curr_pos,expert_target_pos,agent_target_pos):
    """Return the angle in degrees between the expert's and the agent's directions.

    Both directions are measured from ``curr_pos`` towards the respective
    target position. Inputs may be NumPy arrays or any array-like of numbers.

    Returns:
        The angle in [0, 180] degrees when both direction vectors are
        non-zero; 90 when exactly one of them has zero length; 0 when both do.
    """
    curr_pos = np.asarray(curr_pos, dtype=float)
    vector_expert = np.asarray(expert_target_pos, dtype=float) - curr_pos
    vector_agent = np.asarray(agent_target_pos, dtype=float) - curr_pos
    norm_expert = np.linalg.norm(vector_expert)
    norm_agent = np.linalg.norm(vector_agent)
    if norm_expert != 0 and norm_agent != 0:
        cosine = np.dot(vector_expert, vector_agent) / (norm_expert * norm_agent)
        # Clip rather than round: floating-point error can push |cosine|
        # slightly past 1 (arccos would return nan). The former rounding to 3
        # decimals also reported every angle below ~1.8 degrees as exactly 0.
        return np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0)))
    if norm_expert + norm_agent != 0:
        return 90
    return 0
        

class simulator():
    """Roll out one MetaWorld episode that mixes scripted-expert and perturbed steps.

    Every executed step appends to ``self.trajectory``: the pre-step state, the
    executed ``(action_id, gripper_id)`` pair, a shaped reward and the
    ``{"h": ..., "f": ...}`` language feedback. Call ``run_episode`` to drive
    the rollout.

    Rewards: the first step is recorded as -0.5. Later expert steps get 0.5
    when the executed action equals the expert's. Non-expert steps get that
    same term minus 1. Any step on which ``info["success"]`` is set gets +20.
    """

    def __init__(self,env_name,seed=0,verbose=False):
        """Build the environment, then execute and record the first (expert) step.

        Args:
            env_name: key of ``env_dict``, e.g. ``"hammer-v2-goal-observable"``.
            seed: forwarded to the environment constructor.
            verbose: print one line per step when True.

        Raises:
            ValueError: if ``get_policy`` finds no scripted policy for ``env_name``.
        """
        self.verbose=verbose
        self.nonExpert_steps=0  # non-expert segments injected so far
        self.nonExpert_steps_max=3  # at most this many segments per episode
        self.env_name=env_name
        benchmark_env = env_dict[env_name]
        task_name=get_task_text(env_name)
        self.policy=get_policy(env_name)
        if self.policy is None:
            # Previously surfaced as "'NoneType' object has no attribute 'get_action'".
            raise ValueError(f"No scripted policy found for {env_name!r}")
        self.env = benchmark_env(seed)
        self.observation,self.info=self.env.reset()
        self.cumulative_reward=0
        (self.expert_action_id,self.expert_gripper_id),f_language_id=self.policy.get_action(self.observation)
        (self.action_id,self.gripper_id)=(self.expert_action_id,self.expert_gripper_id)
        state=get_state(self.observation,self.policy,env_name)
        action=get_action(self.policy,(self.action_id,self.gripper_id),self.observation,self.env_name)
        self.observation, self.reward, self.terminate,self.truncate, self.info= agent_step(self.policy,self.env,self.observation,action,self.env_name)
        self.task_description=task_name
        self.non_expert_step_list=[]  # timesteps whose executed action differed from the expert's
        self.images=[]  # frame buffer; nothing in this class appends to it
        # The first step's reward is fixed at -0.5 (the env reward returned
        # above is discarded) and is not added to cumulative_reward.
        self.reward=-0.5
        self.timestep=1  # the first step has been taken
        f_language=get_f_language(self.env_name,f_language_id,"training")
        self.trajectory={"state":[state],"action":[np.array([self.action_id,self.gripper_id])],"reward":[self.reward],"manual":self.task_description,"languages":[{"h":("",""),"f":f_language}]}
        self.done=False
        self.success=0
        self.disturbed=False

    def _mark_if_previous_non_expert(self):
        """Record ``timestep - 1`` when the previously executed action was not the expert's."""
        if (self.action_id,self.gripper_id)!=(self.expert_action_id,self.expert_gripper_id):
            self.non_expert_step_list.append(self.timestep-1)

    def _record_and_step(self,h_language,f_language):
        """Log the feedback and pre-step state, execute the current action, and log the action."""
        self.trajectory["languages"].append({"h":h_language,"f":f_language})
        # Uses self.env_name; the original read a module-level ``env_name``
        # that exists only while the __main__ loop is running.
        self.trajectory["state"].append(get_state(self.observation,self.policy,self.env_name))
        action=get_action(self.policy,(self.action_id,self.gripper_id),self.observation,self.env_name)
        self.observation, self.reward, self.terminate,self.truncate, self.info= agent_step(self.policy,self.env,self.observation,action,self.env_name)
        self.trajectory["action"].append(np.array([self.action_id,self.gripper_id]))

    def _finish_step(self):
        """Store the step's reward and advance time; return True if the episode ended."""
        self.trajectory["reward"].append(self.reward)
        self.cumulative_reward += self.reward
        self.timestep+=1
        if self.info["success"]:
            self.success=1
            self.done=True
            return True
        if self.terminate or self.truncate:
            self.done=True
            return True
        return False

    def run_expert_episode(self,steps):
        """Execute up to ``steps`` steps using the expert's action; stop early once done."""
        for _ in range(steps):
            if self.done:
                break
            self._mark_if_previous_non_expert()
            # h language describes whether the previous segment was disturbed.
            h_language=get_h_language(self.env_name,self.disturbed,"training",self.action_id)
            self.disturbed=False
            # f language comes from the expert's choice for the current observation.
            (self.expert_action_id,self.expert_gripper_id),f_language_id=self.policy.get_action(self.observation)
            f_language=get_f_language(self.env_name,f_language_id,"training")
            (self.action_id,self.gripper_id)=(self.expert_action_id,self.expert_gripper_id)
            self._record_and_step(h_language,f_language)
            self.reward=((self.action_id,self.gripper_id)==(self.expert_action_id,self.expert_gripper_id))*0.5
            if self.info["success"]:
                self.reward+=20
            if self.verbose:
                print("Expert Step: ",self.timestep,h_language,f_language,"action selection: ",(self.action_id,self.gripper_id)," ","reward: ",self.reward)
            if self._finish_step():
                break

    def run_non_expert_episode(self,steps):
        """Execute up to ``steps`` steps with a random ``action_id`` drawn once at the start.

        The gripper id keeps its current value. The expert is still queried
        every step so the f language and reward reflect what it would have done.
        """
        # The gripper space is currently unused (only action_id is randomised).
        action_space,_gripper_space=get_action_space(self.env_name)
        self.disturbed=False
        for i in range(steps):
            if self.done:
                break
            self._mark_if_previous_non_expert()
            h_language=get_h_language(self.env_name,self.disturbed,"training",self.action_id)
            self.disturbed=True
            if i==0:
                # NOTE(review): randint is inclusive of action_space; confirm that
                # get_action_space returns the largest valid id, not a count.
                self.action_id=random.randint(0,action_space)
            (self.expert_action_id,self.expert_gripper_id),f_language_id=self.policy.get_action(self.observation)
            f_language=get_f_language(self.env_name,f_language_id,"training")
            self._record_and_step(h_language,f_language)
            self.reward=((self.action_id,self.gripper_id)==(self.expert_action_id,self.expert_gripper_id))*0.5-1
            if self.info["success"]:
                self.reward+=20
            if self.verbose:
                print("Non Expert Step: ",self.timestep,"h: ",h_language,"f: ",f_language,"expert: ", (self.expert_action_id,self.expert_gripper_id),"action selection: ",(self.action_id,self.gripper_id)," ","reward: ",self.reward)
            if self._finish_step():
                break

    def run_episode(self, max_steps=30):
        """Interleave expert steps with short non-expert segments until done or ``max_steps``.

        The first segment is scheduled at a random timestep in [2, 6]. Each
        segment requests 1-3 steps. The next one is scheduled at
        ``randint(timestep + len + 3, timestep + len + 7)``, evaluated after the
        segment (so ``timestep`` has already advanced). At most
        ``nonExpert_steps_max`` segments are injected. ``max_steps`` is only
        checked between calls, so a segment may run past it. The number of
        segments is stored in ``trajectory["nonExpertTime"]``.
        """
        non_expert_step=random.randint(2,6)
        while not self.done and self.timestep<max_steps:
            if self.timestep!=non_expert_step or self.nonExpert_steps>=self.nonExpert_steps_max:
                self.run_expert_episode(1)
            else:
                step1=random.randint(1,3)
                self.run_non_expert_episode(step1)
                non_expert_step=random.randint(self.timestep+step1+3,self.timestep+step1+7)
                self.nonExpert_steps+=1
        self.trajectory["nonExpertTime"]=self.nonExpert_steps

def images_to_video(image_path_list, output_path, fps=2):
    """Write a sequence of frames to ``output_path`` as a video at ``fps`` frames/s.

    Each entry of ``image_path_list`` may be a file path (read with
    ``imageio.imread``) or an already-loaded ``np.ndarray``; entries of any
    other type are silently skipped. Frames are read lazily, one at a time,
    while the writer is open.
    """
    def frames():
        for source in image_path_list:
            if isinstance(source, np.ndarray):
                yield source
            elif isinstance(source, str):
                yield imageio.imread(source)

    with imageio.get_writer(output_path, fps=fps) as writer:
        for frame in frames():
            writer.append_data(frame)

def set_seed(seed=42):
    """Seed every RNG this script uses so runs are reproducible.

    Covers Python's ``random`` module, NumPy's global RNG, and torch on the
    CPU, the current CUDA device and all CUDA devices.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
if __name__=="__main__":
    # For every index i in [--start, --end) and every environment, roll out
    # episodes until one succeeds, then embed and save that trajectory to
    # <trajectory_dir>/<env_name>/trajectory_<i+1>.pth via process_data.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--start",
        default=0,
        type=int,
        help="first trajectory index to generate (inclusive)",
    )
    parser.add_argument(
        "--end",
        # Previously parsed but ignored (the loop was hard-coded to 0..25);
        # the default keeps that number of trajectories.
        default=25,
        type=int,
        help="trajectory index to stop at (exclusive)",
    )
    parser.add_argument(
        "--env_names",
        nargs="+",
        default=["hammer-v2-goal-observable"],
        help="environment ids (keys of env_dict) to roll out",
    )
    parser.add_argument(
        "--trajectory_dir",
        default="/data/metaworld_hammer_dataset",
        help="root output directory",
    )
    args=parser.parse_args()
    # `model` must stay a module-level global: process_data reads it.
    model = SentenceTransformer("sentence-transformers/paraphrase-TinyBERT-L6-v2")
    model.eval()
    set_seed()
    trajectory_dir = args.trajectory_dir
    os.makedirs(trajectory_dir, exist_ok=True)
    for i in range(args.start, args.end):
        # NOTE: `env_name` is deliberately a module-level name here; keep it so
        # in case other code reads it as a global.
        for env_name in args.env_names:
            print(i)
            os.makedirs(f"{trajectory_dir}/{env_name}/", exist_ok=True)
            # Environment seeds cycle through 0-4; retries reuse the same seed.
            seed = i % 5
            success = 0
            while success < 1:
                metaworldSimulator = simulator(env_name=env_name, seed=seed, verbose=False)
                metaworldSimulator.run_episode()
                success += metaworldSimulator.success
                print(f"Seed {seed}, {env_name} success: ",metaworldSimulator.success, " cumulative reward: ",metaworldSimulator.cumulative_reward, metaworldSimulator.timestep)
            process_data(metaworldSimulator.trajectory, f"{trajectory_dir}/{env_name}/trajectory_{i+1}.pth")
    