from transformers import AutoTokenizer
import ipdb
import tokenizers
import json

# Tokenization related
from tokenizers import Tokenizer
from tokenizers.normalizers import NFKC
from tokenizers import decoders
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace, Digits, Metaspace
from tokenizers import SentencePieceBPETokenizer
from transformers import PreTrainedTokenizerFast

# TEMPORARY (commented out): module-level token lists; get_tokenizer now builds these itself.
#special_tokens = ['[PAD]', '[UNK]', '[BOS]', '[SOS]', '[EOS]', '[OBS]', '[ACT]', '[GRAPH]', '[TRIPLE]', '[ADD]', '[DEL]', '[PREV_ACT]', '[TRANSITION_ACT]']
#additional_special_tokens = ['[SOS]', '[OBS]', '[ACT]', '[GRAPH]', '[TRIPLE]', '[ADD]', '[DEL]', '[PREV_ACT]', '[TRANSITION_ACT]']
#atomic2020_relation_vocab = json.load(open('/home/mnskim/workspace/tbg/tbg1/worldformer2/tokenization/atomic2020_vocab.json', 'r'))['relation_vocab']
#additional_special_tokens += atomic2020_relation_vocab

# Default location of the ATOMIC-2020 relation vocabulary read when
# ``add_atomic_tokens`` is set; override via ``atomic_vocab_path``.
_DEFAULT_ATOMIC_VOCAB_PATH = '/home/mnskim/workspace/tbg/tbg1/worldformer2/tokenization/atomic2020_vocab.json'

# Extra special tokens registered on the pretrained (BART / generic HF) tokenizers.
_BASE_ADDITIONAL_SPECIAL_TOKENS = (
    '[SOS]', '[OBS]', '[ACT]', '[GRAPH]', '[TRIPLE]', '[ADD]', '[DEL]',
    '[PREV_ACT]', '[TRANSITION_ACT]',
    '[TEMPLATE]', '[NO_TEMPLATE]', '[TEMPLATE_START]', '[TEMPLATE_END]',
)


def _load_json(path):
    """Read and parse the JSON file at *path*, closing the handle afterwards."""
    with open(path, 'r') as f:
        return json.load(f)


def _collect_added_tokens(tokens_dict, include_action_tokens=True):
    """Concatenate the token groups of an ``added_tokens.json`` dict.

    Groups are taken in the order special_tokens, relations, world_objects,
    then (optionally) action_tokens; that order is what gets passed to
    ``add_tokens``.
    """
    keys = ['special_tokens', 'relations', 'world_objects']
    if include_action_tokens:
        keys.append('action_tokens')
    return [token for key in keys for token in tokens_dict[key]]


def _add_word_tokens(tokenizer, words):
    """Add *words* as regular (non-special) tokens with ``single_word=True``."""
    tokenizer.add_tokens([tokenizers.AddedToken(w, single_word=True) for w in words])


def _build_decoder_tokenizer(added_tokens_path, include_action_tokens):
    """Build a from-scratch fast tokenizer whose vocabulary is only added tokens.

    The backing model is ``BPE()`` constructed with no vocab or merges, so the
    vocabulary consists of the pad/unk/bos/eos specials plus the token lists
    read from ``added_tokens_path``.
    """
    if added_tokens_path is None:
        # Fail early with a clear message instead of an opaque error from open(None).
        raise ValueError("added_tokens_path is required for decoder tokenizers")
    tokens_dict = _load_json(added_tokens_path)

    tokenizer = PreTrainedTokenizerFast(tokenizer_object=Tokenizer(BPE()))
    tokenizer.add_special_tokens({'pad_token': '[PAD]',
                                  'unk_token': '[UNK]',
                                  'bos_token': '[BOS]',
                                  'eos_token': '[EOS]',
                                  })
    _add_word_tokens(tokenizer, _collect_added_tokens(tokens_dict, include_action_tokens))
    return tokenizer


def get_tokenizer(hf_tokenizer_name, added_tokens_path=None, add_atomic_tokens=False,
                  atomic_vocab_path=_DEFAULT_ATOMIC_VOCAB_PATH):
    """
    For loading hf tokenizer with extra vocab.

    Args:
        hf_tokenizer_name: One of
            - 'graph_decoder_tokenizer' / 'action_decoder_tokenizer': build an
              empty-BPE fast tokenizer from the lists in ``added_tokens_path``
              (the action variant also includes 'action_tokens');
            - 'facebook/bart-base': load from the local HF cache only;
            - anything else: passed to ``AutoTokenizer.from_pretrained`` and
              given [PAD]/[UNK]/[BOS]/[EOS]/[SEP] special tokens.
        added_tokens_path: JSON file with 'special_tokens', 'relations',
            'world_objects' (and, except for the graph decoder, 'action_tokens')
            lists. Required for the decoder tokenizers, optional otherwise.
        add_atomic_tokens: If True, also register the 'relation_vocab' list
            from ``atomic_vocab_path`` as additional special tokens. Only the
            pretrained branches use these additional special tokens.
        atomic_vocab_path: JSON file holding the ATOMIC-2020 'relation_vocab'.

    Returns:
        The configured tokenizer.

    Raises:
        ValueError: a decoder tokenizer was requested without ``added_tokens_path``.
    """
    additional_special_tokens = list(_BASE_ADDITIONAL_SPECIAL_TOKENS)
    if add_atomic_tokens:
        print("## Adding atomic tokens to tokenizer")
        additional_special_tokens += _load_json(atomic_vocab_path)['relation_vocab']

    if hf_tokenizer_name in ('graph_decoder_tokenizer', 'action_decoder_tokenizer'):
        tokenizer = _build_decoder_tokenizer(
            added_tokens_path,
            include_action_tokens=(hf_tokenizer_name == 'action_decoder_tokenizer'),
        )
        print(f"Created {hf_tokenizer_name} decoder tokenizer with vocab size {len(tokenizer.vocab)}")
        return tokenizer

    is_bart = hf_tokenizer_name == 'facebook/bart-base'
    if is_bart:
        tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_name, local_files_only=True)
    else:
        tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_name)

    if added_tokens_path is not None:
        # NOTE: these are added as regular tokens, not via add_special_tokens.
        _add_word_tokens(tokenizer, _collect_added_tokens(_load_json(added_tokens_path)))
        print(f"Created extended {hf_tokenizer_name} tokenizer with vocab size {len(tokenizer.vocab)}")

    if not is_bart:
        # BART keeps its own special tokens; other tokenizers get these assigned.
        tokenizer.add_special_tokens({'pad_token': '[PAD]',
                                      'unk_token': '[UNK]',
                                      'bos_token': '[BOS]',
                                      'eos_token': '[EOS]',
                                      'sep_token': '[SEP]',
                                      })

    tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
    return tokenizer


if __name__ == "__main__":
    # Smoke test: build the extended action-decoder tokenizer from its saved
    # directory. Other configurations exercised during development:
    #   get_tokenizer('distilbert-base-cased', '/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/distilbert_extended-tokenizer-fast/extra/added_tokens.json')
    #   get_tokenizer('graph_decoder_tokenizer', '/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/graph_decoder-tokenizer-v1/extra/added_tokens.json')
    tokenizer_dir = '/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/action_decoder-tokenizer-v1'
    tokens_file = '/home/mnskim/workspace/tbg/tbg1/ckpts/tokenizers/action_decoder-tokenizer-v1/extra/added_tokens.json'
    get_tokenizer(tokenizer_dir, tokens_file)

