import torch
import numpy as np
import argparse
import json
import torch
import numpy as np
import gym
import messenger
from messenger.models.emma import EMMA
from messenger.models.utils import ObservationBuffer
from observation_process import observationProcessor, numpy_formatter
import torch
from transformers import RobertaTokenizer
from torch.nn.functional import pad
from data_process import getDataset
import numpy as np
from inferenceDataManager import inferenceDataManager
from deepCopy import copier
from pathSolver import pathSolver
import random
import re
# from prompt import LLMPrompter
import openai
def evaluate_episode_rtg(
    model,
    language_model,
    max_ep_len=50,
    scale=100.,
    state_mean=0.,
    state_std=1.,
    device='cuda',
    target_return=None,
    mode='normal',
    newTask=True,
    hasLanguage=True,
    empty_language=None,
    realLang=True,
    eval_mode="validation",
    eval_type="language",
    probability_threshold=33,
    seed=None,
    gpt=False,
):
    """Evaluate one Messenger episode by delegating to evaluate_episode_rtg_messenger.

    Every argument is forwarded unchanged by keyword. ``prompter`` is not
    forwarded, so the callee's default (None) applies. Note that this wrapper
    defaults ``eval_type`` to "language", whereas the callee defaults to "no".

    Returns:
        The 5-tuple ``(episode_return, episode_length, subgoal_complete,
        goal_complete, data)`` produced by evaluate_episode_rtg_messenger.
    """
    forwarded = dict(
        model=model,
        language_model=language_model,
        max_ep_len=max_ep_len,
        scale=scale,
        state_mean=state_mean,
        state_std=state_std,
        device=device,
        target_return=target_return,
        mode=mode,
        newTask=newTask,
        hasLanguage=hasLanguage,
        empty_language=empty_language,
        realLang=realLang,
        eval_mode=eval_mode,
        eval_type=eval_type,
        probability_threshold=probability_threshold,
        seed=seed,
        gpt=gpt,
    )
    # Unpack and repack so a result that is not exactly five items still fails here.
    total_return, length, subgoals, goals, info = evaluate_episode_rtg_messenger(**forwarded)
    return total_return, length, subgoals, goals, info

def _encode_text(language_model, text, device):
    """Encode ``text`` with ``language_model.encode`` as a tensor with a leading batch axis on ``device``."""
    return torch.unsqueeze(torch.tensor(language_model.encode(text)), dim=0).to(device)


def _encode_language_feedback(language_model, language, device):
    """Encode feedback as the mean embedding of its '.'-separated sentences.

    Each non-empty, stripped sentence is encoded as ``"<sentence>. "`` and the
    embeddings are averaged. When there is no non-empty sentence (empty
    string, or only dots/whitespace), the text is encoded as-is. The
    original code crashed on ``torch.stack([])`` in the dots/whitespace case.
    """
    segments = [seg.strip() for seg in language.split('.') if seg.strip()]
    if not segments:
        return _encode_text(language_model, language, device)
    embeddings = [_encode_text(language_model, seg + ". ", device) for seg in segments]
    return torch.mean(torch.stack(embeddings), dim=0)


