from gensim.models import KeyedVectors
from gensim.models.keyedvectors import FastTextKeyedVectors
from nltk.corpus import wordnet as wn
import os
from enum import Enum

from pytorch_transformers import BertTokenizer

from utility.words_in_synset import WordInSynset, WordInSynsetWriter, WordInSynsetReader
from gensim.models.wrappers import FastText


class PretrainedEmbeddingModel(Enum):
    """Kinds of pretrained embedding model that a vocabulary checker can be built for.

    ``Checker.name_to_type_map`` maps model file names to these members, and
    ``Checker.get_instance_from_path`` picks the checker class from the member.
    """

    # Dispatched to W2VChecker by Checker.get_instance_from_path.
    w2v = 0
    # Dispatched to BertChecker by Checker.get_instance_from_path.
    bert = 1
    # Dispatched to FastTextChecker by Checker.get_instance_from_path.
    fasttext = 2


class Checker:
    """Base class for checking whether words belong to a pretrained model's vocabulary.

    Subclasses provide :meth:`is_in_vocabulary`; the filtering helpers here are
    built on top of it. :meth:`get_instance_from_path` is a factory that picks
    the subclass from the model's file name.
    """

    # Maps the basename of a model path to the kind of model it holds.
    # Replace it through :meth:`name_to_type` to register other file names.
    name_to_type_map = {'bert-base-uncased':  PretrainedEmbeddingModel.bert,
                        'GoogleNews-vectors-negative300.bin': PretrainedEmbeddingModel.w2v,
                        'cc.en.300.bin': PretrainedEmbeddingModel.fasttext}

    def get_OOV(self, test_words: list):
        """Return the distinct words of *test_words* that are out of vocabulary.

        The result is a dict keys view, so duplicates are dropped and
        first-seen order is kept.
        """
        return {word: 1 for word in test_words if not self.is_in_vocabulary(word)}.keys()

    def get_vocabulary(self, test_words: list):
        """Return the distinct words of *test_words* that are in the vocabulary.

        The result is a dict keys view, so duplicates are dropped and
        first-seen order is kept.
        """
        return {word: 1 for word in test_words if self.is_in_vocabulary(word)}.keys()

    def is_in_vocabulary(self, word):
        """Tell whether *word* is in the vocabulary; must be overridden.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise NotImplementedError("Not supported. Choose a subclass and instantiate an object of it")

    @staticmethod
    def name_to_type(name_to_type_map):
        """Replace ``Checker.name_to_type_map`` with a validated copy of the given mapping.

        Each value may be a ``PretrainedEmbeddingModel`` member or the integer
        value of one. Values are stored as enum members, because
        :meth:`get_instance_from_path` compares against members.

        Raises:
            KeyError: if a value does not correspond to any
                ``PretrainedEmbeddingModel`` member. The current map is left
                untouched in that case.
        """
        normalized = {}
        for name, model_type in name_to_type_map.items():
            try:
                # Enum lookup by value; passing a member returns that member.
                normalized[name] = PretrainedEmbeddingModel(model_type)
            except ValueError:
                raise KeyError("This value isn't mapped to a known PretrainedEmbeddingModel.") from None
        Checker.name_to_type_map = normalized

    @staticmethod
    def get_instance_from_path(path, binary=None):
        """Build the checker that matches the model stored at *path*.

        Args:
            path: Model location; only its basename is looked up in
                ``Checker.name_to_type_map``.
            binary: Forwarded to ``W2VChecker`` for word2vec models. ``None``
                (the default) means ``True``. Ignored for other model kinds.

        Raises:
            KeyError: if the basename of *path* is not in ``name_to_type_map``.
            NotImplementedError: if the mapped value has no checker class.
        """
        name = os.path.basename(path)
        try:
            model_type = Checker.name_to_type_map[name]
        except KeyError:
            raise KeyError(
                "Unknown pretrained model name {!r}. "
                "Static field name_to_type_map must be customized".format(name)
            ) from None

        if model_type == PretrainedEmbeddingModel.w2v:
            # Respect an explicit caller choice; keep the historical default of True.
            return W2VChecker(path, binary=True if binary is None else binary)

        if model_type == PretrainedEmbeddingModel.fasttext:
            return FastTextChecker(path)

        if model_type == PretrainedEmbeddingModel.bert:
            return BertChecker(path)

        raise NotImplementedError("Not supported. Static field name_to_type_map must be customized")


class KeyedVectorChecker(Checker):
    """Vocabulary checker backed by vectors stored in word2vec format."""

    def __init__(self, pretrained_embeddings_path, binary: bool):
        """Load the vectors with ``KeyedVectors.load_word2vec_format``.

        Args:
            pretrained_embeddings_path: Location of the word2vec-format file.
            binary: Passed through as the loader's ``binary`` argument.
        """
        self.model = KeyedVectors.load_word2vec_format(pretrained_embeddings_path, binary=binary)

    def is_in_vocabulary(self, word):
        """Return True when *word* is a key of the loaded model's ``vocab``."""
        return word in self.model.vocab


