import argparse
import json
import utils
import random
from utils_prompt import *
import numpy as np
import pickle

from planer import Planer
from sentencebert import sentencebert_encode

import json

# Load the template definitions once at import time. The path is relative, so
# it is resolved against the current working directory of the process.
# JSON interchange text is UTF-8; pin the encoding so decoding does not depend
# on the platform's locale default.
# NOTE(review): `template` is not referenced anywhere else in this file --
# confirm whether another module relies on it before removing.
with open("./homegrid/template.json", encoding="utf-8") as f:
    template = json.load(f)

def make_one_hot_encoding(action, num_actions=10):
    """Return a one-hot vector with a 1 at index ``action``.

    Args:
        action: Integer index of the action to mark.
        num_actions: Length of the returned vector. Defaults to 10, the size
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        A 1-D ``numpy`` float array of length ``num_actions`` that is zero
        everywhere except for a 1 at position ``action``.

    Raises:
        IndexError: If ``action`` is outside the bounds of the vector
            (raised by NumPy indexing).
    """
    one_hot = np.zeros(num_actions)
    one_hot[action] = 1

    return one_hot

def add_perturbation(action_list):
    """Return a copy of ``action_list`` with random mistakes spliced in.

    Every action is perturbed independently with probability 0.4 (drawn from
    the global ``np.random`` state):

    * Actions 0-3 (labelled "navigate" in the original comments) are replaced
      by a three-step sequence: the paired action (0<->1, 2<->3) followed by
      the original action twice.
    * Actions 7-9 (labelled "open") are preceded either by action 5 (labelled
      "drop"), with probability 0.5, or by a different action drawn from
      {7, 8, 9}.
    * Any other action is kept unchanged.

    Args:
        action_list: Iterable of integer action ids.

    Returns:
        A new list containing the perturbed action sequence.
    """
    # Replacement sequences for the navigate-error case.
    detours = {
        0: (1, 0, 0),
        1: (0, 1, 1),
        2: (3, 2, 2),
        3: (2, 3, 3),
    }

    perturbed = []
    for act in action_list:
        # Leave the action untouched with probability 0.6.
        if np.random.rand() >= 0.4:
            perturbed.append(act)
            continue

        if act in (0, 1, 2, 3):
            # insert navigate error
            perturbed.extend(detours[act])
        elif act in (7, 8, 9):
            if np.random.rand() < 0.5:
                # insert drop before open
                perturbed.extend((5, act))
            else:
                # insert wrong open: redraw until it differs from the real one
                wrong = np.random.choice([7, 8, 9])
                while wrong == act:
                    wrong = np.random.choice([7, 8, 9])
                perturbed.extend((wrong, act))
        else:
            perturbed.append(act)

    return perturbed


def _print_split_banner(offset):
    """Print a banner naming the dataset split implied by the seed ``offset``."""
    print("="*50)
    if offset == 0:
        print("YOU ARE RUN FOR VAL")
    elif offset == 100000:
        print("YOU ARE RUN FOR TRAIN")
    elif offset == 200000:
        print("YOU ARE RUN FOR TEST")
    print("="*50)


def _build_data_path(variant, put_cnt):
    """Assemble the output pickle path from the run configuration."""
    data_path = "./homegrid"
    if variant["retrain"]:
        data_path += "-retrain"
    else:
        # Non-retrain datasets are named after the number of "put" episodes.
        data_path += f"-{put_cnt}shot"

    data_path += "-real-lang" if variant["real_lang"] else "-template"
    data_path += f"-{variant['encoder']}"

    if variant["perturb"]:
        data_path += "-perturb"
    if variant["only_fp"]:
        data_path += "-only-fp"
    if variant["only_hindsight"]:
        data_path += "-only-hindsight"

    if variant["random_lang"]:
        data_path += "-random-lang"
    elif variant["no_lang"]:
        data_path += "-no-lang"

    return data_path + ".pkl"


