from nltk.corpus import wordnet as wn
from nltk.corpus.reader import Synset
from enum import Enum
from nltk.corpus import wordnet_ic
import nltk

# Fetch the NLTK 'wordnet_ic' data package at module import time: the
# InformationContent class below calls wordnet_ic.ic('ic-brown.dat') while its
# class body executes, so the data must be available before that point.
# NOTE(review): the 'wordnet' corpus used through `wn` is not downloaded here;
# presumably it is installed separately — confirm in the deployment setup.
nltk.download('wordnet_ic')


class InformationContent:
    """Process-wide holder for the information-content (IC) table.

    The IC-based similarity measures read ``INFORMATION_CONTENT`` each time they
    are called, so swapping the table here affects every later computation.
    """

    # Loaded once, while this class body executes, from the 'ic-brown.dat' file.
    INFORMATION_CONTENT = wordnet_ic.ic('ic-brown.dat')

    @staticmethod
    def set_information_content(name):
        """Load the IC file *name* via ``wordnet_ic.ic`` and make it the shared table.

        :param name: IC file name understood by ``wordnet_ic.ic``.
        """
        table = wordnet_ic.ic(name)
        InformationContent.INFORMATION_CONTENT = table


class SimilarityFunction(Enum):
    """WordNet synset-similarity measures addressable by a short name.

    Each measure is called as ``measure(synset_a, synset_b)``.

    NOTE(review): ``Enum`` keeps descriptors (plain functions included) as ordinary
    class attributes instead of turning them into members. The lambdas below are
    therefore not enum members, and ``SimilarityFunction.res`` is the function
    itself. The bound methods taken from ``wn`` very likely behave the same way.
    Check ``list(SimilarityFunction)`` before relying on enum iteration or ``.value``.
    """

    # Structure-based measures, taken directly from the WordNet reader.
    path = wn.path_similarity
    lch = wn.lch_similarity
    wup = wn.wup_similarity

    # Information-content measures. The IC table is looked up on every call, so a
    # later InformationContent.set_information_content(...) takes effect here too.
    res = lambda x, y: wn.res_similarity(x, y, InformationContent.INFORMATION_CONTENT)
    jcn = lambda x, y: wn.jcn_similarity(x, y, InformationContent.INFORMATION_CONTENT)
    lin = lambda x, y: wn.lin_similarity(x, y, InformationContent.INFORMATION_CONTENT)

    @staticmethod
    def _measures():
        """Return ``(short_name, measure)`` pairs in a fixed lookup order.

        Built inside a method, not as a class attribute, so that Enum does not
        turn the table itself into a member.
        """
        return (
            ('path', SimilarityFunction.path),
            ('lch', SimilarityFunction.lch),
            ('wup', SimilarityFunction.wup),
            ('res', SimilarityFunction.res),
            ('jcn', SimilarityFunction.jcn),
            ('lin', SimilarityFunction.lin),
        )

    @staticmethod
    def name(similarity_function):
        """Return the short name of *similarity_function*.

        :param similarity_function: one of the measures defined on this class.
        :returns: its short name, e.g. ``'path'``, or None when it matches none of them.
        """
        for label, measure in SimilarityFunction._measures():
            if similarity_function == measure:
                return label
        return None

    @staticmethod
    def by_name(similarity_function_name):
        """Return the measure registered under *similarity_function_name*.

        :param similarity_function_name: one of 'path', 'lch', 'wup', 'res', 'jcn', 'lin'.
        :returns: the matching measure, or None for an unknown name.
        """
        for label, measure in SimilarityFunction._measures():
            if similarity_function_name == label:
                return measure
        return None


class SynsetOOVCouple:
    """Pair an out-of-vocabulary (OOV) word and its synset with a second word and its synset."""

    def __init__(self, oov, synset_oov, first, second, synset_second, target_pos, w1_pos, w2_pos):
        """Store the fields, resolving both synset names through ``wn.synset``.

        :param oov: the out-of-vocabulary word.
        :param synset_oov: synset name for ``oov``, as accepted by ``wn.synset``.
        :param first: stored as-is (ReaderSynsetOOVCouple passes a slice of row columns).
        :param second: the second word.
        :param synset_second: synset name for ``second``, as accepted by ``wn.synset``.
        :param target_pos: stored as-is.
        :param w1_pos: stored as-is.
        :param w2_pos: stored as-is.
        """
        self.synset_oov = wn.synset(synset_oov)
        self.oov = oov
        self.first = first
        self.synset_second = wn.synset(synset_second)
        self.second = second
        self.target_pos = target_pos
        self.w1_pos = w1_pos
        self.w2_pos = w2_pos

    def to_dictionary(self):
        """Return the fields as a dict; synset entries hold the resolved synset objects."""
        return {'synset_oov': self.synset_oov,
                'oov': self.oov,
                'first': self.first,
                'synset_second': self.synset_second,
                'second': self.second,
                'target_pos': self.target_pos,
                'w1_pos': self.w1_pos,
                'w2_pos': self.w2_pos}

    def equals(self, other):
        """Return True when *other* carries the same field values as this couple.

        Synsets are compared by ``name()`` rather than by object identity. Fields
        are checked in declaration order, and evaluation stops at the first mismatch.
        """
        return (self.synset_oov.name() == other.synset_oov.name()
                and self.oov == other.oov
                and self.first == other.first
                and self.synset_second.name() == other.synset_second.name()
                and self.second == other.second
                and self.target_pos == other.target_pos
                and self.w1_pos == other.w1_pos
                and self.w2_pos == other.w2_pos)