def evaluate_episode_rtg_messenger(
        model,
        language_model,
        max_ep_len=50,
        scale=100.,
        state_mean=0.,
        state_std=1.,
        device='cuda',
        target_return=None,
        mode='normal',
        newTask=True,
        hasLanguage=True,
        empty_language=None,
        realLang=True,
        eval_mode="validation",
        eval_type="no",
        prompter=None,
        probability_threshold=33,
        seed=None,
        gpt=False
    ):
    """Roll out one "msgr-train-v3" episode with return-to-go and language conditioning.

    Args:
        model: Policy exposing ``get_action(manual, states, actions, rewards,
            returns_to_go, languages, hasLanguage=...)``.
        language_model: Sentence encoder exposing ``encode(text)``.
        max_ep_len: Maximum number of environment steps.
        scale: Divisor applied to ``target_return`` and per-step rewards for the
            return-to-go sequence.
        state_mean, state_std, empty_language, prompter: Accepted for interface
            compatibility. They are not used by this function.
        device: Device for the models and all encoded text.
        target_return: Initial return-to-go (must be numeric; it is divided by ``scale``).
        mode: ``'delayed'`` keeps the return-to-go constant; any other value
            subtracts ``reward / scale`` each step.
        newTask: Selects the task-order sentence prepended to the manual and is
            passed to the env copier and the expert.
        hasLanguage: Forwarded to ``model.get_action``.
        realLang: Use "human" feedback text when True, "template" text otherwise.
        eval_mode: Forwarded as ``mode`` to the state processor and the expert.
        eval_type: ``"no"`` disables per-step feedback. Otherwise feedback is
            produced with probability ``probability_threshold`` percent per step.
        probability_threshold: Percentage in [0, 100].
        seed: Environment reset seed; drawn from ``random.randint(0, 1000)`` when None.
        gpt: When True, feedback is filtered/rewritten through get_GPT_response
            and GPT_translation.

    Returns:
        ``(episode_return, episode_length, subgoal_complete, goal_complete, data)``,
        where ``data`` is currently always ``{"languages": []}``.
    """
    if seed is None:
        seed = random.randint(0, 1000)
    np.set_printoptions(formatter={'int': numpy_formatter})
    model.eval()
    model.to(device=device)
    language_model.eval()
    language_model.to(device=device)

    tempenv = gym.make("msgr-train-v3")
    obs, manual = tempenv.reset(seed=seed)
    # The first manual sentence states which of the two task orders is active.
    if newTask:
        manual.insert(0, "First go to the goal, then go to the message. ")
    else:
        manual.insert(0, "First go to the message, then go to the goal. ")
    env = copier(tempenv).newTask(tempenv, newTask)

    stateProcessor = observationProcessor()
    state = stateProcessor.generate_trajectory_state(obs)
    states, actions, rewards = [state], [], []
    currentState = stateProcessor.generate_state(env)
    envCopier = copier(env)
    stateContainer = []
    # The expert acts on a copy so the evaluated env is not advanced.
    expertEnv = envCopier.deep_copy(env, newTask)
    fp, _, distance_to_target, distance_to_enemy, _, expert_action = getExpertInstruction(
        expertEnv, stateProcessor, newTask, mode=eval_mode, realLang=realLang)
    with torch.no_grad():
        encoded_manual = _encode_text(language_model, " ".join(manual), device)
        # The initial language slot holds the expert's positive foresight.
        encoded_language = _encode_text(language_model, fp, device)
    target_return_list = [target_return / scale]
    languages = encoded_language.unsqueeze(0)
    episode_return, episode_length = 0, 0
    subgoal_complete = 0
    goal_complete = 0
    stateContainer.append(stateProcessor.simplifyState(currentState))
    data = {"languages": []}
    on_track = []
    language_key = "human" if realLang else "template"

    for t in range(max_ep_len):
        # Placeholder action/reward for the current step; overwritten below.
        actions.append(4)
        rewards.append(0)
        action = model.get_action(encoded_manual, states, actions, rewards,
                                  target_return_list, languages, hasLanguage=hasLanguage)
        actions[-1] = action
        action = action.detach().cpu().numpy()
        obs, reward, done, _ = env.step(action)
        currentState = stateProcessor.generate_state(env)
        reward = reward * 100 + stateProcessor.process_reward(currentState) - 0.5
        # Shaped-reward bands used to count subgoal and goal completions.
        if 45 < reward < 70:
            subgoal_complete += 1
        if reward > 95:
            goal_complete += 1
        state = stateProcessor.generate_trajectory_state(obs)
        stateContainer.append(stateProcessor.simplifyState(currentState))
        states.append(state)
        rewards[-1] = reward
        if mode != 'delayed':
            pred_return = target_return_list[-1] - (reward / scale)
        else:
            pred_return = target_return_list[-1]
        target_return_list.append(pred_return)
        episode_return += reward
        episode_length += 1
        if done:
            break

        hindsight, foresight = "", ""
        # randrange(100) draws 0..99, so feedback fires with probability exactly
        # probability_threshold percent (randint(0, 100) drew from 101 values).
        if eval_type != "no" and random.randrange(100) < probability_threshold:
            # stateContainer[t:t + 2] is the (previous, current) simplified state pair.
            hindsight_raw, _ = stateProcessor.process_state_for_GPT_train(
                stateContainer[t:t + 2], mode=eval_mode, moreInfo=True, expert_action=expert_action)
            hp = hindsight_raw["hindsight positive"][language_key] if "hindsight positive" in hindsight_raw else ""
            hn = hindsight_raw["hindsight negative"][language_key] if "hindsight negative" in hindsight_raw else ""
            hindsight = hp + hn
            expertEnv = envCopier.deep_copy(env, newTask)
            fp, fn, distance_to_target, distance_to_enemy, _, expert_action = getExpertInstruction(
                expertEnv, stateProcessor, newTask, mode=eval_mode, realLang=realLang)
            foresight = fp + fn

        on_track.append("bad" not in hindsight)
        with torch.no_grad():
            if gpt:
                language, isSpeak = get_GPT_response(
                    hindsight, foresight, distance_to_target, distance_to_enemy, on_track)
                language = language if isSpeak else ""
                if language != "":
                    language = GPT_translation(language)
            else:
                language = hindsight + foresight
            encoded_language = _encode_language_feedback(language_model, language, device)
        languages = torch.cat((languages, encoded_language.unsqueeze(0)), dim=0)

    print("Episode ends: ", goal_complete == 1)
    return episode_return, episode_length, subgoal_complete, goal_complete, data

