import os

from utility.randomfixedseed import Random
from utility.words_in_synset import SynsetCouple, SaverSynsetCouples

from utility.word_in_vocabulary import WNManager, Checker
from nltk.corpus import wordnet as wn


class Picker:
    """Pick a second (synset, word) to pair with a given WordNet synset and word.

    A candidate word is usable when it differs from the query word, is not a
    multi-word expression according to ``WNManager.is_expression`` and is
    accepted by ``checker.is_in_vocabulary``.

    The noun and verb synset pools are loaded once at construction time and
    are consumed by ``_dissimilar_word_to``: every synset drawn from a pool is
    removed from it, so it is not drawn again by later calls on this instance.
    """

    # How many drawn synsets without any usable lemma are tolerated before a
    # dissimilar pick gives up.
    _MAX_DISSIMILAR_ATTEMPTS = 7

    def __init__(self, checker):
        """Store the vocabulary checker and load all noun and verb synsets.

        :param checker: object exposing ``is_in_vocabulary(word)``.
        """
        self.checker = checker

        self.ALL_NAMES = list(wn.all_synsets('n'))
        self.ALL_VERBS = list(wn.all_synsets('v'))

    def pick_from(self, s1, w1, similar=True):
        """Return a ``(synset, word)`` pair for ``s1``/``w1``.

        :param s1: the reference synset.
        :param w1: the reference word; it is never returned as the picked word.
        :param similar: pick a sister term when true, otherwise a random
            synset with the same part of speech.
        :return: ``(s2, w2)``, or ``(None, None)`` when nothing could be picked.
        """
        if similar:
            return self._similar_word_to(s1, w1)
        return self._dissimilar_word_to(s1, w1)

    def _usable_lemmas(self, synset, w1):
        """Return the lemma names of ``synset`` that may be paired with ``w1``."""
        return [lemma for lemma in synset.lemma_names()
                if lemma != w1
                and not WNManager.is_expression(lemma)
                and self.checker.is_in_vocabulary(lemma)]

    def _similar_word_to(self, s1, w1):
        """Pick a random sister synset of ``s1`` and one usable lemma of it.

        Sister synsets are the hyponyms of the first hypernym of ``s1``, with
        ``s1`` itself excluded. Only one sister is drawn: if it has no usable
        lemma the result is ``(None, None)``.
        """
        hypernyms = s1.hypernyms()
        if not hypernyms:
            return None, None

        # see hypernyms_sister_term_choice file to justify this
        sister_synss = hypernyms[0].hyponyms()
        if s1 in sister_synss:
            sister_synss.remove(s1)

        if not sister_synss:
            return None, None
        s2 = Random.randomchoice(sister_synss)

        in_voc = self._usable_lemmas(s2, w1)
        if not in_voc:
            return None, None
        return s2, Random.randomchoice(in_voc)

    def _dissimilar_word_to(self, s1, w1):
        """Pick a random synset with the part of speech of ``s1`` and a usable lemma.

        No dissimilarity measure is computed: the synset is simply drawn at
        random from the noun pool (when ``s1`` is a noun) or from the verb pool
        (for any other part of speech). Every drawn synset is removed from the
        pool for good, whether or not it yields a result. Drawing ``s1`` itself
        does not count as an attempt, and ``s1`` is put back into the pool
        before returning, whatever the outcome.
        """
        syns = self.ALL_NAMES if s1.pos() == wn.NOUN else self.ALL_VERBS

        self_found = False
        try:
            attempts = 0
            while attempts < self._MAX_DISSIMILAR_ATTEMPTS:
                if not syns:
                    return None, None

                s2 = Random.randomchoice(syns)
                syns.remove(s2)
                if s1 == s2:
                    # Never pair a synset with itself; try again for free.
                    self_found = True
                    continue

                in_voc = self._usable_lemmas(s2, w1)
                if in_voc:
                    return s2, Random.randomchoice(in_voc)
                attempts += 1

            return None, None
        finally:
            # s1 was only set aside because it is the query synset, not
            # because it was used: restore it on every exit path (success and
            # empty pool included) so it stays available for later picks.
            if self_found:
                syns.append(s1)


def get_couples_from(words, picker: Picker, similar=True, output_path=None):
    """Build synset couples for every word, once per part of speech.

    Words flagged by ``WNManager.is_expression`` are skipped. For each
    remaining word and for the noun and verb parts of speech in turn, the
    first WordNet synset of the word (if any) is handed to
    ``picker.pick_from``; when the picker returns a synset, a
    ``SynsetCouple`` is recorded.

    :param words: iterable of words to process.
    :param picker: the ``Picker`` used to choose the second synset and word.
    :param similar: forwarded to ``picker.pick_from``.
    :param output_path: when not ``None``, the couples are saved there through
        ``SaverSynsetCouples.save`` with a tab-separated header.
    :return: the list of couples that were built.
    """
    couples = []
    for w1 in words:
        if WNManager.is_expression(w1):
            continue

        for pos in ('n', 'v'):
            candidates = wn.synsets(w1, pos=pos)
            if len(candidates) == 0:
                continue

            # Only the first synset listed for the word is considered.
            s1 = candidates[0]
            s2, w2 = picker.pick_from(s1, w1, similar=similar)
            if s2 is None:
                continue
            couples.append(SynsetCouple(s1, w1, s2, w2, s1.pos()))

    print(output_path, similar, len(couples))

    if output_path is None:
        return couples

    header = '\t'.join(['S1', 'S2', 'W1', 'W2', 'S1_POS', '#\n'])
    SaverSynsetCouples.save(couples, output_path, header)
    return couples


