from decision_transformer.evaluation.evaluate_episodes_other_task import evaluate_episode_rtg
from other_seed_test_util import seed as seed_list
import argparse
import torch
import numpy as np
from decision_transformer.models.decision_transformer import DecisionTransformer
import pickle
import os
import torch.backends.cudnn as cudnn
from other_seed_test_expert_len import length as other_seed_testing_expert_len

# Seed PyTorch's CPU and CUDA generators with a fixed value so that every
# run of this evaluation script starts from the same RNG state.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Ask cuDNN for deterministic algorithms and turn off its benchmark-driven
# algorithm selection (see the PyTorch reproducibility notes).
cudnn.deterministic = True
cudnn.benchmark = False


def _load_state_stats(variant):
    """Return the ``(state_mean, state_std)`` pair used to normalise states.

    The statistics are read from ``<load_path>/state_mean_std.pkl`` when that
    file exists.  Otherwise they are recomputed from the trajectory pickle
    ``data/<load_pkl_for_state_mean_std>`` by concatenating every
    trajectory's ``"observations"`` along axis 0.

    Raises:
        FileNotFoundError: if the cached statistics file is missing and no
            fallback dataset was configured (or the dataset file is missing).
    """
    # NOTE(security): pickle.load can execute arbitrary code; both files read
    # here must come from a trusted source.
    stats_path = f"{variant['load_path']}/state_mean_std.pkl"
    if os.path.exists(stats_path):
        with open(stats_path, "rb") as f:
            print(f"load {stats_path} for state mean and std")
            state_mean, state_std = pickle.load(f)
        return state_mean, state_std

    dataset_name = variant.get("load_pkl_for_state_mean_std")
    if dataset_name is None:
        # Without this guard the code would try to open "data/None" and fail
        # with a message that hides the real cause.
        raise FileNotFoundError(
            f"{stats_path} does not exist and 'load_pkl_for_state_mean_std' "
            "is not set; cannot determine state mean and std"
        )

    dataset_path = f"data/{dataset_name}"
    with open(dataset_path, "rb") as f:
        print("load " + dataset_path + " for state mean and std")
        trajectories = pickle.load(f)

    # Only the observations are needed for input normalisation.
    states = np.concatenate([path["observations"] for path in trajectories], axis=0)
    # The small epsilon keeps the std strictly positive.
    return np.mean(states, axis=0), np.std(states, axis=0) + 1e-6


def _build_model(variant, device, state_dim, act_dim, lang_dim, max_ep_len):
    """Construct the DecisionTransformer, load its checkpoint and return it.

    The checkpoint is ``<load_path>/model-<model_iter>.pt``.  The model is
    moved to ``device`` and switched to evaluation mode before being returned.
    """
    model = DecisionTransformer(
        state_dim=state_dim,
        act_dim=act_dim,
        lang_dim=lang_dim,
        max_length=variant["K"],
        max_ep_len=max_ep_len,
        hidden_size=variant["embed_dim"],
        n_layer=variant["n_layer"],
        n_head=variant["n_head"],
        n_inner=4 * variant["embed_dim"],
        activation_function=variant["activation_function"],
        n_positions=1024,
        resid_pdrop=variant["dropout"],
        attn_pdrop=variant["dropout"],
        no_lang=variant["no_lang"],
        no_lang_input=variant["no_lang_input"],
    )
    model = model.to(device=device)

    checkpoint_path = f"{variant['load_path']}/model-{variant['model_iter']}.pt"
    print(f"load model from {checkpoint_path}")
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # Dropout probabilities are configured above; eval mode makes sure dropout
    # is disabled while scoring (harmless if the evaluation helper already
    # does this itself).
    model.eval()
    return model


