import glob
import json
import os
import pickle
import random
from typing import List, Dict, Union, Optional, NamedTuple

import numpy as np
from jax import numpy as jnp
from torch.utils.data.dataset import Dataset
    
from common.invariants import available, max_epi_len, n_rewards

    
def find_pkl_files(path: str):
    """Return the ``*.pkl`` files located directly inside ``path``.

    ``path`` is made absolute first, so the returned entries are absolute
    paths. Only the directory itself is searched (no recursion); the order
    is whatever ``glob.glob`` yields.
    """
    pattern = os.path.join(os.path.abspath(path), "*.pkl")
    return glob.glob(pattern)

def find_json_files(path: str):
    """Return ``path`` as an absolute, normalised path.

    Despite the plural name, no directory search happens here: the argument
    is treated as the location of a single file and is only made absolute.
    """
    return os.path.abspath(path)

class EnsembleSample(NamedTuple):
    """Batch container returned by ``ensemble_collate_fn``.

    The collate function fills the array fields with ``jax.numpy`` arrays.
    """

    # Prompts of every step of every trajectory, flattened into one list.
    prompts: List[str]
    # Per-step reward vectors, reshaped to (batch, steps, n_rewards).
    rewards: np.ndarray
    # Per-step flags marking padded (dummy) steps.
    dummys: np.ndarray
    # Per-step task ids.
    task_ids: np.ndarray
    # Per-step labels, flattened to one dimension.
    labels: np.ndarray
    # One entry per trajectory: 1 or -1.
    succs: np.ndarray
    # Per-step reward masks, reshaped to (batch, steps, n_rewards).
    reward_masks: np.ndarray

class TestSample(NamedTuple):
    """Batch container returned by ``test_collate_fn``."""

    # One prompt per step in the batch.
    prompts: List[str]
    # Reward vectors of the batch.
    rewards: np.ndarray
    # Task ids; annotated as ``int`` although ``test_collate_fn`` supplies an
    # array holding one id per step.
    task_ids: int
    # Labels, flattened to one dimension.
    labels: np.ndarray
    # The "history" entry of each step, passed through unchanged.
    historys: List[str]
    # The action string of each step.
    actions: List[str]


