import re
from itertools import chain

import nltk
import numpy

# Part-of-speech labels whose preterminal is rendered as a bare word by
# MRG_labeled (these look like Penn Treebank POS tags).
word_tags = [
    'CC', 'CD', 'DT', 'EX', 'FW', 'IN',
    'JJ', 'JJR', 'JJS',
    'LS', 'MD',
    'NN', 'NNS', 'NNP', 'NNPS',
    'PDT', 'POS', 'PRP', 'PRP$',
    'RB', 'RBR', 'RBS', 'RP',
    'SYM', 'TO', 'UH',
    'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ',
    'WDT', 'WP', 'WP$', 'WRB',
]

# Currency symbols/tokens.
currency_tags_words = ['#', '$', 'C$', 'A$']

# Trace / null-element tokens (presumably treebank empty-category markers;
# not referenced by the code visible in this file).
ellipsis = [
    '*', '*?*', '0', '*T*', '*ICH*',
    '*U*', '*RNR*', '*EXP*', '*PPA*', '*NOT*',
]

# Punctuation as it appears in the tag position ...
punctuation_tags = ['.', ',', ':', '-LRB-', '-RRB-', "''", '``']

# ... and as it appears in the word position.
punctuation_words = [
    '.', ',', ':', '-LRB-', '-RRB-', "''", '``',
    '--', ';', '-', '?', '!', '...',
    '-LCB-', '-RCB-',
]

# NOTE(review): name is presumably a typo for "deleted_tags"; kept as-is
# because other modules may import it under this name.
delated_tags = ['TOP', '-NONE-', ',', ':', '``', "''"]

# Labels that mark sentence-level constituents.
sentence_labels = ['S', 'SBAR']


def precess_arc(label):
    """Rewrite every ``ADVP`` component of a ``+``-joined arc label to ``PRT``.

    The label is split on ``'+'``, each component that is exactly ``'ADVP'``
    is replaced by ``'PRT'``, and the components are joined back with ``'+'``.
    All other components (and their order) are left untouched.
    """
    return '+'.join(
        'PRT' if part == 'ADVP' else part
        for part in label.split('+')
    )


def process_NONE(tree):
    """Return a copy of ``tree`` with every ``-NONE-`` subtree pruned.

    Non-``nltk.Tree`` inputs (leaves) are returned unchanged. A subtree
    labelled ``-NONE-`` is dropped, and so is any subtree left with no
    children after pruning; in both cases ``None`` is returned so the caller
    can skip it.
    """
    if not isinstance(tree, nltk.Tree):
        return tree

    label = tree.label()
    if label == '-NONE-':
        return None

    # Recurse first, then discard children that were pruned away entirely.
    kept = [child for child in map(process_NONE, tree) if child is not None]
    if not kept:
        return None
    return nltk.Tree(label, kept)


def dtree2ctree(tree, node_idx=None):
    """Turn a dependency structure into a nested list of words.

    ``tree`` must expose ``root`` (a node) and ``nodes`` (index -> node),
    where each node is a mapping with ``'address'``, ``'word'`` and ``'deps'``
    (relation -> list of dependent addresses).

    A node without dependents yields its word. Otherwise the result is a
    list holding, in order, the converted dependents whose address is smaller
    than the node's own, the node's word, and the remaining dependents.

    ``node_idx`` selects the node to convert; ``None`` starts at ``tree.root``.
    """
    node = tree.root if node_idx is None else tree.nodes[node_idx]

    own_address = node['address']
    dependents = sorted(chain.from_iterable(node['deps'].values()))
    if not dependents:
        return node['word']

    before = [dtree2ctree(tree, dep) for dep in dependents if dep < own_address]
    head_word = [node['word']]
    after = [dtree2ctree(tree, dep) for dep in dependents if not dep < own_address]
    return before + head_word + after


def ctree2distance(root, idx):
    """Flatten a nested-list tree into one integer per leaf.

    A non-list ``root`` is a leaf and yields ``[idx]``. For a list, the first
    child is processed with ``idx`` and every later child with ``idx + 1``;
    the per-child results are concatenated in order.
    """
    if not isinstance(root, list):
        return [idx]

    result = ctree2distance(root[0], idx)
    deeper = idx + 1
    for sibling in root[1:]:
        result += ctree2distance(sibling, deeper)
    return result


def ctree2last(root, last):
    """Flatten a nested-list tree into one "last" offset per leaf.

    A leaf yields ``[last]``. Within a list, the first child receives the
    incoming ``last``; each following child receives minus the number of
    leaves already emitted by this list.
    """
    if not isinstance(root, list):
        return [last]

    collected = []
    for child in root:
        collected += ctree2last(child, last)
        # The next sibling points back past everything emitted so far.
        last = -len(collected)
    return collected


