import random
import argparse
import llfbench
import numpy as np
import gymnasium as gym

from tqdm import tqdm
from llfbench.envs.llf_env import LLFWrapper

from test_utils import parse_action, parse_observation, extract_task_name, get_subgoal_finished, generate_gpt_feedback

from sentencebert import sentencebert_encode
import pickle
import os
import re
from test_utils import ACTIONS_MAP

# Pre-generated random draws, loaded once at import time. get_return() slices
# 100 entries per seed out of each sequence and consumes them with .pop(), so
# both are presumably plain Python lists -- TODO confirm against the generator
# script that wrote these files.
# NOTE(security): pickle.load executes arbitrary code embedded in the file;
# only run this script against .pkl files you produced yourself.
with open("random_nums.pkl", "rb") as _random_nums_file:
    random_nums = pickle.load(_random_nums_file)
with open("random_int_nums.pkl", "rb") as _random_int_nums_file:
    random_int_nums = pickle.load(_random_int_nums_file)

def get_random_action(env, env_name, info, index=None):
    """Pick a non-expert action for the current step.

    For non-text action spaces the action is drawn with
    ``env.action_space.sample()``. For text action spaces only ALFWorld
    environments are supported: the candidates are the entries of
    ``info["admissible_commands"]`` whose first word is in ``ACTIONS_MAP``
    (excluding "inventory" and "examine"), sorted so that a given ``index``
    always selects the same command.

    Args:
        env: Environment exposing ``action_space``.
        env_name: Environment id; matched case-insensitively against "alfworld".
        info: Step info dict; ``info["admissible_commands"]`` is read for
            ALFWorld text environments.
        index: Integer used modulo the number of candidates to select one.
            When omitted, a candidate is drawn with ``random.randrange``.

    Returns:
        The chosen action, or ``None`` when no action can be produced (text
        action space on a non-ALFWorld environment, or no admissible command
        survives the filter).
    """
    if not isinstance(env.action_space, gym.spaces.Text):
        return env.action_space.sample()

    if "alfworld" not in env_name.lower():
        return None

    allowed_verbs = [verb for verb in ACTIONS_MAP if verb not in ["inventory", "examine"]]

    valid_actions = []
    for command in info["admissible_commands"]:
        words = command.split()
        # Blank commands have no verb to match; skip them instead of crashing.
        if words and words[0] in allowed_verbs:
            valid_actions.append(command)
    valid_actions.sort()

    if not valid_actions:
        # Nothing to choose from; avoids a modulo-by-zero below.
        return None

    if index is None:
        index = random.randrange(len(valid_actions))
    return valid_actions[index % len(valid_actions)]


def get_expert_action(env, env_name, info):
    """Return the expert's action for the current step, if one is available.

    Args:
        env: Environment, forwarded to ``get_random_action`` for the fallback.
        env_name: Environment id, forwarded to ``get_random_action``.
        info: Step info dict; ``info["expert_action"]`` is read when present.

    Returns:
        ``info["expert_action"]`` when it is set; a random action when the key
        is present but its value is ``None``; ``None`` when ``info`` carries no
        "expert_action" key at all.
    """
    if "expert_action" not in info:
        return None

    action = info["expert_action"]
    if action is None:
        # A None expert action means all expert actions have been taken, so
        # fall back to a random one. get_random_action reduces the index modulo
        # the number of candidates, so any non-negative integer is valid here.
        action = get_random_action(env, env_name, info, random.randrange(1 << 30))
    return action


def obs_space_contains_obs(obs, obs_space):
    """Check that every field of an observation fits its declared space.

    A field passes when its value is ``None``, is an empty string, or is
    accepted by ``obs_space[key].contains(value)``.

    Args:
        obs: Mapping from field name to observed value.
        obs_space: Mapping (indexable by the same field names) whose entries
            expose a ``contains(value)`` method.

    Returns:
        ``True`` if all fields pass (vacuously ``True`` for an empty
        observation), ``False`` as soon as one field is rejected.
    """
    for key in obs:
        value = obs[key]
        # Missing fields and blank text are accepted without asking the space.
        if value is None:
            continue
        if isinstance(value, str) and len(value) == 0:
            continue
        if not obs_space[key].contains(value):
            return False
    return True