class EoEDataset(Dataset):
    """Trajectory-level dataset yielding fixed-length lists of step records.

    The dataset file is a JSON list of trajectories; every trajectory is a
    list of step dicts with the keys ``task_id``, ``history``,
    ``observation``, ``action`` and ``succ``. Indexing returns one
    trajectory converted to per-step records and padded with dummy steps up
    to ``max_episode_length``.

    Reward tables (temporal / relational / procedure / expert) are JSON
    objects indexed as ``table[str(task_id)][str(len(history))][action]``.
    """

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        instruction_path: Optional[str] = None,
        temporal_reward_path: Optional[str] = None,
        relational_reward_path: Optional[str] = None,
        procedure_reward_path: Optional[str] = None,
        expert_reward_path: Optional[str] = None,
        prompt_format: Optional[Union[str, List[str]]] = None,
        num_data_limit: Optional[int] = None,
        for_eval: Optional[bool] = False,
        max_episode_length: Optional[int] = max_epi_len,
    ):
        """Load trajectories, instructions and the optional reward tables.

        Args:
            dataset_path: JSON file holding the list of trajectories.
            instruction_path: JSON file mapping ``str(task_id)`` to a
                sequence of instruction strings.
            temporal_reward_path: Optional JSON reward table.
            relational_reward_path: Optional JSON reward table.
            procedure_reward_path: Optional JSON reward table.
            expert_reward_path: Optional JSON table of expert labels; when
                omitted every step gets ``label=None``.
            prompt_format: Template formatted with the keyword arguments
                ``instruction``, ``objects`` and ``grab``.
            num_data_limit: Keep only this many trajectories (taken after
                shuffling).
            for_eval: Stored on the instance; not otherwise used here.
            max_episode_length: Number of records every item is padded to.
        """
        self.dataset_path = dataset_path
        self.instruction_path = instruction_path
        self.prompt_format = prompt_format
        self.max_episode_length = max_episode_length
        self.for_eval = for_eval
        # Misspelled attribute of the original implementation, kept so that
        # existing readers of ``fpr_eval`` keep working.
        self.fpr_eval = for_eval

        with open(dataset_path, "rb") as f:
            dataset = json.load(f)

        with open(instruction_path, "rb") as f:
            self.instruction_set = json.load(f)

        # Shuffle before truncating so the limit selects a random subset.
        random.shuffle(dataset)
        if num_data_limit is not None:
            dataset = dataset[:num_data_limit]

        self.dataset = dataset  # type: List[Dict]

        self.eval = for_eval
        self.idx = 0

        # Every table attribute always exists; ``None`` means "not supplied".
        self.temporal_reward = self._load_optional_json(temporal_reward_path)
        self.relational_reward = self._load_optional_json(relational_reward_path)
        self.procedure_reward = self._load_optional_json(procedure_reward_path)
        self.expert_reward_set = self._load_optional_json(expert_reward_path)

    @staticmethod
    def _load_optional_json(path):
        """Return the parsed JSON file at ``path``, or ``None`` if no path is given."""
        if path is None:
            return None
        with open(find_json_files(path), "r") as fp:
            return json.load(fp)

    @staticmethod
    def _lookup_reward(reward_set, task_id, step_key, action):
        """Return ``(reward, mask)`` for one step from one reward table.

        The mask is 1 only when the table holds an entry for this step and
        action; otherwise the pair is ``(0, 0)``. A table that was not
        supplied (``None``) is treated as holding no entries.
        """
        if reward_set is None:
            return 0, 0
        per_step = reward_set[task_id]
        if step_key in per_step:
            per_action = per_step[step_key]
            if action in per_action:
                return per_action[action], 1
        return 0, 0

    def object_processing(self, objects, grab_obj):
        """Build the comma-separated list of visible, not-held objects.

        Keeps the objects that appear in ``available`` and are not in
        ``grab_obj``, de-duplicated in first-seen order so that the prompt
        is deterministic for a given observation.
        """
        held = set(grab_obj)
        visible = [
            obj
            for obj in dict.fromkeys(objects)
            if obj in available and obj not in held
        ]
        return ', '.join(visible)

    def grab_processing(self, history):
        """Replay ``history`` and return what is currently being held.

        Recognised actions are ``grab <item>``, ``put in <item> <target>``
        and ``put <item> <target>``. The verb is read from the first token
        only, so object names that merely contain "put" or "grab" (for
        example "computer") are not mistaken for an action.

        Returns:
            ``(grab_prompt, grab_item)``: the held items joined by ", " (or
            "nothing" when empty) and the list of held items.
        """
        grab_item = []
        for step in history:
            tokens = step.split()
            if not tokens:
                continue
            verb = tokens[0]
            if "grab" in verb:
                if len(tokens) > 1:
                    grab_item.append(tokens[1])
            elif "put" in verb and len(tokens) > 1:
                # "put in <item> <target>" carries the item one token later.
                if tokens[1] == "in" and len(tokens) > 3:
                    item = tokens[2]
                else:
                    item = tokens[1]
                if item in grab_item:
                    grab_item.remove(item)

        grab_prompt = ", ".join(grab_item) if grab_item else "nothing"
        return grab_prompt, grab_item

    def __len__(self):
        """Return the number of trajectories."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return trajectory ``idx`` as ``max_episode_length`` step records.

        Each record is a dict with the keys ``prompt``, ``rewards``,
        ``dummy``, ``task_id``, ``label``, ``succ`` and ``reward_mask``.
        Real steps come first; the remainder are dummy records
        (``dummy=True``) whose rewards, and label when expert labels are
        loaded, are ``-999``.

        Raises:
            ValueError: If the trajectory is longer than
                ``max_episode_length``.
        """
        traj = self.dataset[idx]
        if len(traj) > self.max_episode_length:
            raise ValueError(
                f"trajectory {idx} has {len(traj)} steps, which exceeds "
                f"max_episode_length={self.max_episode_length}"
            )

        # One instruction is sampled per trajectory and shared by its steps.
        instruction = random.choice(self.instruction_set[str(traj[0]["task_id"])])
        reward_tables = (
            self.temporal_reward,
            self.relational_reward,
            self.procedure_reward,
        )

        traj_data = []
        for x in traj:
            task_id = str(x["task_id"])
            history = x["history"]
            action = x["action"]
            # Reward tables are keyed by the number of actions taken so far.
            step_key = str(len(history))

            grab_prompt, grab_item = self.grab_processing(history)
            obs_prompt = self.object_processing(x["observation"], grab_item)
            prompt = self.prompt_format.format(
                instruction=instruction, objects=obs_prompt, grab=grab_prompt
            )

            rewards = []
            reward_mask = []
            for reward_set in reward_tables:
                reward, mask = self._lookup_reward(reward_set, task_id, step_key, action)
                rewards.append(reward)
                reward_mask.append(mask)

            if self.expert_reward_set is None:
                label = None
            else:
                label = self.expert_reward_set[task_id][step_key][action]

            traj_data.append({
                "prompt": prompt,
                "rewards": rewards,
                "dummy": False,
                "task_id": int(task_id),
                "label": label,
                "succ": x["succ"],
                "reward_mask": reward_mask,
            })

        # Dummy steps reuse the prompt and task id of the last real step.
        len_reward = len(rewards)
        pad_label = None if self.expert_reward_set is None else -999
        for _ in range(self.max_episode_length - len(traj_data)):
            traj_data.append({
                "prompt": prompt,
                "rewards": [-999] * len_reward,
                "dummy": True,
                "task_id": int(task_id),
                "label": pad_label,
                "succ": 0,
                "reward_mask": [0] * len_reward,
            })

        return traj_data
    