def getExpertInstruction(expertEnv,observation_Processor,newTask=False,mode="validation",realLang=True):
    """Take one path-solver step in ``expertEnv`` and describe it as expert foresight.

    ``expertEnv`` is stepped in place (callers in this file pass a deep copy).

    Args:
        expertEnv: Environment the expert acts in.
        observation_Processor: Provides ``generate_state``, ``simplifyState`` and
            ``process_state_for_GPT_expert``.
        newTask: Forwarded to ``process_state_for_GPT_expert``.
        mode: Forwarded as ``mode`` to ``process_state_for_GPT_expert``.
        realLang: Pick "human" text when True, "template" text otherwise.

    Returns:
        ``(future_pos, future_neg, distance_to_target, distance_to_enemy, fn, action)``.
        - If the state has no "goal" or "agent" entry, no step is taken and
          ``("", "", np.inf, np.inf, False, None)`` is returned. Previously this
          path returned None, which broke callers that unpack six values.
        - If building the instruction raises, the texts are "", the distances
          are ``np.inf`` and ``fn`` is False. ``action`` is still the step taken.
    """
    currentState = observation_Processor.generate_state(expertEnv)
    stateList = [observation_Processor.simplifyState(currentState)]
    if "goal" not in currentState or "agent" not in currentState:
        return "", "", np.inf, np.inf, False, None

    path_Solver = pathSolver()
    path_Solver.update(currentState)
    action = path_Solver.get_action()
    expertEnv.step(action)
    stateList.append(observation_Processor.simplifyState(observation_Processor.generate_state(expertEnv)))

    language_key = "human" if realLang else "template"
    try:
        instruct, distance_to_target, distance_to_enemy, fn = observation_Processor.process_state_for_GPT_expert(
            stateList, newTask, mode=mode, moreInfo=True)
        future_pos = instruct["foresight positive"][language_key] if "foresight positive" in instruct else ""
        future_neg = instruct["foresight negative"][language_key] if "foresight negative" in instruct else ""
    except Exception:
        # Best effort: fall back to "no instruction". The original tuple-unpacked a
        # float here (TypeError), so this fallback never actually worked.
        future_pos = future_neg = ""
        distance_to_target = distance_to_enemy = np.inf
        fn = False
    return future_pos, future_neg, distance_to_target, distance_to_enemy, fn, action

# Prompt asking the model whether to intervene and what feedback to give.
# get_GPT_response fills {hindsight}, {foresight} and {states} with str.replace
# (not str.format), so the literal JSON braces below are safe. The example must
# be valid JSON because the reply is parsed with json.loads.
PROMPT_ENABLE_NO_FEEDBACK = """
You are a human expert who teaches a non-expert agent to improve its performance in a grid world. The robot's task is to find the message and then send it to the goal.

Here's the action space of the robot:
["right", "left", "up", "down", "noMotion"]

The game simulator provides the following hints about the last action:
(1) hindsight feedback: {hindsight}
(2) foresight feedback: {foresight}
(3) Supporting state information: {states}

As a human expert, fully consider the given information and hints.

First decide whether it is necessary to intervene as an expert. Keep in mind that you don't want to give any language feedback if the agent has stayed on the right track during the last few steps. You really need to keep in mind that you are reluctant to give language feedback if everything looks good.

You are expected to give the agent hindsight compliments/criticism on its previous actions, and instructions for its future actions. Don't only give foresight instructions; hindsight feedback is also important.

Simply giving the hindsight and foresight information is enough; you don't need to include other supporting information.
You can freely decide whether to include the hindsight or foresight information.

You should only respond in the JSON format described below:

{
   "response": "your hindsight compliment/criticism and foresight instruction to give to the robot (an empty string if if_give_response is false)",
   "if_give_response": true
}

"if_give_response" must be a JSON boolean written in lowercase: true if you feel it is necessary to give a response, otherwise false.

Make sure the response contains all keys listed in the above example and can be parsed by Python json.loads().
"""

