from nltk.corpus import wordnet as wn
from nltk.corpus.reader import Synset


class WordInSynset:
    """A word paired with a WordNet synset name and a part-of-speech tag."""

    def __init__(self, word, synset_name, pos):
        self.word, self.synset_name, self.pos = word, synset_name, pos

    @staticmethod
    def from_word_and_pos(word, pos):
        """Build an instance using the first synset WordNet lists for ``word``/``pos``.

        Returns:
            A WordInSynset whose synset_name is the name of the first entry of
            ``wn.synsets(word, pos=pos)``, or None when that list is empty.
        """
        candidates = wn.synsets(word, pos=pos)
        if not candidates:
            return None
        return WordInSynset(word, candidates[0].name(), pos)

    def equals(self, s):
        """Return True when ``s`` has the same synset name and word.

        The POS tag is deliberately not part of the comparison.
        """
        return (self.synset_name, self.word) == (s.synset_name, s.word)

    def to_dict(self):
        """Return the three fields as a plain dict (word, synset_name, pos)."""
        return dict(word=self.word, synset_name=self.synset_name, pos=self.pos)


class WordInSynsetWriter:
    """Serialize WordInSynset-like objects to a tab-separated text file."""

    @staticmethod
    def write(words_in_synsets, output_path):
        """Overwrite ``output_path`` with a header row plus one row per item.

        Each row is ``word<TAB>synset_name<TAB>pos<TAB>#`` followed by a newline;
        the header uses the column titles WORD, SYN_NAME and POS.
        """
        with open(output_path, 'w+') as out_file:
            out_file.write('\t'.join(['WORD', 'SYN_NAME', 'POS', '#\n']))
            out_file.writelines(
                '\t'.join([item.word, item.synset_name, item.pos, '#\n'])
                for item in words_in_synsets
            )


class WordInSynsetReader:
    """Parse tab-separated rows into WordInSynset objects."""

    @staticmethod
    def read(input_path, has_header=True):
        """Read WordInSynset objects from ``input_path``.

        Args:
            input_path: path to a tab-separated file whose first three columns
                are word, synset name and POS (the layout WordInSynsetWriter
                emits).
            has_header: when True, the first line is skipped.

        Returns:
            A list of WordInSynset, one per non-blank data line. Column values
            are used verbatim; no WordNet lookup is performed.

        Raises:
            IndexError: if a non-blank line has fewer than three columns.
        """
        words_in_synsets = []
        with open(input_path, 'r') as input_file:
            for line_number, line in enumerate(input_file):
                if has_header and line_number == 0:
                    continue
                # Blank lines carry no data; without this check split[1] below
                # would raise IndexError.
                if not line.strip():
                    continue
                split = line.split('\t')
                # NOTE(review): lines are not stripped, so a file with only
                # three columns would leave the newline in `pos`.
                words_in_synsets.append(WordInSynset(word=split[0], synset_name=split[1], pos=split[2]))
        return words_in_synsets


class SynsetCouple:
    """Two (synset, word) entries plus a POS label stored alongside them."""

    def __init__(self, s1: Synset, w1, s2: Synset, w2, s_pos):
        # First (synset, word) entry.
        self.s1, self.w1 = s1, w1
        # Second (synset, word) entry.
        self.s2, self.w2 = s2, w2
        self.s_pos = s_pos

    def to_dictionary(self):
        """Return the five fields keyed by attribute name, in declaration order."""
        return {key: getattr(self, key) for key in ('s1', 'w1', 's2', 'w2', 's_pos')}


class SynsetOOVCouple:
    """An out-of-vocabulary word, its synset, and a second word/synset it is paired with.

    ``synset_oov`` and ``synset_second`` are passed in as synset names and
    resolved with ``wn.synset`` during construction (OOV first, then second).
    """

    def __init__(self, oov, synset_oov, first, second, synset_second, target_pos, w1_pos, w2_pos):
        self.synset_oov = wn.synset(synset_oov)
        self.oov, self.first = oov, first
        self.synset_second = wn.synset(synset_second)
        self.second = second
        # POS-related labels, stored exactly as given.
        self.target_pos, self.w1_pos, self.w2_pos = target_pos, w1_pos, w2_pos

    def to_dictionary(self):
        """Return all eight fields keyed by attribute name; synsets stay as objects."""
        return dict(synset_oov=self.synset_oov,
                    oov=self.oov,
                    first=self.first,
                    synset_second=self.synset_second,
                    second=self.second,
                    target_pos=self.target_pos,
                    w1_pos=self.w1_pos,
                    w2_pos=self.w2_pos)

    def equals(self, other):
        """Return True when every field matches; synsets are compared by name()."""
        return (self.synset_oov.name() == other.synset_oov.name()
                and self.oov == other.oov
                and self.first == other.first
                and self.synset_second.name() == other.synset_second.name()
                and self.second == other.second
                and self.target_pos == other.target_pos
                and self.w1_pos == other.w1_pos
                and self.w2_pos == other.w2_pos)