def data_collection(variant):
    """Roll out planner-driven episodes and optionally pickle the successful ones.

    For each sampled seed an environment is created and reset, a ``Planer`` is
    built from the environment's initial state info, and its action list
    (optionally perturbed) is executed step by step. Episodes whose final
    reward is below 1 are discarded. The remaining episodes are stored as
    dicts with the keys ``task_name``, ``observations``, ``languages``,
    ``rewards``, ``actions`` and ``terminals``.

    Args:
        variant: Mapping of run options. Keys read here: ``retrain``, ``shot``,
            ``real_lang``, ``train_ratio``, ``mode``, ``val_ratio``,
            ``perturb``, ``only_fp``, ``only_hindsight``, ``random_lang``,
            ``no_lang``, ``save`` and ``encoder``. ``shot`` is the maximum
            number of trajectories to keep; ``None`` means "no limit" (run
            until the sampled seeds are exhausted).

    Raises:
        NotImplementedError: If a kept episode's task string does not start
            with one of ``f``, ``g``, ``o``, ``p``, ``m``.
    """
    encoder = sentencebert_encode

    # Number of kept episodes per task type, keyed by the first letter of the
    # task string (presumably find / get / open / put / move).
    task_counts = {"f": 0, "g": 0, "o": 0, "p": 0, "m": 0}

    random.seed(42)
    np.random.seed(42)

    trajectories = []
    tot_cnt = 0

    seed_num = 10000 if variant["retrain"] else 50000

    # 0 for val, 100000 for train, 200000 for test
    offset = 100000
    seed_list = np.random.choice(100000, seed_num, replace=False) + offset

    _print_split_banner(offset)

    # `--shot` defaults to None; comparing an int with None raises TypeError
    # on Python 3, so None is treated explicitly as "unlimited".
    shot = variant["shot"]

    for raw_seed in seed_list:
        if shot is not None and tot_cnt >= shot:
            break

        seed = int(raw_seed)
        env = utils.make_env(
            need_reset=False,
            real_lang=variant["real_lang"],
            train_ratio=variant["train_ratio"],
            mode=variant["mode"],
            val_ratio=variant["val_ratio"],
        )

        obs, _ = env.reset(seed=seed)
        task = env.task

        # Retrain data excludes every "put" task; otherwise only
        # "put ... plates" tasks are collected.
        if variant["retrain"]:
            if "put" in task:
                continue
        elif not ("put" in task and "plates" in task):
            continue

        planer = Planer(env.info["init_state_info"])

        # Sanity check: the filter above must have removed these already.
        if "put" in env.task and "plates" in env.task and variant["retrain"]:
            raise AssertionError("put/plates task reached the retrain rollout")

        if variant["perturb"]:
            if np.random.rand() < 0.7:
                planer.action_list = add_perturbation(planer.action_list)

        observations = []
        languages = []
        rewards = []
        actions = []
        terminateds = []

        reward = 0
        cnt = 0
        lang = env.task
        terminated = False

        while not terminated and len(planer.action_list) != 0:
            cnt += 1

            # Record the head of the plan before planer.step() is called.
            # NOTE(review): presumably step() returns and consumes this same
            # action -- confirm against Planer.
            actions.append(make_one_hot_encoding(planer.action_list[0]))
            next_obs, reward, terminated, _truncated, next_info = env.step(planer.step())

            action_failed_reason = next_info["action_status"]["action_failed_reason"]
            if variant["only_fp"]:
                new_lang = next_obs["log_language_info"]
            elif variant["only_hindsight"]:
                new_lang = action_failed_reason
            else:
                new_lang = action_failed_reason + " " + next_obs["log_language_info"]

            observations.append(obs["image"])
            obs = next_obs

            if variant["random_lang"] and np.random.rand() < 0.7:
                new_lang = ""
            if variant["no_lang"]:
                new_lang = ""

            # The language stored for this step is the one available *before*
            # acting; the freshly produced text is used on the next step.
            languages.append(np.array(encoder([lang])))
            lang = new_lang

            rewards.append(reward)
            terminateds.append(terminated)

        # skip if already success at beginning
        if cnt == 1 and reward >= 1:
            continue
        # skip episodes that did not end in success
        if reward < 1:
            continue

        task_kind = env.task[0]
        if task_kind not in task_counts:
            raise NotImplementedError
        task_counts[task_kind] += 1

        tot_cnt += 1
        # Report progress once per 100 kept trajectories.
        if tot_cnt % 100 == 0:
            print(tot_cnt)

        trajectories.append({
            "task_name": encoder([task]),
            "observations": np.array(observations),
            "languages": np.array(languages),
            "rewards": np.array(rewards),
            "actions": np.array(actions),
            "terminals": np.array(terminateds),
        })

    if variant["save"]:
        data_path = _build_data_path(variant, task_counts["p"])

        with open(data_path, "wb") as f:
            pickle.dump(trajectories, f)

        print(f"save in {data_path}")

    _print_split_banner(offset)


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- becomes ``True``. This parser
    interprets the text instead.

    Args:
        value: The raw argument string (or an actual ``bool``).

    Returns:
        The parsed boolean. The empty string maps to ``False``, matching the
        previous ``bool("")`` behaviour.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--save", type=int, default=1)
    parser.add_argument("--real_lang", type=str2bool, default=False)
    parser.add_argument("--retrain", type=str2bool, default=False)
    # None means "no limit": collect until the sampled seeds run out.
    parser.add_argument("--shot", type=int, default=None)
    parser.add_argument("--encoder", type=str, default="sentencebert")
    parser.add_argument("--perturb", type=str2bool, default=True)
    parser.add_argument("--only_fp", type=str2bool, default=False)
    parser.add_argument("--only_hindsight", type=str2bool, default=False)
    parser.add_argument("--train_ratio", type=float, default=0.8)
    parser.add_argument("--val_ratio", type=float, default=0.1)
    parser.add_argument("--mode", type=str, default="train")
    parser.add_argument("--random_lang", type=str2bool, default=False)
    parser.add_argument("--no_lang", type=str2bool, default=False)
    args = parser.parse_args()

    # Validate option combinations with parser.error rather than assert, so
    # the checks still run under `python -O`.
    if args.real_lang and not args.train_ratio:
        parser.error("--real_lang requires a non-zero --train_ratio")
    if args.real_lang and not args.mode:
        parser.error("--real_lang requires --mode")

    if args.no_lang and args.real_lang:
        parser.error("--no_lang cannot be combined with --real_lang")
    if args.no_lang and args.random_lang:
        parser.error("--no_lang cannot be combined with --random_lang")
    if args.no_lang and args.only_fp:
        parser.error("--no_lang cannot be combined with --only_fp")

    if args.only_fp and args.only_hindsight:
        parser.error("--only_fp cannot be combined with --only_hindsight")
    if not (args.val_ratio is None or args.val_ratio == 0.1):
        parser.error("--val_ratio must be 0.1")

    data_collection(variant=vars(args))
