"""Adapted from https://github.com/ysymyth/ReAct/blob/master/alfworld.ipynb"""

import os
import sys
import json
import yaml
import openai
import importlib
import alfworld
import alfworld.agents.environment
from env_history_elite import EnvironmentHistory
import time
import ipdb
import copy

from tenacity import (
    retry,
    stop_after_attempt, # type: ignore
    wait_random_exponential, # type: ignore
    wait_exponential_jitter,
    retry_if_exception_type,
)

# Base delay in seconds slept before every OpenAI request: llm() multiplies it
# by the 1-based attempt number, llm_chat() uses it unchanged.
SLEEP_TIME=2
# Temperature increment per inner attempt: attempt k (0-based) samples at
# k * TEMPERATURE, so the first attempt always runs at temperature 0.
TEMPERATURE=0.2

from typing import List, Dict, Any, Tuple
 
# NOTE(review): the environment-variable lookup below is disabled and the key
# is blanked instead; llm() overwrites openai.api_key from API_KEYS on every
# request, so real keys must be filled into API_KEYS before running.
#openai.api_key = os.environ["OPENAI_API_KEY"]
openai.api_key = ""
# Pool of API keys that llm() cycles through, advancing after each successful request.
API_KEYS = ["",
            ""]
# Index into API_KEYS of the key llm() will use for its next request.
current_key_index = 0

# Location of the JSON file holding the few-shot prompts.
FOLDER = './prompts'
PROMPT_FILE = 'alfworld_3prompts.json'


# Loaded once at import time; run_trial() looks prompts up in `d` under keys of
# the form 'react_<task>_0' and 'react_<task>_1'.
with open(os.path.join(FOLDER, PROMPT_FILE), 'r') as f:
    d = json.load(f)
#with open('./challenge_few_shot_examples.txt', 'r') as f:
#    challenge_examples = f.read()

@retry(
    wait=wait_exponential_jitter(initial=5, max=30),
    stop=stop_after_attempt(6),
    # Keep the historical contract: once every retry has failed the caller
    # receives None instead of an exception.
    retry_error_callback=lambda retry_state: None,
)
def llm(prompt, stop=["\n"]):
    """Query the completion model and return the first usable completion.

    Up to six requests are made per call. Request ``k`` (0-based) sleeps
    ``SLEEP_TIME * (k + 1)`` seconds first and samples at temperature
    ``k * TEMPERATURE``, so a too-short answer is retried with more
    randomness. The API key is taken from ``API_KEYS`` and the index is
    advanced after every request that completes.

    Any exception is logged together with the prompt and re-raised so that
    the ``@retry`` decorator can back off and try again; previously the
    exception was swallowed here, which left the decorator without effect.

    Args:
        prompt: Full text prompt sent to the model.
        stop: Stop sequences forwarded to the API (never mutated here).

    Returns:
        The completion text if its stripped length is at least 5 characters,
        ``""`` if all six requests produced shorter text, or ``None`` if the
        API kept failing until the retry budget was exhausted.
    """
    global current_key_index

    try:
        for cur_try in range(6):
            time.sleep(SLEEP_TIME * (cur_try + 1))
            openai.api_key = API_KEYS[current_key_index]
            response = openai.Completion.create(
                model="code-davinci-002",
                prompt=prompt,
                temperature=cur_try * TEMPERATURE,
                max_tokens=100,
                top_p=1,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                stop=stop,
            )

            # Rotate to the next key only after a request went through.
            current_key_index = (current_key_index + 1) % len(API_KEYS)

            text = response["choices"][0]["text"]
            # Crude usefulness check: very short completions are treated as empty.
            if len(text.strip()) >= 5:
                return text
        return ""
    except Exception as e:
        print(prompt)
        print(e)
        # Re-raise so tenacity can retry; retry_error_callback yields None at the end.
        raise

