import os
import glob
from tqdm import tqdm
from tokenizers import models, Tokenizer, pre_tokenizers
from transformers import AutoTokenizer
import json

def build_tokenizer(tokenizer_dir, dataset_dir):
    """Build a whitespace/word-level tokenizer from a JSONL dataset and save it.

    Every ``*.jsonl`` file directly inside ``dataset_dir`` is read line by
    line. Each non-blank line is parsed as a JSON object and its
    ``"sentence"`` value is split on whitespace, after detaching ``.`` and
    ``,`` from the preceding word. The distinct tokens, sorted, are appended
    to the vocabulary after the special tokens ``<bos>``, ``<pad>`` and
    ``<unk>`` (ids 0, 1 and 2).

    Args:
        tokenizer_dir: Output directory (created if missing). Receives
            ``tokenizer.json`` and a minimal ``config.json``.
        dataset_dir: Directory containing the ``*.jsonl`` dataset files.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"sentence"`` key.
    """
    special_tokens = ["<bos>", "<pad>", "<unk>"]

    # read all words from the dataset
    all_tokens = set()
    for file in glob.glob(os.path.join(dataset_dir, "*.jsonl")):
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(file, encoding="utf-8") as f:
            for line in tqdm(f, desc=f"Reading {file}"):
                if not line.strip():
                    continue  # tolerate blank lines instead of failing in json.loads
                sentence = json.loads(line)["sentence"]
                # detach punctuation so "." and "," become tokens of their own
                sentence = sentence.replace(".", " .").replace(",", " ,")
                # str.split() with no argument never yields empty strings
                all_tokens.update(sentence.split())

    # Special tokens take the first ids. A dataset token that happens to spell
    # a special token is dropped here: re-assigning it would overwrite the
    # reserved id and hand the same id to the following token as well.
    vocab = {token: i for i, token in enumerate(special_tokens)}
    for token in sorted(all_tokens.difference(special_tokens)):
        vocab[token] = len(vocab)

    # create tokenizer
    tokenizer = Tokenizer(models.WordPiece(unk_token="<unk>", vocab=vocab))
    tokenizer.add_special_tokens(special_tokens)
    tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()  # split on whitespace
    tokenizer.model.continuing_subword_prefix = ""              # allows to split "hey." into "hey" and "."

    # save
    os.makedirs(tokenizer_dir, exist_ok=True)
    tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))

    # add a config.json file to the tokenizer folder
    # NOTE(review): presumably needed so AutoTokenizer.from_pretrained can
    # resolve a tokenizer class for this folder - confirm.
    with open(os.path.join(tokenizer_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"model_type": "gpt2"}, f)

def test(tokenizer_dir):
    """Round-trip a fixed sample sentence through the saved tokenizer.

    Loads the tokenizer from ``tokenizer_dir`` with ``AutoTokenizer``,
    encodes and then decodes the sample, prints all three stages, and asserts
    that the decoded text equals the input once every space is removed.

    Args:
        tokenizer_dir: Directory passed to ``AutoTokenizer.from_pretrained``.

    Raises:
        AssertionError: If the space-stripped decoded text differs from the
            space-stripped input.
    """
    sample = "The card is in Box G, the pipe is in Box E, there is nothing in Box D, the cheese is in Box C, the computer and the game are in Box B, the drug and the engine and the pot are in Box A, the boat and the tea are in Box F. Remove the pipe from Box E. Box A contains the drug and the engine and the pot, Box B contains the computer and the game, Box C contains the cheese, Box D is empty, Box E is empty, Box F contains the boat and the tea, Box G contains the card."

    loaded = AutoTokenizer.from_pretrained(tokenizer_dir)
    token_ids = loaded.encode(sample)
    round_trip = loaded.decode(token_ids)

    for label, value in (
        ("Input:", sample),
        ("Tokenized:", token_ids),
        ("Decoded:", round_trip),
    ):
        print(label, value)

    # Only the non-space characters are compared, so spacing differences
    # between the input and the decoded text are ignored.
    is_reversible = sample.replace(" ", "") == round_trip.replace(" ", "")
    assert is_reversible, "Tokenizer is not reversible."
    print("Test passed!")


if __name__ == "__main__":
    # Script entry point: build the tokenizer from the dataset, then check
    # that it round-trips the sample sentence.
    # Alternative dataset: "./datasets/boxes_dataset_move_only3"
    dataset_dir = "./datasets/boxes_dataset3"
    tokenizer_dir = "tokenizers/boxes_tokenizer5"
    build_tokenizer(tokenizer_dir=tokenizer_dir, dataset_dir=dataset_dir)
    test(tokenizer_dir=tokenizer_dir)
