import argparse
from torch.utils.data import Dataset
from itertools import chain
from pathlib import Path
from collections import OrderedDict
import torch
import random
import tqdm
import json
import ipdb
import re

from worldformer2.tools import jericho_utils

from copy import deepcopy
from typing import Any, Dict, List, Set, Union, Optional, Tuple

import logging
from worldformer2.tools.logging_util import basic_logging
basic_logging()

from worldformer2.tokenization.train_hf_tokenizer import train_sentencepiece, train_bpe, make_triples_tokenizer
from transformers import PreTrainedTokenizerFast
from tokenizers.pre_tokenizers import Whitespace
import tokenizers
from transformers import AutoTokenizer

# Tokenization related
from tokenizers import Tokenizer
from tokenizers.normalizers import NFKC
from tokenizers import decoders
from tokenizers.models import BPE, WordLevel
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace, Digits, Metaspace
from tokenizers import SentencePieceBPETokenizer
from transformers import PreTrainedTokenizerFast


import jericho
from worldformer2.dataset import jerichoworld_defines_updated

# Module-wide Hugging Face whitespace pre-tokenizer instance.
# NOTE(review): presumably consumed by tokenization helpers elsewhere in this file.
whitespace_tokenizer = Whitespace()

# Short module-level aliases for the action tables defined in
# jerichoworld_defines_updated.
basic_actions = jerichoworld_defines_updated.BASIC_ACTIONS
noeffect_actions = jerichoworld_defines_updated.NO_EFFECT_ACTIONS
abbrv_dict = jerichoworld_defines_updated.ABBRV_DICT



