from tokenizers import Tokenizer
from tokenizers.normalizers import NFKC
from tokenizers import decoders

from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace, Digits, Metaspace
from tokenizers import SentencePieceBPETokenizer
from transformers import PreTrainedTokenizerFast

from tokenizers.trainers import WordLevelTrainer

import ipdb

def train_sentencepiece(text, vocab_size, special_tokens=None):
    """Train a SentencePiece-style BPE tokenizer from an iterator of text.

    Args:
        text: Iterable of training strings, handed unchanged to
            ``train_from_iterator``.
        vocab_size: Target vocabulary size.
        special_tokens: Optional sequence of special tokens to reserve.
            ``None`` (the default) is treated as "no special tokens".

    Returns:
        The trained ``SentencePieceBPETokenizer``.
    """
    # Build a fresh list per call instead of sharing a mutable default.
    special_tokens = [] if special_tokens is None else list(special_tokens)

    # NOTE(review): '[UNK]' is only reserved in the vocabulary if the caller
    # lists it in special_tokens -- confirm callers do so.
    tokenizer = SentencePieceBPETokenizer(unk_token='[UNK]')

    tokenizer.normalizer = NFKC()
    # Whitespace pre-tokenization is used here. To switch to Metaspace, set
    # `tokenizer.pre_tokenizer = Metaspace()` at this point and
    # `tokenizer.decoder = decoders.Metaspace()` after training.
    tokenizer.pre_tokenizer = Whitespace()

    tokenizer.train_from_iterator(
        text,
        vocab_size=vocab_size,
        min_frequency=2,
        show_progress=True,
        limit_alphabet=1000,
        special_tokens=special_tokens,
    )

    return tokenizer

def train_bpe(text, special_tokens=None, *, vocab_size=None):
    """Train a plain BPE tokenizer from an iterator of text.

    Args:
        text: Iterable of training strings, handed unchanged to
            ``train_from_iterator``.
        special_tokens: Optional sequence of special tokens to reserve.
            ``None`` (the default) is treated as "no special tokens".
        vocab_size: Optional target vocabulary size (keyword-only). When
            omitted, ``BpeTrainer`` is built without a ``vocab_size``
            argument, exactly as before this parameter existed.

    Returns:
        The trained ``Tokenizer``.
    """
    # Build a fresh list per call instead of sharing a mutable default.
    special_tokens = [] if special_tokens is None else list(special_tokens)

    # NOTE(review): '[UNK]' is only reserved in the vocabulary if the caller
    # lists it in special_tokens -- confirm callers do so.
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))

    tokenizer.normalizer = NFKC()
    tokenizer.pre_tokenizer = Whitespace()

    trainer_kwargs = {
        "special_tokens": special_tokens,
        "show_progress": True,
    }
    # Only forward vocab_size when explicitly requested so the trainer's own
    # default keeps applying to existing callers.
    if vocab_size is not None:
        trainer_kwargs["vocab_size"] = vocab_size
    trainer = BpeTrainer(**trainer_kwargs)

    tokenizer.train_from_iterator(text, trainer=trainer)

    return tokenizer


def _make_triples_tokenizer_bpe(tokens: dict, special_tokens=None):
    """Train a BPE tokenizer on the keys of ``tokens``.

    The previous body referenced the undefined names ``text`` and
    ``special_tokens`` and therefore raised ``NameError`` on every call.
    NOTE(review): training on the keys of ``tokens`` with the special tokens
    used by ``make_triples_tokenizer`` is the inferred intent -- confirm.

    Args:
        tokens: Mapping whose keys are the strings to train on.
        special_tokens: Optional sequence of special tokens to reserve.
            Defaults to ``["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]``.

    Returns:
        The trained ``Tokenizer``.
    """
    if special_tokens is None:
        special_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]

    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))

    tokenizer.normalizer = NFKC()
    tokenizer.pre_tokenizer = Whitespace()

    trainer = BpeTrainer(special_tokens=list(special_tokens))

    # The mapping's keys are the training corpus; its values are not used.
    tokenizer.train_from_iterator(list(tokens.keys()), trainer=trainer)
    return tokenizer



def make_triples_tokenizer(tokens: dict):
    """
    Build the tokenizer that turns graph s,r,o symbols into ids.

    A word-level tokenizer is trained on an empty corpus with a fixed set of
    special tokens, wrapped in ``PreTrainedTokenizerFast``, and then every
    key of ``tokens`` is added to it as a token.

    Args:
        tokens: Mapping whose keys are the symbols to register; the values
            are not read here.

    Returns:
        The ``PreTrainedTokenizerFast`` wrapper with the keys added.
    """
    from tokenizers.models import WordLevel

    reserved = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]

    word_level = Tokenizer(WordLevel(unk_token="[UNK]"))
    # Pre-tokenize on digits only, emitting each digit separately.
    word_level.pre_tokenizer = Digits(individual_digits=True)

    # Training on an empty iterator: only the special tokens are supplied.
    word_level.train_from_iterator(
        [],
        trainer=WordLevelTrainer(vocab_size=1, special_tokens=reserved),
    )

    fast_tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=word_level,
        special_tokens=[],
    )

    # Register the graph symbols after wrapping.
    symbols = list(tokens.keys())
    fast_tokenizer._add_tokens(symbols)

    return fast_tokenizer