class ReaderSynsetOOVCouple:
    """Build SynsetOOVCouple objects from a tab-separated text file."""

    @staticmethod
    def read(input_path, s1_index=5, w1_index=1, s2_index=9, w2_index=10, first_indexes=(2, 4),
             s_pos_index=6, w1_pos=7, w2_pos=8, exclude_first=False, encoding=None):
        """Parse *input_path* into one SynsetOOVCouple per non-blank line.

        Each line loses its trailing newline and is split on tab characters. The
        index arguments below are 0-based column positions into that split.

        :param input_path: path of the tab-separated file.
        :param s1_index: column holding the synset name passed as ``synset_oov``.
        :param w1_index: column holding the word passed as ``oov``.
        :param s2_index: column holding the synset name passed as ``synset_second``.
        :param w2_index: column holding the word passed as ``second``.
        :param first_indexes: ``(start, stop)`` bounds; columns ``[start, stop)`` are
            passed as ``first`` (a list).
        :param s_pos_index: column passed as ``target_pos``.
        :param w1_pos: column passed as ``w1_pos``.
        :param w2_pos: column passed as ``w2_pos``.
        :param exclude_first: when True, skip the first line of the file (e.g. a header).
        :param encoding: text encoding forwarded to ``open``; None keeps the platform default.
        :returns: list of SynsetOOVCouple, in file order.
        :raises IndexError: if a non-blank line has fewer columns than an index requires.
        """
        first_start, first_stop = first_indexes[0], first_indexes[1]
        couples = []
        with open(input_path, 'r', encoding=encoding) as handle:
            for line_number, line in enumerate(handle):
                if exclude_first and line_number == 0:
                    continue

                # Strip only the line terminator. Otherwise the last column (by
                # default w2_index=10) would carry a trailing '\n' into the couple.
                line = line.rstrip('\n')
                if not line.strip():
                    # Blank lines (e.g. a trailing empty line) have no columns to index.
                    continue

                columns = line.split('\t')
                couples.append(SynsetOOVCouple(oov=columns[w1_index], synset_oov=columns[s1_index],
                                               first=columns[first_start:first_stop],
                                               second=columns[w2_index],
                                               synset_second=columns[s2_index],
                                               target_pos=columns[s_pos_index],
                                               w1_pos=columns[w1_pos], w2_pos=columns[w2_pos]))
        return couples


class SisterOOVPair:
    """Record an OOV target word together with one of its sister terms.

    Every field is stored as a string.
    """

    # Field names, in the order to_list() emits them. Each one is both the key
    # read from the input mapping and the attribute name on the instance.
    _FIELDS = ('target_synset', 'target_word', 'data_w1', 'data_w2', 'target_pos',
               'w1_pos', 'w2_pos', 'definition', 'sister_synset', 'sister_word')

    def __init__(self, couple):
        """Copy each field of the mapping-like *couple* onto the instance via ``str()``.

        Fields are read in ``_FIELDS`` order. A missing key propagates whatever
        ``couple[key]`` raises.
        """
        for field in self._FIELDS:
            setattr(self, field, str(couple[field]))

    def to_list(self):
        """Return the field values as a list, in ``_FIELDS`` order."""
        return [getattr(self, field) for field in self._FIELDS]


class SisterInVocPair:
    """Record an in-vocabulary target word together with one of its sister terms.

    Every field is stored as a string.
    """

    # (attribute name, source key) pairs, in the order to_list() emits them.
    _COLUMNS = (
        ('target_synset', 'S1'),
        ('target_word', 'W1'),
        ('target_pos', 'S1_POS'),
        ('sister_synset', 'S2'),
        ('sister_word', 'W2'),
    )

    def __init__(self, couple):
        """Copy the S1/W1/S1_POS/S2/W2 entries of *couple* onto named attributes via ``str()``.

        A missing key propagates whatever ``couple[key]`` raises.
        """
        for attribute, key in self._COLUMNS:
            setattr(self, attribute, str(couple[key]))

    def to_list(self):
        """Return ``[target_synset, target_word, target_pos, sister_synset, sister_word]``."""
        return [getattr(self, attribute) for attribute, _ in self._COLUMNS]