import torch
import numpy as np
import random
import numpy as np
import os
import sys
import openai 
import re 
import json
from state_processor import get_GPT_state
current_dir = os.path.dirname(os.path.abspath(__file__))
metaworld_dir = os.path.join(current_dir, 'metaworld')
if metaworld_dir not in sys.path:
    sys.path.insert(0, metaworld_dir)
import torch
from metaworld import policies
from sentence_transformers import SentenceTransformer
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE as env_dict
from typing import List
import random
from data_generation import get_policy,get_task_text,get_action,get_h_language,get_f_language
from utils import agent_step,get_state
def evaluate_episode_rtg(
    model,
    max_ep_len=50,
    device='cuda',
    target_return=None,
    eval_mode="validation",
    probability_threshold=33,
    task_name=None,
    language_model=None,
    empty_language=None,
    seed=0,
    config=None,
    gpt=False,
):
    """Evaluate ``model`` for one episode of the Meta-World task ``task_name``.

    Thin entry point: every argument is forwarded unchanged to
    ``evaluate_episode_rtg_metaworld``. ``task_name`` is passed as its
    ``env_name`` and ``mode`` is fixed to ``'normal'`` (the return-to-go is
    reduced by each step's reward).

    Returns:
        The ``(episode_return, success, length)`` tuple produced by
        ``evaluate_episode_rtg_metaworld``.
    """
    forwarded = dict(
        target_return=target_return,
        mode='normal',
        env_name=task_name,
        eval_mode=eval_mode,
        probability_threshold=probability_threshold,
        language_model=language_model,
        empty_language=empty_language,
        seed=seed,
        config=config,
        gpt=gpt,
    )
    return evaluate_episode_rtg_metaworld(model, max_ep_len, device, **forwarded)


