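"""
A chess tokenizer that converts UCI moves (e2e4 d7d5 e4d5 ...) to token ids.
Each move is split into its from-square and to-square tokens (e.g. "e2e4" -> "e2", "e4").
Built with the Hugging Face `tokenizers` library and loaded through `transformers`.
"""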

import os
from tokenizers import models, Tokenizer, pre_tokenizers
from transformers import AutoTokenizer
import json

def build_tokenizer(tokenizer_dir, *, include_promotions=False):
    """Build the chess move tokenizer and save it to *tokenizer_dir*.

    The vocabulary is the special tokens ``<bos>``, ``<pad>``, ``<unk>``
    (ids 0-2) followed by the 64 board squares in file-major order
    ``a1, a2, ..., a8, b1, ..., h8`` (ids 3-66). Input is split on
    whitespace, and each move is then split greedily into square tokens,
    so ``"e2e4"`` becomes ``["e2", "e4"]``.

    Two files are written:

    * ``tokenizer.json``: the serialized ``tokenizers.Tokenizer``.
    * ``config.json``: ``{"model_type": "gpt2"}``. ``AutoTokenizer.from_pretrained``
      (as used by ``test``) reads it to choose a tokenizer class.

    Args:
        tokenizer_dir: Output directory; created if it does not exist.
        include_promotions: If True, append the UCI promotion pieces
            ``q``, ``r``, ``b``, ``n`` to the vocabulary (ids 67-70). Moves
            such as ``"e7e8q"`` then tokenize as ``["e7", "e8", "q"]``. Off
            by default to keep the original 67-token vocabulary. Without it,
            a promotion move has no complete tokenization, and WordPiece
            encodes the whole move as ``<unk>``.
    """
    special_tokens = ["<bos>", "<pad>", "<unk>"]
    all_tiles = [f"{chr(ord('a') + col)}{row + 1}" for col in range(8) for row in range(8)]
    pieces = special_tokens + all_tiles
    if include_promotions:
        # Appended last so the ids of every pre-existing token stay unchanged.
        pieces += ["q", "r", "b", "n"]
    vocab = {token: i for i, token in enumerate(pieces)}

    tokenizer = Tokenizer(models.WordPiece(vocab=vocab, unk_token="<unk>"))
    tokenizer.add_special_tokens(special_tokens)
    tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
    # WordPiece expects a prefix (default "##") on non-initial pieces. An empty
    # prefix lets the second half of "e2e4" match the plain "e4" entry.
    tokenizer.model.continuing_subword_prefix = ""

    os.makedirs(tokenizer_dir, exist_ok=True)
    tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))

    with open(os.path.join(tokenizer_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"model_type": "gpt2"}, f)

def test(tokenizer_dir):
    """Round-trip a sample move sequence through the saved tokenizer.

    Loads *tokenizer_dir* with ``AutoTokenizer``, then encodes and decodes a
    string of UCI moves, printing each stage. The decoded text is compared
    with the input with all spaces removed, because decoding may separate
    square tokens with spaces.

    Args:
        tokenizer_dir: Directory previously written by ``build_tokenizer``.

    Raises:
        AssertionError: If the round trip does not reproduce the input.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    sample = "e2e4 d7d5 e4d5 d5c6 d2d4 c6b7 d4d5 b7c8 d5d6 c8d8 d6d7 d8e8 d7d8 e8f8 d8e8"
    tokenized = tokenizer.encode(sample)
    decoded = tokenizer.decode(tokenized)
    print("Input:", sample)
    print("Tokenized:", tokenized)
    print("Decoded:", decoded)
    # Explicit raise rather than an `assert` statement: `python -O` strips
    # asserts, which would make this check always report success.
    if sample.replace(" ", "") != decoded.replace(" ", ""):
        raise AssertionError("Tokenizer is not reversible.")
    print("Test passed!")


if __name__ == "__main__":
    # Relative to the current working directory; a plain string (the original
    # f-string had no placeholders).
    tokenizer_dir = "tokenizers/chess_tokenizer"
    build_tokenizer(tokenizer_dir)
    test(tokenizer_dir)