class SaverSynsetCouples:
    """Write SynsetCouple objects as tab-separated rows.

    Row layout: ``s1.name()``, ``s2.name()``, ``w1``, ``w2``, ``s_pos``, then a
    literal ``#`` column and a newline. This matches the default column indexes
    of ``ReaderSynsetCouples.read``.
    """

    @staticmethod
    def _format_row(couple):
        """Return the serialized, newline-terminated row for one couple."""
        return '\t'.join([couple.s1.name(), couple.s2.name(),
                          couple.w1, couple.w2, couple.s_pos, '#\n'])

    @staticmethod
    def save(couples, output_path, header):
        """Overwrite ``output_path`` with ``header`` followed by one row per couple.

        ``header`` is written verbatim, so it must carry its own trailing newline.
        """
        with open(output_path, 'w+') as output:
            output.write(header)
            for couple in couples:
                output.write(SaverSynsetCouples._format_row(couple))

    @staticmethod
    def append(couples, output_path, header):
        """Append one row per couple to ``output_path``.

        ``header`` (written verbatim) is emitted only when the file is new or
        empty. Writing it on every call would leave header lines in the middle
        of the file, and ``ReaderSynsetCouples.read`` skips only the first line.
        """
        with open(output_path, 'a+') as output:
            # Append mode positions the stream at end of file on open, so a
            # position of 0 means the file has no content yet.
            if output.tell() == 0:
                output.write(header)
            for couple in couples:
                output.write(SaverSynsetCouples._format_row(couple))


class ReaderSynsetCouples:
    """Parse tab-separated rows into SynsetCouple objects."""

    @staticmethod
    def read(input_path, s1_index=0, w1_index=2, s2_index=1, w2_index=3, s_pos_index=4, exclude_first=True):
        """Read SynsetCouple objects from ``input_path``.

        Args:
            input_path: path to a tab-separated file.
            s1_index, w1_index, s2_index, w2_index, s_pos_index: column positions
                of the first synset name, first word, second synset name,
                second word and POS label. The defaults match the row layout
                written by SaverSynsetCouples.
            exclude_first: when True, the first line (a header) is skipped.

        Returns:
            A list of SynsetCouple, one per non-blank data line, with synset
            names resolved through ``wn.synset``.

        Raises:
            IndexError: if a non-blank line has too few columns for the indexes.
            Any error ``wn.synset`` raises for an unrecognized synset name.
        """
        couples = []
        with open(input_path, 'r') as input_file:
            for line_number, line in enumerate(input_file):
                if exclude_first and line_number == 0:
                    continue
                # Blank lines carry no data; indexing their split would fail.
                if not line.strip():
                    continue
                split = line.split('\t')
                couples.append(SynsetCouple(s1=wn.synset(split[s1_index]), w1=split[w1_index],
                                            s2=wn.synset(split[s2_index]), w2=split[w2_index],
                                            s_pos=split[s_pos_index]))
        return couples


class ReaderSynsetOOVCouple:
    """Parse tab-separated rows into SynsetOOVCouple objects."""

    @staticmethod
    def read(input_path, s1_index=5, w1_index=1, s2_index=9, w2_index=10, first_indexes=(2, 4),
             s_pos_index=6, w1_pos=7, w2_pos=8, exclude_first=False):
        """Read SynsetOOVCouple objects from ``input_path``.

        Args:
            input_path: path to a tab-separated file.
            s1_index: column holding the OOV word's synset name.
            w1_index: column holding the OOV word.
            s2_index: column holding the second synset name.
            w2_index: column holding the second word.
            first_indexes: ``(start, stop)`` slice bounds; the columns
                ``split[start:stop]`` (a list) become ``first``. A tuple
                default is used to avoid a shared mutable default; any
                indexable pair (e.g. a list) is still accepted.
            s_pos_index, w1_pos, w2_pos: columns for ``target_pos``,
                ``w1_pos`` and ``w2_pos``.
            exclude_first: when True, the first line (a header) is skipped.

        Returns:
            A list of SynsetOOVCouple, one per non-blank line.

        Raises:
            IndexError: if a non-blank line has too few columns for the indexes.
            Any error ``wn.synset`` raises (inside SynsetOOVCouple) for an
            unrecognized synset name.
        """
        couples = []
        with open(input_path, 'r') as input_file:
            for line_number, line in enumerate(input_file):
                if exclude_first and line_number == 0:
                    continue
                # Blank lines carry no data; indexing their split would fail.
                if not line.strip():
                    continue
                split = line.split('\t')
                # NOTE(review): lines are not stripped. If w2_index (default 10)
                # is the last column of the file, `second` keeps the trailing
                # newline. Confirm against the file layout.
                couples.append(SynsetOOVCouple(oov=split[w1_index], synset_oov=split[s1_index],
                                               first=split[first_indexes[0]:first_indexes[1]], second=split[w2_index],
                                               synset_second=split[s2_index], target_pos=split[s_pos_index],
                                               w1_pos=split[w1_pos], w2_pos=split[w2_pos]))
        return couples