class JerichoWorldDataset(object):
    def __init__(self, path):
        """
        args
            path: json file
        """    

        self.data = {}
        self.load(path)

    def load(self, path, skip_empty_graphs=False):
        dataset = json.load(open(path, 'r'))

        for gametype in dataset:
            rom = gametype[0]["rom"]
            assert rom not in self.data
            self.data[rom] = gametype


        # debugging
        #self.filter_games('huntdark')


        # We process the actions so they are in normalized, tokenized forms
        #self.bugfix_data()
        #self.normalize_player_actions() # bug in dataset
        #ipdb.set_trace()
        self.normalize_actions(skip_empty_graphs=skip_empty_graphs)
        logging.info(f"Finished loading games from {path}")


    # Utility functions

    def filter_games(self, include: list):
        """
        Keep only the games whose rom name is found in `include`.

        Replaces self.data with the reduced mapping and prints which games
        were dropped and which remain.

        args
            include: container tested with `rom in include`
        """
        kept = {rom: instances for rom, instances in self.data.items() if rom in include}
        # Must be computed before self.data is replaced below.
        dropped = set(self.list_games()) - set(kept)
        self.data = kept
        print(f"Dropped: {dropped}\nRemaining: {self.list_games()}")


    def bugfix_data(self):
        """
        Fix bug in dataset
        """

        for game in self.list_games():
            for idx, instance in enumerate(self.data[game]):
                check_fix_state_kg(self.data[game], idx)


        #"""
        new_instances = []
        actions = []

        for game in self.list_games():
            game_actions = []
            game_cs_actions = []
            game_ns_actions = []

            for idx, instance in enumerate(self.data[game]):
                #if idx == len(self.data[game]) - 1:
                #    break
            
                if idx == 0:
                    continue
        
                if not self.data[game][idx-1]['next_state']['walkthrough_act'] == self.data[game][idx]['state']['walkthrough_act']:


                    print('ACT', self.data[game][idx-1]['next_state']['walkthrough_act'])
                    print('OBS', self.data[game][idx-1]['next_state']['obs']) 
                    print('GRAPH DIFF', get_graph_diff(self.data[game][idx-1]['state']['graph'], self.data[game][idx-1]['next_state']['graph']))
                    print('')
                    print('ACT', self.data[game][idx]['state']['walkthrough_act'])
                    print('OBS', self.data[game][idx]['state']['obs']) 
                    print('GRAPH DIFF', get_graph_diff(self.data[game][idx-1]['state']['graph'], self.data[game][idx]['state']['graph']))
                else:
                    print('ACT', self.data[game][idx-1]['next_state']['walkthrough_act'])
                    print('OBS', self.data[game][idx-1]['next_state']['obs']) 
                    print('GRAPH DIFF', get_graph_diff(self.data[game][idx-1]['state']['graph'], self.data[game][idx-1]['next_state']['graph']))
                    print('')
                    print('ACT', self.data[game][idx]['state']['walkthrough_act'])
                    print('OBS', self.data[game][idx]['state']['obs']) 
                    print('GRAPH DIFF', get_graph_diff(self.data[game][idx-1]['state']['graph'], self.data[game][idx]['state']['graph']))

                    ipdb.set_trace()

                """
                print('instance idx', idx)
                print('transition act',instance['action'])
                print('cs',self.data[game][idx]['state']['walkthrough_act'])
                print('cs',self.data[game][idx]['state']['obs'])
                print('\n')
                print('ns',self.data[game][idx]['next_state']['walkthrough_act'])
                print('ns',self.data[game][idx]['next_state']['obs'])
    
                game_actions.append(instance['action'])
                game_cs_actions.append(instance['state']['walkthrough_act'])
                game_ns_actions.append(instance['next_state']['walkthrough_act'])

                to_add = [self.data[game][idx]['state']['graph'],
                          self.data[game][idx]['next_state']['graph'],
                          ]
                new_instances.append(to_add)
                """
 
                """
                to_add = [self.data[game][idx]['next_state'],
                          self.data[game][idx+1]['next_state'],
                          self.data[game][idx]['action']]
                         
                adds, dels = get_graph_diff(to_add[0]['graph'], to_add[1]['graph'])
                print(to_add[-1])
                print(adds)
                print(dels)

                new_instances.append(to_add)
                #print(f"{self.data[game][idx]['action']}\n\n{self.data[game][idx]['next_state']['valid_acts'].values()}\n\n{self.data[game][idx]['next_state']['graph']}\n\n{self.data[game][idx+1]['state']['graph']}")
                """
                #ipdb.set_trace()

            print(game_actions)
            print(game_cs_actions)
            print(game_ns_actions)
            #list(zip(game_ns_actions[1:150],game_cs_actions[2:150],game_actions[:150]))
            #ipdb.set_trace()
        #"""
        #ipdb.set_trace()


    def normalize_player_actions(self):
        """
        Convert player actions into normalized, tokenized forms (unabbreviated and lowercase).
        Adds to data dict as additional field.
        """
        for game in self.list_games():
            for instance in self.data[game]:
                # Process player action
                # NOTE using hf whitespace pretokenizer to parse player command into individual tokens
                player_action_tokens = split(instance['action'], method='hf_whitespace')
                converted_player_action = []
                for token in player_action_tokens:
                    token = token.lower()
                    if token in abbrv_dict:
                        token = abbrv_dict[token]
                    converted_player_action.append(token)
                
                instance['action_normalized'] = converted_player_action
                #print(instance['action'], converted_player_action)

    def normalize_actions(self, skip_empty_graphs=False):
        """
        Convert valid actions into tokenized forms
        """
        #ipdb.set_trace()
        fail = 0
        success = 0
        
        n = 0

        n_skip = 0
        n_noskip = 0

        n_no_valid_acts = 0
        n_empty_graphs = 0

        for game in self.list_games():
            for idx, instance in enumerate(self.data[game]):
                n += 1
                """
                if skip_empty_graphs:
                    if instance['state']['graph'] == []:
                        n_skip += 1
                        continue
                    else:
                        n_noskip += 1
                """
                if instance['state']['graph'] == []:
                    n_empty_graphs += 1
                if instance['next_state']['graph'] == []:
                    n_empty_graphs += 1
            
                # Process valid (engine) actions 
                if len(instance['state']['valid_acts']) == 0: # Skipping because prev state doesn't have valid actions
                    state_valid_acts = []
                    n_no_valid_acts += 1
                else:
                    state_valid_acts = convert_to_tokens(instance['state']['valid_acts'], lower=True, unabbrev=True)

                if len(instance['next_state']['valid_acts']) == 0:
                    next_state_valid_acts = []
                    n_no_valid_acts += 1
                else:
                    next_state_valid_acts = convert_to_tokens(instance['next_state']['valid_acts'], lower=True, unabbrev=True)

                instance['state']['valid_acts_normalized'] = state_valid_acts
                instance['next_state']['valid_acts_normalized'] = next_state_valid_acts

                state_act = convert_to_tokens_simple([instance['state']['walkthrough_act']], lower=True, unabbrev=True)
                next_state_act = convert_to_tokens_simple([instance['next_state']['walkthrough_act']], lower=True, unabbrev=True)
                instance['state']['walkthrough_act_normalized'] = state_act
                instance['next_state']['walkthrough_act_normalized'] = next_state_act
                #ipdb.set_trace()
          
                transition_act = next_state_act
                instance['transition_action_normalized'] = transition_act

                """ # skip check
                if not ' '.join(transition_act) in instance['state']['valid_acts'].values():
                    if next_state_act in abbrv_dict:
                        success += 1
                        next_state_act = abbrv_dict[next_state_act]
                    else:
                        # super hacky
                        if 'x' in next_state_act.split():
                            next_state_act = next_state_act.replace('x ','examine ')
                            success += 1
                            #print(next_state_act)
                            pass
                        else:
                            fail += 1
                            print(next_state_act)
                            #print(next_state_act, instance['state']['valid_acts'].values())
                            #ipdb.set_trace()
                else:
                    success += 1
                """
    
                
                #curr_graph = instance['state']['graph']
                #next_graph = instance['next_state']['graph']
                #adds, dels = get_graph_diff(curr_graph, next_graph)

                #print(next_state_act)
                #print(self.data[game][idx-1]['action'])
                #print(adds)
                #print(dels)

                #if idx > 10:
                #    ipdb.set_trace()
        
                #if instance['state']['graph'] == self.data[game][idx-1]['next_state']['graph']:
                #    ipdb.set_trace()

                #if instance['next_state']['graph'] == []:
                #    ipdb.set_trace()


                """
                if instance['state']['graph'] != []:
                    #print(self.data[game][idx-1]['next_state']['loc_desc'])
                    #print(self.data[game][idx]['state']['loc_desc'])
                    #print()
                    #print(self.data[game][idx-1]['action'])
                    #ipdb.set_trace()

                    adds, dels = get_graph_diff(self.data[game][idx]['state']['graph'], instance['next_state']['graph'])
                    print(f"ADDS: {adds}")
                    print(f"DELS: {dels}")
                    print(self.data[game][idx]['action'], transition_act)
                    ipdb.set_trace()
                """

                """
                if instance['state']['graph'] == []:
                    #continue
                    #print(self.data[game][idx-1]['next_state']['loc_desc'])
                    #print(self.data[game][idx]['state']['loc_desc'])
                    #print()
                    #print(self.data[game][idx-1]['action'])
                    #ipdb.set_trace()

                    if self.data[game][idx-1]['next_state']['loc_desc'] == self.data[game][idx]['state']['loc_desc']:
                        adds, dels = get_graph_diff(self.data[game][idx-1]['next_state']['graph'], instance['next_state']['graph'])
                        print(f"ADDS: {adds}")
                        print(f"DELS: {dels}")
                        print(self.data[game][idx-1]['action'], transition_act)
                        ipdb.set_trace()
                """
               
        print("skips", n_skip, "no skips",n_noskip, "no valid acts", n_no_valid_acts, "no graph triples (empty graph)", n_empty_graphs)
        #ipdb.set_trace()
    

    def list_games(self):
        """Return the rom names currently held in self.data, as a new list."""
        return list(self.data)

    def sample_game(self, rom):
        return random.sample(self.data[rom], 1)
   
    def __str__(self):
        """Summarise the dataset: a header, one line per game and a grand total."""
        parts = ["## JerichoWorldDataset ##\n"]
        parts.extend(
            f" [{rom}]: {len(instances)} instances\n"
            for rom, instances in self.data.items()
        )
        total_instances = sum(len(instances) for instances in self.data.values())
        parts.append(
            f"Total {total_instances} instances. Each instance contains current state and next state."
        )
        return "".join(parts)


    def flatten(self):
        """
        Collect the instances of all games into one flat list.

        The list is stored in self.instances, ordered game by game in the
        order given by list_games().
        """
        self.instances = [
            instance
            for game in self.list_games()
            for instance in self.data[game]
        ]

        logging.info(f"Flattened: {len(self.instances)} instances in {len(self.list_games())} games")

    def split_validation(self, ratio, seed=42):
        """
        Shuffle self.instances and carve off a validation subset.

        The global `random` module is seeded with `seed`, self.instances is
        shuffled in place, and the first int(len * ratio) entries are
        deep-copied and returned; self.instances keeps the remainder.

        args
            ratio: fraction of instances to split off
            seed: seed passed to random.seed
        returns
            list of deep-copied validation instances
        """
        random.seed(seed)
        random.shuffle(self.instances)

        n_valid = int(len(self.instances) * ratio)
        held_out, remaining = self.instances[:n_valid], self.instances[n_valid:]

        valid_examples = deepcopy(held_out)
        self.instances = remaining

        logging.info(f"Train/valid split: {len(self.instances)} train instances remaining, split off {len(valid_examples)} instances")

        return valid_examples

