import numpy as np
from nltk.corpus import wordnet as wn

from defiNNet.DefAnalyzer import DefAnalyzer
from defiNNet.parsing.parse_tree import Parser
from preprocessing.w2v_preprocessing_embedding import PreprocessingWord2VecEmbedding


class Additive:
    """Predict a vector for a target word by summing two word embeddings.

    ``DefAnalyzer.analyze`` is run on the target word. The entries at indices
    1 and 2 of the returned word list are looked up in the pretrained
    embeddings, and their vectors are summed element-wise.
    """

    # Expected length of a single embedding vector. ``predict_analyzed`` only
    # returns a sum when the stacked lookups have shape (2, embedding_dim).
    # Override on a subclass or instance when loading embeddings of another size.
    embedding_dim = 300

    def __init__(self, pretrained_embeddings_path, binary=True):
        """Load the pretrained embeddings and build the definition analyzer.

        :param pretrained_embeddings_path: path passed to
            ``PreprocessingWord2VecEmbedding`` to load the vectors.
        :param binary: forwarded unchanged as ``binary=`` to
            ``PreprocessingWord2VecEmbedding`` (presumably selects the binary
            word2vec file format -- confirm against that class).
        """
        self.preprocessor = PreprocessingWord2VecEmbedding(pretrained_embeddings_path, binary=binary)
        self.defAnalyzer = DefAnalyzer(parser=Parser())

    @staticmethod
    def _resolve_pos(target, pos):
        """Return the part-of-speech category to analyze ``target`` with.

        An explicit ``pos`` is lower-cased. When ``pos`` is None, the POS of
        the first WordNet synset of ``target`` is used. If WordNet has no
        synset for ``target``, the result is None.
        """
        if pos is not None:
            return pos.lower()
        synsets = wn.synsets(target)  # query WordNet once, not twice
        if synsets:
            return synsets[0].pos()
        return None

    def _analyze_pair(self, target, pos):
        """Analyze ``target`` and return the analyzer's words at indices 1 and 2."""
        words, _ = self.defAnalyzer.analyze(target, category=self._resolve_pos(target, pos))
        return words[1], words[2]

    def predict_and_word(self, target, pos=None):
        """Predict a vector for ``target`` and report which words produced it.

        :param target: word to analyze and embed.
        :param pos: optional part-of-speech tag. It is lower-cased when given;
            otherwise it is taken from the first WordNet synset of ``target``.
        :returns: tuple ``(vector, label)``. ``vector`` is the result of
            ``predict_analyzed`` (None when its shape check fails). ``label``
            is the two analyzed words joined with ``"_"``.
        """
        w1, w2 = self._analyze_pair(target, pos)
        return self.predict_analyzed([w1, w2]), w1 + "_" + w2

    def predict(self, target, pos=None):
        """Predict a vector for ``target``.

        Same as ``predict_and_word`` but returns only the vector (or None).

        :param target: word to analyze and embed.
        :param pos: optional part-of-speech tag; see ``predict_and_word``.
        """
        w1, w2 = self._analyze_pair(target, pos)
        return self.predict_analyzed([w1, w2])

    def predict_analyzed(self, words):
        """Sum the embeddings of ``words`` element-wise.

        :param words: sequence of words, each looked up via
            ``self.preprocessor.get_vector``.
        :returns: a 1-D array of length ``embedding_dim``, or None when the
            stacked lookups do not have shape (2, embedding_dim), e.g. when
            the number of words is not exactly two.
        """
        x = np.array([self.preprocessor.get_vector(word) for word in words])

        if x.shape != (2, self.embedding_dim):
            # Callers receive None instead of a vector on a shape mismatch.
            return None
        return np.sum(a=x, axis=0)
