import gym
import numpy as np
import torch
import wandb

import argparse
import pickle
import random
import sys

from decision_transformer.evaluation.evaluate_episodes import (
    evaluate_episode_rtg,
)
from decision_transformer.evaluation.evaluate_episodes_real_lang import (
    evaluate_episode_rtg as evaluate_episode_rtg_real_lang,
)
from decision_transformer.evaluation.evaluate_episodes_other_task import (
    evaluate_episode_rtg as evaluate_episode_rtg_other_task,
)
from decision_transformer.evaluation.evaluate_episodes_other_task_real_lang import (
    evaluate_episode_rtg as evaluate_episode_rtg_other_task_real_lang,
)
from decision_transformer.models.decision_transformer import DecisionTransformer
from decision_transformer.models.mlp_bc import MLPBCModel
from decision_transformer.training.act_trainer import ActTrainer
from decision_transformer.training.seq_trainer import SequenceTrainer
import torch.backends.cudnn as cudnn
from putplate_seed_util import seed as seed_validation_list
from putplate_seed_test_util import seed as seed_testing_list
from other_seed_val_util import seed as other_seed_validation_list

from putplate_seed_val_expert_len import length as seed_validation_expert_len
from other_seed_val_expert_len import length as other_seed_validation_expert_len

from decision_transformer.models.sentencebert import sentencebert_encode

import os