#@retry(wait=wait_random_exponential(min=5, max=60), stop=stop_after_attempt(6))
#@retry(retry=retry_if_exception_type(RateLimitError), wait=wait_exponential_jitter(initial=1, max=30))
@retry(
    wait=wait_exponential_jitter(initial=5, max=30),
    stop=stop_after_attempt(6),
    # Keep the historical contract: once every retry has failed the caller
    # receives None instead of an exception.
    retry_error_callback=lambda retry_state: None,
)
def llm_chat(messages, stop=["\n"], max_tokens=100):
    """Query the chat model and return the first usable reply.

    Up to six requests are made per call. Each sleeps ``SLEEP_TIME`` seconds
    first; request ``k`` (0-based) samples at temperature ``k * TEMPERATURE``.

    The API key is read from the ``OPENAI_API_KEY`` environment variable. A
    secret key used to be hard-coded here; credentials must not live in
    source control. When the variable is unset, ``openai.api_key`` is left
    as it currently is.

    Any exception is logged together with the messages and re-raised so that
    the ``@retry`` decorator can back off and try again; previously the
    exception was swallowed here, which left the decorator without effect.

    Args:
        messages: Chat messages (dicts with ``role`` and ``content``).
        stop: Stop sequences forwarded to the API (never mutated here).
        max_tokens: Upper bound on generated tokens.

    Returns:
        The reply text if its stripped length is at least 5 characters,
        ``""`` if all six requests produced shorter text, or ``None`` if the
        API kept failing until the retry budget was exhausted.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        openai.api_key = api_key

    try:
        for cur_try in range(6):
            time.sleep(SLEEP_TIME)
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo-0301",
                messages=messages,
                temperature=cur_try * TEMPERATURE,
                max_tokens=max_tokens,
                top_p=1,
                n=1,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                stop=stop,
            )
            text = response["choices"][0]['message']['content']
            # Crude usefulness check: very short replies are treated as empty.
            if len(text.strip()) >= 5:
                return text
        return ""
    except Exception as e:
        print(messages)
        print(e)
        # Re-raise so tenacity can retry; retry_error_callback yields None at the end.
        raise

def process_ob(ob):
    """Strip the leading "You arrive at loc N. " sentence from an observation.

    Only observations starting with ``'You arrive at loc '`` are touched;
    everything after the first ``'. '`` is returned. If such an observation
    contains no ``'. '`` separator it is returned unchanged (the previous
    slice-based version silently dropped its first character in that case).

    Args:
        ob: Raw observation text from the environment.

    Returns:
        The observation without the arrival sentence, or ``ob`` itself.
    """
    if ob.startswith('You arrive at loc '):
        _, separator, remainder = ob.partition('. ')
        if separator:
            return remainder
    return ob

# System message shared by every chat prompt issued during recovery.
_AGENT_SYSTEM_PROMPT = "You are an agent acting in a virtual household environment, operating by following goals at any given time."


def _select_backtrack(history):
    """Return the trailing history entries that a recovery should roll back.

    The last action/observation pair is always selected. If the history has
    more than four entries and the observation just before that pair starts
    with 'OK' (the reply given to a thought), the preceding pair is
    selected as well.
    """
    if len(history) > 4:
        pivot = history[-3]
        if pivot['label'] == 'observation' and pivot['value'].startswith('OK'):
            return [history[-4], history[-3], history[-2], history[-1]]
    return [history[-2], history[-1]]


def _format_steps(steps):
    """Render history entries as transcript text ('\\n> action' / '\\nobservation')."""
    rendered = ''
    for item in steps:
        if item['label'] == 'observation':
            rendered += '\n' + item['value']
        elif item['label'] == 'action':
            rendered += '\n> ' + item['value']
    return rendered


def _build_state_query(few_shot_examples, scenario, objs):
    """Build the world-state question prompt: two questions per object plus the agent's location."""
    query_str = f"""As a world model, your job is to accurately provide information about the current environment. Here are two demonstrations of the world modeling task: {few_shot_examples}. Here is your current trajectory: {scenario}\n Answer the following questions. Here are the questions:"""
    q_id = 1
    for o in objs:
        query_str += f"\n {q_id}) Where was the {o} located initially?"
        q_id += 1
        query_str += f"\n {q_id}) Where is the {o} located now?"
        q_id += 1
    query_str += f"\n {q_id}) Where am I located now?"
    query_str += 'For each question, repeat the question itself, then answer it. If you do not know an answer, answer I don\'t know. end with the word "END"'
    return query_str


