import glob
import json
import os
import pickle
import random
from typing import List, Dict, Union, Optional, NamedTuple
from PIL import Image

import numpy as np
from jax import numpy as jnp
from torch.utils.data.dataset import Dataset

from common.vh_invariants import skill2id, available, VH_SKILL
from collections import Counter
import natsort

def find_pkl_files(path: str):
    """Return the ``*.pkl`` files located directly inside ``path``.

    Args:
        path: Directory to search; it is converted to an absolute path first.

    Returns:
        A list of absolute file paths matching ``<path>/*.pkl``, in the order
        produced by ``glob.glob``. The list is empty when nothing matches or
        the directory does not exist. Sub-directories are not searched.
    """
    directory = os.path.abspath(path)
    # Escape the directory part so glob metacharacters in its name
    # ("[", "*", "?") are matched literally; only "*.pkl" acts as a pattern.
    search_pattern = os.path.join(glob.escape(directory), "*.pkl")

    return glob.glob(search_pattern)

def load_path(path: str):
    """Return ``path`` converted to an absolute path.

    The path is only converted with ``os.path.abspath``; nothing is opened
    or read.
    """
    return os.path.abspath(path)

class EalfDataSample(NamedTuple):
    """Batch container returned by ``ealf_collate_fn``.

    The text fields hold one entry per sample; the array fields are built
    from the per-sample values of the batch.
    """

    # Formatted prompt for the current step of each sample.
    prompts: List[str]
    # Formatted prompt for the step after the action of each sample.
    next_prompts: List[str]
    # Caption text for the current step of each sample.
    captions: List[str]
    # Caption text for the next step of each sample.
    next_captions: List[str]
    # Per-sample action ids (the collate function looks them up in skill2id).
    actions: Union[np.ndarray, jnp.ndarray]
    # Per-sample rewards.
    rewards: Union[np.ndarray, jnp.ndarray]
    # Per-sample done flags.
    dones: Union[np.ndarray, jnp.ndarray]