class VocabBuilderDataset(object):
    def __init__(self, dataset: JerichoWorldDataset, args):
        """
        Keep the raw dataset and build the tokenizer from it right away.

        args
            dataset: JerichoWorldDataset holding the raw data
            args: settings object kept as self.args and read by parse_data()
        """
        self.args = args
        self.text_data = dataset
        # parse_data() returns the tokenizer together with the dict of token groups.
        built = self.parse_data()
        self.tokenizer, self.tokens_dict = built

    def save_tokenizer(self, save_path, save_extra=True):
        """
        Write the tokenizer to disk, optionally with the token groups.

        args
            save_path: directory passed to tokenizer.save_pretrained
            save_extra: when truthy, also dump self.tokens_dict as json to
                <save_path>/extra/added_tokens.json
        """
        self.tokenizer.save_pretrained(save_path)
        if save_extra:
            extra_dir = Path(save_path) / 'extra'
            extra_dir.mkdir(parents=True,exist_ok=True)

            with open(extra_dir / "added_tokens.json", 'w') as handle:
                handle.write(json.dumps(self.tokens_dict))
        print(f"Saved to {save_path}")

    def parse_data(self):
        """
        Flatten contents in text_data to flat list

        """
        # jericho 
        global abbrv_dict

        world_objects = OrderedDict()

        self.text_observations = []
        self.actions = set()

        self.engine_vocab = set()

        self.triples = []
        self.triples_vocab = set()
        self.extra_vocab = set()
        self.relations = []

        self.oov = set()

        for game in self.text_data.list_games():
            jericho_path = Path("/home/mnskim/workspace/tbg/jericho/z-machine-games-master/jericho-game-suite/")
            path = list(jericho_path.glob(f"{game}.z*"))[0] # because extention can differ
            env = jericho.FrotzEnv(str(path))           

            engine_commands = set([str(item) for item in env.get_dictionary()])
            self.engine_vocab.update(engine_commands)
            #print(f"{game} engine vocab size: {len(engine_commands)}")

            # NOTE env.get_world_objects covers all objects/entities appearing in the triples
            # (it covers objects, locations and characters)
            unique_objects = set([item.name for item in list(env.get_world_objects()) if not item.name == ''])
            #ipdb.set_trace()
            for obj in unique_objects:
                if not obj in world_objects:
                    world_objects[obj] = len(world_objects)
            #ipdb.set_trace()

            #for instance in tqdm.tqdm(self.text_data.data[game], desc=f"{game}"):
            for iid, instance in enumerate(self.text_data.data[game]):

                # Text observations
                # the attr values are emtpy
                self.text_observations.append(instance['state']['obs'])
                self.text_observations.append(instance['state']['loc_desc'])
                self.text_observations.append(instance['state']['inv_desc'])

                self.text_observations.extend(instance['state']['inv_objs'].keys())
                self.text_observations.extend(chain(*instance['state']['inv_objs'].values()))
                #self.text_observations.extend(instance['state']['inv_attrs'].keys())

                self.text_observations.extend(instance['state']['surrounding_objs'].keys())
                self.text_observations.extend(chain(*instance['state']['surrounding_objs'].values()))
                self.text_observations.extend(instance['state']['surrounding_attrs'].keys())


                # Add all action tokens to the action vocab.
                for action in instance['state']['valid_acts_normalized'] + instance['next_state']['valid_acts_normalized'] + instance['transition_action_normalized']:
                    for token in action:
                        self.actions.add(token)

                #ipdb.set_trace()

                
                """ 
                # Not using this 
                for tok in raw_tokens:
                    action_tokens.extend(tok.split())
                ipdb.set_trace()

                for token in action_tokens:
                    token = token.lower()
                    if token in abbrv_dict:
                        token = abbrv_dict[token]

                    if check_action(token, engine_commands=engine_commands):
                        self.actions.add(token)
                    else:
                        # Try checking 6-letter truncated form
                        if not check_action(token[:6], engine_commands=engine_commands):
                            # NOTE Debug print
                            #print(token)
                            #print(instance['action'])
                            #ipdb.set_trace()
                            pass
                        self.actions.add(token)
                """


                #self.triples.extend(instance['graph_diff'])
                #ipdb.set_trace()
                #print('\n',' '.join(instance['transition_action_normalized'][0]),'\n', instance['action'], '\n\n', instance['state']['graph'], '\n\n', get_graph_diff(instance['state']['graph'], instance['next_state']['graph'])) 
                for state in [instance['state'], instance['next_state']]:

                    # Location
                    #self.unique_locations.add(state['location']['name'])
                    self.triples.extend(state['graph'])
                    triples_items = list(chain(*state['graph']))
                    # NOTE choose lowercase or not
                    #triples_items = [item.lower() for item in triples_items]
                    triples_items = [item for item in triples_items]

                    #ipdb.set_trace()

                    self.triples_vocab.update(triples_items)

                    #ipdb.set_trace()
                    for triple in state['graph']:
                        s, r, o = triple

                        # DEBUG
                        #if s == 'jewel encrusted dagger' or o == 'jewel encrusted dagger':
                        #    ipdb.set_trace()


                        self.relations.append(r)
                        if not (s in unique_objects):
                            #print(game, s)
                            #ipdb.set_trace()
                            self.oov.add(s)
                        if not (o in unique_objects):
                            #print(game, o)
                            #ipdb.set_trace()
                            self.oov.add(o)

                        #if s == 'sword' or o == 'sword':
                        #    ipdb.set_trace()

                    #self.extra_vocab.update(list(chain(*state['surrounding_objs'].values())))
                    #self.extra_vocab.update(list(chain(*state['surrounding_attrs'].values())))
                    #self.extra_vocab.update(state['surrounding_attrs'].keys())
                    #self.extra_vocab.add(state['location']['name'])
                    #self.extra_vocab.update(list(chain(*state['inv_objs'].values())))
                    #ipdb.set_trace()

                #if game == 'pentari':
                #    ipdb.set_trace()

        #tokenizer_training_data = self.text_observations + [' '.join(self.actions)] + [' '.join(world_objects)]
        tokenizer_training_data = self.text_observations #+ list(self.actions) + list(world_objects)

        game_tokens = list(world_objects.keys()) + list(self.actions)

        special_tokens = ['[PAD]', '[UNK]', '[BOS]', '[SOS]', '[EOS]', '[OBS]', '[ACT]', '[GRAPH]', '[TRIPLE]', '[ADD]', '[DEL]', '[PREV_ACT]', '[TRANSITION_ACT]']

        #mode = 'graph_decoder_tokenizer'
        #mode = 'action_decoder_tokenizer'
        mode = self.args.mode

        #mode = 'bpe'
        #mode = 'sentencepiece'
        #mode = 'bert'
        #mode = 'distilbert'
        #mode = 'gpt2'

        #mode = 'distilbert'

        pre_action_vocab = json.load(open(self.args.pre_action_vocab,'r'))
        pre_graph_vocab = json.load(open(self.args.pre_graph_vocab,'r'))

        if mode == 'sentencepiece':
            sp_tokenizer = train_sentencepiece(tokenizer_training_data, special_tokens=special_tokens)
            tokenizer = PreTrainedTokenizerFast(
                tokenizer_object=sp_tokenizer,
            )

            added_tokens = ['[OBS]', '[ACT]', '[GRAPH]', '[TRIPLE]'] + list(set(self.relations)) \
                                + list(world_objects.keys()) + list(self.actions)
            tokenizer.add_tokens([tokenizers.AddedToken(w, single_word=True) for w in added_tokens])
            print(f"Created sentenpiece tokenizer with vocab size {len(tokenizer.vocab)}")


            #tokenizer.pad_token = '[PAD]'
            #tokenizer.unk_token = '[UNK]'
            #tokenizer.bos_token = '[BOS]'
            #tokenizer.EOS_token = '[EOS]'

            ipdb.set_trace()

        if mode == 'graph_encoder_tokenizer':

            input_vocab = {}
            #input_vocab = {"this": 0, "is": 1, "a": 2, "test": 3, "[UNK]": 4}
            for tok in special_tokens:
                if not tok in input_vocab:
                    input_vocab[tok] = len(input_vocab)
            # special tokens are inthis order:
            # pad, unk, bos, sos, eos, obs, act, graph, triple ...

            for tok in set(list(set(self.relations)) + pre_graph_vocab + list(world_objects.keys())):
                if not tok in input_vocab:
                    input_vocab[tok] = len(input_vocab)
            tokenizer = Tokenizer(WordLevel(input_vocab, unk_token="[UNK]"))
            tokenizer.normalizer = NFKC()
            tokenizer.pre_tokenizer = Whitespace()
            tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
            #whitespace_tokenizer.tokenize('hi annyoung hello!') 
            added_tokens = []
            #ipdb.set_trace()


        if mode == 'graph_decoder_tokenizer':
            #tokenizer_training_data = self.text_observations + list(self.actions) + list(world_objects)
            #sp_tokenizer = train_sentencepiece(tokenizer_training_data, 7500)
            #tokenizer = PreTrainedTokenizerFast(tokenizer_object=sp_tokenizer)
            #added_tokens = special_tokens + list(set(self.relations)) + list(world_objects.keys()) + list(self.actions)

            input_vocab = {}
            #input_vocab = {"this": 0, "is": 1, "a": 2, "test": 3, "[UNK]": 4}
            for tok in special_tokens:
                if not tok in input_vocab:
                    input_vocab[tok] = len(input_vocab)
            # special tokens are inthis order:
            # pad, unk, bos, sos, eos, obs, act, graph, triple ...

            for tok in set(list(set(self.relations)) + pre_graph_vocab + list(world_objects.keys())):
                if not tok in input_vocab:
                    input_vocab[tok] = len(input_vocab)
            tokenizer = Tokenizer(WordLevel(input_vocab, unk_token="[UNK]"))
            tokenizer.normalizer = NFKC()
            tokenizer.pre_tokenizer = Whitespace()
            tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
            #whitespace_tokenizer.tokenize('hi annyoung hello!') 
            added_tokens = []
            #ipdb.set_trace()


        if mode == 'action_decoder_tokenizer':

            #tokenizer_training_data = self.text_observations + list(self.actions) + pre_action_vocab #+ list(world_objects)
            #sp_tokenizer = train_sentencepiece(tokenizer_training_data, 7500)
            #tokenizer = PreTrainedTokenizerFast(tokenizer_object=sp_tokenizer)
            #added_tokens = special_tokens + list(set(self.relations)) + list(world_objects.keys()) + list(self.actions)
            # New version 1/18/22
            #added_tokens = special_tokens + list(set(self.relations)) + list( set(list(self.actions) + pre_action_vocab) )

            #tokenizer.add_tokens([tokenizers.AddedToken(w, single_word=True) for w in added_tokens])

            #tokenizer.tokenize(self.text_data.data['pentari'][0]['state']['obs'])
            #tokenizer.tokenize(' '.join(self.triples[24]))
            #self.tokenizer_test_t5(tokenizer)

            """
            ipdb.set_trace()

            tokenizer_sp = tokenizers.SentencePieceBPETokenizer(vocab=added_tokens)
            tokenizer_sp.train_from_iterator(tokenizer_training_data, vocab_size=12000, min_frequency=1, show_progress=True, limit_alphabet=500)
            tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_sp)

            ipdb.set_trace()
            tokenizer = Tokenizer(BPE())
            tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
            added_tokens = special_tokens + list(set(self.relations)) + list(world_objects.keys()) + list(self.actions) 
            tokenizer.add_tokens([tokenizers.AddedToken(w, single_word=True) for w in added_tokens])
            """
            
            # https://huggingface.co/docs/tokenizers/python/latest/components.html
            # NOTE WordLevel requires pretokenizer (like whitespace)
    
            input_vocab = {}
            #input_vocab = {"this": 0, "is": 1, "a": 2, "test": 3, "[UNK]": 4}
            for tok in special_tokens:
                if not tok in input_vocab:
                    input_vocab[tok] = len(input_vocab)
            # special tokens are inthis order:
            # pad, unk, bos, sos, eos, obs, act, graph, triple ...

            for tok in set(list(self.actions) + pre_action_vocab):
                if not tok in input_vocab:
                    input_vocab[tok] = len(input_vocab)
            tokenizer = Tokenizer(WordLevel(input_vocab, unk_token="[UNK]"))
            tokenizer.normalizer = NFKC()
            tokenizer.pre_tokenizer = Whitespace()
            tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
            #whitespace_tokenizer.tokenize('hi annyoung hello!') 
            added_tokens = []
            #ipdb.set_trace()

        if mode == 'bpe':
            bpe_tokenizer = train_bpe(tokenizer_training_data, special_tokens=special_tokens)
            tokenizer = PreTrainedTokenizerFast(
                tokenizer_object=bpe_tokenizer,
            )

            tokenizer.pad_token = '[PAD]'
            tokenizer.unk_token = '[UNK]'
            tokenizer.bos_token = '[BOS]'
            tokenizer.EOS_token = '[EOS]'

        if mode == 't5':
            tokenizer = AutoTokenizer.from_pretrained("t5-base")
            #tokenizer._add_tokens(['_'+item for item in game_tokens])
            #tokenizer._add_tokens(game_tokens)
            #tokenizer.add_tokens(game_tokens)
            #tokenizer.add_tokens([tokenizers.AddedToken("_"+w, single_word=True) for w in game_tokens])
            added_tokens = ['[OBS]', '[ACT]', '[GRAPH]', '[TRIPLE]'] + list(set(self.relations)) \
                                + list(world_objects.keys()) + list(self.actions)
            tokenizer.add_tokens([tokenizers.AddedToken(w, single_word=True) for w in added_tokens])
        
            #tokenizer.bos_token = tokenizer.eos_token # t5 has no default bos


            #tokenizer2 = AutoTokenizer.from_pretrained("t5-base")
            #tokenizer2.add_tokens(['_'+item for item in game_tokens])
            #ipdb.set_trace()

        if mode == 'bart':
            #tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
            #tokenizer.add_tokens([tokenizers.AddedToken(w, single_word=True) for w in game_tokens])
            ##tokenizer.add_tokens(game_tokens)

            #tokenizer._add_tokens(game_tokens)
            #tokenizer.tokenize(self.text_observations[0])
            1

        if mode == 'bert':
            tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
            added_tokens = ['[OBS]', '[ACT]', '[GRAPH]', '[TRIPLE]'] + list(set(self.relations)) \
                                + list(world_objects.keys()) + list(self.actions)
            tokenizer.add_tokens([tokenizers.AddedToken(w, single_word=True) for w in added_tokens])
            print(f"Created extended BERT tokenizer with vocab size {len(tokenizer.vocab)}")

        if mode == 'distilbert':
            ipdb.set_trace()
            tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
            added_tokens = special_tokens + list(set(self.relations)) + list(world_objects.keys()) + list(self.actions)
            #tokenizer.add_tokens([tokenizers.AddedToken(w, single_word=True) for w in added_tokens])
            print(f"Created extended BERT tokenizer with vocab size {len(tokenizer.vocab)}")

        if mode == 'gpt2':
            #ipdb.set_trace()
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            added_tokens = ['[OBS]', '[ACT]', '[GRAPH]', '[TRIPLE]'] + list(set(self.relations)) \
                                + list(world_objects.keys()) + list(self.actions)
            tokenizer.add_tokens([tokenizers.AddedToken(w, single_word=True) for w in added_tokens])
            print(f"Created extended BERT tokenizer with vocab size {len(tokenizer.vocab)}")


        # triples tokenizer, for graph

        #graph_triples_tokenizer = make_triples_tokenizer(world_objects)
        
        #action_graph_triples_tokenizer = make_triples_tokenizer(world_objects)

        #ipdb.set_trace()
        #self.tokenizer_test_t5(tokenizer)


        # do tests
        #self.tokenizer_test_action(tokenizer)
        #self.tokenizer_test_graph(tokenizer)

        #ipdb.set_trace()

        # save the added tokens
        tokens_dict = {}
        tokens_dict['special_tokens'] = special_tokens #['[OBS]', '[ACT]', '[GRAPH]', '[TRIPLE]']
        tokens_dict['relations'] = list(set(self.relations))
        tokens_dict['world_objects'] = list(world_objects.keys())
        tokens_dict['action_tokens'] = list(self.actions)

        return tokenizer, tokens_dict

    def tokenizer_test_t5(self, tokenizer):
        """
        Interactive sanity check: do graph entities map to single tokens?

        For every triple in ``self.triples`` the subject and the object are
        encoded with ``add_special_tokens=False``. An entity counts as a
        failure when it does not encode to exactly one id, or when decoding
        that id does not give back the original string. The relation of the
        triple is not checked.

        A running failure rate is printed after every triple, and the method
        ends in an ``ipdb`` breakpoint so the counters can be inspected.

        NOTE(review): the printed rate divides by ``n_tests + 1``, so it is
        slightly lower than the true ratio -- confirm whether that is intended.

        args
            tokenizer: callable returning a mapping with 'input_ids', and
                       providing ``decode``
        """
        def _round_trips(ids, text):
            # Exactly one id, and that id decodes back to the input string.
            return len(ids) == 1 and tokenizer.decode(ids[0]) == text

        total = 0
        failures = 0

        for subj, _rel, obj in tqdm.tqdm(self.triples):
            total += 2  # one check for the subject, one for the object

            subj_ids = tokenizer(subj, add_special_tokens=False)['input_ids']
            obj_ids = tokenizer(obj, add_special_tokens=False)['input_ids']

            if not _round_trips(subj_ids, subj):
                failures += 1
            if not _round_trips(obj_ids, obj):
                failures += 1

            print(f"Failure rate: {failures/float(total+1)}")
        ipdb.set_trace()


    def tokenizer_test_graph(self, tokenizer):
        """
        Interactive sanity check of `tokenizer` on graph entities.

        Walks over ``self.triples`` and encodes the subject and the object of
        each triple with the tokenizer's default call arguments. An entity is
        counted as failed unless it yields exactly one id that decodes back to
        the same string. The relation is ignored.

        The running failure rate is printed after each triple and an ``ipdb``
        breakpoint is hit once the loop is done.

        NOTE(review): the printed rate uses ``n_tests + 1`` as denominator, so
        it slightly under-reports -- confirm whether that is intended.

        args
            tokenizer: callable returning a mapping with 'input_ids', and
                       providing ``decode``
        """
        checked = 0
        failed = 0

        for subj, _rel, obj in tqdm.tqdm(self.triples):
            checked += 2  # subject + object

            subj_ids = tokenizer(subj)['input_ids']
            obj_ids = tokenizer(obj)['input_ids']

            for text, ids in ((subj, subj_ids), (obj, obj_ids)):
                single_and_faithful = len(ids) == 1 and tokenizer.decode(ids[0]) == text
                if not single_and_faithful:
                    failed += 1

            print(f"Failure rate: {failed/float(checked+1)}")
        ipdb.set_trace()

    def tokenizer_test_action(self, tokenizer):
        """
        Interactive sanity check of `tokenizer` on action strings.

        Each entry of ``self.actions`` is encoded; it counts as a failure
        unless the result is exactly one id that decodes back to the same
        action string. The running failure rate is printed after every action.

        The method stops in an ``ipdb`` breakpoint both before the loop starts
        and after it finishes.

        NOTE(review): the printed rate uses ``n_tests + 1`` as denominator, so
        it slightly under-reports -- confirm whether that is intended.

        args
            tokenizer: callable returning a mapping with 'input_ids', and
                       providing ``decode``
        """
        attempts = 0
        misses = 0

        ipdb.set_trace()
        # Check every action string
        for action in tqdm.tqdm(self.actions):
            attempts += 1

            ids = tokenizer(action)['input_ids']

            if len(ids) != 1 or tokenizer.decode(ids[0]) != action:
                misses += 1

            print(f"Failure rate: {misses/float(attempts+1)}")
        ipdb.set_trace()


    def get_jericho_base_vocab(self):
        """
        Collect the names of all world objects of the games in the dataset.

        For each game listed by ``self.text_data.list_games()`` the first file
        matching ``<game>.z*`` in a hard-coded local Jericho game directory is
        opened with ``jericho.FrotzEnv``, and every non-empty world object
        name is added to the result. The size of the resulting set is printed
        and an ``ipdb`` breakpoint is hit before returning.

        Raises IndexError when no game file matches the pattern.

        returns
            set of object names (str)
        """
        rom_dir = Path("/home/mnskim/workspace/tbg/jericho/z-machine-games-master/jericho-game-suite/")
        object_names = set()

        for game in self.text_data.list_games():
            # The file extension differs between games, so take the first match.
            rom_path = list(rom_dir.glob(f"{game}.z*"))[0]
            env = jericho.FrotzEnv(str(rom_path))

            # The game dictionary is fetched but its result is not used here.
            env.get_dictionary()

            for world_object in env.get_world_objects():
                if world_object.name == '':
                    continue
                object_names.add(world_object.name)

        print(len(object_names))
        ipdb.set_trace()

        return object_names
    
    def _build_graph_vocab(self):
        """
        Build the graph vocabulary as a set.

        The vocabulary holds every element of every triple in
        ``self.triples`` (elements are kept whole, not split on spaces),
        plus everything in ``self.extra_vocab``.

        returns
            set of vocabulary entries
        """
        triple_elements = {element for triple in self.triples for element in triple}
        return triple_elements.union(self.extra_vocab)
 
    def __getitem__(self, idx) -> dict:
        """
        Return example `idx` as a dict of tensors.

        Every key of ``self.encodings`` is mapped to a tensor built from its
        `idx`-th entry, and the `idx`-th label is stored under 'labels'.
        """
        example = {}
        for name, values in self.encodings.items():
            example[name] = torch.tensor(values[idx])
        example['labels'] = torch.tensor(self.labels[idx])
        return example

    def __len__(self):
        """Number of examples, taken from the length of ``self.labels``."""
        n_examples = len(self.labels)
        return n_examples