def last2distance(lasts):
    """Convert a sequence of "last" offsets into distances.

    The output always starts with ``0`` (the first element of ``lasts`` is
    ignored). For every later offset: ``0`` maps to distance ``0``; any other
    value indexes into the distances built so far (negative values count from
    the end) and the new distance is that entry plus one.
    """
    result = [0]
    for back in lasts[1:]:
        result.append(result[back] + 1 if back != 0 else 0)
    return result


from collections import deque


def tree2list(tree, parent_arc=None, binary=False):
    """Convert an ``nltk.Tree`` into nested word lists plus arc and tag labels.

    Args:
        tree: An ``nltk.Tree`` or a leaf (anything else, converted with
            ``str``).
        parent_arc: Labels of the unary chain collected above ``tree``.
            ``None`` (the default) means an empty chain. The list is never
            modified.
        binary: When true, nodes with more than two children are nested as
            ``[first, rest]`` instead of being flattened. The flag is applied
            at every level of the tree.

    Returns:
        A ``(root, arc, tag)`` tuple:

        * ``root`` - the word for a leaf, otherwise a (nested) list.
        * ``arc`` - one ``+``-joined label per split between adjacent leaves.
          Splits introduced for children beyond the second are labelled
          ``<empty>``.
        * ``tag`` - one entry per leaf: the ``+``-joined unary chain above the
          leaf with its last label removed, or ``<empty>`` when the chain
          holds a single label.

    Raises:
        IndexError: If a leaf is reached with an empty ``parent_arc`` or if
            ``tree`` is an ``nltk.Tree`` without children.
    """
    if parent_arc is None:
        parent_arc = []

    if not isinstance(tree, nltk.Tree):
        # Work on a copy so the caller's chain is left untouched.
        unary_chain = list(parent_arc)
        if len(unary_chain) == 1:
            unary_chain.insert(0, '<empty>')
        # Drop the label sitting directly above the word.
        del unary_chain[-1]
        return str(tree), [], ['+'.join(unary_chain)]

    label = tree.label()
    if isinstance(tree[0], nltk.Tree):
        # Keep only the part of the label before the first '-' or '='.
        label = re.split('-|=', label)[0]
    root_arc_list = parent_arc + [label]
    root_arc = '+'.join(root_arc_list)

    if len(tree) == 1:
        # Unary node: pass the accumulated chain down to the only child.
        return tree2list(tree[0], parent_arc=root_arc_list, binary=binary)

    c0, arc0, tag0 = tree2list(tree[0], binary=binary)
    if len(tree) == 2:
        c1, arc1, tag1 = tree2list(tree[1], binary=binary)
        root = [c0, c1]
    else:
        # Group the remaining children under a placeholder node.
        rest = nltk.Tree('<empty>', tree[1:])
        c1, arc1, tag1 = tree2list(rest, binary=binary)
        root = [c0, c1] if binary else [c0] + c1

    return root, arc0 + [root_arc] + arc1, tag0 + tag1


def get_brackets(tree, start_idx=0, root=False):
    """Collect the ``(start, end)`` leaf spans of a nested-list tree.

    Leaves (non-list values) occupy one position each. Every list contributes
    the half-open span covering its leaves, except the outermost list when
    ``root`` is true.

    Returns:
        ``(brackets, end_idx)`` where ``brackets`` is a set of spans and
        ``end_idx`` is the position just past the last leaf of ``tree``.
    """
    if not isinstance(tree, list):
        return set(), start_idx + 1

    spans = set()
    cursor = start_idx
    for child in tree:
        child_spans, cursor = get_brackets(child, cursor)
        spans |= child_spans
    if not root:
        spans.add((start_idx, cursor))
    return spans, cursor


def MRG(tr):
    """Render a nested-list tree as an unlabeled bracket string.

    A string leaf becomes ``( leaf )``. Any other value is treated as an
    iterable of subtrees: each rendered subtree is followed by one space and
    the whole sequence is wrapped in parentheses.
    """
    if isinstance(tr, str):
        return '( %s )' % tr
    inner = ''.join(MRG(child) + ' ' for child in tr)
    return '(' + inner + ')'


