from nltk.corpus import wordnet as wn

from defiNNet.parsing.parse_tree import HeadAwareParentedTree


class WordSynset:
    """Bundle a word with its resolved WordNet synset, POS tag and definition.

    ``synset_name`` is resolved eagerly through ``wn.synset``; any lookup
    error raised by NLTK propagates out of the constructor.
    """

    def __init__(self, word, synset_name, pos, definition):
        self.word = word
        # Resolve the synset object up front so later accessors need no lookup.
        self.synset = wn.synset(synset_name)
        self.pos, self.definition = pos, definition


class ParsedDefinition:
    """A word, its WordNet synset and the constituency parse of a definition.

    ``definition`` is handed to ``HeadAwareParentedTree.fromstring``, so it is
    presumably a bracketed parse string (in this file, the first value returned
    by ``parser.parse`` in ``DefAnalyzer.analyze``). The same string is also
    stored on ``word_syn.definition``.
    """

    def __init__(self, word, synset_name, pos, definition):
        # NOTE(review): this instance attribute shadows the ``word()`` method
        # defined below, so ``obj.word`` yields the raw value and ``obj.word()``
        # never reaches the method. Kept because callers may read the attribute
        # directly; ``obj.word_syn.word`` is the unambiguous accessor.
        self.word = word
        self.word_syn = WordSynset(word, synset_name, pos, definition)
        self.definition_tree = HeadAwareParentedTree.fromstring(definition)

    def definition(self):
        """Return the definition string given to the constructor."""
        return self.word_syn.definition

    def word(self):
        """Return the word.

        Only reachable as ``ParsedDefinition.word(obj)``: ``__init__`` sets an
        instance attribute of the same name that hides this method.
        """
        return self.word_syn.word

    def synset(self):
        """Return the resolved WordNet synset object."""
        return self.word_syn.synset

    def pos(self):
        """Return the POS value given to the constructor."""
        return self.word_syn.pos

    def first_phrase(self, max_n=2):
        """Locate the first composite constituent of the top-level sentence.

        Args:
            max_n: minimum number of non-determiner children required.

        Returns:
            ``(sentence, first_child, constituents)``. Here ``sentence`` is the
            first child of the parse root, ``first_child`` is
            ``sentence.first_composite_constituent()``, and ``constituents``
            are ``first_child``'s children whose label is not ``'DT'``.
            Returns ``None`` when the tree is empty, its first child is a bare
            leaf, no composite constituent exists, or fewer than ``max_n``
            non-determiner constituents remain.
        """
        # Guard against an empty root or a leaf-only root: indexing or calling
        # tree methods on those would raise instead of signalling "no phrase".
        if len(self.definition_tree) == 0:
            return None
        sentence = self.definition_tree[0]
        if not isinstance(sentence, HeadAwareParentedTree):
            return None

        first_child = sentence.first_composite_constituent()
        if not isinstance(first_child, HeadAwareParentedTree):
            return None

        # Determiners carry no content, so they are skipped.
        constituents = [constituent for constituent in first_child
                        if constituent.label() != 'DT']
        if len(constituents) < max_n:
            return None
        return sentence, first_child, constituents

    def first_phrase_summarized(self, max_n=2):
        """Summarize the first ``max_n`` constituents of ``first_phrase`` by their heads.

        Returns:
            ``(sentence_label, first_child_label, labels, heads)``. ``labels[i]``
            is the label of the i-th constituent's head and ``heads[i]`` is that
            head's first leaf. Returns ``None`` if ``first_phrase`` found
            nothing, or if any examined constituent has no head (a diagnostic
            line is printed in that case).
        """
        first = self.first_phrase(max_n=max_n)
        if first is None:
            return None
        sentence, first_child, constituents = first

        heads = []
        labels = []
        # Clamp at 0 so a non-positive max_n examines nothing (a negative slice
        # bound would instead drop items from the end).
        for constituent in constituents[:max(max_n, 0)]:
            head = constituent.head()
            if head is None:
                print("HEAD IS NONE")
                return None
            labels.append(head.label())
            heads.append(head.leaves()[0])

        return sentence.label(), first_child.label(), labels, heads



class DefAnalyzer:
    """Reduce a word's first WordNet definition to the word plus two head words.

    ``parser`` is used via ``parse(word, definition, category=...)``, which must
    return ``(parse_string, pos)``, and ``pos(text)``, which must return an
    indexable sequence of ``(token, tag)`` pairs. Both contracts are inferred
    from the calls below; confirm them against the parser implementation.
    """

    def __init__(self, parser):
        self.parser = parser

    def analyze(self, word, category=None):
        """Return ``([word, w1, w2], pos_tags)`` for the first synset of ``word``.

        Args:
            word: the word to analyze.
            category: optional WordNet POS filter; ``'j'`` is mapped to ``'a'``.

        Returns:
            A list ``[word, w1, w2]`` and a dict with keys ``'target_pos'``,
            ``'w1_pos'`` and ``'w2_pos'``. ``w1``/``w2`` are the head words of the
            definition's first phrase when it can be summarized. Otherwise they
            are the first two whitespace tokens of the definition (the first
            token repeated if there is only one), tagged with ``parser.pos``.
        """
        if category == 'j':
            category = 'a'

        syn = self._first_synset(word, category)
        definition = syn.definition()

        p_definition, pos_word = self.parser.parse(word, definition, category=category)
        parsed_definition = ParsedDefinition(word, synset_name=syn.name(),
                                             definition=p_definition, pos=pos_word)

        summarization = parsed_definition.first_phrase_summarized()
        if summarization is None:
            # Fallback: take the raw leading tokens of the gloss and tag them.
            tokens = definition.split()
            w1 = tokens[0]
            w2 = tokens[1] if len(tokens) > 1 else tokens[0]

            _, target_pos = self.parser.pos(word)[0]
            _, w1_pos = self.parser.pos(w1)[0]
            _, w2_pos = self.parser.pos(w2)[0]
            return [word, w1, w2], {'target_pos': target_pos, 'w1_pos': w1_pos, 'w2_pos': w2_pos}

        # first_phrase_summarized uses max_n=2 by default, so two heads exist.
        _, _, labels, heads = summarization
        return [word, heads[0], heads[1]], {'target_pos': pos_word, 'w1_pos': labels[0], 'w2_pos': labels[1]}

    @staticmethod
    def _first_synset(word, category):
        """Pick the first synset for ``word``, preferring ``category``.

        Falls back to any POS, then to the first ``"entity"`` synset (printing
        a diagnostic) when WordNet knows nothing about ``word``.
        """
        synsets = wn.synsets(word, pos=category)
        if synsets:
            return synsets[0]
        synsets = wn.synsets(word)
        if synsets:
            return synsets[0]
        print("word with no synset:", word)
        return wn.synsets("entity")[0]