def main(variant):
    """Evaluate a saved DecisionTransformer checkpoint and print summary stats.

    Args:
        variant: mapping of run options (``vars(args)`` when run as a script).
            Keys read here include ``eval_times``, ``num_eval_episodes``,
            ``mode``, ``gpu``, ``device``, ``K``, ``load_path``,
            ``model_iter``, ``load_pkl_for_state_mean_std``, ``skip``,
            ``seed``, the model hyper-parameters (``embed_dim``, ``n_layer``,
            ``n_head``, ``activation_function``, ``dropout``, ``no_lang``,
            ``no_lang_input``) and the language / attack options forwarded to
            ``evaluate_episode_rtg``.

    For each of the ``eval_times`` repetitions, episodes with index
    ``skip .. num_eval_episodes - 1`` are rolled out.  Episodes reported as
    not valid are ignored.  A dictionary of aggregate metrics is printed per
    repetition; nothing is returned.

    Raises:
        FileNotFoundError: if no source for the state statistics is available.
    """
    if variant["eval_times"] <= 0:
        # Nothing to evaluate, so skip loading the statistics and the model.
        return

    num_eval_episodes = variant["num_eval_episodes"]
    state_dim = (96, 96, 3)
    act_dim = 10
    lang_dim = 768
    max_ep_len = 100
    scale = 1000.0
    target_rew = 1.5
    mode = variant["mode"]
    if variant["gpu"] is None:
        device = variant.get("device", "cuda")
    else:
        device = f"cuda:{variant['gpu']}"

    # The statistics and the model do not change between repetitions, so they
    # are loaded once instead of once per repetition.
    state_mean, state_std = _load_state_stats(variant)
    model = _build_model(variant, device, state_dim, act_dim, lang_dim, max_ep_len)

    # A negative skip behaves like 0 (no episode is skipped).
    first_episode = max(variant["skip"], 0)

    for _ in range(variant["eval_times"]):
        returns, lengths, have_lang_ratios, normalized_returns = [], [], [], []
        for i in range(first_episode, num_eval_episodes):
            with torch.no_grad():
                # The first positional argument is passed as None, exactly as
                # in the original call; the returned log is not used.
                ret, length, valid, have_lang_ratio, _log = evaluate_episode_rtg(
                    None,
                    state_dim,
                    act_dim,
                    model,
                    max_ep_len=max_ep_len,
                    scale=scale,
                    target_return=target_rew / scale,
                    mode=mode,
                    state_mean=state_mean,
                    state_std=state_std,
                    device=device,
                    seed=int(seed_list[i]) if variant["seed"] is None else variant["seed"],
                    real_lang=(variant["real_lang"] or variant["real_lang_eval"]),
                    encoder_type=variant["encoder"],
                    only_fp=variant["only_fp"],
                    train_ratio=variant["train_ratio"],
                    lang_mode=variant["lang_mode"],
                    no_lang_eval=variant["no_lang_eval"],
                    use_gpt=variant["use_gpt"],
                    use_gpt_enable_no_feedback=variant["use_gpt_enable_no_feedback"],
                    random_lang=variant["random_lang"],
                    attack_hind=variant["attack_hind"],
                    attack_fore=variant["attack_fore"],
                    random_rate=variant["random_rate"],
                )

            if not valid:
                continue

            # Scale the return by expert_len / max(length, expert_len), so an
            # episode longer than the expert reference is penalised.
            expert_len = other_seed_testing_expert_len[i]
            normalized_return = ret * expert_len / max(length, expert_len)

            returns.append(ret)
            lengths.append(length)
            have_lang_ratios.append(have_lang_ratio)
            normalized_returns.append(normalized_return)

        if not returns:
            # Averaging empty lists would only produce NaNs and numpy warnings.
            print(
                f"no valid episodes were evaluated for target return {target_rew}; "
                "skipping summary"
            )
            continue

        returns_arr = np.array(returns)
        result = {
            f"target_{target_rew}_return_mean": np.mean(returns),
            f"target_{target_rew}_return_std": np.std(returns),
            f"target_{target_rew}_length_mean": np.mean(lengths),
            f"target_{target_rew}_length_std": np.std(lengths),
            f"target_{target_rew}_complete_rate": np.sum(returns_arr >= 1)
            / len(returns),
            f"target_{target_rew}_subgoal_complete_rate": np.sum(returns_arr >= 0.5)
            / len(returns),
            f"target_{target_rew}_have_lang_ratio": np.mean(have_lang_ratios),
            f"target_{target_rew}_normalized_return_mean": np.mean(normalized_returns),
        }

        print(result)