def MRG_labeled(tr):
    """Render an ``nltk.Tree`` as a labeled bracket string.

    * Anything that is not an ``nltk.Tree`` renders as the empty string.
    * A subtree whose label is in ``word_tags`` renders as its first leaf
      followed by a space.
    * Any other subtree renders as ``(LABEL `` + rendered children + ``) ``,
      where ``LABEL`` is the part of the label before the first ``-`` or
      ``=``.
    """
    if not isinstance(tr, nltk.Tree):
        return ''

    tag = tr.label()
    if tag in word_tags:
        return tr.leaves()[0] + ' '

    head = re.split(r'[-=]', tag)[0]
    body = ''.join(MRG_labeled(child) for child in tr)
    return ('(%s ' % head) + body + ') '


def build_tree(depth, sen):
    """Build a binary nested-list tree over ``sen`` driven by ``depth``.

    The tokens are consumed left to right with a stack. For every entry
    ``point`` of ``depth`` after the first, the two topmost stack items are
    merged into ``[left, right]`` up to ``point - head`` times (stopping early
    once one item remains), where ``head`` is the previously processed depth
    minus one. The next token is then pushed. Whatever is left on the stack
    is finally merged from the right until at most two items remain.

    Args:
        depth: Sequence of integers; ``depth[0]`` belongs to the first token.
        sen: Non-empty sequence of tokens of any type.

    Returns:
        The nested-list tree. A tree consisting of a single token is returned
        as that bare token, whatever its type.

    Raises:
        IndexError: If ``sen`` or ``depth`` is empty.
    """
    queue = deque(sen)
    stack = [queue.popleft()]
    head = depth[0] - 1
    for point in depth[1:]:
        d = point - head
        if d > 0:
            for _ in range(d):
                if len(stack) == 1:
                    break
                right = stack.pop()
                left = stack.pop()
                stack.append([left, right])
        if queue:
            stack.append(queue.popleft())
            head = point - 1

    # Fold any leftover items from the right into a single two-element list.
    while len(stack) > 2:
        right = stack.pop()
        left = stack.pop()
        stack.append([left, right])

    # Unwrap single-element lists. The type is checked before len() so that
    # a lone token without a length (e.g. an int id) is returned instead of
    # raising TypeError.
    while isinstance(stack, list) and len(stack) == 1:
        stack = stack.pop()
    return stack


def build_nltktree(depth, arc, tag, sen, arcdict, tagdict, stagdict, stags=None):
    """Recursively rebuild an ``nltk.Tree`` from per-split and per-token labels.

    The span is split at the position of the largest value in ``depth``; the
    two halves are built recursively and joined under the label chain looked
    up in ``arcdict`` for that split. Halves whose label is ``<empty>`` are
    spliced in, i.e. replaced by their children.

    Args:
        depth: One value per gap between adjacent tokens
            (``len(sen) - 1`` entries).
        arc: One ``arcdict`` key per gap.
        tag: One ``tagdict`` key per token.
        sen: The tokens; must be non-empty.
        arcdict, tagdict, stagdict: Lookups from key to a label string;
            ``+`` separates the labels of a unary chain, outermost first.
        stags: Optional per-token ``stagdict`` keys. Per the original note,
            these are the Stanford-predicted tags present in the
            train/valid/test files. When given, the looked-up tag becomes the
            innermost label around each word, below the ``tagdict`` chain.

    Returns:
        The reconstructed ``nltk.Tree``.
    """
    assert len(sen) > 0
    assert len(depth) == len(sen) - 1, ("%s_%s" % (len(depth), len(sen)))
    if stags:
        assert len(stags) == len(tag)

    if len(sen) == 1:
        # Innermost label first, so wrapping in order yields the full chain.
        labels = str(tagdict[tag[0]]).split('+')[::-1]
        if stags is not None:
            assert len(stags) > 0
            labels = [str(stagdict[stags[0]])] + labels
        leaf = str(sen[0])
        for lab in labels:
            leaf = nltk.Tree(lab, [leaf])
        assert isinstance(leaf, nltk.Tree)
        return leaf

    split = numpy.argmax(depth)
    left = build_nltktree(
        depth[:split], arc[:split], tag[:split + 1], sen[:split + 1],
        arcdict, tagdict, stagdict, stags[:split + 1] if stags else None)
    right = build_nltktree(
        depth[split + 1:], arc[split + 1:], tag[split + 1:], sen[split + 1:],
        arcdict, tagdict, stagdict, stags[split + 1:] if stags else None)

    # Placeholder nodes are dissolved into their children.
    children = []
    for half in (left, right):
        if half.label() == '<empty>':
            children.extend(half)
        else:
            children.append(half)

    # The innermost arc label holds the children; the rest wrap it as unaries.
    arc_labels = str(arcdict[arc[split]]).split('+')
    node = nltk.Tree(arc_labels[-1], children)
    for lab in reversed(arc_labels[:-1]):
        node = nltk.Tree(lab, [node])
    return node