def alfworld_run(env, base_prompt, memory: List[str], to_print=True, ob='', world_model: List[str]=[]) -> Tuple[EnvironmentHistory, bool]:
    """Run one ALFWorld episode, recovering from exhausted states by re-thinking.

    Each loop iteration asks ``llm`` for the next action, executes it with
    ``env.step`` and records the action/observation pair in an
    ``EnvironmentHistory``. Any action containing 'think' gets the
    observation 'OK.' instead of the environment's reply.

    When ``env_history.check_is_exhausted()`` is true, a recovery is tried:
    the objects of interest and the current world state are queried through
    ``llm_chat``, the last one or two action/observation pairs are removed
    from the history, and a freshly generated thought is appended. The
    episode ends early if no thought could be generated or if the new
    thought equals the entry now last in the history.

    Debugger breakpoints (``ipdb.set_trace()``) that used to halt the run
    have been removed, so the function can run unattended.

    Args:
        env: Batched environment; ``env.step([action])`` must return
            ``(observations, rewards, dones, info)`` with ``info['won']``.
        base_prompt: Few-shot prompt passed to ``EnvironmentHistory``.
        memory: Earlier reflections; only the last three are used.
        to_print: Whether to echo the regular action/observation transcript.
        ob: Initial observation of the episode.
        world_model: Optional world-model entries; forwarded to
            ``EnvironmentHistory`` only when non-empty. Never mutated here.

    Returns:
        ``(env_history, won)`` where ``won`` is ``info['won'][0]`` from the
        most recent environment step.
    """
    recent_memory = memory[-3:] if len(memory) > 3 else memory
    if len(world_model) > 0:
        env_history = EnvironmentHistory(base_prompt, ob, recent_memory, [], world_model)
    else:
        env_history = EnvironmentHistory(base_prompt, ob, recent_memory, [])
    env_history.reset()
    if to_print:
        print(ob)
        sys.stdout.flush()

    cur_step = 0
    while cur_step < 50:
        action = llm(str(env_history) + ">", stop=['\n'])
        # llm() reports failure as None; continue with an empty action.
        action = '' if action is None else action.strip()

        observation, reward, done, info = env.step([action])
        observation, reward, done = process_ob(observation[0]), info['won'][0], done[0]

        if 'think' in action:
            observation = 'OK.'

        env_history.add("action", action)
        env_history.add("observation", observation)
        if to_print:
            print(f'> {action}\n{observation}')
            sys.stdout.flush()
        if done:
            return env_history, reward

        if env_history.check_is_exhausted():
            # Try to recover from the exhausted state.
            print('-ANALYZE-')
            print("[Generating state...]")

            # Only the question-answering few-shot file is used by the prompts below.
            with open("./elite_few_shot_world_modeling_nothoughts_qa.txt", 'r') as f:
                WM_FEW_SHOT_EXAMPLES = f.read()

            history = env_history._history
            backtrack = _select_backtrack(history)
            recent_past_str = _format_steps(backtrack)
            print(f"## Backtrack portion: {backtrack}")

            # Trajectory up to (not including) the steps being rolled back.
            traj_backtrack = env_history.__str__(exclude_last=len(backtrack))
            traj_backtrack_nothoughts = env_history.__str__(exclude_last=len(backtrack), exclude_thoughts=True)
            scenario = 'Here is the task' + traj_backtrack_nothoughts.split('Here is the task')[1]

            # Ask which objects matter for the task.
            obj_m = [
                {"role": "system", "content": _AGENT_SYSTEM_PROMPT},
                {"role": "user", "content": f'As a world model, your job is to accurately provide information about the current environment. Here is your task: {history[0]}\n. What are the objects of interest? Return only the name of each object per line, and end with the word "END".'},
            ]
            obj = llm_chat(obj_m, stop=['END'], max_tokens=300)
            objs = [o for o in obj.split('\n') if o != ''] if obj is not None else []

            # Ask where those objects and the agent are.
            qa_m = [
                {"role": "system", "content": _AGENT_SYSTEM_PROMPT},
                {"role": "user", "content": _build_state_query(WM_FEW_SHOT_EXAMPLES, scenario, objs)},
            ]
            qa = llm_chat(qa_m, stop=['END'], max_tokens=300)
            print(f"\n## State: {qa}")

            # Roll the history back by one pair, or two when a thought preceded the action.
            if len(backtrack) > 2:
                env_history._history = env_history._history[:-4]
            else:
                env_history._history = env_history._history[:-2]

            thought_repair_m = [
                {"role": "system", "content": _AGENT_SYSTEM_PROMPT},
                {"role": "user", "content": f'You are solving a task in a virtual household environment. Here are demonstration(s) from randomly selected successful instances of a similar environment, followed by the current task: {traj_backtrack}. Your next trajectory originally was {recent_past_str}, but it led to a failure. This is what you know so far: {qa}. Write a new thought that is more likely to lead to a successful trajectory, and only base your answer on what you know. End with the word "END".'},
            ]
            rethink = llm_chat(thought_repair_m, stop=['END', '\n'], max_tokens=300)

            if rethink is None:
                # No replacement thought could be generated: quit the trial.
                return env_history, reward

            if 'think:' not in rethink:
                rethink = 'think: ' + rethink

            # Guard against an empty history after the rollback before comparing.
            if env_history._history and rethink == env_history._history[-1]['value']:
                # Re-thinking reproduced the same thought: quit the trial.
                return env_history, reward

            # Recovered: record the new thought as an ordinary think step.
            env_history._is_exhausted = False
            action = rethink
            observation = 'OK.'
            env_history.add("action", action)
            env_history.add("observation", observation)
            print(f'> {action}\n{observation}')
            sys.stdout.flush()
            # The recovery thought consumes a step of its own in addition to
            # the regular increment below.
            cur_step += 1

        cur_step += 1
    return env_history, reward

