from nltk.corpus import wordnet as wn
from enum import Enum
from nltk.corpus import wordnet_ic
import nltk

from utility.words_in_synset import SynsetCouple

# Make sure NLTK's 'wordnet_ic' data is available: InformationContent below
# loads 'ic-brown.dat' from it as soon as its class body executes.
# NOTE(review): this runs on every import of the module and may try to reach
# the network; consider moving it to a one-off setup step.
nltk.download('wordnet_ic')

class InformationContent:
    """Process-wide holder of the WordNet information-content (IC) table.

    The table is read by the IC-based measures ``SimilarityFunction.res``,
    ``SimilarityFunction.jcn`` and ``SimilarityFunction.lin`` each time they
    are called. It is loaded from the ``wordnet_ic`` file 'ic-brown.dat' when
    this class body executes (i.e. at import time).
    """

    INFORMATION_CONTENT = wordnet_ic.ic('ic-brown.dat')

    @staticmethod
    def set_information_content(name):
        """Load the ``wordnet_ic`` file *name* and make it the shared IC table."""
        loaded_table = wordnet_ic.ic(name)
        InformationContent.INFORMATION_CONTENT = loaded_table


class SimilarityFunction(Enum):
    """Named WordNet similarity measures, each meant to be called as ``f(synset1, synset2)``.

    Use :meth:`by_name` to get a measure from its short name and :meth:`name`
    to get the short name back from a measure.

    NOTE(review): Enum never turns descriptors into members, and functions
    (the lambdas below) are descriptors, so ``res``/``jcn``/``lin`` are plain
    functions rather than enum members. Do not rely on ``.value``, iteration
    or ``SimilarityFunction['res']``. Also note that the staticmethod ``name``
    below shadows Enum's own ``name`` attribute.
    """

    path = wn.path_similarity
    lch = wn.lch_similarity
    wup = wn.wup_similarity

    # The IC-based measures read InformationContent.INFORMATION_CONTENT at call
    # time, so InformationContent.set_information_content() affects them
    # without redefining anything here.
    res = lambda x, y: wn.res_similarity(x, y, InformationContent.INFORMATION_CONTENT)
    jcn = lambda x, y: wn.jcn_similarity(x, y, InformationContent.INFORMATION_CONTENT)
    lin = lambda x, y: wn.lin_similarity(x, y, InformationContent.INFORMATION_CONTENT)

    @staticmethod
    def _registry():
        """Return the ``(short_name, measure)`` pairs of every supported measure."""
        # Built on demand instead of stored as a class attribute: a plain
        # tuple/dict assigned in an Enum body would itself become an enum member.
        return (
            ('path', SimilarityFunction.path),
            ('lch', SimilarityFunction.lch),
            ('wup', SimilarityFunction.wup),
            ('res', SimilarityFunction.res),
            ('jcn', SimilarityFunction.jcn),
            ('lin', SimilarityFunction.lin),
        )

    @staticmethod
    def name(similarity_function):
        """Return the short name ('path', 'lch', 'wup', 'res', 'jcn', 'lin') of a measure.

        Raises:
            ValueError: if *similarity_function* is not one of this class's measures.
        """
        for short_name, measure in SimilarityFunction._registry():
            if similarity_function == measure:
                return short_name
        raise ValueError('Unknown similarity function: {!r}'.format(similarity_function))

    @staticmethod
    def by_name(similarity_function_name):
        """Return the measure whose short name is *similarity_function_name*.

        Raises:
            ValueError: if the name is not one of 'path', 'lch', 'wup', 'res',
                'jcn', 'lin'.
        """
        for short_name, measure in SimilarityFunction._registry():
            if similarity_function_name == short_name:
                return measure
        raise ValueError('Unknown similarity function name: {!r}'.format(similarity_function_name))


class Comparator:
    """Score a list of synset couples with one similarity measure and export the results.

    Args:
        couples: the couples to score; every element must be a ``SynsetCouple``
            (or a subclass of it).
        similarity_function: a callable ``f(synset1, synset2)`` returning the score.
    """

    def __init__(self, couples: list, similarity_function):
        self.couples = couples
        self.similarity_function = similarity_function

    def write_similarities(self, path, header):
        """Write *header* followed by one line per couple to *path* (created/truncated).

        All lines are computed before the file is opened, so an invalid couple
        (``ValueError``) leaves any existing file at *path* untouched. The file
        handle is always closed.
        """
        lines = self._get_string_similarities()
        with open(path, 'w') as output:
            output.write(header)
            output.writelines(lines)

    def get_similarities(self):
        """Return one row per couple: ``[s1.name(), s2.name(), w1, w2, str(score), s_pos]``.

        The score is ``similarity_function(s1, s2)`` converted with ``str``
        (so a ``None`` score becomes ``'None'``).

        Raises:
            ValueError: if an element of ``self.couples`` is not a ``SynsetCouple``.
        """
        similarities = []
        for couple in self.couples:
            if not isinstance(couple, SynsetCouple):
                raise ValueError('Expected a SynsetCouple, got {}'.format(type(couple).__name__))
            similarities.append([couple.s1.name(), couple.s2.name(), couple.w1, couple.w2,
                                 str(self.similarity_function(couple.s1, couple.s2)), couple.s_pos])
        return similarities

    def _get_string_similarities(self):
        """Format every row of :meth:`get_similarities` as one output line.

        Columns are tab-separated, and each line ends with an extra ``#``
        column followed by a newline.
        """
        # NOTE(review): str.join needs w1, w2 and s_pos to be strings — assumed
        # from SynsetCouple, TODO confirm.
        return ['\t'.join(row + ['#\n']) for row in self.get_similarities()]