def ensemble_collate_fn(samples: List) -> EnsembleSample:
    """Collate trajectories (lists of step records) into an ``EnsembleSample``.

    ``samples`` holds one list of step dicts per trajectory. Prompts are
    flattened across all steps; rewards and reward masks are reshaped to
    ``(batch, -1, n_rewards)``; labels are flattened. ``succs`` holds one
    value per trajectory: 1 if the first step's ``succ`` is positive, else -1.
    Labels go through ``jnp.array`` and therefore presumably have to be
    numeric (``None`` labels are not expected to convert) -- TODO confirm.
    """
    batch = len(samples)

    def per_step(key):
        # Nested list: one inner list of ``key`` values per trajectory.
        return [[step[key] for step in sample] for sample in samples]

    prompts = [step["prompt"] for sample in samples for step in sample]
    succs = [1 if sample[0]["succ"] > 0 else -1 for sample in samples]

    rewards = jnp.array(per_step("rewards")).reshape(batch, -1, n_rewards)
    dummys = jnp.array(per_step("dummy"))
    task_ids = jnp.array(per_step("task_id"))
    labels = jnp.array(per_step("label")).reshape(-1)
    succs = jnp.array(succs)
    reward_masks = jnp.array(per_step("reward_mask")).reshape(batch, -1, n_rewards)

    return EnsembleSample(
        prompts=prompts,
        rewards=rewards,
        dummys=dummys,
        task_ids=task_ids,
        labels=labels,
        succs=succs,
        reward_masks=reward_masks,
    )