class FastTextChecker(Checker):
    """Vocabulary checker backed by a model loaded through the gensim FastText wrapper."""

    def __init__(self, pretrained_embeddings_path):
        """Load the model with ``FastText.load_fasttext_format`` and keep its ``wv`` vectors."""
        loaded = FastText.load_fasttext_format(pretrained_embeddings_path)
        keyed_vectors: FastTextKeyedVectors = loaded.wv
        self.model = keyed_vectors

    def is_in_vocabulary(self, word):
        """Return True when *word* is a key of the loaded vectors' ``vocab``."""
        known_words = self.model.vocab
        return word in known_words


class W2VChecker(KeyedVectorChecker):
    """Checker for word2vec models; all behaviour comes from ``KeyedVectorChecker``."""

    def __init__(self, pretrained_embeddings_path, binary: bool):
        """Forward both arguments unchanged to the parent initializer."""
        super().__init__(pretrained_embeddings_path, binary)

class BertChecker(Checker):
    """Vocabulary checker backed by a BERT tokenizer's ``vocab``."""

    def __init__(self, tokenizer):
        """Store the tokenizer to check against.

        Args:
            tokenizer: Either a string, which is handed to
                ``BertTokenizer.from_pretrained``, or any other object, which
                is used as the tokenizer as-is.
        """
        if isinstance(tokenizer, str):
            tokenizer = BertTokenizer.from_pretrained(tokenizer)
        self.model = tokenizer

    def is_in_vocabulary(self, word):
        """Return True when *word* is an entry of the tokenizer's ``vocab``."""
        vocabulary = self.model.vocab
        return word in vocabulary

    def get_vocabulary(self, test_words: list):
        """Return every entry of the tokenizer's ``vocab`` as a list.

        NOTE: unlike the base-class method, *test_words* is ignored here.
        """
        return list(self.model.vocab)

class WNManager:
    """Helpers for enumerating WordNet lemmas and classifying lemma/token strings."""

    def __init__(self):
        # Keep the callable rather than its result, so synsets are only
        # enumerated when lemma_from_synsets is actually called.
        self.all_synsets = wn.all_synsets

    @staticmethod
    def is_expression(lemma: str):
        """Return True when *lemma* contains an underscore or a hyphen.

        Either character alone is enough; e.g. ``'ice_cream'`` and
        ``'well-known'`` both count, while ``'dog'`` does not.
        """
        return '_' in lemma or '-' in lemma

    @staticmethod
    def is_special_token(lemma: str):
        """Return True when *lemma* starts with ``'##'`` or ``'['``."""
        return lemma.startswith(('##', '['))

    def lemma_from_synsets(self, allow_expression: bool):
        """List the distinct lemma names found across all synsets.

        Args:
            allow_expression: When False, lemma names for which
                :meth:`is_expression` is True are left out.

        Returns:
            A list of unique lemma names in first-seen order.
        """
        # A dict is used as an ordered set to drop duplicates while keeping order.
        unique_lemmas = dict.fromkeys(
            lemma
            for synset in self.all_synsets()
            for lemma in synset.lemma_names()
            if allow_expression or not WNManager.is_expression(lemma)
        )
        return list(unique_lemmas)


