from json import JSONDecodeError
from nltk.tree import ParentedTree
from defiNNet.parsing.stanfordNLP import StanfordNLP

def collinsrule():
    """Build the head-rule table consulted by ``HeadAwareParentedTree.head``.

    Returns:
        dict: Maps a phrase label (e.g. ``'NP'``) to a dict with two keys:

        * ``'dir'`` -- ``'LEFT'`` or ``'RIGHT'``, the direction in which the
          children of a node with that label are scanned.
        * ``'candidate'`` -- list of child labels, in priority order.

        A fresh dict (with fresh candidate lists) is built on every call, so
        callers may mutate the result without affecting each other.
    """
    left, right = 'LEFT', 'RIGHT'

    def rule(direction, *candidates):
        # One table entry; ``list(...)`` guarantees every entry owns its list.
        return {'dir': direction, 'candidate': list(candidates)}

    # Candidate sequences that several labels share.
    clause = ('IN', 'VP', 'S', 'SBAR', 'ADJP', 'UCP', 'NP')
    subordinate = ('S', 'WHNP', 'WHPP', 'WHADVP', 'WHADJP', 'DT', 'SQ', 'SINV', 'SBAR',
                   'FRAG', 'VP', 'NP', 'IN')
    nominal = ('NN', 'NNP', 'NNPS', 'NNS', 'NX', 'POS', 'PRP', 'NP', 'JJR', 'JJ', 'VBG')

    return {
        'ADJP': rule(left, 'NNS', 'QP', 'NN', 'ADVP', 'JJ', 'VBN', 'VBG', 'ADJP', 'JJR', 'NP',
                     'JJS', 'DT', 'FW', 'RBR', 'RBS', 'SBAR', 'RB'),
        'ADVP': rule(right, 'RB', 'RBR', 'RBS', 'FW', 'ADVP', 'TO', 'CD', 'JJR', 'JJ', 'IN',
                     'NP', 'JJS', 'NN'),
        'CONJP': rule(right, 'CC', 'RB', 'IN'),

        'PP': rule(left, 'NP', 'S', 'PP', 'ADJP', 'VBG', 'VBN', 'RP', 'FW', 'IN', 'TO'),
        'S': rule(left, *clause),
        'UCP': rule(left, *clause),
        'S1': rule(left, 'S', 'TO', 'IN', 'VP', 'SBAR', 'ADJP', 'UCP', 'NP'),
        'SBAR': rule(left, *subordinate),
        'FRAG': rule(left, *subordinate),

        'VP': rule(left, 'VBD', 'VBN', 'MD', 'VBZ', 'VB', 'VBG', 'VBP', 'VP', 'ADJP', 'TO',
                   'NN', 'NNS', 'NP'),
        'WHADJP': rule(left, 'CC', 'WRB', 'JJ', 'ADJP'),
        'WHADVP': rule(right, 'CC', 'WRB'),
        # The third entry used to be a second 'WP'; the possessive wh-pronoun
        # tag 'WP$' is what belongs here.
        'WHNP': rule(left, 'WDT', 'WP', 'WP$', 'WHADJP', 'WHPP', 'WHNP'),
        'WHPP': rule(right, 'IN', 'TO', 'FW'),

        'NP': rule(left, *nominal),
        'NML': rule(left, *nominal),
    }