class TestDataset(Dataset):
    """Step-level dataset yielding one record per stored step.

    The dataset file is a JSON list of step dicts with the keys
    ``task_id``, ``history``, ``observation`` and ``action``. Reward tables
    (temporal / relational / procedure / expert) are JSON objects indexed
    as ``table[str(task_id)][str(len(history))][action]``.
    """

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        instruction_path: Optional[str] = None,
        temporal_reward_path: Optional[str] = None,
        relational_reward_path: Optional[str] = None,
        procedure_reward_path: Optional[str] = None,
        expert_reward_path: Optional[str] = None,
        prompt_format: Optional[Union[str, List[str]]] = None,
        num_data_limit: Optional[int] = None,
        for_eval: Optional[bool] = False,
        max_episode_length: Optional[int] = max_epi_len,
    ):
        """Load steps, instructions and the optional reward tables.

        Args:
            dataset_path: JSON file holding the list of step dicts.
            instruction_path: JSON file mapping ``str(task_id)`` to a
                sequence of instruction strings.
            temporal_reward_path: Optional JSON reward table.
            relational_reward_path: Optional JSON reward table.
            procedure_reward_path: Optional JSON reward table.
            expert_reward_path: Optional JSON table of expert labels; when
                omitted every record gets ``label=None``.
            prompt_format: Template formatted with the keyword arguments
                ``instruction``, ``objects`` and ``grab``.
            num_data_limit: Keep only this many steps (taken after
                shuffling).
            for_eval: Stored on the instance; not otherwise used here.
            max_episode_length: Stored on the instance; not otherwise used
                here.
        """
        self.dataset_path = dataset_path
        self.instruction_path = instruction_path
        self.prompt_format = prompt_format
        self.max_episode_length = max_episode_length
        self.for_eval = for_eval

        with open(dataset_path, "rb") as f:
            dataset = json.load(f)

        with open(instruction_path, "rb") as f:
            self.instruction_set = json.load(f)

        # Shuffle before truncating so the limit selects a random subset.
        random.shuffle(dataset)
        if num_data_limit is not None:
            dataset = dataset[:num_data_limit]

        self.dataset = dataset  # type: List[Dict]

        self.eval = for_eval
        self.idx = 0

        # Every table attribute always exists; ``None`` means "not supplied".
        self.temporal_reward = self._load_optional_json(temporal_reward_path)
        self.relational_reward = self._load_optional_json(relational_reward_path)
        self.procedure_reward = self._load_optional_json(procedure_reward_path)
        self.expert_reward_set = self._load_optional_json(expert_reward_path)

    @staticmethod
    def _load_optional_json(path):
        """Return the parsed JSON file at ``path``, or ``None`` if no path is given."""
        if path is None:
            return None
        with open(find_json_files(path), "r") as fp:
            return json.load(fp)

    @staticmethod
    def _lookup_reward(reward_set, task_id, step_key, action):
        """Return the reward for one step from one table, or 0 if absent.

        A table that was not supplied (``None``) is treated as holding no
        entries.
        """
        if reward_set is None:
            return 0
        per_step = reward_set[task_id]
        if step_key in per_step:
            per_action = per_step[step_key]
            if action in per_action:
                return per_action[action]
        return 0

    def object_processing(self, objects, grab_obj):
        """Build the comma-separated list of visible, not-held objects.

        Keeps the objects that appear in ``available`` and are not in
        ``grab_obj``, de-duplicated in first-seen order so that the prompt
        is deterministic for a given observation.
        """
        held = set(grab_obj)
        visible = [
            obj
            for obj in dict.fromkeys(objects)
            if obj in available and obj not in held
        ]
        return ', '.join(visible)

    def grab_processing(self, history):
        """Replay ``history`` and return what is currently being held.

        Recognised actions are ``grab <item>``, ``put in <item> <target>``
        and ``put <item> <target>``. The verb is read from the first token
        only, so object names that merely contain "put" or "grab" (for
        example "computer") are not mistaken for an action.

        Returns:
            ``(grab_prompt, grab_item)``: the held items joined by ", " (or
            "nothing" when empty) and the list of held items.
        """
        grab_item = []
        for step in history:
            tokens = step.split()
            if not tokens:
                continue
            verb = tokens[0]
            if "grab" in verb:
                if len(tokens) > 1:
                    grab_item.append(tokens[1])
            elif "put" in verb and len(tokens) > 1:
                # "put in <item> <target>" carries the item one token later.
                if tokens[1] == "in" and len(tokens) > 3:
                    item = tokens[2]
                else:
                    item = tokens[1]
                if item in grab_item:
                    grab_item.remove(item)

        grab_prompt = ", ".join(grab_item) if grab_item else "nothing"
        return grab_prompt, grab_item

    def __len__(self):
        """Return the number of stored steps."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return step ``idx`` as a record dict.

        The record has the keys ``prompt``, ``rewards`` (one value per
        reward table, 0 where the table has no entry), ``task_id``,
        ``label`` (``None`` without expert labels), ``history`` and
        ``action``.
        """
        x = self.dataset[idx]
        task_id = str(x["task_id"])
        history = x["history"]
        action = x["action"]
        # Reward tables are keyed by the number of actions taken so far.
        step_key = str(len(history))

        instruction = random.choice(self.instruction_set[task_id])

        grab_prompt, grab_item = self.grab_processing(history)
        obs_prompt = self.object_processing(x["observation"], grab_item)
        prompt = self.prompt_format.format(
            instruction=instruction, objects=obs_prompt, grab=grab_prompt
        )

        rewards = [
            self._lookup_reward(reward_set, task_id, step_key, action)
            for reward_set in (
                self.temporal_reward,
                self.relational_reward,
                self.procedure_reward,
            )
        ]

        if self.expert_reward_set is not None:
            label = self.expert_reward_set[task_id][step_key][action]
        else:
            label = None

        return {
            "prompt": prompt,
            "rewards": rewards,
            "task_id": int(task_id),
            "label": label,
            "history": history,
            "action": action,
        }
    

def test_collate_fn(samples: List) -> TestSample:
    """Collate step records into a ``TestSample``.

    ``samples`` is a list of step dicts with the keys ``prompt``,
    ``rewards``, ``task_id``, ``label``, ``history`` and ``action``.
    Rewards are wrapped in one extra outer list before conversion, so the
    resulting array has a leading axis of length 1; labels are wrapped the
    same way and then flattened. Prompts, histories and actions stay plain
    Python lists.
    """
    # Pull every column out first, in the same order as the record keys.
    prompts = [record["prompt"] for record in samples]
    reward_rows = [record["rewards"] for record in samples]
    ids = [record["task_id"] for record in samples]
    label_row = [record["label"] for record in samples]
    historys = [record["history"] for record in samples]
    actions = [record["action"] for record in samples]

    return TestSample(
        prompts=prompts,
        rewards=jnp.array([reward_rows]),
        task_ids=jnp.array(ids),
        labels=jnp.array([label_row]).reshape(-1),
        historys=historys,
        actions=actions,
    )