def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along axis 0.

    The result satisfies ``out[-1] = x[-1]`` and
    ``out[t] = x[t] + gamma * out[t + 1]`` for every earlier index ``t``.

    Args:
        x: NumPy array; the recursion runs over its first axis.
        gamma: Discount factor applied to the already-accumulated tail.

    Returns:
        Array created with ``np.zeros_like(x)`` (same shape and dtype as
        ``x``) holding the discounted sums. An empty ``x`` yields an empty
        array instead of raising ``IndexError``.
    """
    totals = np.zeros_like(x)
    if x.shape[0] == 0:
        # Nothing to accumulate; indexing x[-1] below would fail.
        return totals

    totals[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        totals[t] = x[t] + gamma * totals[t + 1]
    return totals


def experiment(
    exp_prefix,
    variant,
):
    """Train (or fine-tune) and evaluate a policy on an offline trajectory dataset.

    The function picks the device, derives the wandb group/run name and the
    dataset path from ``variant``, loads the pickled trajectories, computes
    state normalization statistics, builds the model (``"dt"`` or ``"bc"``)
    together with its trainer, runs an initial evaluation and then
    ``variant["max_iters"]`` training iterations. Metrics are optionally
    logged to wandb and checkpoints are optionally written to
    ``variant["save"]``.

    Args:
        exp_prefix: Prefix used to build the wandb group and run names.
        variant: Dict of settings (the parsed command-line arguments when the
            file is run as a script).

    Raises:
        NotImplementedError: If ``variant["model_type"]`` is neither ``"dt"``
            nor ``"bc"``, or, at evaluation time, if no evaluation routine
            matches the configuration.
    """
    print(f"RUNNING in {variant['lang_mode']} MODE!!!!")

    # An explicit GPU index takes precedence over the generic device string.
    if variant["gpu"] is None:
        device = variant.get("device", "cuda")
    else:
        device = f"cuda:{variant['gpu']}"
    log_to_wandb = variant.get("log_to_wandb", False)
    save_dir = variant["save"]

    env_name, dataset = variant["env"], variant["dataset"]
    model_type = variant["model_type"]
    if variant["no_lang_input"]:
        exp_type = "nolanginput"
    elif variant["no_lang"]:
        exp_type = "nolang"
    elif variant["load_path"] is not None:
        if variant["real_lang"]:
            exp_type = "real-lang"
        else:
            exp_type = "template"
    else:
        if variant["real_lang"]:
            exp_type = "retrain-real-lang"
        else:
            exp_type = "retrain-template"

    if variant["load_path"] is not None:
        group_name = f"{exp_prefix}-{env_name}-putplates-{exp_type}-{variant['shot']}shots-{variant['model_iter']}iter"
    else:
        group_name = f"{exp_prefix}-{env_name}-putplates-{exp_type}"

    if variant["encoder"] == "sentencebert":
        group_name += "-sentencebert"
    if variant["perturb"]:
        group_name += "-perturb"
    if variant["only_fp"]:
        group_name += "-only-fp"
    if variant["only_hindsight"]:
        group_name += "-only-hindsight"
    if variant["random_lang"]:
        group_name += "-random-lang"

    # The random suffix is drawn before the global seeds are fixed below, so
    # run names differ between launches.
    exp_prefix = f"{group_name}-{random.randint(int(1e5), int(1e6) - 1)}"

    # Fixed seeds and deterministic cuDNN for reproducibility.
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    rng = np.random.RandomState(42)

    encoder = sentencebert_encode

    max_ep_len = 100

    env_targets = [1.5]
    scale = 1000.0

    if model_type == "bc":
        env_targets = env_targets[
            :1
        ]  # since BC ignores target, no need for different evaluations

    state_dim = (96, 96, 3)
    act_dim = 10
    lang_dim = 768

    # load dataset
    if variant["shot"] is not None:
        if variant["real_lang"] or variant["adapt"]:
            dataset_path = f"data/homegrid-{variant['shot']}shot-real-lang"
        else:
            dataset_path = f"data/homegrid-{variant['shot']}shot-template"
    else:
        if variant["real_lang"]:
            dataset_path = "data/homegrid-retrain-real-lang"
        else:
            dataset_path = "data/homegrid-retrain-template"

    if variant["encoder"] == "sentencebert":
        dataset_path += f"-{variant['encoder']}"
    if variant["perturb"]:
        dataset_path += "-perturb"
    if variant["only_fp"]:
        dataset_path += "-only-fp"
    if variant["only_hindsight"]:
        dataset_path += "-only-hindsight"
    if variant["random_lang"]:
        dataset_path += "-random-lang"
    if variant["no_lang_data"]:
        dataset_path += "-no-lang"

    dataset_path += ".pkl"

    # NOTE: pickle can execute arbitrary code; only load dataset files that
    # come from a trusted source.
    with open(dataset_path, "rb") as f:
        print("load", dataset_path)
        trajectories = pickle.load(f)

    # save all path information into separate lists
    mode = variant.get("mode", "normal")
    states, traj_lens, returns = [], [], []
    for path in trajectories:
        if mode == "delayed":  # delayed: all rewards moved to end of trajectory
            path["rewards"][-1] = path["rewards"].sum()
            path["rewards"][:-1] = 0.0

        states.append(path["observations"])
        traj_lens.append(len(path["observations"]))
        returns.append(sum(path["rewards"]))

    traj_lens, returns = np.array(traj_lens), np.array(returns)

    # used for input normalization
    states = np.concatenate(states, axis=0)

    state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

    if save_dir:
        print(f"save state mean and std to {save_dir}/state_mean_std.pkl")

        os.makedirs(f"{save_dir}", exist_ok=True)

        with open(f"{save_dir}/state_mean_std.pkl", "wb") as f:
            pickle.dump([state_mean, state_std], f)

    num_timesteps = sum(traj_lens)

    print("=" * 50)
    print(f"Starting new experiment: {env_name} {dataset}")
    print(f"{len(traj_lens)} trajectories, {num_timesteps} timesteps found")
    print(f"Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}")
    print(f"Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}")
    print("=" * 50)

    K = variant["K"]
    batch_size = variant["batch_size"]
    num_eval_episodes = variant["num_eval_episodes"]
    pct_traj = variant.get("pct_traj", 1.0)

    # only train on top pct_traj trajectories (for %BC experiment)
    num_timesteps = max(int(pct_traj * num_timesteps), 1)
    sorted_inds = np.argsort(returns)  # lowest to highest
    num_trajectories = 1
    timesteps = traj_lens[sorted_inds[-1]]
    ind = len(trajectories) - 2
    # Walk from the best trajectory downwards until the timestep budget is
    # exhausted; the best trajectory is always kept.
    while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:
        timesteps += traj_lens[sorted_inds[ind]]
        num_trajectories += 1
        ind -= 1
    sorted_inds = sorted_inds[-num_trajectories:]

    # used to reweight sampling so we sample according to timesteps instead of trajectories
    p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])

    if variant["no_lang_input"]:
        # Embedding of the empty string, substituted for every language input.
        no_lang_input = encoder([""]).to(device=device, dtype=torch.float32)

    def get_batch(batch_size=256, max_len=K):
        """Sample a batch of left-padded sub-sequences from the kept trajectories.

        Each sample starts at a random step of a trajectory drawn with
        probability proportional to its length and spans at most ``max_len``
        steps; shorter windows are padded on the left.

        Returns:
            Tuple ``(tasks, s, l, a, r, d, rtg, timesteps, mask)`` of tensors
            on ``device``. ``rtg`` carries one more step than the other
            sequences; ``mask`` is 1 on real steps and 0 on padding.
        """
        batch_inds = rng.choice(
            np.arange(num_trajectories),
            size=batch_size,
            replace=True,
            p=p_sample,  # reweights so we sample according to timesteps
        )

        s, a, l, r, d, rtg, timesteps, mask, tasks = [], [], [], [], [], [], [], [], []
        for i in range(batch_size):
            traj = trajectories[int(sorted_inds[batch_inds[i]])]
            tasks.append(traj["task_name"].reshape(1, 1, lang_dim))
            si = random.randint(0, traj["rewards"].shape[0] - 1)

            # get sequences from dataset
            s.append(
                traj["observations"][si : si + max_len].reshape(
                    1, -1, state_dim[0], state_dim[1], state_dim[2]
                )
            )
            a.append(traj["actions"][si : si + max_len].reshape(1, -1, act_dim))

            l.append(traj["languages"][si : si + max_len].reshape(1, -1, lang_dim))

            r.append(traj["rewards"][si : si + max_len].reshape(1, -1, 1))
            if "terminals" in traj:
                d.append(traj["terminals"][si : si + max_len].reshape(1, -1))
            else:
                d.append(traj["dones"][si : si + max_len].reshape(1, -1))
            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= max_ep_len] = (
                max_ep_len - 1
            )  # padding cutoff
            rtg.append(
                discount_cumsum(traj["rewards"][si:], gamma=1.0)[
                    : s[-1].shape[1] + 1
                ].reshape(1, -1, 1)
            )
            # Returns-to-go needs one extra step; append a zero when the
            # window reaches the end of the trajectory.
            if rtg[-1].shape[1] <= s[-1].shape[1]:
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

            # padding and state + reward normalization
            tlen = s[-1].shape[1]
            pad = max_len - tlen
            s[-1] = np.concatenate(
                [
                    np.zeros((1, pad, state_dim[0], state_dim[1], state_dim[2])),
                    s[-1],
                ],
                axis=1,
            )
            s[-1] = (s[-1] - state_mean) / state_std
            a[-1] = np.concatenate(
                [np.ones((1, pad, act_dim)) * -10.0, a[-1]], axis=1
            )
            l[-1] = np.concatenate([np.zeros((1, pad, lang_dim)), l[-1]], axis=1)
            r[-1] = np.concatenate([np.zeros((1, pad, 1)), r[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, pad)) * 2, d[-1]], axis=1)
            rtg[-1] = (
                np.concatenate([np.zeros((1, pad, 1)), rtg[-1]], axis=1) / scale
            )
            timesteps[-1] = np.concatenate(
                [np.zeros((1, pad)), timesteps[-1]], axis=1
            )
            mask.append(
                np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1)
            )

        def to_tensor(chunks, dtype):
            # Stack the per-sample arrays along the batch axis and move to device.
            return torch.from_numpy(np.concatenate(chunks, axis=0)).to(
                dtype=dtype, device=device
            )

        tasks = to_tensor(tasks, torch.float32)
        s = to_tensor(s, torch.float32)
        a = to_tensor(a, torch.float32)
        l = to_tensor(l, torch.float32)

        if variant["no_lang_input"]:
            l = no_lang_input.repeat(batch_size, 1, 1)

        r = to_tensor(r, torch.float32)
        d = to_tensor(d, torch.long)
        rtg = to_tensor(rtg, torch.float32)
        timesteps = to_tensor(timesteps, torch.long)
        # The mask keeps its NumPy dtype; only the device is changed.
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)

        return tasks, s, l, a, r, d, rtg, timesteps, mask

    def eval_episodes(target_rew):
        """Return a callback that evaluates a model for ``target_rew``."""

        def fn(model):
            """Roll out ``num_eval_episodes`` episodes and aggregate metrics."""
            returns, lengths, have_lang_ratios, normalized_returns = [], [], [], []
            other_task = variant["other_task"]
            if not other_task:
                if variant["testing"]:
                    seed_list = seed_testing_list
                else:
                    seed_list = seed_validation_list
            else:
                print("=" * 25)
                print("OTHER TASK")
                seed_list = other_seed_validation_list

            # NOTE(review): the validation expert lengths are used even when
            # variant["testing"] selects the testing seed list - confirm that
            # the two lists are aligned and long enough.
            if other_task:
                expert_lens = other_seed_validation_expert_len
            else:
                expert_lens = seed_validation_expert_len

            # Defined before the loop so it exists even if no episode is valid.
            complete_reward_threshold = 1 if other_task else 1.5

            # Keyword arguments shared by every evaluation routine.
            eval_kwargs = dict(
                max_ep_len=max_ep_len,
                scale=scale,
                target_return=target_rew / scale,
                mode=mode,
                state_mean=state_mean,
                state_std=state_std,
                device=device,
                real_lang=variant["real_lang"],
                encoder_type=variant["encoder"],
                only_fp=variant["only_fp"],
                train_ratio=variant["train_ratio"],
                lang_mode=variant["lang_mode"],
                val_ratio=variant["val_ratio"],
                use_gpt=variant["use_gpt"],
                use_gpt_enable_no_feedback=variant["use_gpt_enable_no_feedback"],
                random_lang=variant["random_lang"],
            )
            # Extra keyword arguments passed only to the non-real-language
            # evaluators.
            template_kwargs = dict(
                only_hindsight=variant["only_hindsight"],
                no_lang_eval=variant["no_lang_eval"],
            )

            if variant["real_lang_eval"]:
                if other_task:
                    evaluate = evaluate_episode_rtg_other_task_real_lang
                else:
                    evaluate = evaluate_episode_rtg_real_lang
            elif other_task:
                evaluate = evaluate_episode_rtg_other_task
                eval_kwargs.update(template_kwargs)
            elif model_type == "dt":
                evaluate = evaluate_episode_rtg
                eval_kwargs.update(template_kwargs)
            else:
                # Previously `assert NotImplementedError`, which never fired.
                raise NotImplementedError(
                    f"no evaluation routine for model_type={model_type!r} "
                    "with this configuration"
                )

            for i in range(num_eval_episodes):
                with torch.no_grad():
                    ret, length, valid, have_lang_ratio = evaluate(
                        None,
                        state_dim,
                        act_dim,
                        model,
                        seed=int(seed_list[i]),
                        **eval_kwargs,
                    )
                    # Scale the return down when the episode took longer than
                    # the recorded expert length for this seed.
                    normalized_return = (
                        ret * expert_lens[i] / max(length, expert_lens[i])
                    )
                if not valid:
                    continue
                returns.append(ret)
                lengths.append(length)
                have_lang_ratios.append(have_lang_ratio)
                normalized_returns.append(normalized_return)

            return {
                f"target_{target_rew}_return_mean": np.mean(returns),
                f"target_{target_rew}_return_std": np.std(returns),
                f"target_{target_rew}_length_mean": np.mean(lengths),
                f"target_{target_rew}_length_std": np.std(lengths),
                f"target_{target_rew}_complete_rate": np.sum(
                    np.array(returns) >= complete_reward_threshold
                )
                / len(returns),
                f"target_{target_rew}_subgoal_complete_rate": np.sum(
                    np.array(returns) >= 0.5
                )
                / len(returns),
                f"target_{target_rew}_have_lang_ratio": np.mean(have_lang_ratios),
                f"target_{target_rew}_normalized_return_mean": np.mean(normalized_returns),
            }

        return fn

    if model_type == "dt":
        model = DecisionTransformer(
            state_dim=state_dim,
            act_dim=act_dim,
            lang_dim=lang_dim,
            max_length=K,
            max_ep_len=max_ep_len,
            hidden_size=variant["embed_dim"],
            n_layer=variant["n_layer"],
            n_head=variant["n_head"],
            n_inner=4 * variant["embed_dim"],
            activation_function=variant["activation_function"],
            n_positions=1024,
            resid_pdrop=variant["dropout"],
            attn_pdrop=variant["dropout"],
            no_lang=variant["no_lang"],
            no_lang_input=variant["no_lang_input"],
        )
    elif model_type == "bc":
        model = MLPBCModel(
            state_dim=state_dim,
            act_dim=act_dim,
            max_length=K,
            hidden_size=variant["embed_dim"],
            n_layer=variant["n_layer"],
        )
    else:
        raise NotImplementedError

    model = model.to(device=device)

    if variant["load_path"] is not None:
        print("load model from path:", variant["load_path"])
        # map_location lets a checkpoint saved on another device be loaded here.
        model.load_state_dict(
            torch.load(
                f"{variant['load_path']}/model-{variant['model_iter']}.pt",
                map_location=device,
            )
        )
    else:
        print("train from scratch!")

    # Guard against a zero warmup, which would divide by zero in the schedule.
    warmup_steps = max(variant["warmup_steps"], 1)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=variant["learning_rate"],
        weight_decay=variant["weight_decay"],
    )

    # Linear warmup of the learning rate up to its configured value.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)
    )

    bce_with_logits = torch.nn.BCEWithLogitsLoss()

    def action_loss(s_hat, a_hat, r_hat, s, a, r):
        # Only the predicted actions contribute to the loss.
        return bce_with_logits(a_hat, a)

    # model_type was validated when the model was built above.
    trainer_cls = SequenceTrainer if model_type == "dt" else ActTrainer
    trainer = trainer_cls(
        model=model,
        optimizer=optimizer,
        batch_size=batch_size,
        get_batch=get_batch,
        scheduler=scheduler,
        loss_fn=action_loss,
        eval_fns=[eval_episodes(tar) for tar in env_targets],
    )

    if log_to_wandb:
        wandb.login()
        wandb.init(
            name=exp_prefix,
            group=group_name,
            project="decision-transformer",
            config=variant,
        )
        # wandb.watch(model)  # wandb has some bug
    print("initial eval")
    outputs = trainer.initial_eval()
    if log_to_wandb:
        wandb.log(outputs)

    print("in training")
    for iter_idx in range(variant["max_iters"]):
        outputs = trainer.train_iteration(
            num_steps=variant["num_steps_per_iter"],
            iter_num=iter_idx + 1,
            print_logs=True,
        )
        if log_to_wandb:
            wandb.log(outputs)
        # Same truthiness test as for the directory creation above, so an
        # empty save path never reaches torch.save.
        if save_dir:
            print(f"save model at {save_dir}/model-{iter_idx}.pt")
            torch.save(model.state_dict(), f"{save_dir}/model-{iter_idx}.pt")

    if save_dir:
        print(f"save model at {save_dir}/model.pt")
        torch.save(model.state_dict(), f"{save_dir}/model.pt")


def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"`` and ``"0"``) as ``True``; this parser interprets the text.

    Args:
        value: A bool, or a string such as ``"true"``/``"false"``,
            ``"1"``/``"0"``, ``"yes"``/``"no"`` (case-insensitive). The empty
            string maps to ``False``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", type=str, default="homegrid")
    parser.add_argument(
        "--dataset", type=str, default="non-expert"
    )  # medium, medium-replay, medium-expert, expert
    parser.add_argument(
        "--mode", type=str, default="normal"
    )  # normal for standard setting, delayed for sparse
    parser.add_argument("--K", type=int, default=10)
    parser.add_argument("--pct_traj", type=float, default=1.0)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument(
        "--model_type", type=str, default="dt"
    )  # dt for decision transformer, bc for behavior cloning
    parser.add_argument("--embed_dim", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=3)
    parser.add_argument("--n_head", type=int, default=1)
    parser.add_argument("--activation_function", type=str, default="relu")
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--learning_rate", "-lr", type=float, default=1e-4)
    parser.add_argument("--weight_decay", "-wd", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=10000)
    parser.add_argument("--num_eval_episodes", type=int, default=100)
    parser.add_argument("--max_iters", type=int, default=30)
    parser.add_argument("--num_steps_per_iter", type=int, default=1000)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--log_to_wandb", "-w", type=int, default=1)
    parser.add_argument("--shot", type=int, default=None)
    parser.add_argument("--save", type=str, default=None)
    parser.add_argument("--no_lang", type=str2bool, default=False)
    parser.add_argument("--gpu", type=int, default=None)
    parser.add_argument("--model_iter", type=int, default=29)
    parser.add_argument("--no_lang_input", type=str2bool, default=False)
    parser.add_argument("--real_lang", type=str2bool, default=False)
    parser.add_argument("--encoder", type=str, default="sentencebert")
    parser.add_argument("--load_path", type=str, default=None)
    parser.add_argument("--perturb", type=str2bool, default=True)
    parser.add_argument("--only_fp", type=str2bool, default=False)
    parser.add_argument("--real_lang_eval", type=str2bool, default=False)
    parser.add_argument("--only_hindsight", type=str2bool, default=False)
    parser.add_argument("--train_ratio", type=float, default=0.8)
    parser.add_argument("--lang_mode", type=str, default="val")
    parser.add_argument("--adapt", type=str2bool, default=False)
    parser.add_argument("--no_lang_eval", type=int, default=0)
    parser.add_argument("--testing", type=str2bool, default=False)
    parser.add_argument("--val_ratio", type=float, default=0.1)
    parser.add_argument("--use_gpt", type=str2bool, default=False)
    parser.add_argument("--use_gpt_enable_no_feedback", type=str2bool, default=False)
    parser.add_argument("--random_lang", type=str2bool, default=False)
    parser.add_argument("--no_lang_data", type=str2bool, default=False)
    parser.add_argument("--other_task", type=str2bool, default=False)

    args = parser.parse_args()

    # Validate argument combinations explicitly; `assert` is stripped under -O.
    if args.no_lang_eval and args.real_lang_eval:
        parser.error("--no_lang_eval and --real_lang_eval cannot both be set")
    if args.num_eval_episodes > 100 and not args.testing:
        parser.error("--num_eval_episodes > 100 is only allowed with --testing")
    if args.use_gpt and not args.use_gpt_enable_no_feedback:
        parser.error("--use_gpt requires --use_gpt_enable_no_feedback")

    experiment("gym-experiment", variant=vars(args))