def evaluate_episode_rtg_metaworld(
        model,
        max_ep_len=30,
        device='cuda',
        target_return=None,
        mode='normal',
        env_name=None,
        eval_mode="validation",
        probability_threshold=33,
        language_model=None,
        empty_language=None,
        seed=None,
        config=None,
        gpt=False
):
    """Roll out ``model`` for one Meta-World episode with simulated language feedback.

    The agent acts through ``model.get_action``, except at steps where a
    non-model action is injected:

    * without ``config``: up to 3 random "disturbances", each lasting 1-3 steps;
    * with ``config``: the steps listed in ``config[seed][0]``.

    After every step the scripted expert ``policy`` is queried for two
    sentences:

    * a hindsight sentence, chosen by how many consecutive recent steps
      disagreed with the expert (capped at 3);
    * a foresight sentence, chosen from the expert's action for the new
      observation.

    The feedback is encoded with ``language_model`` and appended to the
    language context for the next ``get_action`` call.

    Args:
        model: Policy with ``eval()``, ``to(device=...)`` and
            ``get_action(manual, states, actions, returns, languages, env_name=...)``.
        max_ep_len: Currently unused; the rollout is capped at 30 steps.
        device: Device the model and all language encodings are placed on.
        target_return: Initial return-to-go.
        mode: ``'delayed'`` keeps the return-to-go constant; any other value
            subtracts each step's reward from it.
        env_name: Key into ``env_dict`` selecting the Meta-World environment.
        eval_mode: Forwarded to ``get_h_language`` / ``get_f_language``.
        probability_threshold: Currently unused.
        language_model: Sentence encoder exposing ``encode(str)``.
        empty_language: Currently unused.
        seed: Environment seed; also the key into ``config`` when given.
        config: Optional per-seed overrides. ``config[seed][0]`` is the
            collection of overridden steps and ``config[seed][1][str(t)]`` the
            position action for step ``t``. Presumably loaded from JSON (hence
            the string keys); confirm with the producer.
        gpt: If True, feedback goes through ``get_GPT_response`` and
            ``GPT_translation`` instead of being used verbatim.

    Returns:
        ``(episode_return, success, length)``:

        * ``episode_return`` sums +0.5 per step matching the expert, -0.5
          otherwise, plus 20 on success;
        * ``success`` is 1 if the env reported success, else 0;
        * ``length`` is the number of env steps taken.
    """
    model.eval()
    model.to(device=device)
    benchmark_env = env_dict[env_name]
    task_name = get_task_text(env_name)
    policy = get_policy(env_name)
    manual = task_name
    env = benchmark_env(seed)
    observation, info = env.reset()
    with torch.no_grad():
        # Encodings must live on the same device as the model. They used to be
        # hard-coded to 'cuda', which broke evaluation with any other device.
        encoded_manual = torch.unsqueeze(
            torch.tensor(language_model.encode(manual)), dim=0).to(device)
        encoded_language = torch.unsqueeze(
            torch.tensor(language_model.encode("")), dim=0).to(device)
    state = get_state(observation, policy, env_name)
    # The first step has no feedback yet: its language slot is the encoding of "".
    states, actions, target_return_list, languages = (
        [state], [], [target_return], encoded_language.unsqueeze(0))
    episode_return, success = 0, 0
    duration = 0                            # remaining extra steps of the current disturbance
    max_distort = 3                         # disturbances still allowed this episode
    non_expert_step = random.randint(2, 6)  # step at which the next disturbance starts
    # Steps where the agent's action differed from the expert's; drives hindsight feedback.
    non_expert_step_list = []
    length = 0

    # NOTE(review): the horizon is hard-coded to 30; `max_ep_len` is ignored.
    for t in range(30):
        # Placeholder for this step's action; overwritten after the env step.
        actions.append([0, 0])
        if config is None:
            if (t == non_expert_step and max_distort > 0) or duration > 0:
                # Disturbance: replace the model's position action with a random one.
                if duration == 0:
                    duration = random.randint(0, 2)
                    # The gripper command is carried over from the previous step.
                    action = (random.randint(0, 6), pred_gripper)
                    non_expert_step = random.randint(t + duration + 4, t + duration + 8)
                    max_distort -= 1
                else:
                    duration -= 1
                pred_pos, pred_gripper = action
                pred_pos = torch.tensor(pred_pos)
            else:
                action = model.get_action(encoded_manual, states, actions, target_return_list, languages, env_name="metaworld")
                pred_pos, pred_gripper = (action[0].detach().cpu(), action[1].detach().cpu())
        else:
            if t in config[seed][0]:
                # Scripted override of the position action for this step.
                # NOTE(review): raises NameError if the first listed step comes
                # before any model step, because pred_gripper is not yet bound.
                pred_pos = torch.tensor(config[seed][1][str(t)])
                action = (pred_pos, pred_gripper)
            else:
                action = model.get_action(encoded_manual, states, actions, target_return_list, languages, env_name="metaworld")
                pred_pos, pred_gripper = (action[0].detach().cpu(), action[1].detach().cpu())

        # Expert action for the observation the agent is about to act on.
        (expert_action_id, expert_gripper_id), f_language_id = policy.get_action(observation)
        observation, reward, terminated, truncated, info = agent_step(
            policy, env, observation, get_action(policy, action, observation, env_name), env_name)
        # The env reward is discarded: +0.5 for matching the expert, -0.5 otherwise.
        reward = ((pred_pos, pred_gripper) == (expert_action_id, expert_gripper_id)) - 0.5
        state = get_state(observation, policy, env_name)
        states.append(state)
        actions[-1] = [pred_pos, pred_gripper]
        if mode != 'delayed':
            pred_return = target_return_list[-1] - reward
        else:
            pred_return = target_return_list[-1]
        target_return_list.append(pred_return)
        # The success bonus is added after the return-to-go update, so it does
        # not lower the return-to-go.
        if info["success"]:
            reward += 20
        episode_return += reward
        done = terminated or truncated
        length += 1
        if info["success"]:
            success = 1
            break
        if done:
            break
        # Unused draw (it fed the disabled probability_threshold feedback
        # dropout). Kept so the sequence of random draws, and therefore the
        # disturbance schedule, is unchanged.
        random.randint(0, 100)
        if (pred_pos, pred_gripper) != (expert_action_id, expert_gripper_id):
            non_expert_step_list.append(t)
        # Hindsight level: length of the run of consecutive non-expert steps
        # ending at t, capped at 3.
        it = 0
        while t - it in non_expert_step_list:
            it += 1
        it = min(it, 3)
        h_language = get_h_language(env_name, it, eval_mode, pred_pos)[0]
        # Foresight is based on the expert's action for the new observation.
        (expert_action_id, expert_gripper_id), f_language_id = policy.get_action(observation)
        f_language = get_f_language(env_name, f_language_id, eval_mode)[0]
        if gpt:
            language, isSpeak = get_GPT_response(h_language, f_language, state, env_name)
            language = language if isSpeak else ""
            if language != "":
                language = GPT_translation(language)
        else:
            language = h_language + f_language
        with torch.no_grad():
            encoded_language = torch.unsqueeze(
                torch.tensor(language_model.encode(language)), dim=0).to(device)
        languages = torch.cat((languages, encoded_language.unsqueeze(0)), dim=0)
    return episode_return, success, length

# Prompt for the simulated human expert (filled in by get_GPT_response).
# The {task}, {hindsight}, {foresight} and {states} placeholders are filled
# with str.replace, not str.format, because the JSON example contains
# literal braces. The reply is parsed with json.loads, so booleans must be
# lowercase JSON literals.
PROMPT_ENABLE_NO_FEEDBACK = """
You are a human expert that teaches a non-expert agent to improve its performance in a robotic manipulation task. The robot task is to {task}.

The game simulator provides the following hint due to the last action:
(1) hindsight feedback: {hindsight}
(2) foresight feedback: {foresight}
(3) Supporting state information: {states}

As a human expert, fully consider the given information and hint. 

First decide whether it is necessary to intervene as expert. Keep in mind that you don't want to give any language feedback if the agent stays on the right track during the last few steps. You really need to keep in mind that you are reluctant to give language feedback if everything looks good.

You are expected to give the agent hindsight compliment/criticism on the agent's previous actions, and give instructions for the agent's future actions. Don't only give foresight instructions, hindsight feedback is also important. 

Simply giving the hindsight and foresight information is enough, you don't need to include other supporting information.

You should only respond in a json format as described below:

{
   "response": "your hindsight compliment/criticism and foresight instruction to give to the robot. (empty string if if_give_response is false)",
   "if_give_response": true/false (lowercase JSON boolean, not Python True/False), true if you feel necessary to give response, otherwise false
}

Make sure the response contains all keys listed in the above example and must be parsed by Python json.loads().
"""