def voc_sim(couples_output_dir='data/similarity_pedersen_test/sister_terms',
            model_name=None, binary=True):
    """Build and save sister-term couples for the words of the model vocabulary.

    A checker is obtained from ``model_name`` and a single ``Picker`` is
    shared by both passes. The similar (positive) couples are generated
    first, then the non-similar (negative) ones; each list is written to its
    own file inside ``couples_output_dir``.

    :param couples_output_dir: directory prefix for the two output files.
    :param model_name: forwarded to ``Checker.get_instance_from_path``.
    :param binary: forwarded to ``Checker.get_instance_from_path``.
    :return: ``(positive_couples, negative_couples)``.
    """
    checker = Checker.get_instance_from_path(model_name, binary=binary)
    sister_picker = Picker(checker)

    positive_path = couples_output_dir + '/in_voc_sister_terms_positive.txt'
    negative_path = couples_output_dir + '/in_voc_sister_terms_negative.txt'

    # Positive pass runs before the negative one, with the same picker.
    positives = get_couples_from(checker.model.vocab, picker=sister_picker,
                                 similar=True, output_path=positive_path)
    negatives = get_couples_from(checker.model.vocab, picker=sister_picker,
                                 similar=False, output_path=negative_path)
    return positives, negatives


def oov_sim(couples_output_dir='data/similarity_pedersen_test/sister_terms',
            model_name=None, binary=True):
    """Build and save sister-term couples for out-of-vocabulary WordNet lemmas.

    The candidate words are what ``checker.get_OOV`` returns for the lemmas
    given by ``WNManager.lemma_from_synsets(allow_expression=False)``. The
    similar (positive) couples are generated first, then the non-similar
    (negative) ones, with one shared ``Picker``; each list is written to its
    own file inside ``couples_output_dir``.

    :param couples_output_dir: directory prefix for the two output files.
    :param model_name: forwarded to ``Checker.get_instance_from_path``.
    :param binary: forwarded to ``Checker.get_instance_from_path``.
    :return: ``(positive_couples, negative_couples)``.
    """
    wn_manager = WNManager()
    checker = Checker.get_instance_from_path(model_name, binary=binary)
    sister_picker = Picker(checker)

    wordnet_lemmas = wn_manager.lemma_from_synsets(allow_expression=False)
    oovs = checker.get_OOV(wordnet_lemmas)

    positive_path = couples_output_dir + '/oov_sister_terms_positive.txt'
    negative_path = couples_output_dir + '/oov_sister_terms_negative.txt'

    # Positive pass runs before the negative one, with the same picker.
    positives = get_couples_from(oovs, picker=sister_picker, similar=True,
                                 output_path=positive_path)
    negatives = get_couples_from(oovs, picker=sister_picker, similar=False,
                                 output_path=negative_path)
    return positives, negatives


def compute_sister_terms(seeds, sister_terms_output_path,
                         model_name='data/pretrained_embeddings/GoogleNews-vectors-negative300.bin',
                         check_if_computed=True):
    """Generate in-vocabulary sister-term couples once per seed.

    For every seed the shared random generator is re-seeded and the couples
    are written by ``voc_sim`` into ``<sister_terms_output_path>/seed_<seed>``,
    which is created (with any missing parent) when absent.

    :param seeds: iterable of seeds; each must be accepted by ``int()``
        (strings or integers).
    :param sister_terms_output_path: root directory of the per-seed folders.
    :param model_name: path of the embedding model, forwarded to ``voc_sim``.
    :param check_if_computed: when true, a seed whose folder already contains
        an in-vocabulary output file is skipped.
    """
    # voc_sim names its files 'in_voc_sister_terms_{positive,negative}.txt'.
    # The bare 'voc_sister_terms' prefix that used to be checked here never
    # matches those names; it is still accepted in case such files exist.
    output_prefixes = ('in_voc_sister_terms', 'voc_sister_terms')

    for seed in seeds:
        Random.set_seed(int(seed))
        # str() lets integer seeds work, as int(seed) above already does.
        couples_output_dir = os.path.join(sister_terms_output_path,
                                          'seed_' + str(seed))
        os.makedirs(couples_output_dir, exist_ok=True)

        if check_if_computed and any(
                f.startswith(output_prefixes)
                for f in os.listdir(couples_output_dir)):
            continue

        voc_sim(couples_output_dir=couples_output_dir,
                model_name=model_name)


def compute_oov_sister_terms(seeds, sister_terms_output_path,
                             model_name='data/pretrained_embeddings/GoogleNews-vectors-negative300.bin',
                             check_if_computed=True):
    """Generate out-of-vocabulary sister-term couples once per seed.

    For every seed the shared random generator is re-seeded and the couples
    are written by ``oov_sim`` into ``<sister_terms_output_path>/seed_<seed>``,
    which is created (with any missing parent) when absent.

    :param seeds: iterable of seeds; each must be accepted by ``int()``
        (strings or integers).
    :param sister_terms_output_path: root directory of the per-seed folders.
    :param model_name: path of the embedding model, forwarded to ``oov_sim``.
    :param check_if_computed: when true, a seed whose folder already contains
        a file starting with ``oov_sister_terms`` is skipped.
    """
    for seed in seeds:
        Random.set_seed(int(seed))
        # str() lets integer seeds work, as int(seed) above already does.
        couples_output_dir = os.path.join(sister_terms_output_path,
                                          'seed_' + str(seed))
        # makedirs also creates a missing output root and tolerates an
        # existing folder, unlike the exists()/mkdir() pair.
        os.makedirs(couples_output_dir, exist_ok=True)

        if check_if_computed and any(
                f.startswith('oov_sister_terms')
                for f in os.listdir(couples_output_dir)):
            continue

        oov_sim(couples_output_dir=couples_output_dir,
                model_name=model_name)