def get_return(env, env_name, agent, seed, only_fp=False, only_hindsight=False, no_lang=False, real_lang=False, encoder=None):
    """Roll out one episode and record it as an offline trajectory.

    Both ``random`` and ``np.random`` are seeded with ``seed``, the environment
    is reset with the same seed, and the agent acts until the episode
    terminates, is truncated, or 1000 steps have elapsed.

    Args:
        env: LLF environment; must expose ``reward_range``, ``seed``,
            ``set_paraphrase_method`` and, when ``real_lang`` is set,
            ``last_infos``.
        env_name: Environment id, forwarded to the action helpers.
        agent: "random" or "expert". The expert takes a random action whenever
            the next pre-generated draw from ``random_nums`` is below 0.3.
        seed: Episode seed; also selects the 100-entry slice of the
            pre-generated random numbers used by the expert agent.
        only_fp: With ``real_lang``, keep only the ``fp`` part of the feedback.
        only_hindsight: With ``real_lang``, keep only the ``hn`` part.
        no_lang: Replace the environment's feedback with an empty string
            (ignored when ``real_lang`` is set).
        real_lang: Obtain feedback from ``generate_gpt_feedback`` instead of
            the environment's own feedback.
        encoder: Callable applied to a one-element list of strings. Its result
            must expose ``.numpy()``. Required despite the ``None`` default.

    Returns:
        ``(total_return, path)``. ``path`` is a dict with keys "task_name",
        "observations", "languages", "rewards", "actions" and "terminals".
        When the agent cannot produce an action, ``(nan, None)`` is returned
        so callers can still unpack two values.

    Raises:
        AssertionError: For an unknown ``agent`` or when a step violates the
            sanity checks at the end of the loop body.
    """
    np.random.seed(seed)
    random.seed(seed)
    # NOTE(review): only 100 pre-generated draws are available per seed while
    # the horizon below is 1000; an expert episode longer than 100 steps would
    # exhaust these and fail on .pop() -- confirm episodes are always shorter.
    cur_random_nums = random_nums[seed*100: (seed+1)*100]
    cur_random_int_nums = random_int_nums[seed*100: (seed+1)*100]
    observations = []
    languages = []
    rewards = []
    actions = []
    terminateds = []

    assert len(env.reward_range) == 2

    env.seed(seed=seed)
    obs, info = env.reset(seed=seed)
    total_return = 0.0
    completed = False

    horizon = 1000
    task_name = extract_task_name(obs["instruction"])

    # The first recorded observation is the instruction text itself.
    obs["observation"] = obs["instruction"]

    # Start every episode without a stale log.txt in the working directory.
    if os.path.exists("log.txt"):
        os.remove("log.txt")

    # Language paired with the first step: the task name as a sentence.
    lang = task_name + '.'
    lang = lang.capitalize()
    # Subgoal flags; they only ever increase (see the max() updates below).
    subgoal_1 = 0
    subgoal_2 = 0
    subgoal_3 = 0
    while not completed:
        idx = np.random.randint(0, int(200*0.8))
        env.set_paraphrase_method(idx)
        horizon -= 1

        if agent == "random":
            # get_random_action reduces the index modulo the number of
            # candidates, so any non-negative integer selects a valid action.
            action = get_random_action(env, env_name, info, random.randrange(1 << 30))
        elif agent == "expert":
            random_num = cur_random_nums.pop()
            if random_num < 0.3:
                random_num_int = cur_random_int_nums.pop()
                action = get_random_action(env, env_name, info, random_num_int)
            else:
                action = get_expert_action(env, env_name, info)
        else:
            raise AssertionError(f"Unhandled agent type {agent}. Can only handle random and expert agent.")

        if action is None:
            print(f"Cannot evaluate {agent} for {env_name}")
            # Keep the (return, path) shape so callers can always unpack.
            return float("nan"), None

        next_obs, reward, terminated, truncated, next_info = env.step(action)
        # The stored observation uses the subgoal flags from *before* this step.
        observations.append(parse_observation(obs["observation"], task_name, subgoal_1, subgoal_2, subgoal_3))
        subgoal1, subgoal2, subgoal3 = get_subgoal_finished(task_name, next_obs["observation"])

        subgoal_1 = max(subgoal1, subgoal_1)
        subgoal_2 = max(subgoal2, subgoal_2)
        subgoal_3 = max(subgoal3, subgoal_3)
        actions.append(parse_action(action))
        rewards.append(reward)
        terminateds.append(terminated)

        # Encode the language available when the action was chosen; the
        # feedback produced by this step is stored with the next one.
        languages.append(encoder([lang]).numpy())

        if real_lang:
            feedback = generate_gpt_feedback(env, action, reward, info, env.last_infos)
            if only_fp:
                lang = feedback.fp
            elif only_hindsight:
                lang = feedback.hn
            else:
                lang = feedback.hn + ' ' + feedback.fp
        elif next_obs["feedback"] is None or no_lang:
            lang = ""
        else:
            lang = next_obs["feedback"]

        total_return += reward
        completed = terminated or truncated or (horizon <= 0)

        # Sanity checks on the environment's step contract.
        assert env.action_space.contains(action)
        assert obs_space_contains_obs(obs, env.observation_space)
        assert obs_space_contains_obs(next_obs, env.observation_space)
        assert type(terminated) == bool
        assert type(truncated) == bool
        assert type(reward) == float or type(reward) == int
        assert type(info) == dict and type(next_info) == dict
        assert env.reward_range[0] <= reward <= env.reward_range[1]
        info = next_info

        obs = next_obs

    path = {
        "task_name": encoder([task_name]),
        "observations": np.array(observations),
        "languages": np.array(languages),
        "rewards": np.array(rewards),
        "actions": np.array(actions),
        "terminals": np.array(terminateds),
    }
    return total_return, path