# Task-type prefix of a game name -> fragment of the few-shot prompt keys
# ('react_<fragment>_0' / 'react_<fragment>_1') used for that task type.
PREFIXES = dict(
    pick_and_place='put',
    pick_clean_then_place='clean',
    pick_heat_then_place='heat',
    pick_cool_then_place='cool',
    look_at_obj='examine',
    pick_two_obj='puttwo',
)

def _read_completed_envs(world_log_path, trial_idx):
    """Collect the environments the world log already reports for ``trial_idx``.

    Lines are expected in the form ``Environment #0 Trial #0: SUCCESS``.
    Returns a dict mapping environment index to ``{'success': bool}``.
    """
    completed = {}
    with open(world_log_path, 'r') as rf:
        for line in rf:
            if not line.startswith('Environment'):
                continue
            log_env_idx = int(line.split('#')[1].split(' ')[0])
            log_trial_idx = int(line.split('#')[2].split(':')[0])
            if log_trial_idx != trial_idx:
                continue
            completed[log_env_idx] = {'success': line.split(':')[-1].strip() == 'SUCCESS'}
    return completed


def run_trial(
        trial_log_path: str,
        world_log_path: str,
        trial_idx: int,
        env_configs: List[Dict[str, Any]],
        use_memory: bool,
        use_state: bool,
        is_resume_crash: bool,        
    ) -> List[Dict[str, Any]]:
    """Run every configured environment once and log the outcomes.

    The ALFWorld environment is built from ``base_config.yaml`` on the
    ``eval_out_of_distribution`` split and reset once per entry of
    ``env_configs``. Environments already marked successful are skipped
    (and logged as successes); when resuming after a crash, environments
    already present in the world log for this trial are skipped instead.
    The environment is closed even if an episode raises.

    Args:
        trial_log_path: File that receives full trajectories and the summary.
        world_log_path: File that receives one status line per environment
            and the summary; also read when ``is_resume_crash`` is set.
        trial_idx: Index of the current trial, used in log lines.
        env_configs: Per-environment dicts; ``is_success`` is read and
            updated, ``memory`` and ``world_model`` are read on demand.
        use_memory: Pass ``env_config["memory"]`` to the agent if true.
        use_state: Pass ``env_config["world_model"]`` to the agent if true.
        is_resume_crash: Skip environments already logged for this trial.

    Returns:
        ``env_configs`` with ``is_success`` updated in place.
    """
    importlib.reload(alfworld)
    importlib.reload(alfworld.agents.environment)

    with open('base_config.yaml') as reader:
        config = yaml.safe_load(reader)
    split = "eval_out_of_distribution"

    env = getattr(alfworld.agents.environment, config["env"]["type"])(config, train_eval=split)
    env = env.init_env(batch_size=1)

    num_successes: int = 0
    num_additional_successes: int = 0
    num_envs: int = len(env_configs)

    try:
        completed_envs_in_trial = {}
        if is_resume_crash:  # only for trial 0 TODO: later trials
            completed_envs_in_trial = _read_completed_envs(world_log_path, trial_idx)
            print(f'RESUMING TRIAL {trial_idx} FROM ENV {len(completed_envs_in_trial)}')

        for z, env_config in enumerate(env_configs):
            # Reset for every config, including skipped ones, so the
            # environment stays aligned with the config index.
            ob, info = env.reset()
            ob = '\n'.join(ob[0].split('\n\n')[1:])
            name = '/'.join(info['extra.gamefile'][0].split('/')[-3:-1])

            print(f"using {name}")

            if is_resume_crash:
                if z in completed_envs_in_trial:
                    print(f'Skipping env {z} because it was completed in trial {trial_idx}')
                    if completed_envs_in_trial[z]['success']:
                        num_successes += 1
                        num_additional_successes += 1  # TODO only makes sense for trial 0
                        env_config["is_success"] = True
                    # Nothing is written to the logs: the entry is already there.
                    continue
            else:
                if env_config["is_success"]:
                    num_successes += 1

                    with open(world_log_path, 'a') as wf:
                        wf.write(f'Environment #{z} Trial #{trial_idx}: SUCCESS\n')
                    with open(trial_log_path, 'a') as wf:
                        wf.write(f'\n#####\n\nEnvironment #{z}: Success\n\n#####\n')
                    continue

            for k, v in PREFIXES.items():
                if not name.startswith(k):
                    continue

                base_prompt = 'Interact with a household to solve a task. Here are two examples.\n' + d[f'react_{v}_1'] + d[f'react_{v}_0']

                final_env_history, is_success = alfworld_run(env, base_prompt, env_config["memory"] if use_memory else [], to_print=True, ob=ob, world_model=env_config["world_model"] if use_state else [])

                # update env config
                if is_success:
                    status_str: str = f'Environment #{z} Trial #{trial_idx}: SUCCESS'
                    env_configs[z]['is_success'] = True
                    num_successes += 1
                    num_additional_successes += 1
                else:
                    status_str: str = f'Environment #{z} Trial #{trial_idx}: FAIL'

                # log to world log
                with open(world_log_path, 'a') as wf:
                    wf.write(status_str + '\n')

                # log env results to trial log
                with open(trial_log_path, 'a') as wf:
                    wf.write(f'\n#####\n\nEnvironment #{z}:\n{str(final_env_history)}\n\nSTATUS: {"OK" if is_success else "FAIL"}\n\n#####\n')

                # One episode per environment: stop after the matching task type.
                break
    finally:
        # close environment object
        env.close()

    # An empty env_configs list would otherwise divide by zero.
    accuracy = round(num_successes / num_envs, 2) if num_envs else 0.0

    # log trial results to trial and world logs
    log_str: str = f"""
-----
SUCCESS: {num_successes}
ADDITIONAL SUCCESS: {num_additional_successes}
FAIL: {num_envs - num_successes}
TOTAL: {num_envs}
ACCURACY: {accuracy}
-----"""
    with open(trial_log_path, 'a') as wf:
        wf.write(log_str)
    with open(world_log_path, 'a') as wf:
        wf.write(log_str + '\n')

    return env_configs