# Prompt asking GPT to paraphrase one feedback message (filled in by
# GPT_translation). The {response} placeholder is filled with str.replace,
# because the JSON example contains literal braces. The example must itself
# be valid JSON: the model tends to copy it, and the reply is parsed with
# json.loads (a trailing comma would make every reply unparseable).
PROMPT_TRANSLATOR="""
You are a helpful translator. You are given a language: {response}. 

You will convey the language to an agent as language feedback. You need to rephrase or modify some vocabulary in the sentence to make it a natural and diverse daily human language. The meaning should be exactly the same.

Try to use more diverse languages to replace some verbs or adjectives.

If the language contains more than one sentence, translate within each sentence. The length of the result language should be roughly the same as the input to you.


You should only respond in a json format as described below:

{
   "response": "your translated language response should be natural."
}

Make sure the response contains all keys listed in the above example and must be parsed by Python json.loads().
"""

def get_GPT_response(hindsight,foresight,state,env_name):
    """Ask GPT whether (and how) a human expert would give feedback this step.

    Fills ``PROMPT_ENABLE_NO_FEEDBACK`` with the simulator's hindsight and
    foresight sentences, a textual state description from ``get_GPT_state``
    and a task description. It then parses the first brace-delimited JSON
    object in the reply.

    Args:
        hindsight: Simulated hindsight feedback sentence.
        foresight: Simulated foresight feedback sentence.
        state: Environment state, passed to ``get_GPT_state``.
        env_name: Environment key. ``'assembly-v2-goal-observable'`` selects
            the wrench/peg task text; any other value selects hammer/nail.

    Returns:
        ``(language, is_speak)``. On any API, extraction or parsing failure
        this is ``("", False)``, i.e. the expert stays silent.
    """
    # Read the key from the environment, as GPT_translation does. The previous
    # hard-coded empty key made every request fail authentication, and the
    # error was swallowed below, so the expert was silently always mute.
    openai.api_key = os.environ.get("OPENAI_API_KEY")
    state_info = get_GPT_state(env_name, state)
    if env_name == 'assembly-v2-goal-observable':
        task = "pick up the wrench with the gripper and put it accurately on the peg"
    else:
        task = "pick up the hammer and hit accurately at the nail."
    try:
        # str.replace rather than str.format: the prompt contains literal JSON braces.
        prompt = (
            PROMPT_ENABLE_NO_FEEDBACK.replace("{hindsight}", hindsight)
            .replace("{foresight}", foresight)
            .replace("{states}", state_info)
            .replace("{task}", task)
        )
        completion = openai.chat.completions.create(
            model="gpt-3.5-turbo-1106",
            messages=[
                {"role": "user", "content": prompt},
            ],
            temperature=1.2,
        )
        content = completion.choices[0].message.content
        # The model may wrap the JSON in prose or code fences; take the first
        # non-nested {...} span and parse only that.
        matches = re.findall(r"\{([^{}]*)\}", content)
        parsed = json.loads("{" + matches[0] + "}")
        lang = parsed["response"]
        is_speak = parsed["if_give_response"]
    except Exception:
        # Best effort: any failure means "no feedback this step". Catching
        # Exception (not a bare except) lets Ctrl-C / SystemExit through.
        return "", False
    if isinstance(is_speak, str):
        # Tolerate "true"/"false" returned as strings; any non-empty string
        # would otherwise be truthy and make the expert always speak.
        is_speak = is_speak.strip().lower() == "true"
    return lang, is_speak

def GPT_translation(response):
    """Paraphrase a feedback sentence into more natural, varied language via GPT.

    Args:
        response: Feedback text to paraphrase; substituted into
            ``PROMPT_TRANSLATOR``.

    Returns:
        The paraphrased text, or ``""`` if the API call, JSON extraction or
        parsing fails. The untranslated input is not returned as a fallback.
    """
    openai.api_key = os.environ.get("OPENAI_API_KEY")
    try:
        # str.replace rather than str.format: the prompt contains literal JSON braces.
        prompt = PROMPT_TRANSLATOR.replace("{response}", response)
        completion = openai.chat.completions.create(
            model="gpt-3.5-turbo-1106",
            messages=[
                {"role": "user", "content": prompt},
            ],
            temperature=1.2,
        )
        content = completion.choices[0].message.content
        # The model may wrap the JSON in prose or code fences; take the first
        # non-nested {...} span and parse only that.
        matches = re.findall(r"\{([^{}]*)\}", content)
        parsed = json.loads("{" + matches[0] + "}")
        lang = parsed["response"]
    except Exception:
        # Best effort: any failure yields no feedback text. Catching Exception
        # (not a bare except) lets Ctrl-C / SystemExit through.
        return ""
    return lang