import re
import sys
from itertools import chain

from span import spans_in_span
from mark_edge_with_label import regex_not_word

# Integer codes placed inside equations instead of a word or an (i, j)
# variable tuple.  print_equations renders them as 'L', 'W' and 'S'.
LEMMA_VARIABLE = -1
W_VARIABLE = -2
SENSE_VARIABLE = -3

# (disabled) combined list of 'be' verb forms:
# v_be_forms = ['be', 'been', 'has been', 'had been', 'have been',
#               'are', 'is', 'am', 'will be', 'was', 'were',
#               'being', 'is being', 'are being']

# Words that simplify() rewrites to the literal 'be'.
be_1_forms = [
    "is", "are", "am", "was", "were", "being", "be", "been",
    "'s", "'re", "'m",
]
# Words that simplify() drops from an equation.
v_tense_forms = [
    "has", "have", "had", "will", "having", "shall", "would",
    "'d", "'ll", "'ve",
]
# Words that simplify() replaces with W_VARIABLE.
w_forms = ["who", "what", "which", "where", "when", "whom", "that"]
# Words that simplify() rewrites to the literal 'do'.
do_forms = ["did", "does"]


def equations_var_pos(eqs, var):
    """Return the position of the first item equal to ``var``.

    The equations are scanned in order, and each equation left to right.
    The result is ``(equation_index, item_index)``, or ``None`` when no
    item compares equal to ``var``.
    """
    matches = (
        (row, col)
        for row, equation in enumerate(eqs)
        for col, item in enumerate(equation)
        if item == var
    )
    return next(matches, None)


def equations_lemma_pos(eqs):
    """Return the ``(equation, item)`` position of LEMMA_VARIABLE, or None."""
    return equations_var_pos(eqs, var=LEMMA_VARIABLE)


def equations_sense_pos(eqs):
    """Return the ``(equation, item)`` position of SENSE_VARIABLE, or None."""
    return equations_var_pos(eqs, var=SENSE_VARIABLE)


def is_normal_variable(var):
    """Tell whether ``var`` is a tuple.

    Within this module the tuple items of an equation are the ``(i, j)``
    variables; the special variables are plain integers and words are
    strings.
    """
    is_tuple = isinstance(var, tuple)
    return is_tuple


def simplify(equation, node, string):
    """Return a new list with the words of ``equation`` normalised.

    Items that are not strings are copied unchanged.  String items are
    handled by the first rule that applies:

    * a word in ``v_tense_forms`` or an empty string is dropped;
    * a word in ``do_forms`` becomes ``'do'``;
    * a word in ``be_1_forms`` becomes ``'be'`` (a directly following
      ``'being'`` is consumed with it), except that it is dropped when the
      next item is LEMMA_VARIABLE, ``node['pos']`` is ``'v'`` or ``'u'``
      and ``string`` ends with ``'ing'``;
    * a word in ``w_forms`` becomes W_VARIABLE;
    * a run of consecutive words that each occur as a substring of
      ``node['sense']`` becomes a single SENSE_VARIABLE, and
      ``node['sense-used']`` is set to True;
    * any other word is kept as is.

    ``node`` must provide the key ``'sense'``; ``'pos'`` is read only in
    the progressive check described above.
    """
    result = []
    sense = node['sense']
    total = len(equation)
    pos = 0
    while pos < total:
        item = equation[pos]
        pos += 1

        if not isinstance(item, str):
            result.append(item)
            continue

        if item == '' or item in v_tense_forms:
            continue

        if item in do_forms:
            result.append('do')
            continue

        if item in be_1_forms:
            if pos < total:
                following = equation[pos]
                progressive = (following == LEMMA_VARIABLE
                               and node['pos'] in ['v', 'u']
                               and string.endswith('ing'))
                if progressive:
                    # Progressive tense: drop the 'be' and keep the lemma.
                    pos += 1
                    result.append(LEMMA_VARIABLE)
                    continue
                if following == 'being':
                    pos += 1
            result.append('be')
            continue

        if item in w_forms:
            result.append(W_VARIABLE)
            continue

        if sense and item in sense:
            # Swallow the following strings as long as each of them also
            # appears in the sense.
            while pos < total:
                candidate = equation[pos]
                if not (isinstance(candidate, str) and candidate in sense):
                    break
                pos += 1
            result.append(SENSE_VARIABLE)
            node['sense-used'] = True
            continue

        result.append(item)
    return result


def equations_information(in_edges, out_edges, uttr):
    """Collect the span data needed to build equations.

    Each edge is a 3-item sequence ``(_, spans, is_output)``; edges are
    numbered consecutively over ``in_edges`` followed by ``out_edges``.
    Edges whose ``spans`` is None are ignored (but still numbered).

    Returns a tuple ``(all_spans, output_spans, output_index)``:

    * ``all_spans`` is the sorted list of ``(span, (edge_no, span_no))``
      for every span of every non-output edge;
    * ``output_spans`` is the ``spans`` of the last output edge seen, or
      ``[(0, len(uttr))]`` when there is none;
    * ``output_index`` is that edge's number, or -1 when there is none.
    """
    output_index = -1
    output_spans = [(0, len(uttr))]
    input_spans = []

    # Gather every span.
    for edge_no, edge in enumerate(chain(in_edges, out_edges)):
        _, spans, is_output = edge
        if spans is None:
            continue
        if not is_output:
            for span_no, span in enumerate(spans):
                input_spans.append((span, (edge_no, span_no)))
            continue
        output_spans = spans
        output_index = edge_no

    input_spans.sort()
    return input_spans, output_spans, output_index


