import json
import os

import numpy as np

from bilm.data import Batcher


class BiLMVocabLoader(object):
    """Produce fixed-length biLM character-id inputs for a single sentence.

    The loader reads ``options.json`` and ``vocab.txt`` from a model
    directory, builds a ``Batcher`` from them, and exposes
    :meth:`get_chars_input` to encode one tokenized sentence and zero-pad
    it along the token axis.
    """

    def __init__(self, data_dir: str) -> None:
        """Load the character-CNN options and create the batcher.

        Args:
            data_dir: Directory containing ``vocab.txt`` and ``options.json``.

        Raises:
            OSError: If ``options.json`` cannot be opened.
            KeyError: If the options lack
                ``['char_cnn']['max_characters_per_token']``.
        """
        vocab_file = os.path.join(data_dir, 'vocab.txt')
        options_file = os.path.join(data_dir, 'options.json')

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(options_file, 'r', encoding='utf-8') as fin:
            options = json.load(fin)

        self.max_word_length = options['char_cnn']['max_characters_per_token']
        self.batcher = Batcher(vocab_file, self.max_word_length)

    def get_chars_input(self, words, padded_length, boundaries=True):
        """Encode ``words`` with the batcher and zero-pad the first axis.

        Args:
            words: One tokenized sentence; it is passed to
                ``self.batcher.batch_sentences`` as a single-element batch.
            padded_length: Target number of rows, not counting the two
                extra rows reserved when ``boundaries`` is true.
            boundaries: Forwarded to ``batch_sentences``. When true the
                result has ``padded_length + 2`` rows, otherwise
                ``padded_length`` rows (presumably the two extra rows hold
                sentence begin/end markers added by the batcher -- confirm
                against ``Batcher``).

        Returns:
            A 2-D array: the batcher output for the sentence with zero rows
            appended at the end; the second axis is left untouched.

        Raises:
            ValueError: If the encoded sentence already has more rows than
                the requested padded size.
        """
        bilm_chars = self.batcher.batch_sentences(
            [words], boundaries=boundaries)[0]

        target_rows = padded_length + (2 if boundaries else 0)
        rows_to_add = target_rows - bilm_chars.shape[0]
        if rows_to_add < 0:
            # np.pad would fail here with an opaque "negative values"
            # message; report the actual mismatch instead.
            raise ValueError(
                "padded_length={} is too small: encoded sentence has {} rows "
                "but only {} are allowed (boundaries={})".format(
                    padded_length, bilm_chars.shape[0], target_rows,
                    boundaries))

        # Pad only at the end of the token axis; never along the second axis.
        return np.pad(
            bilm_chars,
            ((0, rows_to_add), (0, 0)),
            "constant",
            constant_values=0
        )


if __name__ == '__main__':
    # Manual smoke test: encode a single word and print the padded char ids.
    import sys

    # An optional first CLI argument overrides the model directory; the
    # original hard-coded location remains the default.
    model_dir = (sys.argv[1] if len(sys.argv) > 1
                 else "/home/author_name/work/bilm/official/")
    loader = BiLMVocabLoader(model_dir)
    print(loader.get_chars_input(["Municipals"], 10))