def test_wrapper(env):
    """Report whether an ``LLFWrapper`` appears anywhere in a wrapper chain.

    Starting at ``env``, follows ``.env`` attributes inward. Returns ``True``
    at the first layer that is an ``LLFWrapper`` and ``False`` once a layer
    without an ``env`` attribute is reached.
    """
    layer = env
    while True:
        if isinstance(layer, LLFWrapper):
            return True
        if not hasattr(layer, 'env'):
            return False
        # Step one wrapper inward and check again.
        layer = layer.env


def test_env(env_name, agent, num_eps=1, seed=0, save=None, no_lang=False, only_fp=False, only_hindsight=False, real_lang=False, encoder=None):
    """Run ``num_eps`` episodes on one environment and optionally pickle them.

    Args:
        env_name: Environment id passed to ``llfbench.envs.alfworld.make_env``.
        agent: "random" or "expert", forwarded to ``get_return``.
        num_eps: Number of episodes; episode ``i`` is run with seed ``i``.
        seed: Seed applied to ``random`` and ``np.random`` before the
            environment is created.
        save: When not ``None``, the collected trajectories are pickled to a
            file in the working directory whose name encodes ``num_eps`` and
            the language options.
        no_lang, only_fp, only_hindsight, real_lang: Language options,
            forwarded to ``get_return``. ``only_fp`` and ``only_hindsight``
            also select the feedback types requested from the environment and
            are mutually exclusive.
        encoder: Text encoder forwarded to ``get_return``.

    Raises:
        AssertionError: If both ``only_fp`` and ``only_hindsight`` are set, or
            if the created environment is not wrapped in an ``LLFWrapper``.
    """
    traj = []
    # The result is unused; the call is kept in case it validates env_name
    # -- NOTE(review): confirm, then drop if it has no such effect.
    llfbench.supported_types(env_name)

    assert not (only_fp and only_hindsight)
    if only_fp:
        feedback_type = ['fp']
    elif only_hindsight:
        feedback_type = ['hn']
    else:
        feedback_type = ['hn', 'fp']

    configs = [{"instruction_type": 'b', "feedback_type": feedback_type}]

    for config in configs:

        random.seed(seed)
        np.random.seed(seed)

        env = llfbench.envs.alfworld.make_env(env_name, **config)  # test llfbench.make

        assert test_wrapper(env)                 # test LLFWrapper is used

        all_returns = []

        for i in range(num_eps):
            if i % 10 == 0:
                print(i)
            ret, path = get_return(env=env, env_name=env_name, agent=agent, seed=i, only_fp=only_fp, only_hindsight=only_hindsight, no_lang=no_lang, real_lang=real_lang, encoder=encoder)
            all_returns.append(ret)
            traj.append(path)

        print(f"Environment: {env_name}, Config {config}, Number of episodes {num_eps}, "
              f"Mean return {np.mean(all_returns):.3f}, "
              f"Std return {np.std(all_returns):.3f}, "
              f"Max return {np.max(all_returns):.3f}, "
              f"Min return {np.min(all_returns):.3f}")

    if save is not None:
        # NOTE(review): the file name does not include env_name, so saving
        # several environments with the same options overwrites the same file.
        data_path = f"alfworld-{num_eps}shot-hypo1"
        if real_lang:
            data_path += "-real-lang"
        if only_fp:
            data_path += "-only-fp"
        elif no_lang:
            data_path += "-no-lang"
        elif only_hindsight:
            data_path += "-only-hindsight"
        # No suffix otherwise: that is the default hindsight + foresight setup.
        data_path += ".pkl"
        with open(data_path, "wb") as f:
            pickle.dump(traj, f)
        # Reported only when a file was written; data_path is undefined otherwise.
        print(f"save offline data at {data_path}")

    print("traj len", len(traj))