class HeadAwareParentedTree(ParentedTree):
    """A ``ParentedTree`` that can pick its head child using ``collinsrule()``.

    Labels are compared by prefix (see ``has_label``), so a rule candidate such
    as ``'NN'`` also matches nodes labelled ``'NNS'`` or ``'NNP'``.
    """

    def __init__(self, node, children=None):
        """Create the node and attach a head-rule table to it.

        Args:
            node: Label of this node; forwarded to ``ParentedTree``.
            children: Children of this node; forwarded to ``ParentedTree``.
        """
        super(HeadAwareParentedTree, self).__init__(node, children)
        self.rules = collinsrule()

    def has_label(self, label):
        """Return True when this node's label starts with ``label``."""
        return str(self.label()).startswith(label)

    def couples_having_label(self, label1, label2):
        """Collect (node, child) pairs matching ``label1`` / ``label2``.

        Every proper descendant of this tree whose label starts with ``label1``
        is paired with each of its direct children whose label starts with
        ``label2``.

        Returns:
            list: Tuples ``(label1_node, label2_node)``; the order inside the
            tuple is swapped to ``(label2_node, label1_node)`` when
            ``label2 == 'JJ'``.
        """
        return self._recursive_couples_having_label(label1, label2, [])

    def _recursive_couples_having_label(self, label1, label2, siblings):
        """Depth-first worker for ``couples_having_label``.

        ``siblings`` is the accumulator list; it is appended to in place and
        also returned.
        """
        for child in self:
            if isinstance(child, str):
                # Leaves carry no label and have no children to inspect.
                continue
            if child.has_label(label1):
                for descendant in child:
                    if isinstance(descendant, HeadAwareParentedTree) and descendant.has_label(label2):
                        if label2 == 'JJ':
                            siblings.append((descendant, child))
                        else:
                            siblings.append((child, descendant))

            siblings = child._recursive_couples_having_label(label1, label2, siblings)
        return siblings

    def head(self):
        """Return the node that sits directly above the head word of this subtree.

        A node whose only child is a string is its own head. Otherwise, if
        ``self.rules`` has an entry for this node's label, the candidates are
        tried in priority order; for each one the children are scanned in the
        rule's direction and the first matching child that yields a head wins.
        When no rule applies or none matches, the first child (left to right)
        that yields a head is used.

        Returns:
            HeadAwareParentedTree or None: the head node, or ``None`` when no
            child produces one.
        """
        children = list(self)

        if len(children) == 1 and isinstance(children[0], str):
            return self

        # String leaves mixed in among subtrees have neither label nor head.
        subtrees = [child for child in children if not isinstance(child, str)]

        rule = self.rules.get(str(self.label()))
        if rule is not None:
            ordered = list(subtrees)
            if rule['dir'] == 'RIGHT':
                # Must stay a list: it is re-scanned once per candidate, and a
                # ``reversed()`` iterator would be empty after the first scan.
                ordered.reverse()

            for candidate in rule['candidate']:
                for child in ordered:
                    if child.has_label(candidate):
                        head = child.head()
                        if head is not None:
                            return head

        # Fallback: first child, in natural order, that has a head.
        for child in subtrees:
            head = child.head()
            if head is not None:
                return head
        return None

    def lemma_head(self):
        """Return the first leaf under ``head()`` as a string, or None if there is no head."""
        head = self.head()
        if head is not None:
            return str(head.leaves()[0])
        else:
            return None

    def first_composite_constituent(self):
        """Descend through the tree and return a multi-child node, or None.

        Single-child chains are followed downwards. At a node with several
        children the decision depends on which labels (exact match, not
        prefix) occur among its subtree children; see the inline comments.

        Returns:
            HeadAwareParentedTree or None.
        """
        if len(self) == 1:
            if isinstance(self[0], HeadAwareParentedTree):
                first = self[0].first_composite_constituent()
                if first is not None:
                    return first
                # Otherwise falls through to the final ``return None``.
            else:
                # Only child is a leaf.
                return None

        if len(self) > 1:
            constituents = [constituent for constituent in self if isinstance(constituent, HeadAwareParentedTree)]

            labels = [x.label() for x in constituents]
            if 'CC' not in labels and ':' not in labels and '-LRB-' not in labels and \
                    'FRAG' not in labels and 'S' not in labels and 'SBAR' not in labels:
                # No coordination, punctuation or clausal child: this node is
                # the answer, unless it starts with SQ or TO.
                # NOTE(review): the ``== 'S'`` test cannot be true here because
                # 'S' is known not to be in ``labels``.
                if constituents[0].label() == 'S' or constituents[0].label() == 'SQ':
                    return constituents[0].first_composite_constituent()
                if constituents[0].label() == 'TO':
                    return constituents[1].first_composite_constituent()
                return self
            else:
                if 'S' in labels:
                    # Descend into the first S child, except below a VP.
                    if self.label() != 'VP':
                        first = constituents[labels.index('S')].first_composite_constituent()
                        if first is not None:
                            return first
                    return self

                if 'FRAG' in labels:
                    return constituents[labels.index('FRAG')].first_composite_constituent()
                if 'SBAR' in labels:
                    return constituents[labels.index('SBAR')].first_composite_constituent()
                if 'CC' in labels:
                    i = labels.index('CC')
                    if len(constituents) > i + 1:
                        # Prefer the constituent right after the conjunction.
                        first = constituents[i + 1].first_composite_constituent()
                        if first is not None:
                            return first
                        else:
                            return self

                    # CC is the last constituent: try the first one instead.
                    first = constituents[0].first_composite_constituent()
                    if first is not None:
                        return first
                if ':' in labels:
                    i = labels.index(':')
                    # NOTE(review): when ':' is the first constituent, i - 1 is
                    # -1 and this picks the LAST constituent -- confirm intent.
                    first = constituents[i - 1].first_composite_constituent()
                    if first is not None:
                        return first
                    else:
                        return self
                if '-LRB-' in labels:
                    i = labels.index('-LRB-')
                    # NOTE(review): same wrap-around as for ':' when i == 0.
                    first = constituents[i - 1].first_composite_constituent()
                    if first is not None:
                        return first

        return None




class Parser:
    """Wrapper around ``StanfordNLP`` for tagging a word and parsing its definition."""

    # POS tag forced onto a word when the tagger disagrees with the requested
    # one-letter category.
    _CATEGORY_TO_POS = {'V': 'VB', 'N': 'NN', 'J': 'JJ'}

    def __init__(self):
        self.sNLP = StanfordNLP()

    def pos(self, word):
        """Return whatever ``StanfordNLP.pos`` yields for ``word``."""
        return self.sNLP.pos(word)

    def parse(self, word, definition, category=None):
        """Tag ``word`` and parse ``definition``.

        Args:
            word: Text handed to ``self.sNLP.pos``; the first ``(token, tag)``
                pair of the result is used.
            definition: Text handed to ``self.sNLP.parse``.
            category: Optional expected tag prefix. ``'A'`` is treated as
                ``'J'``. When the tagger's tag does not start with it, the tag
                is overridden: ``'V'`` -> ``'VB'``, ``'N'`` -> ``'NN'``,
                ``'J'`` -> ``'JJ'``; any other category leaves the tag as is.

        Returns:
            tuple or None: ``(parsed_definition, pos)``, or ``None`` when
            ``definition`` is ``None``/empty or the parser raises
            ``JSONDecodeError`` (which is reported on stdout).
        """
        if category == 'A':
            category = 'J'

        word, pos = self.sNLP.pos(word)[0]
        if category is not None and not pos.startswith(category):
            # Trust the caller's category over the tagger. The adjective case
            # used to be ``pos == 'JJ'``, a comparison with no effect.
            pos = self._CATEGORY_TO_POS.get(category, pos)

        if definition is None or definition == '':
            return None

        try:
            definition = self.sNLP.parse(definition)
        except JSONDecodeError:
            print("JSONDecodeError")
            return None
        return definition, pos