def find_oov_and_synset(name, binary=None, pos_tags=None, output_path='oov_in_synset.txt'):
    """Write and return synset records for WordNet lemmas missing from a model's vocabulary.

    Args:
        name: Model path handed to ``Checker.get_instance_from_path``.
        binary: Forwarded to ``Checker.get_instance_from_path``.
        pos_tags: Part-of-speech tags to try for each word; defaults to
            ``['n', 'v']``.
        output_path: File the records are written to through
            ``WordInSynsetWriter.write``.

    Returns:
        The list of records that was written. Word/tag pairs for which
        ``WordInSynset.from_word_and_pos`` returns ``None`` are skipped.
    """
    tags = ['n', 'v'] if pos_tags is None else pos_tags

    vocabulary_checker = Checker.get_instance_from_path(name, binary=binary)

    # Candidate words are the lemmas that WNManager does not flag as expressions.
    lemmas = WNManager().lemma_from_synsets(allow_expression=False)
    missing_words = vocabulary_checker.get_OOV(lemmas)

    candidates = (
        WordInSynset.from_word_and_pos(word, pos)
        for word in missing_words
        for pos in tags
    )
    words_in_synsets = [record for record in candidates if record is not None]

    WordInSynsetWriter.write(words_in_synsets, output_path)

    return words_in_synsets


def find_voc_synset(pretrained_embeddings_name, pos_tags=None, output_path='vocabulary_in_synset.txt'):
    """Write and return synset records for the vocabulary words of a pretrained model.

    Args:
        pretrained_embeddings_name: Model path handed to
            ``Checker.get_instance_from_path``.
        pos_tags: Part-of-speech tags to look up for each word; defaults to
            ``['n', 'v', 'a']``.
        output_path: File the records are written to through
            ``WordInSynsetWriter.write``.

    Returns:
        The list of ``WordInSynset`` records that was written: one per
        word/tag pair that has at least one WordNet synset, built from the
        first synset returned.
    """
    tags = ['n', 'v', 'a'] if pos_tags is None else pos_tags

    checker = Checker.get_instance_from_path(pretrained_embeddings_name)
    words_in_synsets = []

    for word in checker.model.vocab:
        # Skip entries flagged by either WNManager predicate.
        if WNManager.is_expression(word) or WNManager.is_special_token(word):
            continue
        for pos in tags:
            matches = wn.synsets(word, pos=pos)
            if matches:
                first_synset_name = matches[0].name()
                words_in_synsets.append(WordInSynset(word, first_synset_name, pos))

    WordInSynsetWriter.write(words_in_synsets, output_path)

    return words_in_synsets


def collect_words_in_synset(model_name, output_dir):
    """Write both the in-vocabulary and out-of-vocabulary synset files for a model.

    Args:
        model_name: Model path handed to ``find_voc_synset`` and
            ``find_oov_and_synset``.
        output_dir: Directory that receives ``vocabulary_in_synset.txt`` and
            ``oov_in_synset.txt``. It is created, with any missing parent
            directories, if it does not exist yet.
    """
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # exists() + mkdir() and also handles nested output directories.
    os.makedirs(output_dir, exist_ok=True)
    find_voc_synset(model_name, output_path=os.path.join(output_dir, 'vocabulary_in_synset.txt'))
    find_oov_and_synset(model_name, output_path=os.path.join(output_dir, 'oov_in_synset.txt'))


def collected_vocabulary_words_in_synset(model_name, output_path, pos_tags=None):
    """Return the vocabulary synset records, reusing *output_path* when it already exists.

    If *output_path* exists, its content is read back with
    ``WordInSynsetReader.read`` and *model_name* / *pos_tags* are not used.
    Otherwise the records are computed (and written to *output_path*) by
    ``find_voc_synset``.
    """
    if os.path.exists(output_path):
        return WordInSynsetReader.read(output_path)

    return find_voc_synset(model_name, pos_tags=pos_tags, output_path=output_path)