def test_benchmark(benchmark_prefix, num_eps, agent, save, no_lang=False, only_fp=False, real_lang=False, only_hindsight=False):
    """Run ``test_env`` on every registered environment matching a prefix.

    An environment is selected when ``benchmark_prefix`` occurs anywhere in
    its registry id. All selected environments share the same agent, episode
    count, save flag, language options and sentence encoder.
    """
    all_envs = [name for name in gym.envs.registry if benchmark_prefix in name]

    print(f'Number of {benchmark_prefix} environments: ', len(all_envs))

    # Options that are identical for every environment.
    shared_options = dict(
        agent=agent,
        num_eps=num_eps,
        save=save,
        no_lang=no_lang,
        only_fp=only_fp,
        only_hindsight=only_hindsight,
        real_lang=real_lang,
        encoder=sentencebert_encode,
    )

    for env_name in tqdm(all_envs):
        test_env(env_name=env_name, **shared_options)


def parse_bool_flag(value):
    """Convert a command-line string into a bool, for use as argparse ``type=``.

    ``type=bool`` treats every non-empty string as True, so ``--flag False``
    would enable the flag. This accepts the usual spellings
    (true/false, t/f, yes/no, y/n, 1/0; case-insensitive; the empty string
    counts as False).

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Script to test heuristic agents on LLF environments.')
    parser.add_argument('-b', '--benchmark_prefix', type=str, default='llf-alfworld', help='Prefix of the suite of environments to test')
    parser.add_argument('-n', '--num_eps', type=int, default=10, help="Number of episodes to evaluate")
    parser.add_argument('-a', '--agent', type=str, default="expert", help="type of agent: random or expert",
                        choices=["random", "expert"])
    parser.add_argument('--save', type=int, default=1)
    # Boolean options take an explicit value, e.g. "--no_lang True".
    parser.add_argument('--no_lang', type=parse_bool_flag, default=False)
    parser.add_argument('--only_fp', type=parse_bool_flag, default=False)
    parser.add_argument("--only_hindsight", type=parse_bool_flag, default=False)
    parser.add_argument('--real_lang', type=parse_bool_flag, default=False)
    # Argument names match test_benchmark's parameters, so pass them through.
    test_benchmark(**vars(parser.parse_args()))