def get_strings(string):
    """Yield the normalised words of ``string``.

    The text is stripped, lower-cased and split on whitespace, and every
    match of ``regex_not_word`` is removed from each token.  Contractions
    are then split in two:

    * ``it's`` yields ``'it'`` and ``'be'``;
    * a token starting with ``there'`` yields ``'there'`` and the rest;
    * a token starting with one of ``w_forms`` plus an apostrophe yields
      that word and the rest (apostrophe included).

    Any other token is yielded unchanged.
    """
    for token in string.strip().lower().split():
        token = re.sub(regex_not_word, '', token)

        if token == "it's":
            yield 'it'
            yield 'be'
            continue

        if token.startswith("there'"):
            yield 'there'
            yield token[5:]
            continue

        matched = False
        for w_word in w_forms:
            if token.startswith(w_word + "'"):
                matched = True
                yield w_word
                yield token[len(w_word):]
        if not matched:
            yield token


def make_equations(in_edges, out_edges, uttr, node, after_lemma_var=None):
    """Yield one simplified equation (a tuple) for every output span.

    An equation starts with the target variable ``(output_index, j)`` and
    continues, left to right over the output span, with the words of the
    text between input spans and the ``(edge_no, span_no)`` variable of
    each input span.  When a stretch of text contains ``node['span']``,
    the node's own text is replaced by LEMMA_VARIABLE, followed by
    ``after_lemma_var`` if that is truthy.  The item list is finally
    passed through ``simplify`` together with the node's lower-cased text.
    """
    all_spans, output_spans, output_index = equations_information(
        in_edges, out_edges, uttr)

    node_beg, node_end = node['span']
    for out_pos, (cursor, out_end) in enumerate(output_spans):
        # Presumably the input spans lying inside this output span
        # (spans_in_span is defined in another module).
        covered = spans_in_span(all_spans, (cursor, out_end),
                                key=lambda entry: entry[0])
        # Sentinel entry so the text after the last input span is handled.
        covered.append(((out_end, out_end), None))
        equation = [(output_index, out_pos)]

        for span, variable in covered:
            # Text between the cursor and the start of this span.
            gap_end = span[0]
            if gap_end > cursor:
                if node_beg >= cursor and node_end <= gap_end:
                    equation.extend(get_strings(uttr[cursor:node_beg]))
                    equation.append(LEMMA_VARIABLE)  # the node itself
                    if after_lemma_var:
                        equation.append(after_lemma_var)
                    equation.extend(get_strings(uttr[node_end:gap_end]))
                else:
                    equation.extend(get_strings(uttr[cursor:gap_end]))
            if variable:
                equation.append(variable)
            cursor = span[1]

        yield tuple(simplify(equation, node,
                             uttr[node_beg:node_end].strip().lower()))


def print_equations(equations, out=sys.stdout):
    """Write each equation to ``out`` as ``TARGET = term + term + ...``.

    The first item of an equation is the target ``(ti, tj)``; it is shown
    as ``G`` when ``ti`` is negative.  Of the remaining items, strings are
    double-quoted (with embedded ``"`` escaped), the lemma, W and sense
    variables appear as ``L``, ``W`` and ``S``, and everything else is
    formatted as a variable.
    """
    for equation in equations:
        (ti, tj), *sources = equation
        target = 'G' if ti < 0 else _format_variable(ti, tj)

        terms = []
        for src in sources:
            if isinstance(src, str):
                term = '"' + src.replace('"', '\\"') + '"'
            elif src == LEMMA_VARIABLE:
                term = 'L'
            elif src == W_VARIABLE:
                term = 'W'
            elif src == SENSE_VARIABLE:
                term = 'S'
            else:
                term = _format_variable(*src)
            terms.append(term)

        for piece in (target, ' = ', ' + '.join(terms), '\n'):
            out.write(piece)


def make_rule(label, in_states, out_states):
    """Build a rule ``(label, in_states, out_states)`` with tuple states."""
    inputs = tuple(in_states)
    outputs = tuple(out_states)
    return (label, inputs, outputs)


def print_rule(rule, labels, out=sys.stdout):
    """Write ``rule`` to ``out`` as ``{inputs} => label => {outputs}``.

    ``rule`` is indexed as ``(label, in_states, out_states)``.  Output
    states are numbered starting after the last input state.
    """
    emit = out.write
    emit('{')
    emit(_format_multiset(rule[1], labels))
    emit('} => ')
    emit(rule[0])
    emit(' => {')
    first_output = len(rule[1])
    emit(_format_multiset(rule[2], labels, first_output))
    emit('}\n')


def _format_multiset(states, labels, start_index=0):
    """Format ``states`` as a comma-separated list of states.

    Each state is a pair ``(label_index, num)``; its label is looked up in
    ``labels`` and its position is counted from ``start_index``.
    """
    rendered = []
    for position, (label_index, num) in enumerate(states, start_index):
        rendered.append(_format_state(labels[label_index], position, num))
    return ', '.join(rendered)


def _format_variable(i, j):
    """Return the name of variable ``(i, j)``: a letter then ``j``.

    ``i`` selects the character with code ``97 + i`` (``0`` gives ``'a'``).
    """
    letter = chr(ord('a') + i)
    return ''.join((letter, str(j)))


def _format_state(label, i, num):
    """Format a state as ``label(v0, v1, ...)`` with ``num`` variables.

    The variables are ``(i, 0)`` to ``(i, num - 1)``.
    """
    variables = [_format_variable(i, j) for j in range(num)]
    return '{}({})'.format(label, ', '.join(variables))