def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"`` and ``"0"``) becomes ``True``.  This
    converter accepts the usual spellings instead.

    Args:
        value: the raw argument (a string, or an already-parsed bool).

    Returns:
        ``True`` for true/t/yes/y/1 and ``False`` for false/f/no/n/0 or the
        empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Build the command-line parser for this evaluation script.

    Option names, types and defaults are unchanged, except that boolean
    options are parsed with ``str2bool`` instead of ``bool``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", type=str, default="homegrid")
    parser.add_argument(
        "--dataset", type=str, default="non-expert"
    )  # medium, medium-replay, medium-expert, expert
    parser.add_argument(
        "--mode", type=str, default="normal"
    )  # normal for standard setting, delayed for sparse
    parser.add_argument("--K", type=int, default=10)
    parser.add_argument("--pct_traj", type=float, default=1.0)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument(
        "--model_type", type=str, default="dt"
    )  # dt for decision transformer, bc for behavior cloning
    parser.add_argument("--embed_dim", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=3)
    parser.add_argument("--n_head", type=int, default=1)
    parser.add_argument("--activation_function", type=str, default="relu")
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--learning_rate", "-lr", type=float, default=1e-4)
    parser.add_argument("--weight_decay", "-wd", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=10000)
    parser.add_argument("--num_eval_episodes", type=int, default=100)
    parser.add_argument("--max_iters", type=int, default=30)
    parser.add_argument("--num_steps_per_iter", type=int, default=1000)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--log_to_wandb", "-w", type=int, default=1)
    parser.add_argument("--load", type=str2bool, default=False)
    parser.add_argument("--shot", type=int, default=None)
    parser.add_argument("--save", type=str, default=None)
    parser.add_argument("--no_lang", type=str2bool, default=False)
    parser.add_argument("--gpu", type=int, default=None)
    parser.add_argument("--model_iter", type=int, default=None)
    parser.add_argument("--no_lang_input", type=str2bool, default=False)
    parser.add_argument("--real_lang", type=int, default=1)
    parser.add_argument("--encoder", type=str, default="sentencebert")
    parser.add_argument("--load_path", type=str, default=None)
    parser.add_argument("--only_fp", type=str2bool, default=False)
    parser.add_argument("--real_lang_eval", type=str2bool, default=False)
    parser.add_argument("--only_hindsight", type=str2bool, default=False)
    parser.add_argument("--train_ratio", type=float, default=0.8)
    parser.add_argument("--lang_mode", type=str, default="test")
    parser.add_argument("--no_lang_eval", type=int, default=0)
    parser.add_argument("--testing", type=str2bool, default=True)
    parser.add_argument("--load_pkl_for_state_mean_std", type=str, default=None)
    parser.add_argument("--use_gpt", type=int, default=1)
    parser.add_argument("--use_gpt_enable_no_feedback", type=int, default=1)
    parser.add_argument("--skip", type=int, default=0)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--random_lang", type=str2bool, default=False)
    parser.add_argument("--attack_hind", type=str2bool, default=False)
    parser.add_argument("--attack_fore", type=str2bool, default=False)
    parser.add_argument("--eval_times", type=int, default=1)
    parser.add_argument("--random_rate", type=float, default=0.3)
    return parser


if __name__ == "__main__":
    parser = build_arg_parser()
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are removed when
    # Python runs with -O, and parser.error prints the usage text.
    if not args.testing:
        parser.error("--testing must be true for this evaluation script")
    # --load_pkl_for_state_mean_std is intentionally optional: main() first
    # looks for state_mean_std.pkl under --load_path.
    if args.load_path is None:
        parser.error("--load_path is required")
    if args.model_iter is None:
        parser.error("--model_iter is required")
    if args.use_gpt_enable_no_feedback and not args.use_gpt:
        parser.error("--use_gpt_enable_no_feedback requires --use_gpt")

    main(variant=vars(args))
