from torch.utils.data import Dataset
import more_itertools
from itertools import chain
from pathlib import Path
from collections import OrderedDict
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
import functools
import torch
import random
import tqdm
import json
import copy
import ipdb
import re
import numpy as np
from pprint import pprint

from worldformer2.tools import jericho_utils
import logging
from worldformer2.tools.logging_util import basic_logging
basic_logging()

from worldformer2.tokenization.train_hf_tokenizer import train_sentencepiece, train_bpe, make_triples_tokenizer
from worldformer2.tools.graph_utils import get_graph_diff

from transformers import PreTrainedTokenizerFast
from tokenizers.pre_tokenizers import Whitespace
import tokenizers
from transformers import AutoTokenizer

import jericho
from worldformer2.dataset import jerichoworld_defines_updated
from worldformer2.dataset.jerichoworld_dataset import JerichoWorldDataset


"""
global whitespace_tokenizer
whitespace_tokenizer = Whitespace()  

global basic_actions
global noeffect_actions
global abbrv_dict

basic_actions = jerichoworld_defines_updated.BASIC_ACTIONS
noeffect_actions = jerichoworld_defines_updated.NO_EFFECT_ACTIONS
abbrv_dict = jerichoworld_defines_updated.ABBRV_DICT
"""


class JerichoWorldTorchTrainDataset(Dataset):
    def __init__(self, tokenizer, graph_encoder_tokenizer, action_decoder_tokenizer, graph_decoder_tokenizer,
                 cache_path, refresh_cache=False, add_special_tokens=False, max_input_tokens=1024, tokenizer_type='bert',
                 save_text=False, negative_mode=None, skip_nohole=False):
        """
        Supply training data for the full multitask training of Worldformer

        args
            dataset: JerichoWorldDataset holding the raw data

        """
        self.save_text = save_text

        self.negative_mode = negative_mode
        self.skip_nohole = skip_nohole

        #self.text_data = dataset
        self.tokenizer = tokenizer
        self.graph_encoder_tokenizer = graph_encoder_tokenizer
        self.action_decoder_tokenizer = action_decoder_tokenizer
        self.graph_decoder_tokenizer = graph_decoder_tokenizer

        # TODO TMP
        self.common_templates = set(json.load(open('/home/mnskim/workspace/tbg/tbg1/scripts/analysis/common_templates.json', 'r'))['common_templates'])
        self.all_templates = json.load(open('/home/mnskim/workspace/tbg/tbg1/scripts/analysis/all_templates.json', 'r'))

        self.unique_states = set()
        self.add_special_tokens = add_special_tokens
        self.max_input_tokens = max_input_tokens
        self.tokenizer_type = tokenizer_type
        self.cache_path = cache_path
        self.mlm_probability = 0.15

        # Prepare data
        #self.register_examples()
        #self.pretokenize_examples(refresh_cache=refresh_cache)

    def register_examples(self, dataset: JerichoWorldDataset, empty_graph_policy='skip'):
        """
        Populate self.examples from text_data
            args:
                empty_graph_policy: skip if a graph is empty
        """
        self.text_data = dataset

        self.examples = []
        for instance in self.text_data.instances:
            if empty_graph_policy == 'skip':
                if instance['state']['graph'] != [] and instance['next_state']['graph'] != []:
                    instance['ex_id'] = len(self.examples)
                    self.examples.append(instance)
            elif empty_graph_policy == 'include_all':
                instance['ex_id'] = len(self.examples)
                self.examples.append(instance)
 

    def pretokenize_examples(self, refresh_cache: bool = False, add_ckg: bool = False):
        """
        Run tokenizers
        """
        import mpire
        from mpire import WorkerPool
        from mpire.utils import make_single_arguments

        if add_ckg:
            #self.t_finder = TripletFinder()
            #self.me_re = MeaningResolver(wikdict_fn = '/home/mnskim/workspace/tbg/KEAR/data/kear/wik_dict.json')
            self.ckg_searcher = None
        else:
            self.ckg_searcher = None

        if refresh_cache:
            logging.info(f"Pretokenizing dataset.")

            input_list = make_single_arguments(self.examples, generator=False)

            # Sanity check
            #dd = self.tokenize_example(self.examples[42], self.tokenizer, self.graph_encoder_tokenizer, self.action_decoder_tokenizer, self.graph_decoder_tokenizer, ckg_searcher=self.ckg_searcher)            
            #ff = self.collate_fn(dd)
            # self.tokenizer.decode(dd['tokenized']['next_state']['inputs']['text_desc_ids'])            
            #for ii in range(1000):
            #    dd = self.tokenize_example(self.examples[42], self.tokenizer, self.graph_encoder_tokenizer, self.action_decoder_tokenizer, self.graph_decoder_tokenizer, ckg_searcher=self.ckg_searcher)            
            #    ff = self.collate_fn(dd)
            #ipdb.set_trace()

            func = functools.partial(self.tokenize_example, 
                                                    tokenizer=self.tokenizer, 
                                                    graph_encoder_tokenizer=self.graph_encoder_tokenizer, 
                                                    action_decoder_tokenizer=self.action_decoder_tokenizer, 
                                                    graph_decoder_tokenizer=self.graph_decoder_tokenizer, 
                                                    ckg_searcher=self.ckg_searcher)

            with WorkerPool(n_jobs=16) as pool:
                results = pool.map(func, input_list, progress_bar=True)
            # Flatten
            results = list(chain(*results))            

            torch.save(results, self.cache_path)
            logging.info(f"Saved dataset to cache location {self.cache_path}")
            self.tokenized_data = results
            #ipdb.set_trace()

        else:
            #ipdb.set_trace()
            self.tokenized_data = torch.load(self.cache_path)
            logging.info(f"Loaded dataset from cache. Cache location {self.cache_path}")
           
        #ipdb.set_trace()

        #tokenized_ids = self.examples[idx]


    def tokenize_state(self, state, tokenizer, graph_encoder_tokenizer, action_decoder_tokenizer, graph_decoder_tokenizer, ckg_searcher=None):

        text_loc = state['loc_desc']
        text_loc = text_loc.strip()
        text_loc = f"Location: {text_loc}"

        #text_act = ' '.join(state['walkthrough_act_normalized'][0])
        #text_act = text_act.strip()

        text_obs = state['obs']
        text_obs = text_obs.strip()

        #ipdb.set_trace()
        text_inv = state['inv_desc']
        text_inv = f"Inventory: {text_inv}"

        # Worldformer paper doesnt use this in the main task training
        """
        for k, v in state['inv_objs'].items():
            text_inv += f"{' '.join(v)} : {k}"
            #text_inv += f" "
            text_inv += f" [SEP] "
        text_inv = text_inv.strip()
        """

        text_surrounding = ''
        for k, v in state['surrounding_objs'].items():
            text_surrounding += f"{' '.join(v)} : {k}"
            text_surrounding += f" "
        text_surrounding = text_surrounding.strip()

        #text_all_pre = ' '.join(['[OBS]', text_loc])
        #text_all_post = ' '.join(['[OBS]', text_obs, text_inv, text_surrounding])

        # TODO move functionality to util
        prev_act_ids = []
        for tokens in state['walkthrough_act_normalized']:
            act_ids = [tokenizer.vocab['[ACT]']]
            #ipdb.set_trace()
            #for tok in tokens:
            #    act_ids.append(tokenizer.vocab[tok])
            act_ids.extend(tokenizer.encode(' '.join(tokens), add_special_tokens=False))
            prev_act_ids.extend(act_ids)

        # Process valid actions
        # NOTE valid_act_ids to store the concatenated version
        # NOTE template2acts organizes all actions according to their template
        unique_templates = list(more_itertools.unique_everseen([item[1][1] for item in state['valid_acts_normalized']]))
        template2acts = {}
        for template in unique_templates:
            template_id = ' '.join(template)
            template_tokens = ['[TEMPLATE_START]']
            
            for tok in template:
                if tok == 'OBJ':
                    template_tokens.append(tokenizer.mask_token)
                else:
                    template_tokens.append(tok)

            #template_tokens.append('[TEMPLATE_END]')
            template_tokens.append(tokenizer.eos_token)

            template2acts[template_id] = {'template_text': template, 
                                          'template_tokens': template_tokens,
                                          'ids': None,
                                          'targets_text': [],
                                          'n_holes': None,
                                         }
        #template2acts['no_template'] = {'ids': ['[NO_TEMPLATE]'] }                                                                                 

        #template2acts = {' '.join(template): {'template': template, 
        #                                      'ids': None,
        #                                      'targets_text': []
        #                                     } for template in unique_templates}

        valid_acts = []
        for tokens, (objects, template) in state['valid_acts_normalized']:
            
            n_holes = len([_ for _ in template if _ == 'OBJ'])
            template_id = ' '.join(template)                                    
            template2acts[template_id]['targets_text'].append(tokens)
            template2acts[template_id]['n_holes'] = n_holes            
            valid_acts.append('[ACT] ' + ' '.join(tokens))
        
        for template_id in template2acts.keys():            
            # Fill task            
            template2acts[template_id]['fill_target_ids'] = []
            if not template2acts[template_id]['n_holes'] == 0:
                for target_action in template2acts[template_id]['targets_text']:
                    template2acts[template_id]['fill_target_ids'] += ['[ACT]']
                    template2acts[template_id]['fill_target_ids'] += target_action
                template2acts[template_id]['fill_target_ids'] = tokenizer.encode(' '.join(template2acts[template_id]['fill_target_ids']), add_special_tokens=False)
            

            # Cls task
            template2acts[template_id]['ids'] = tokenizer.encode(' '.join(template2acts[template_id]['template_tokens']), add_special_tokens=False)            
            #ipdb.set_trace()

        valid_act_ids = tokenizer.encode(' '.join(valid_acts), add_special_tokens=False)
        valid_act_ids_out = copy.deepcopy(valid_act_ids) # It's the same thing

        #ipdb.set_trace()

        # NOTE not including prev act
        """
        text_pre_ids = tokenizer.encode(text_all_pre, add_special_tokens=False)
        text_post_ids = tokenizer.encode(text_all_post, add_special_tokens=False)
        text_desc_ids = text_pre_ids + prev_act_ids + text_post_ids
        """

        # NOTE not using text_surrounding
        text_all = '[OBS] ' + ' [OBS] '.join([text_loc, text_obs, text_inv])
        #text_all = '[OBS] ' + ' [OBS] '.join([text_loc, text_obs, text_inv, text_surrounding])
        text_desc_ids = tokenizer.encode(text_all, add_special_tokens=False)
        #ipdb.set_trace()
        
        
        # Graph (for input)
        graph_in_tuples = []
        ckg_heads = set()
        for tup in state['graph']:
            #assert len(tup) == 3
            #s, r, o = tup
            graph_in_tuples.append('[TRIPLE] ' + ' '.join(tup))
            if ckg_searcher is not None:
                ckg_heads.update(set(chain(*[jericho_utils.tokenize(tok) for tok  in tup])))
        graph_in = tokenizer.encode(' '.join(graph_in_tuples), add_special_tokens=False)

        graph_out = copy.deepcopy(graph_in)
        
        
        #ipdb.set_trace()
        ckg_in_tuples = []
        ckg_in = []
        # NOTE fixed to 10 triples
        if ckg_searcher is not None:
            for word in ckg_heads:
                ckg = np.array(ckg_searcher.search(word))
                ckg = ckg[:10]
                if len(ckg) > 0:
                    for tup in ckg:
                        ckg_in_tuples.append('[TRIPLE] ' + ' '.join(tup))
            ckg_in = tokenizer.encode(' '.join(ckg_in_tuples), add_special_tokens=False)

        inputs = {'valid_act_ids': valid_act_ids,
                  'text_desc_ids': text_desc_ids,
                  'graph': graph_in,
                  'ckg': ckg_in,
                  }

        outputs = {'valid_act_ids': valid_act_ids_out,
                   'act_token_id': action_decoder_tokenizer.vocab['[ACT]'],
                   'graph': graph_out,
                   'unk_token_id': graph_decoder_tokenizer.unk_token_id,
                   'template2acts': template2acts
                  }

        #ipdb.set_trace()

        return {'inputs': inputs, 'outputs': outputs}


    def tokenize_example(self, example, tokenizer, graph_encoder_tokenizer, action_decoder_tokenizer, graph_decoder_tokenizer, ckg_searcher=None):
        """
        Tokenize examples into input (text, graph) and output (action pred, graph pred)

        example order: loc_desc, prev_act, obs, inv_desc, inv_objs, inv_attrs, surrounding_objs, surrounding_attr + valid acts

        The inputs are O_t, V_t, G_t, and A, targets are G_{t+1} (or G{t+1} - G{t}), V_{t+1}

        """
        #self.tokenizer = tokenizer

        # TODO move this to a util
        # attrs are empty (seem broken as they don't refer to full object names) so skipped
        # TODO why are inventory object name tokens shuffled?
        
        # TODO should we do lookup for action words or just treat them as text?
        # TODO NOTE pretokenizer, such as bert's, probably has an effect here?
        # NOTE lookup will have tigther matching with decoding vocab? but only applicable for models sharing enc-dec vocab
        
        #ipdb.set_trace()


        prev_state = self.tokenize_state(example['state'], tokenizer, graph_encoder_tokenizer, action_decoder_tokenizer, graph_decoder_tokenizer, ckg_searcher=ckg_searcher)
        next_state = self.tokenize_state(example['next_state'], tokenizer, graph_encoder_tokenizer, action_decoder_tokenizer, graph_decoder_tokenizer, ckg_searcher=ckg_searcher)
        #graph_diff = get_graph_diff(prev_state['outputs']['graph'], next_state['outputs']['graph'])
        graph_adds = example['graph_adds']
        graph_diff_tuples = []
        for tup in graph_adds:
            assert len(tup) == 3
            graph_diff_tuples.append('[TRIPLE] ' + ' '.join(tup))
        graph_diff = tokenizer.encode(' '.join(graph_diff_tuples), add_special_tokens=False)

        if not self.negative_mode is None:
            def make_neg_example(neg_template):
                #_n_holes = len([_ for _ in neg_template.split() if _ == 'OBJ'])
                _n_holes = len([_ for _ in neg_template.split(' ') if _ == 'OBJ'])
                #ipdb.set_trace()
                neg_template_tokens = ['[TEMPLATE_START]']
                for tok in neg_template.split(' '):
                    if tok == 'OBJ':
                        neg_template_tokens.append(tokenizer.mask_token)
                    else:
                        neg_template_tokens.append(tok)
                #neg_template_tokens += ['[TEMPLATE_END]']
                neg_template_tokens += [tokenizer.eos_token]
                neg_template_ids = tokenizer.encode(' '.join(neg_template_tokens), add_special_tokens=False) 
                
                neg_template_output = {'template_text': neg_template.split(' '),
                                    'template_tokens': neg_template_tokens,
                                    'ids': neg_template_ids,
                                    'targets_text': None}

                output = {'rom': example['rom'],
                        'prev_state': prev_state,
                        'next_state': next_state,
                        'graph_diff': graph_diff,
                        'transition_act': transition_act_ids,
                        'template': neg_template_output,
                        'template_label': 'False',
                        'template_n_holes': _n_holes,
                        }
                return output

        finalized_examples = []
        
        # TODO move functionality to util
        transition_act_ids = tokenizer.encode(' [TRANSITION_ACT] ' + ' '.join(example['transition_action_normalized'][0]), add_special_tokens=False)
        #ipdb.set_trace()

        # NOTE prev_state['inputs']['text_desc_ids']
        # create positive examples per template
        # create negative example(s) per positive

        for template in next_state['outputs']['template2acts'].keys():       
            # TODO Add n_holes     

            #_next_state = copy.deepcopy(next_state) # NOTE copy needed or not?          

            #ipdb.set_trace()                                                           
            
            if self.skip_nohole:
                if next_state['outputs']['template2acts'][template]['n_holes'] == 0:
                    continue

            # Positive
            output = {'rom': example['rom'],
                    'prev_state': prev_state,
                    'next_state': next_state,
                    'graph_diff': graph_diff,
                    'transition_act': transition_act_ids,
                    'template': next_state['outputs']['template2acts'][template],
                    'template_label': 'True',
                    'template_n_holes': next_state['outputs']['template2acts'][template]['n_holes'],
                    }

            if self.save_text:
                ex_pos = copy.deepcopy(example)
            else:
                ex_pos = {'rom': example['rom'],
                          'ex_id': example['ex_id'],
                        }
            ex_pos['template']= template

            ex_pos['tokenized'] = output
            finalized_examples.append(ex_pos)

            # Negative            
            if not self.negative_mode is None:
                if self.negative_mode == 'sample':
                    neg_templates = set(self.all_templates[example['rom']]) - set(next_state['outputs']['template2acts'].keys())
                    neg_template = random.sample(neg_templates, 1)[0]
                    output = make_neg_example(neg_template)          

                    if self.save_text:
                        ex_neg = copy.deepcopy(example)
                    else:
                        ex_neg = {'rom': example['rom'],
                                  'ex_id': example['ex_id'],
                                }
                    ex_neg['template']= neg_template

                    ex_neg['tokenized'] = output                
                    finalized_examples.append(ex_neg)
                              
            #ipdb.set_trace()
            #print(self.save_text)

        if not self.negative_mode is None:
            if self.negative_mode == 'all':
                neg_templates = set(self.all_templates[example['rom']]) - set(next_state['outputs']['template2acts'].keys())
                for neg_template in neg_templates:
                    output = make_neg_example(neg_template) 
                    
                    if self.save_text:
                        ex_neg = copy.deepcopy(example)
                    else:
                        ex_neg = {'rom': example['rom'],
                                  'ex_id': example['ex_id'],
                                    }
                    ex_neg['template']= neg_template

                    ex_neg['tokenized'] = output                
                    finalized_examples.append(ex_neg)              

        return finalized_examples

    def get_collate_fn(self, **kwargs):
        return functools.partial(self.collate_fn, **kwargs)

    def collate_fn(self, examples, sos=False, mlm=False):
        """
        Construct tensor batches
        Make masks and labels for mlm
        """
        #ipdb.set_trace()
        
        keys = ['input_text', 'input_graph', 'output_valid_act', 'output_graph_diff', 'binary_labels']
        #keys = ['input_text', 'output_valid_act', 'output_graph_diff']

        id_lists = {k: [] for k in keys}
        id_tensors = {}
        output = {}

        batch_token_ids = []
        for orig_example in examples:
            example = copy.deepcopy(orig_example)
            #ipdb.set_trace()
            # Encoder Inputs
            
            text_desc_ids = example['tokenized']['prev_state']['inputs']['text_desc_ids']
            valid_act_ids = example['tokenized']['prev_state']['inputs']['valid_act_ids']
            transition_act_ids = example['tokenized']['transition_act']

            template_ids = example['tokenized']['template']['ids']
            #ipdb.set_trace()
            template_ids = self.tokenizer.encode(' '.join(example['tokenized']['template']['template_tokens']), add_special_tokens=False)            
            all_tokens = self.concatenate_and_truncate_text_inputs(text_desc_ids, valid_act_ids, template_ids, transition_act_ids, self.max_input_tokens-1)
            id_lists['input_text'].append(all_tokens)
            #ipdb.set_trace()
    
            
            # Input for graph encoder            
            graph_input_ids = example['tokenized']['prev_state']['inputs']['graph']
            graph_input_ids += example['tokenized']['transition_act']
            if self.ckg_searcher is not None:
                graph_input_ids += example['tokenized']['prev_state']['inputs']['ckg']
            graph_input_ids = graph_input_ids[:(self.max_input_tokens - 1)]
            graph_input_ids += [self.tokenizer.eos_token_id]
            id_lists['input_graph'].append(graph_input_ids)
            #ipdb.set_trace()

            
            graph_target_ids = [self.tokenizer.bos_token_id] + example['tokenized']['graph_diff']
            graph_target_ids = graph_target_ids[:(self.max_input_tokens-1)] + [self.tokenizer.eos_token_id]
            
            id_lists['output_graph_diff'].append(graph_target_ids)
            
            # Get template gold label
            if example['tokenized']['template_label'] == 'False':
                lb = 0
            else:
                lb = 1
            id_lists['binary_labels'].append([lb])

            # Fill target depends on label
            if lb == 1:
                fill_target = example['tokenized']['template']['fill_target_ids'] 
            else:
                fill_target = []
            action_target_ids = [self.tokenizer.bos_token_id] + fill_target
            action_target_ids = action_target_ids[:(self.max_input_tokens-1)] + [self.tokenizer.eos_token_id]
            
            id_lists['output_valid_act'].append(action_target_ids)
            #ipdb.set_trace()
            
            

        #ipdb.set_trace()
        for k, v in id_lists.items():
            if k == 'input_text':
                tokenizer = self.tokenizer
                if mlm:
                    mlm_input_ids , mlm_labels = self.torch_mask_tokens(id_tensors[k])
                    id_tensors[k+'_mlm_ids'] = mlm_input_ids
                    id_tensors[k+'_mlm_labels'] = mlm_labels
            elif k == 'input_graph':
                tokenizer = self.tokenizer
                if mlm:
                    mlm_input_ids , mlm_labels = self.torch_mask_tokens(id_tensors[k])
                    id_tensors[k+'_mlm_ids'] = mlm_input_ids
                    id_tensors[k+'_mlm_labels'] = mlm_labels
            elif k == 'output_graph_diff':
                tokenizer = self.tokenizer
            elif k == 'output_valid_act':
                #tokenizer = self.action_decoder_tokenizer
                tokenizer = self.tokenizer
                #ipdb.set_trace()
            #if k == 'binary_labels':
            #    ipdb.set_trace()
            id_tensors[k] = self._torch_collate_batch(v, tokenizer)
            id_tensors[k+'_mask'] = (id_tensors[k] != tokenizer.pad_token_id).long()

        id_tensors['rom'] = [example['rom'] for example in examples]
        id_tensors['ex_id'] = [example['ex_id'] for example in examples]
        id_tensors['template'] = [example['template'] for example in examples]
        id_tensors['n_holes'] = [example['tokenized']['template_n_holes'] for example in examples]


        """
        # NOTE added for template conditioning        
        #id_tensors['template2acts'] = []
        for example in examples:
            #template2acts_update = copy.deepcopy(example['tokenized']['next_state']['outputs']['template2acts'])
            for k,v in example['tokenized']['next_state']['outputs']['template2acts'].items():
                v['ids'] = [self.tokenizer.bos_token_id] + v['ids']
                v['ids'] = v['ids'][:(self.max_input_tokens-1)] + [self.tokenizer.eos_token_id]                            
                ipdb.set_trace()
        id_tensors['template2acts'] = [example['tokenized']['next_state']['outputs']['template2acts'] for example in examples]
        """

        #ipdb.set_trace()

        return id_tensors

    def _torch_collate_batch(self, examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
        """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
        #import numpy as np
        #import torch

        # Tensorize if necessary.
        if isinstance(examples[0], (list, tuple, np.ndarray)):
            examples = [torch.tensor(e, dtype=torch.long) for e in examples]

        length_of_first = examples[0].size(0)

        # Check if padding is necessary.

        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
            return torch.stack(examples, dim=0)

        # If yes, check if we have a `pad_token`.
        if tokenizer._pad_token is None:
            raise ValueError(
                "You are attempting to pad samples but the tokenizer you are using"
                f" ({tokenizer.__class__.__name__}) does not have a pad token."
            )

        # Creating the full tensor and filling it with our data.
        max_length = max(x.size(0) for x in examples)
        if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
        result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
        for i, example in enumerate(examples):
            if tokenizer.padding_side == "right":
                result[i, : example.shape[0]] = example
            else:
                result[i, -example.shape[0] :] = example
        return result


    # TODO move to util

    def concatenate_and_truncate_text_inputs(self, text_desc_ids, valid_act_ids, template_ids, transition_act_ids, max_length):
        """
        Transition act is just one act so expected to be short. Similarly while valid acts has more actions itll never be longer than max length (1024)

        args:
            text_desc_ids: list
            valid_act_ids: list
            template_ids: list
            transition_act_ids: list
            max_length

        returns:
            all_tokens: list
            
        """
        # Old order
        #action_ids = valid_act_ids + transition_act_ids
        #all_tokens = text_desc_ids + action_ids
        #if len(all_tokens) > max_length:
        #    n_overflow = len(all_tokens) - max_length
        #    all_tokens = text_desc_ids[:-n_overflow] + action_ids  # truncate from text

        # New order
        max_length = max_length - len(template_ids)
        pre_ids = text_desc_ids + transition_act_ids
        all_tokens = pre_ids + valid_act_ids
        all_tokens = all_tokens[:max_length]
        all_tokens += template_ids

        return all_tokens

    def concatenate_acts(self, actions_list, act_token_id):
        """

        args:
            actctions_list: list of list
        """
        output = []
        for tokens in actions_list:
            act_ids = [act_token_id]
            for tok in tokens:
                act_ids.append(tok)
            output.extend(act_ids)
        return output

    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """

        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def filter_roms(self, roms):
        len_prev = len(self)
        self.tokenized_data = [item for item in self.tokenized_data if item['rom'] in roms]        
        logging.info(f"Filtering to roms: {roms}. Prev: {len_prev} instances, Filtered: {len(self)} instances")
        #ipdb.set_trace()

    def __getitem__(self, idx) -> dict:
        """
        Outputs {Text descriptions including location and inventory, and Valid actions}, for the mlm task
        """
        #ipdb.set_trace()
        return self.tokenized_data[idx]

    def __len__(self):
        return len(self.tokenized_data)