"""
utils
"""

def convert_to_tokens_simple(valid_acts: dict, lower=False, unabbrev=False):
    """
    Split every action string of `valid_acts` into whitespace tokens.

    args
        valid_acts: iterable of action strings (for a dict, its keys)
        lower: lowercase each token
        unabbrev: replace tokens found in the module-level `abbrv_dict` by
                  their mapped value (looked up after lowercasing)

    returns
        list with one token list per action, in iteration order
    """
    def _normalize(token):
        if lower:
            token = token.lower()
        if unabbrev:
            token = abbrv_dict.get(token, token)
        return token

    return [[_normalize(token) for token in action.split()] for action in valid_acts]

def convert_to_tokens(valid_acts: dict, lower=False, unabbrev=False):
    """
    Tokenize actions together with their objects and templates.

    Despite the annotation, `valid_acts` is iterated as a sequence of
    ``[action, [objects, template]]`` entries. The action is split on any
    whitespace; each object and the template are split on single spaces.
    `lower` and `unabbrev` are applied to the action tokens only.

    args
        valid_acts: iterable of [action, [list of objects, template]]
        lower: lowercase each action token
        unabbrev: replace action tokens found in the module-level
                  `abbrv_dict` by their mapped value (after lowercasing)

    returns
        list of [action_tokens, [list of object token lists, template tokens]]
    """
    converted = []
    for action, (objects, template) in valid_acts:
        words = action.split()
        if lower:
            words = [word.lower() for word in words]
        if unabbrev:
            words = [abbrv_dict[word] if word in abbrv_dict else word for word in words]

        object_tokens = [obj.split(' ') for obj in objects]
        template_tokens = template.split(' ')
        converted.append([words, [object_tokens, template_tokens]])
    return converted


