from preprocessing.w2v_preprocessing_embedding import PreprocessingWord2VecEmbedding


class W2VModel:
    """Thin facade over ``PreprocessingWord2VecEmbedding`` for vector lookup.

    Every lookup takes one of two routes, selected by ``entity_on_error``:

    * truthy -- the call is forwarded to the same-named preprocessor method;
    * falsy -- ``preprocessor.model.word_vec`` is called directly, so whatever
      it raises reaches the caller unchanged.

    NOTE(review): the flag name suggests the preprocessor substitutes some
    entity when a lookup fails -- confirm against the preprocessor module.
    """

    def __init__(self, pretrained_embeddings_path, binary=True, entity_on_error=True):
        """Build the preprocessor and remember which lookup route to use.

        Args:
            pretrained_embeddings_path: Passed as-is to
                ``PreprocessingWord2VecEmbedding``.
            binary: Forwarded to the preprocessor as its ``binary`` keyword.
            entity_on_error: Truthy to route lookups through the preprocessor
                methods; falsy to query ``preprocessor.model`` directly.
        """
        self.entity_on_error = entity_on_error
        self.preprocessor = PreprocessingWord2VecEmbedding(
            pretrained_embeddings_path,
            binary=binary,
        )

    def get_vector_and_word(self, word: str):
        """Return the vector for ``word`` together with a word.

        In the direct route the result is the 2-tuple
        ``(model.word_vec(word), word)``; otherwise it is whatever the
        preprocessor's ``get_vector_and_word`` returns.
        """
        if not self.entity_on_error:
            vector = self.preprocessor.model.word_vec(word)
            return vector, word

        return self.preprocessor.get_vector_and_word(word)

    def predict(self, word: str):
        """Alias for :meth:`get_vector`."""
        vector = self.get_vector(word)
        return vector

    def get_vector(self, word: str):
        """Return the vector for ``word`` via the configured route."""
        if not self.entity_on_error:
            return self.preprocessor.model.word_vec(word)

        return self.preprocessor.get_vector(word)

    def get_vector_example(self, words):
        """Turn a sequence of words into a target/data vector example.

        Args:
            words: Indexable, sliceable sequence whose first item is the
                target word and whose remaining items are the data words.

        Returns:
            In the direct route, a dict with key ``'target'`` (vector of
            ``words[0]``) and key ``'data'`` (list of vectors for
            ``words[1:]``); otherwise whatever the preprocessor's
            ``get_vector_example`` returns.
        """
        if not self.entity_on_error:
            head = words[0]
            rest = words[1:]
            lookup = self.preprocessor.model.word_vec
            # Target is resolved before the data words, matching call order.
            target_vector = lookup(head)
            data_vectors = [lookup(item) for item in rest]
            return {'target': target_vector, 'data': data_vectors}

        return self.preprocessor.get_vector_example(words)