class EalfDataset(Dataset):
    """Dataset of transitions rendered as text prompts.

    Each record read from the dataset JSON is a dict with the keys
    ``history`` (list of step strings), ``action`` (step string),
    ``task_id``, ``observation``, ``next_observation`` and ``done``.
    ``__getitem__`` turns one record into a current and a next prompt by
    filling the ``prompt`` template.

    Step strings are whitespace separated and start with the skill name,
    e.g. ``"grab apple"``, ``"put apple table"`` or ``"put in apple fridge"``.
    Skill keywords are only looked for in that first token, so an object name
    that happens to contain a keyword ("computer" contains "put") is never
    mistaken for a skill.
    """

    # Hint appended to a history prompt whose most recent step is a "find".
    _SKILL_HINT = "\nAvailable skill type: grab, put, sit."
    # Factor applied to every reward read from the reward table.
    _REWARD_SCALE = 0.5
    # File extensions that select the pickle loader for the reward table.
    _PICKLE_SUFFIXES = (".pkl", ".pickle")

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        instruction_path: Optional[str] = None,
        reward_path: Optional[str] = None,
        prompt: Optional[Union[str, List[str]]] = None,
        num_data_limit: Optional[int] = None,
        for_eval: Optional[bool] = False
    ):
        """Load the transition records, instructions and optional rewards.

        Args:
            dataset_path: JSON file holding the list of transition records.
            instruction_path: JSON file mapping ``str(task_id)`` to a list of
                instruction strings.
            reward_path: Optional reward table. Files ending in ``.pkl`` or
                ``.pickle`` are unpickled, anything else is read as JSON.
                Required for indexing when ``for_eval`` is false.
            prompt: Template filled with ``str.format`` using the fields
                ``instruction``, ``history_prompt`` and ``objects``.
                NOTE(review): the annotation also allows a list, but
                ``__getitem__`` calls ``.format`` on it, so only a string
                works there.
            num_data_limit: If given, keep only this many records (taken
                after shuffling).
            for_eval: When true, items get reward ``0.0`` and the reward
                table is not consulted.
        """
        self.dataset_path = dataset_path
        self.instruction_path = instruction_path
        self.prompt = prompt

        with open(dataset_path, "r") as f:
            dataset = json.load(f)

        with open(instruction_path, "r") as f:
            self.instruction_set = json.load(f)

        # Shuffle before truncating so the kept subset is a random sample.
        random.shuffle(dataset)

        if num_data_limit is not None:
            dataset = dataset[:num_data_limit]

        self.dataset = dataset  # type: List[Dict]
        self.eval = for_eval

        self.reward_set = None
        if reward_path is not None:
            # Decide by file extension; a plain substring test would also
            # fire on a directory name such as "pkl_cache/rewards.json".
            if reward_path.lower().endswith(self._PICKLE_SUFFIXES):
                # NOTE: unpickling can execute arbitrary code; only load
                # reward files that come from a trusted source.
                with open(reward_path, "rb") as fp:
                    self.reward_set = pickle.load(fp)
            else:
                with open(os.path.abspath(reward_path), "r") as fp:
                    self.reward_set = json.load(fp)

    def __len__(self):
        """Return the number of loaded transition records."""
        return len(self.dataset)

    @staticmethod
    def _has_skill(step, keyword):
        """Return True when ``keyword`` occurs in the first token of ``step``."""
        head = step.split(maxsplit=1)
        return bool(head) and keyword in head[0]

    def object_processing(self, objects_list, history, grab_items):
        """Build the comma separated list of visible objects.

        Args:
            objects_list: Observed object names; only names contained in
                ``available`` are kept.
            history: Steps taken so far. If the last one is a "find" step,
                its target is added to the visible objects.
            grab_items: Names of currently held items; one occurrence of each
                is removed from the visible objects.

        Returns:
            The remaining names joined with ``", "``, without duplicates and
            in first-seen order (so the text is the same on every run).

        Raises:
            ValueError: If the last step is a "find" step that does not
                consist of exactly two tokens.
        """
        visible = [name for name in objects_list if name in available]

        if history and self._has_skill(history[-1], "find"):
            _, item = history[-1].split()
            if item not in visible:
                visible.append(item)

        # Drop a single occurrence per held item, so a second instance of the
        # same object that is still observed stays listed.
        for held in grab_items:
            if held in visible:
                visible.remove(held)

        # dict.fromkeys de-duplicates while keeping insertion order.
        return ", ".join(dict.fromkeys(visible))

    def grab_processing(self, history):
        """Replay ``history`` to find out which items are currently held.

        A "grab" step adds its item; a "put ..." or "put in ..." step removes
        its item if it is held.

        Args:
            history: List of step strings.

        Returns:
            A tuple ``(grab_prompt, grab_item)`` where ``grab_item`` is the
            list of held item names in grab order and ``grab_prompt`` is that
            list joined with ``", "``, or ``"nothing"`` when it is empty.

        Raises:
            ValueError: If a grab/put step does not have the expected number
                of tokens (2 for grab, 3 for put, 4 for "put in").
        """
        grab_item = []
        for step in history:
            tokens = step.split()
            if not tokens:
                continue
            skill = tokens[0]

            if "grab" in skill:
                _, item = tokens
                grab_item.append(item)

            if "put" in skill:
                if len(tokens) > 1 and tokens[1] == "in":
                    # "put in <item> <container>"
                    _, _, item, _ = tokens
                else:
                    # "put <item> <target>"
                    _, item, _ = tokens
                if item in grab_item:
                    grab_item.remove(item)

        grab_prompt = ", ".join(grab_item) if grab_item else "nothing"

        return grab_prompt, grab_item

    def __getitem__(self, idx):
        """Render record ``idx`` as prompts, captions, action, reward, done.

        The instruction is drawn at random from the instructions of the
        record's task. In training mode the reward is looked up under
        ``reward_set[str(task_id)][str(len(history))][action]`` and scaled by
        0.5; it is ``0.0`` when the action has no entry there or is not in
        ``VH_SKILL``, and always ``0.0`` in eval mode.

        Returns:
            A dict with the keys ``prompt``, ``next_prompt``, ``caption``,
            ``next_caption``, ``history``, ``actions``, ``rewards`` and
            ``dones``.

        Raises:
            ValueError: In training mode when no reward table was loaded.
        """
        x = self.dataset[idx]
        history = x["history"]
        action = x["action"]
        task_id = x["task_id"]
        next_history = history + [action]
        instruction = random.choice(self.instruction_set[str(task_id)])

        # Default covers eval mode and actions without a usable reward entry.
        reward = 0.0
        if not self.eval:
            if self.reward_set is None:
                raise ValueError(
                    "reward_path must be provided when for_eval is False"
                )
            rewards_for_actions = self.reward_set[str(task_id)][str(len(history))]
            if action in rewards_for_actions and action in VH_SKILL:
                reward = rewards_for_actions[action] * self._REWARD_SCALE

        history_prompt = ", ".join(history)
        next_history_prompt = ", ".join(next_history)

        grab_prompt, grab_items = self.grab_processing(history)
        next_grab_prompt, next_grab_item = self.grab_processing(next_history)

        obj_prompt = self.object_processing(x["observation"], history, grab_items)
        caption = f"Visible objects: {obj_prompt}\nGrabbed: {grab_prompt}"

        next_obj_prompt = self.object_processing(
            x["next_observation"], next_history, next_grab_item
        )
        next_caption = f"Visible objects: {next_obj_prompt}\nGrabbed: {next_grab_prompt}"

        # Right after a "find" step the prompt lists the skill types that
        # may follow it.
        if history and self._has_skill(history[-1], "find"):
            history_prompt += self._SKILL_HINT

        if self._has_skill(action, "find"):
            next_history_prompt += self._SKILL_HINT

        prompt = self.prompt.format(
            instruction=instruction, history_prompt=history_prompt, objects=caption
        )
        next_prompt = self.prompt.format(
            instruction=instruction,
            history_prompt=next_history_prompt,
            objects=next_caption,
        )

        return {
            "prompt": prompt,
            "next_prompt": next_prompt,
            "caption": caption,
            "next_caption": next_caption,
            "history": x["history"],
            "actions": x["action"],
            "rewards": reward,
            "dones": x["done"],
        }
    

def ealf_collate_fn(samples: List) -> EalfDataSample:
    """Merge per-item dicts from the dataset into one ``EalfDataSample``.

    Text fields are gathered into plain lists. Action names are mapped to
    ids through ``skill2id``, and actions, rewards and done flags are each
    packed into a ``jnp`` array.
    """

    def gather(key):
        # One value per sample, in batch order.
        return [sample[key] for sample in samples]

    return EalfDataSample(
        prompts=gather("prompt"),
        next_prompts=gather("next_prompt"),
        captions=gather("caption"),
        next_captions=gather("next_caption"),
        actions=jnp.array([skill2id[name] for name in gather("actions")]),
        rewards=jnp.array(gather("rewards")),
        dones=jnp.array(gather("dones")),
    )