def check_action(action: str, engine_commands: set = None):
    """
    Tell whether `action` is one of the known actions.

    The known actions are the union of the module-level `basic_actions`,
    `noeffect_actions`, the values of `abbrv_dict`, and -- when given --
    `engine_commands`.

    args
        action: action string to look up
        engine_commands: optional extra commands to accept

    returns
        True if `action` is in that union, else False
    """
    known = set() if engine_commands is None else set(engine_commands)
    known |= set(basic_actions)
    known |= set(noeffect_actions)
    known |= set(abbrv_dict.values())

    return action in known

def split(sent: str, method: str,):
    """
    Split `sent` with the chosen method.

    args
        sent: text to split
        method: 'regex' splits on any of the characters . , ' "
                'hf_whitespace' runs the module-level `whitespace_tokenizer`
                pre-tokenizer and keeps only the token strings

    returns
        list of pieces, or None when `method` is neither of the above
    """
    if method == 'regex':
        pieces = re.split(r'[.,\'\"]', sent)
    elif method == 'hf_whitespace':
        # pre_tokenize_str yields (token, offsets) pairs; keep the token only.
        pieces = [token for token, *_ in whitespace_tokenizer.pre_tokenize_str(sent)]
    else:
        pieces = None
    return pieces

def get_graph_diff(g1, g2):
    """
    Compute the operations needed to turn g1 into g2.

    returns
        (adds, dels): items of g2 that are not in g1, and items of g1 that
        are not in g2, each in their original order
    """
    def _only_in(source, other):
        kept = []
        for entry in source:
            if entry not in other:
                kept.append(entry)
        return kept

    adds = _only_in(g2, g1)
    dels = _only_in(g1, g2)
    return adds, dels