# Paraphrasing prompt used by GPT_translation, which fills {response} via
# str.replace. The example must be valid JSON (no trailing comma) because the
# reply is parsed with json.loads.
PROMPT_TRANSLATOR="""
You are a helpful translator. You are given a language: {response}. 

You will convey the language to an agent as language feedback. You need to rephrase or modify some vocabulary in the sentence to make it natural and diverse everyday human language. The meaning should be exactly the same.

Try to use more diverse language to replace some verbs or adjectives.

Keep in mind that the action space is only "right,left,up,down" for the agent, and the agent won't understand other actions or directions.
If the language contains more than one sentence, translate within each sentence. The length of the resulting language should be roughly the same as the input given to you.


You should only respond in the JSON format described below:

{
   "response": "your translated language response should be natural."
}

Make sure the response contains all keys listed in the above example and can be parsed by Python json.loads().
"""

def get_GPT_response(hindsight,foresight,distance_to_target,distance_to_enemy,on_track):
    """Ask the chat model whether to intervene and what feedback to give.

    Args:
        hindsight: Simulator hindsight hint for the last action ("" if none).
        foresight: Simulator foresight hint for the next action ("" if none).
        distance_to_target: Inserted verbatim into the state description.
        distance_to_enemy: Inserted verbatim into the state description.
        on_track: Per-step booleans. Only the last three are shown to the model.

    Returns:
        ``(lang, isSpeak)``: the reply's "response" text and "if_give_response"
        value, or ``("", False)`` if the request or reply parsing fails.
    """
    import os
    # Read the key from the environment, as GPT_translation does. The previous
    # hard-coded "" overrode any configured key, so requests carried no valid key.
    openai.api_key = os.environ.get("OPENAI_API_KEY")
    state_info = (
        f"The distance from the agent to the target is {distance_to_target}, "
        f"the distance from the agent to the enemy is {distance_to_enemy}. "
        f"Whether the agent is on the right track for the last three steps: {on_track[-3:]}"
    )
    try:
        prompt = (
            PROMPT_ENABLE_NO_FEEDBACK
            .replace("{hindsight}", hindsight)
            .replace("{foresight}", foresight)
            .replace("{states}", state_info)
        )
        completion = openai.chat.completions.create(
            model="gpt-3.5-turbo-1106",
            messages=[{"role": "user", "content": prompt}],
            temperature=1.2,
        )
        content = completion.choices[0].message.content
        # Use the first flat {...} object in the reply; it may be wrapped in prose or fences.
        body = re.findall(r"\{([^{}]*)\}", content)[0]
        parsed = json.loads("{" + body + "}")
        lang = parsed["response"]
        isSpeak = parsed["if_give_response"]
    except Exception:
        # Best effort: any API or parsing failure means "no feedback". Catching
        # Exception instead of using a bare except keeps Ctrl-C working.
        return "", False
    return lang, isSpeak

def GPT_translation(response):
    """Paraphrase a feedback text into more natural, varied wording via the chat model.

    Args:
        response: Feedback text substituted for ``{response}`` in PROMPT_TRANSLATOR.

    Returns:
        The reply's "response" field, or "" if the request or reply parsing fails.
    """
    import os
    openai.api_key = os.environ.get("OPENAI_API_KEY")
    try:
        prompt = PROMPT_TRANSLATOR.replace("{response}", response)
        completion = openai.chat.completions.create(
            model="gpt-3.5-turbo-1106",
            messages=[{"role": "user", "content": prompt}],
            temperature=1.2,
        )
        content = completion.choices[0].message.content
        # Use the first flat {...} object in the reply; it may be wrapped in prose or fences.
        body = re.findall(r"\{([^{}]*)\}", content)[0]
        # Tolerate a trailing comma before the closing brace, which json.loads rejects.
        body = re.sub(r",\s*$", "", body)
        lang = json.loads("{" + body + "}")["response"]
    except Exception:
        # Best effort: any API or parsing failure yields "". Catching Exception
        # instead of using a bare except no longer swallows KeyboardInterrupt/SystemExit.
        return ""
    return lang

if __name__ == "__main__":
    # Manual smoke test: paraphrase one sample feedback text. This calls the
    # OpenAI API through GPT_translation, which reads OPENAI_API_KEY.
    sample_feedback = "It's a bad move not moving effectively towards the goal queen. Move left to approach the goal queen."
    translated = GPT_translation(sample_feedback)
    print(translated)