from nltk.corpus.reader import Synset
from nltk.corpus import wordnet as wn

from preprocessing.w2v_preprocessing_embedding import PreprocessingWord2VecEmbedding
from utility.randomfixedseed import Random
from utility.word_in_vocabulary import WNManager


class HeadModel:
    """Pick a "head" word for a WordNet sense and return its embedding.

    The head word is chosen in three steps:

    1. Walk every hypernym path of the synset from the synset itself up to
       the root and return the first in-vocabulary lemma that is a prefix of
       a token of the synset's definition (allowing a short suffix).
    2. Otherwise pick, via ``Random.randomchoice``, one in-vocabulary lemma
       of the first direct hypernym that differs from the input word.
    3. Otherwise fall back to the word ``"entity"``.

    ``predict`` and ``predict_and_word`` share this selection and differ only
    in which preprocessor method turns the chosen word into a result.
    """

    # Word used when WordNet has no synset for the input or when no suitable
    # hypernym lemma is found.
    _FALLBACK_WORD = "entity"

    # A definition token may be at most this many characters longer than the
    # lemma it starts with. The original check was
    # ``(len(w) - len(voc)) in range(-3, 3)``; since ``w.startswith(voc)``
    # already guarantees a non-negative difference, that is exactly 0..2.
    _MAX_EXTRA_CHARS = 2

    def __init__(self, pretrained_embeddings_path, binary=True):
        """Create the embedding preprocessor used for all lookups.

        Both arguments are forwarded unchanged to
        ``PreprocessingWord2VecEmbedding``.
        """
        self.preprocessor = PreprocessingWord2VecEmbedding(pretrained_embeddings_path, binary=binary)

    def predict_and_word(self, word, pos_tag=None, synset=None):
        """Return ``preprocessor.get_vector_and_word`` of the selected head word.

        :param word: lemma to look up in WordNet when ``synset`` is None; it is
            also excluded from the direct-hypernym candidates.
        :param pos_tag: part-of-speech filter for the WordNet lookup; ``'j'``
            is mapped to WordNet's ``'a'``. Ignored when ``synset`` is given.
        :param synset: optional sense, either a ``Synset`` or a synset name
            accepted by ``wn.synset``.
        :raises TypeError: if ``synset`` is neither None, a str nor a Synset.
        """
        head = self._select_head_word(word, pos_tag=pos_tag, synset=synset)
        return self.preprocessor.get_vector_and_word(head)

    def predict(self, word, pos_tag=None, synset=None):
        """Return ``preprocessor.get_vector`` of the selected head word.

        Takes the same parameters and raises the same errors as
        ``predict_and_word``.
        """
        head = self._select_head_word(word, pos_tag=pos_tag, synset=synset)
        return self.preprocessor.get_vector(head)

    @staticmethod
    def _resolve_synset(word, pos_tag, synset):
        """Return the synset to analyse, or None if WordNet has none for ``word``."""
        if synset is None:
            # 'j' is not a WordNet POS constant; adjectives are 'a' there.
            pos = 'a' if pos_tag == 'j' else pos_tag
            candidates = wn.synsets(word, pos=pos)
            # The first synset returned by WordNet is taken as the sense.
            return candidates[0] if candidates else None
        if isinstance(synset, str):
            return wn.synset(synset)
        if isinstance(synset, Synset):
            return synset
        # Previously this case left the local unbound and failed later with an
        # UnboundLocalError; fail early with an explicit message instead.
        raise TypeError(
            "synset must be None, a synset name (str) or a Synset, got %s"
            % type(synset).__name__
        )

    def _select_head_word(self, word, pos_tag=None, synset=None):
        """Return the head word shared by ``predict`` and ``predict_and_word``."""
        syn = self._resolve_synset(word, pos_tag, synset)
        if syn is None:
            return self._FALLBACK_WORD

        vocab = self.preprocessor.model.vocab
        definition_tokens = syn.definition().split(' ')
        max_extra = self._MAX_EXTRA_CHARS

        # Step 1: closest ancestor first (each path is listed root-first, so it
        # is walked in reverse; the synset itself is the first element visited).
        for h_path in syn.hypernym_paths():
            for hyp in reversed(h_path):
                for lemma in hyp.lemma_names():
                    if lemma not in vocab or WNManager.is_expression(lemma=lemma):
                        continue
                    for token in definition_tokens:
                        # Accept the lemma itself or the lemma plus a short
                        # suffix (at most ``max_extra`` extra characters).
                        if token.startswith(lemma) and len(token) - len(lemma) <= max_extra:
                            return lemma

        # Step 2: any usable lemma of the first direct hypernym.
        hypernyms = syn.hypernyms()
        if hypernyms:
            candidates = [
                lemma for lemma in hypernyms[0].lemma_names()
                if lemma in vocab and lemma != word and not WNManager.is_expression(lemma=lemma)
            ]
            if candidates:
                return Random.randomchoice(candidates)

        # Step 3: nothing suitable was found.
        return self._FALLBACK_WORD