def check_fix_state_kg(game: Dict, game_step_id: int) -> None:
    """
    Tries to fix the KG for the state associated with game_step_id, in place.

    Nothing is done when the state's graph is already non-empty. Otherwise
    the graph is replaced by a copy of the `next_state` graph of the previous
    step, or by a copy of the step's own `graph_diff` if game_step_id == 0.

    args
        game: steps of one game, indexable by step id; each step holds
              'state', 'next_state' and 'graph_diff' entries
        game_step_id: index of the step whose state graph should be fixed
    """
    step = game[game_step_id]

    if step['state']['graph']:
        return

    if game_step_id == 0:
        # The first empty step, replace with the diff. Copy it, as the other
        # branch does, so that later edits to the state graph do not leak
        # into `graph_diff` through a shared object.
        replacement = step['graph_diff']
    else:
        # Get the graph from the next_state that was created in the previous step.
        replacement = game[game_step_id - 1]['next_state']['graph']

    step['state']['graph'] = deepcopy(replacement)

if __name__ == "__main__":
    # Script entry point: only parses the command line. The tokenizer-building
    # script that used to run here is disabled (see the string literal below).

    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', type=str, help='train json file')
    parser.add_argument('--test_file', type=str, help='test json file')

    # Other
    parser.add_argument('--seed', type=int, default=42, help='')


    args = parser.parse_args()

    # NOTE: everything below is kept as a bare string literal, so none of it
    # is executed. It is the old tokenizer-building script with hard-coded
    # local paths; the parsed `args` are not used by any live code here.
    """
    #global BASIC_ACTIONS

    #save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/sentencepiece-tokenizer-v1"
    #save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/sentencepiece-tokenizer-v2"

    #save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/t5_extended-tokenizer-v1"
    #save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/bert_extended-tokenizer-v1"
    #save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/distilbert_extended-tokenizer-v1"
    save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/distilbert_extended-tokenizer-fast"
    save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/gpt2_extended-tokenizer-fast"

    save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/action_decoder-tokenizer-v2"
    #save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/graph_decoder-tokenizer-v1"

    #save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/text_encoder-tokenizer-v2"
    #save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/graph_encoder-tokenizer-v2"

    # New data
    #save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/text_encoder-tokenizer-newdata"
    #save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/graph_decoder-tokenizer-newdata"
    save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/action_decoder-tokenizer-newdata"

    # for testing
    #train_dataset = JerichoWorldDataset('/home/mnskim/workspace/tbg/JerichoWorld/data/train.json')
    #test_dataset = JerichoWorldDataset('/home/mnskim/workspace/tbg/JerichoWorld/data/test.json')
    # Updated dataset
    train_dataset = JerichoWorldDataset('/home/mnskim/workspace/tbg/jerichoworld_build/JerichoWorld/datasets/v1/train.json')
    test_dataset = JerichoWorldDataset('/home/mnskim/workspace/tbg/jerichoworld_build/JerichoWorld/datasets/v1/test.json')
    print(train_dataset)
    print(test_dataset)

    train_dataset.data.update(test_dataset.data)
    #train_dataset.filter_games('loose')
    #train_dataset.filter_games('zork2')
    #train_dataset.filter_games('temple')
    #train_dataset.filter_games('pentari')
    #train_dataset.filter_games('ztuu')

    dataset = train_dataset

    #ipdb.set_trace()
    vocab_dataset = VocabBuilderDataset(dataset)
    vocab_dataset.save_tokenizer(save_path)

    #print(dataset)
    #ipdb.set_trace()
    """
