import numpy as np
import torch

from decision_transformer.models.sentencebert import sentencebert_encode
from decision_transformer.evaluation.eval_utils import extract_task_name, parse_observation, form_action_str, get_subgoal_finished, generate_gpt_feedback

import random
import json

from llfbench.prompt_template import PROMPT, PROMPT_ENABLE_NO_FEEDBACK
import openai
import re

rng = np.random.default_rng()

def evaluate_episode_rtg(
    env,
    state_dim,
    act_dim,
    model,
    max_ep_len=1000,
    scale=1000.0,
    state_mean=0.0,
    state_std=1.0,
    device="cuda",
    target_return=None,
    mode="normal",
    seed=None,
    real_lang=False,
    encoder_type="sentencebert",
    only_fp=False,
    only_hindsight=False,
    train_ratio=0.8,
    lang_mode="val",
    no_lang_eval=False,
    use_gpt=False,
    use_gpt_enable_no_feedback=False,
    attack_hind=0,
    attack_fore=0,
    random_lang=False,
    random_rate=0.3
):
    encoder = sentencebert_encode

    model.eval()
    model.to(device=device)

    state_mean = torch.from_numpy(state_mean).to(device=device)
    state_std = torch.from_numpy(state_std).to(device=device)

    env.seed(seed)
    np.random.seed(seed)
    
    state, _ = env.reset(seed=seed)

    task_name = extract_task_name(state["instruction"])
    task = encoder([task_name]).to(device=device, dtype=torch.float32)
        
    state["observation"] = state["instruction"]

    state = parse_observation(state["observation"], task_name, 0, 0, 0)

    if mode == "noise":
        state = state + np.random.normal(0, 0.1, size=state.shape)
    # we keep all the histories on the device
    # note that the latest action and reward will be "padding"
    states = (
        torch.from_numpy(state)
        .reshape(1, state_dim)
        .to(device=device, dtype=torch.float32)
    )

    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)

    languages = (encoder([task_name])).to(device=device, dtype=torch.float32)

    languages = languages.reshape(1, -1).to(device=device, dtype=torch.float32)

    ep_return = target_return
    target_return = torch.tensor(ep_return, device=device, dtype=torch.float32).reshape(
        1, 1
    )
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    episode_return, episode_length = 0, 0
    success = 0
    give_lang_steps = 0
    
    subgoal_1, subgoal_2, subgoal_3 = 0, 0, 0


    for t in range(max_ep_len):
        if lang_mode == "val":
            idx = np.random.randint(int(200*0.8), int(200*0.9))
        else:
            idx = np.random.randint(int(200*0.9), 200)
        env.set_paraphrase_method(idx)
        # add padding
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])
        action = model.get_action(
            task.to(dtype=torch.float32),
            (states.to(dtype=torch.float32) - state_mean) / state_std,
            languages.to(dtype=torch.float32),
            actions.to(dtype=torch.float32),
            rewards.to(dtype=torch.float32),
            target_return.to(dtype=torch.float32),
            timesteps.to(dtype=torch.long),
        )

        action = action.detach().cpu().numpy()

        action = np.array(action, dtype=float)

        action_str = form_action_str(action, task_name)

        actions[-1] = torch.tensor(action)
        action = list(action)

        state, reward, done, truncated, info = env.step(action_str)

        
        feedback = generate_gpt_feedback(env, action_str, reward, info, env.last_infos, attack_hind=attack_hind, attack_fore=attack_fore)

        lang = feedback.hn + " " + feedback.fp

        if no_lang_eval:
            lang = ""

        if no_lang_eval:
            assert lang == ""
        if random_lang and rng.random() < random_rate:
            lang = ""
        if use_gpt:
            assert real_lang
            openai.api_key = ""
            if not use_gpt_enable_no_feedback:
                prompt = (
                    PROMPT.replace("{feedback}", state["feedback"])
                    .replace("{task_name}", task_name)
                )
            else:
                prompt = (
                    PROMPT_ENABLE_NO_FEEDBACK.replace("{feedback}", state["feedback"])
                    .replace("{task_name}", task_name)
                )

            try:
                response = openai.chat.completions.create(
                    model="gpt-3.5-turbo-1106",
                    messages=[
                        {"role": "user", "content": prompt},
                    ],
                    temperature=0.7,
                )
                response = response.choices[0].message.content

                pattern = r"\{([^{}]*)\}"
                matches = re.findall(pattern, response)
                response = "{" + matches[0] + "}"

                response = json.loads(response)

                lang = response["response"]

                if use_gpt_enable_no_feedback:
                    if not response["if_give_response"]:
                        lang = ""
                    else:
                        give_lang_steps += 1
                else:
                    give_lang_steps += 1

            except Exception as e:
                return 0, 0, False, 1, 0, (0, 0, 0), None
            
        lang = encoder([lang])
        
        _state = parse_observation(state["observation"], task_name, subgoal_1, subgoal_2, subgoal_3)
        
        subgoal1, subgoal2, subgoal3 = get_subgoal_finished(task_name, state["observation"])
        
        subgoal_1 = max(subgoal_1, subgoal1)
        subgoal_2 = max(subgoal_2, subgoal2)
        subgoal_3 = max(subgoal_3, subgoal3)

        cur_state = torch.from_numpy(_state).to(device=device).reshape(1, state_dim)
        cur_lang = lang.to(device=device).reshape(1, 768)
        states = torch.cat([states, cur_state], dim=0)
        languages = torch.cat([languages, cur_lang], dim=0)

        if mode != "delayed":
            pred_return = target_return[0, -1] - (reward / scale)
        else:
            pred_return = target_return[0, -1]
        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat(
            [timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)],
            dim=1,
        )

        episode_return += reward
        episode_length += 1

        if done or truncated:
            if done:
                success = 1
            break
        
    have_lang_ratio = give_lang_steps / episode_length
    return episode_return, episode_length, True, success, have_lang_ratio, (subgoal_1, subgoal_2, subgoal_3), None

