from gensim.models import KeyedVectors

"""
Given the path of a file containing embeddings in Word2Vec format (binary or text),
PreprocessingWord2VecEmbedding provides embeddings for a single word or for a group of words:
- get_vector provides the embedding for a single word, if known
- get_vector_example provides embeddings for a list of words: the first word is meant to be the target, the second and third the data
- get_vector_example_couple works as get_vector_example does, but returns 2 examples for each list of words:
  in this method the order in which the data words are given is ignored; the first example is computed as defined above
  and the second with the data words in reverse order
- get_list_of_vectors returns the embedding for each of the given words
Getters raise the exceptions defined above when one of the given words has no embedding.
(Note: in the current implementation get_vector falls back to the embedding of "entity" for unknown words,
and get_vector_example returns None when the target word has no embedding.)
"""


class PreprocessingWord2VecEmbedding:
    """Look up word embeddings loaded from a pretrained Word2Vec-format file.

    Out-of-vocabulary words passed to ``get_vector`` / ``get_vector_and_word``
    fall back to the embedding of ``fallback_word``.
    """

    # Word whose embedding stands in for out-of-vocabulary words. It must itself
    # be in the vocabulary, otherwise the fallback lookup raises KeyError.
    fallback_word = "entity"

    def __init__(self, pretrained_embeddinds_path: str, binary: bool):
        """Load embeddings via ``KeyedVectors.load_word2vec_format``.

        :param pretrained_embeddinds_path: path of the Word2Vec-format file.
        :param binary: True if the file is in binary format, False for text.
        """
        self.model = KeyedVectors.load_word2vec_format(pretrained_embeddinds_path, binary=binary)

    def get_vector_and_word(self, word: str):
        """Return ``(vector, word_used)`` for *word*.

        If *word* is not in the vocabulary, the fallback word's vector is
        returned together with the fallback word itself, so callers can tell
        which word was actually embedded.

        :raises KeyError: if neither *word* nor the fallback word is known.
        """
        try:
            # ``get_vector`` replaces the deprecated ``word_vec`` alias.
            return self.model.get_vector(word), word
        except KeyError:
            fallback = self.fallback_word
            return self.model.get_vector(fallback), fallback

    def get_vector(self, word: str):
        """Return the embedding of *word*, or of the fallback word if unknown.

        :raises KeyError: if neither *word* nor the fallback word is known.
        """
        return self.get_vector_and_word(word)[0]

    def get_vector_example(self, words):
        """Build an example dict from a target word followed by data words.

        :param words: sequence whose first element is the target word and whose
            remaining elements are the data words.
        :return: ``{'target': vector, 'data': [vector, ...]}``, or ``None`` when
            *words* is empty or the target word has no embedding. Unknown data
            words use the fallback embedding instead.
        """
        if not words:
            return None
        target, data = words[0], words[1:]
        try:
            # The target must be known exactly; data words may fall back.
            return {'target': self.model.get_vector(target),
                    'data': [self.get_vector(word) for word in data]}
        except KeyError:
            return None


class POSToIndex:
    """Translate part-of-speech / constituent tags into single-element index lists.

    ``tagset`` is a class-level list shared by every user of this class; a tag's
    index is its position in that list. Tags missing from ``tagset`` are given
    the index of the last entry.
    """

    tagset = ['ADJP', '-ADV', 'ADVP', '-BNF', 'CC', 'CD', '-CLF', '-CLR', 'CONJP', '-DIR', 'DT', '-DTV', 'EX', '-EXT',
              'FRAG', 'FW', '-HLN', 'IN', 'INTJ', 'JJ', 'JJR', 'JJS', '-LGS', '-LOC', 'LS', 'LST', 'MD', '-MNR', 'NAC',
              'NN', 'NNS', 'NNP', 'NNPS', '-NOM', 'NP', 'NX', 'PDT', 'POS', 'PP', '-PRD', 'PRN', 'PRP', '-PRP',
              'PRP$', 'PRP-S', 'PRT', '-PUT', 'QP', 'RB', 'RBR', 'RBS', 'RP', 'RRC', 'S', 'SBAR', 'SBARQ', '-SBJ',
              'SINV', 'SQ', 'SYM', '-TMP', 'TO', '-TPC', '-TTL', 'UCP', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ',
              '-VOC', 'VP', 'WDT', 'WHADJP', 'WHADVP', 'WHNP', 'WHPP', 'WP', 'WP$', 'WP-S', 'WRB', 'X', 'AFX', '#', '$',
              '-LRB-', '\"', '(', ')', ',', '.', ':', '``']

    @staticmethod
    def index(pos):
        """Return ``[i]``, where ``i`` is the position of *pos* in ``POSToIndex.tagset``.

        An unknown tag yields ``[len(POSToIndex.tagset) - 1]``, so it shares the
        index of the final tag in the set.
        """
        tags = POSToIndex.tagset
        try:
            position = tags.index(pos)
        except ValueError:
            position = len(tags) - 1
        return [position]


class POSAwarePreprocessingWord2VecEmbedding(PreprocessingWord2VecEmbedding):
    """Word2Vec embedding lookup that also attaches POS-tag indices to examples.

    NOTE: the tag set is stored on the ``POSToIndex`` class itself, so
    constructing an instance with a tag set changes tag indexing for every user
    of ``POSToIndex`` in the process.
    """

    def __init__(self, pretrained_embeddinds_path: str, binary: bool, tagset=None):
        """Load the embeddings and optionally install *tagset* on ``POSToIndex``.

        :param pretrained_embeddinds_path: path of the Word2Vec-format file.
        :param binary: True if the file is in binary format, False for text.
        :param tagset: list of tags used for indexing. ``None`` keeps the tag
            set currently defined on ``POSToIndex``; installing ``None`` itself
            would make ``POSToIndex.index`` fail.
        """
        super().__init__(pretrained_embeddinds_path, binary)
        if tagset is not None:
            POSToIndex.tagset = tagset

    def _pos_to_index(self, pos):
        """Return the single-element index list for tag *pos* via ``POSToIndex.index``."""
        return POSToIndex.index(pos)

    def get_vector_example(self, words, pos_tags=None):
        """Build an example as the base class does and add POS-tag indices to it.

        :param words: target word followed by data words.
        :param pos_tags: optional mapping from example key to POS tag. Each
            entry is stored in the example as ``POSToIndex.index(tag)``. Keys
            ``'target'`` or ``'data'`` would overwrite the embeddings. When
            ``None``, this behaves like the base-class method, which keeps the
            override call-compatible with it.
        :return: the example dict, or ``None`` if the base class returned ``None``.
        """
        vectors_example = super().get_vector_example(words)
        if vectors_example is None:
            return None

        if pos_tags is not None:
            for key, tag in pos_tags.items():
                vectors_example[key] = self._pos_to_index(tag)

        return vectors_example
