from itertools import chain, combinations, product, repeat

from extract_rules import is_empty_state
from rule import (is_normal_variable,
                  equations_lemma_pos,
                  equations_sense_pos)


def is_label_with_pos(label, pos):
    '''
    Return True when `label` starts with '_' and its second character is
    contained in `pos`.

    `pos` is any container supporting `in` (the callers in this file pass a
    short string such as 'nu' or 'va'). The second character is only looked
    at when the first one is an underscore.
    '''
    if label[0] != '_':
        return False
    return label[1] in pos


def is_noun_label(label):
    '''
    Return True for labels treated as nouns: an underscore label whose
    second character is 'n' or 'u', or the literal label 'named'.
    '''
    if is_label_with_pos(label, 'nu'):
        return True
    return label == 'named'


def is_verb_label(label):
    '''
    Return True for labels treated as verbs: an underscore label whose
    second character is 'v' or 'a'.
    '''
    verb_tags = 'va'
    return is_label_with_pos(label, verb_tags)


# States handled on the "before" side (use_before=True) by
# Worker.get_states_to_add. Each entry is (state name, predicates):
# the name is resolved through state2index in Worker.__init__, and
# `predicates` is either None or a list of label predicates. A state is only
# offered for insertion when it has predicates and at least one of them
# accepts the node label; entries with None are never inserted.
SPECIAL_STATES_BEFORE = [
    ('poss:ARG1', None),
    ('card:ARG1', None),
    ('basic_card:ARG1', None),
    ('AP:ARG1', [is_noun_label]),
    ('ADJ:ARG1', [is_noun_label]),
    ('ADV:ARG1', None),
    # compound may also point to a verb, in which case the verb is in its
    # gerund form; that case is deliberately not handled here.
    ('compound:ARG1',  [is_noun_label]),
    ('NP:ARG1', None)
]

# States handled on the "after" side (use_before=False) by
# Worker.get_states_to_add. Same (state name, predicates) layout as
# SPECIAL_STATES_BEFORE: a state is only offered for insertion when its
# predicate list is non-empty and one predicate accepts the node label.
SPECIAL_STATES_AFTER = [
    ('NP:ARG2', None),
    ('appos:ARG1', None),
    ('PP:ARG1', [is_verb_label, is_noun_label]),
    ('loc_nonsp:ARG1', [is_verb_label, is_noun_label])
]

# (state name, max_count) pairs used to build Worker.special_states_map.
# - generate_rules_by_deletion may drop any incoming state listed here.
# - generate_rules_by_addition only duplicates states whose max_count is
#   greater than 1, and never beyond max_count copies in total.
SPECIAL_STATES_STRINGS = [
    ('poss:ARG1', 1),
    ('AP:ARG1', 3),
    ('compound:ARG1', 3),
    ('PP:ARG1', 3),
    ('ADJ:ARG1', 3),
    ('ADV:ARG1', 3),
    ('card:ARG1', 1),
    ('basic_card:ARG1', 1),
    ('appos:ARG1', 1),
    # NOTE(review): 'P:ARG1' appears only in this table (the other tables use
    # 'PP:ARG1') - confirm that this is intentional.
    ('P:ARG1', 3),
    ('loc_nonsp:ARG1', 2)
]


def subsets(s, from_size=1):
    '''
    Iterate over all combinations of `s` whose size ranges from `from_size`
    up to len(s), smallest sizes first.

    Each subset is yielded as a tuple that keeps the element order of `s`.
    '''
    sizes = range(from_size, len(s) + 1)
    return chain.from_iterable(combinations(s, size) for size in sizes)


def subsets_multi(s, from_size=1):
    '''
    Enumerate subsets of `s` together with a multiplicity for each member.

    Every element of `s` is a (key, max_count) pair. For each subset (see
    `subsets`), every assignment of a count in 1..max_count to each chosen
    key is yielded as a list of (key, count) pairs.
    '''
    for chosen in subsets(s, from_size):
        keys = [item[0] for item in chosen]
        count_ranges = [range(1, item[1] + 1) for item in chosen]
        for counts in product(*count_ranges):
            yield list(zip(keys, counts))


def add_states_at_places(states, places):
    '''
    Return a tuple of `states` in which some states have been duplicated.

    `places` is a sequence of ((index, _), count) items. For each item,
    `count` extra copies of states[index - 1] are inserted right before
    position `index`. Items are consumed in the given order, which is
    expected to be ascending by index.
    '''
    result = []
    cursor = 0
    for (position, _), copies in places:
        # Copy the untouched stretch, then the duplicates of the state that
        # ends this stretch.
        result += states[cursor:position]
        result += [states[position - 1]] * copies
        cursor = position
    result += states[cursor:]
    return tuple(result)


def add_vars_at_place(new_eqs, place):
    '''
    Insert variables for newly added states into `new_eqs`, in place.

    `place` is ((index, (eq_index, var_index)), count). Every normal
    variable whose state index is >= `index` is shifted up by `count`, then
    `count` new variables (index, 0) .. (index + count - 1, 0) are inserted
    into equation `eq_index` at position `var_index`.

    `new_eqs` must be a list of lists; it is modified by this call.
    '''
    (index, (eq_index, var_index)), count = place

    # Make room in the numbering for the states that are being added.
    for equation in new_eqs:
        for pos, var in enumerate(equation):
            if not is_normal_variable(var):
                continue
            if var[0] >= index:
                equation[pos] = (var[0] + count, var[1])

    target = new_eqs[eq_index]
    # Each added state is assumed to carry a single variable (number 0).
    inserted = [(state_index, 0)
                for state_index in range(index, index + count)]
    new_eqs[eq_index] = target[:var_index] + inserted + target[var_index:]


def add_vars_at_places(eqs, places):
    '''
    Return a copy of `eqs` (tuple of tuples) with the variables required by
    every item of `places` inserted; see `add_vars_at_place`.
    '''
    working = [list(equation) for equation in eqs]
    # Places must be applied from the largest state index down to the
    # smallest so that earlier insertions do not shift later ones.
    ordered = sorted(places, key=lambda item: item[0][0], reverse=True)
    for place in ordered:
        add_vars_at_place(working, place)
    return tuple(map(tuple, working))


def remove_var_by_indices(equation, indices):
    '''
    Return `equation` as a tuple without the normal variables whose state
    index is in `indices`, renumbering the remaining normal variables.

    `indices` is expected to be sorted ascending: each surviving variable is
    shifted down by the number of leading entries of `indices` that are
    smaller than its state index.
    '''
    kept = [var for var in equation
            if not (is_normal_variable(var) and var[0] in indices)]

    renumbered = []
    for var in kept:
        if not is_normal_variable(var):
            renumbered.append(var)
            continue
        # Offset: how many removed states precede this variable's state.
        shift = 0
        for removed in indices:
            if var[0] > removed:
                shift += 1
            else:
                break
        renumbered.append((var[0] - shift, var[1]))

    return tuple(renumbered)


def does_eqs_have_continuous_indices(eqs, indices):
    '''
    Look for an equation containing len(indices) consecutive variables that
    all come from the states listed in `indices`.

    The first element of each equation (its left-hand side) is skipped.
    Returns (eq_index, position) for the first match, where `position` is
    the index in the full equation just after the last matching variable,
    or None when no equation matches.
    '''
    needed = len(indices)
    for eq_index, equation in enumerate(eqs):
        remaining = needed
        # start=2: one for the skipped left-hand side, one to point just
        # past the current variable.
        for next_pos, var in enumerate(equation[1:], start=2):
            if is_normal_variable(var) and var[0] in indices:
                remaining -= 1
            else:
                remaining = needed
            if remaining == 0:
                return eq_index, next_pos
    return None


class Worker:
    '''
    Derive additional ("extended") rules from a base rule set.

    Data layout, as used by the methods below:

    * a rule is a (node_label, in_states, out_states) tuple;
    * a state is a sequence whose item 0 is a state index (as found in
      `state2index`) and whose item 1 is its variable count;
    * `rules` maps rule -> {eqs: rule_index}; `eqs` is a tuple of equations
      and an equation is a tuple whose first item is the left-hand side;
    * a normal variable is (state_position, var_number); positions count
      the rule's in_states first and its out_states after them.

    Generated rules are stored in `extended_rules` (same shape as `rules`,
    but with negative ids) and, in creation order, in `ex_rule_list` as
    (rule, eqs, prototype_index) triples.
    '''

    def __init__(self, rules, states, state2index):
        self.extended_rules = {}
        self.ex_rule_list = []

        # Resolve state names to indices; unknown names become -1.
        self.special_states_before = [
            (state2index.get(string, -1), predicts)
            for string, predicts in SPECIAL_STATES_BEFORE
        ]

        self.special_states_after = [
            (state2index.get(string, -1), predicts)
            for string, predicts in SPECIAL_STATES_AFTER
        ]

        self.special_states_map = {
            state2index.get(string, -1): max_count
            for string, max_count in SPECIAL_STATES_STRINGS
        }
        self.states = states
        self.rules = rules
        self.state2index = state2index

    def add_rules(self, rule, eqs, prototype_index, strict=True):
        '''
        Register a new extended rule.

        With `strict` true, (rule, eqs) is accepted when `eqs` is not
        already attached to `rule` in the base rules. Independently of
        `strict`, it is accepted when `rule` is unknown to both the base
        and the extended rules. An accepted pair is stored only if the
        extended rules do not contain it yet.

        Returns True when the pair was stored, False otherwise.
        '''
        missing_from_base = strict and eqs not in self.rules.get(rule, {})
        unknown_rule = (rule not in self.rules and
                        rule not in self.extended_rules)
        if not (missing_from_base or unknown_rule):
            return False

        eqs2index = self.extended_rules.setdefault(rule, {})
        if eqs in eqs2index:
            return False
        # Extended rules get negative ids: list position i <-> id -i - 2.
        eqs2index[eqs] = - len(self.ex_rule_list) - 2
        self.ex_rule_list.append((rule, eqs, prototype_index))
        return True

    def get_states_to_add(self, node_label, states, equation, use_before):
        '''
        Describe the special states that could be added around the lemma.

        Returns (states_to_add, states_enabled):

        * `states_to_add` has one ((state_index, 1), var_index) entry per
          special state of the chosen side, where `var_index` is the
          position in `equation` of a variable of that state if the rule
          already has it, else -1;
        * `states_enabled` lists the positions (within `states_to_add`) of
          states that are absent from `states`, have predicates, and whose
          predicates accept `node_label`.
        '''
        states_to_add = []
        states_enabled = []
        # Direction: ... <- L -> ...
        if use_before:
            items = reversed(self.special_states_before)
        else:
            items = self.special_states_after
        for k, (index, predicts) in enumerate(items):
            state_index, var_index = -1, -1
            for i, state in enumerate(states):
                if state[0] == index and state[1] <= 1:
                    state_index = i
                    break

            if state_index != -1:
                for j, var in enumerate(equation):
                    if is_normal_variable(var) and var[0] == state_index:
                        var_index = j
                        # Before L keep the first occurrence, after L keep
                        # the last one.
                        if use_before:
                            break
            # Does this state need to be added?
            if predicts and \
               state_index == -1 and \
               any(predict(node_label) for predict in predicts):
                states_enabled.append(k)
            states_to_add.append(((index, 1), var_index))
        return states_to_add, states_enabled

    # ----------------------------------------

    @staticmethod
    def _stream_target_index(eqs):
        '''
        Return the state position named by the left-hand side of the first
        equation when it is a normal variable, else -1.
        '''
        if len(eqs) > 0 and is_normal_variable(eqs[0][0]):
            return eqs[0][0][0]  # eqs[0][0] = (index, num)
        return -1

    def generate_rules_by_deletion(self, rule, eqs):
        '''
        Yield (new_rule, new_eqs) variants obtained by dropping every
        non-empty subset of the removable incoming states.

        An incoming state is removable when its state index is a key of
        `special_states_map` and it is not the stream target.
        '''
        node_label, in_states, out_states = rule
        stream_target_index = self._stream_target_index(eqs)

        phrase_indices = []
        for index, in_state in enumerate(in_states):
            if in_state[0] in self.special_states_map \
               and index != stream_target_index:
                phrase_indices.append(index)

        for indices_to_remove in subsets(phrase_indices):
            indices_to_remove = sorted(indices_to_remove)
            new_in_states = tuple(state for i, state in enumerate(in_states)
                                  if i not in indices_to_remove)
            new_eqs = tuple(remove_var_by_indices(eq, indices_to_remove)
                            for eq in eqs)
            yield (node_label, new_in_states, out_states), new_eqs

    def generate_rules_by_addition(self, rule, eqs):
        '''
        Yield (new_rule, new_eqs) variants obtained by repeating incoming
        states that are allowed to occur several times.

        A run of identical incoming states qualifies when its max_count in
        `special_states_map` is greater than 1, it carries exactly one
        variable, it does not start at the stream target, and its variables
        appear consecutively in one equation. The run may grow up to
        max_count members in total.
        '''
        node_label, in_states, out_states = rule
        stream_target_index = self._stream_target_index(eqs)

        addition_places = []
        index, n = 0, len(in_states)
        while index < n:
            in_state = in_states[index]
            index += 1
            # States such as poss (max_count <= 1) are never repeated.
            max_count = self.special_states_map.get(in_state[0], 0)
            if max_count <= 1 or \
               index - 1 == stream_target_index or \
               in_state[1] != 1:  # state with several variables
                continue

            # Measure the run of identical states starting here.
            count = 1
            while index < n and in_states[index] == in_state:
                count += 1
                index += 1

            var_pos = does_eqs_have_continuous_indices(
                eqs, range(index - count, index))
            if var_pos is None:
                continue

            max_count -= count
            # There is still room for some more copies.
            if max_count > 0:
                addition_places.append(((index, var_pos), max_count))

        for places in subsets_multi(addition_places):
            new_in_states = add_states_at_places(in_states, places)
            new_eqs = add_vars_at_places(eqs, places)
            yield (node_label, new_in_states, out_states), new_eqs

    def generate_rules_by_aggressive_addition(self, rule, eqs):
        '''
        Yield (new_rule, new_eqs) variants obtained by inserting special
        states that the rule does not have yet.

        Only rules with exactly one equation containing a lemma position
        are handled. Every combination of enabled "before" and "after"
        states (see `get_states_to_add`), except the empty one, produces one
        variant: the chosen states are merged into the incoming states and
        one variable per added state is inserted into the equation, before
        the lemma for "before" states and after the lemma/sense for "after"
        states.
        '''
        lemma_pos = equations_lemma_pos(eqs)
        if len(eqs) != 1 or not lemma_pos:
            return

        node_label, in_states, out_states = rule
        states_before_all, states_before = \
            self.get_states_to_add(node_label,
                                   in_states, eqs[0], use_before=True)
        states_after_all, states_after = \
            self.get_states_to_add(node_label,
                                   in_states, eqs[0], use_before=False)

        lemma_index, before_n = lemma_pos[1], len(states_before_all)
        sense_pos = equations_sense_pos(eqs)
        sense_index = -1
        if sense_pos:
            sense_index = sense_pos[1]

        set_iter = product(subsets(states_before, from_size=0),
                           subsets(states_after, from_size=0))
        for states_b, states_a in set_iter:
            if not states_b and not states_a:
                continue
            equation = list(eqs[0])
            vars_to_add = []  # (var, place_index)
            new_states = list(in_states)
            # Direction: ... <- L -> ...
            state_iter = enumerate(chain(states_before_all,
                                         states_after_all))
            last_var_index = lemma_index
            for count, (state, var_index) in state_iter:
                if count == before_n:  # switching to the "after" side
                    last_var_index = max(lemma_index, sense_index) + 1
                # Position at which the variable would be inserted.
                if count >= before_n:
                    var_index = max(var_index + 1, last_var_index)
                elif var_index >= 0:
                    var_index = min(var_index, last_var_index)
                else:
                    var_index = last_var_index

                # States that are not added only update last_var_index.
                if count >= before_n:
                    if count - before_n not in states_a:
                        last_var_index = var_index
                        continue
                elif count not in states_b:
                    last_var_index = var_index
                    continue

                j = 0  # the state will be inserted at position j
                for old_state in new_states:
                    if old_state > state:
                        break
                    j += 1
                new_states = new_states[:j] + [state] + new_states[j:]
                # Renumber the variables already in the equation.
                for i, var in enumerate(equation):
                    if is_normal_variable(var) and var[0] >= j:
                        equation[i] = var[0] + 1, var[1]
                # Renumber the variables waiting to be inserted.
                for i, (var, place) in enumerate(vars_to_add):
                    if var[0] >= j:
                        vars_to_add[i] = (var[0] + 1, var[1]), place
                vars_to_add.append(((j, 0), var_index))

                last_var_index = var_index

            # Splice the pending variables into the equation, left to right.
            last_end = 0
            new_eq = []
            vars_to_add.sort(key=lambda x: x[1])
            for var, place in vars_to_add:
                new_eq.extend(equation[last_end: place])
                new_eq.append(var)
                last_end = place
            new_eq.extend(equation[last_end:])
            new_eqs = tuple((tuple(new_eq),))
            new_rule = (node_label, tuple(new_states), out_states)
            yield new_rule, new_eqs

    def generate_rules_by_removing_empty(self, rule, eqs):
        '''
        Yield (new_rule, new_eqs) variants obtained by dropping every
        non-empty subset of the EMPTY states among the outgoing states.
        '''
        node_label, in_states, out_states = rule
        empty_indices = []
        for index, out_state in enumerate(out_states):
            if is_empty_state(out_state):
                empty_indices.append(index)

        for indices_to_remove in subsets(empty_indices):
            new_out_states = tuple(state for i, state in enumerate(out_states)
                                   if i not in indices_to_remove)

            # Variables number outgoing states after the incoming ones.
            indices_to_remove = [i + len(in_states)
                                 for i in indices_to_remove]
            new_eqs = tuple(remove_var_by_indices(eq, indices_to_remove)
                            for eq in eqs)
            yield (node_label, in_states, new_out_states), new_eqs

    def generate(self, ignore_indices=()):
        '''
        Run the generation methods in order and record the new rules.

        The methods are, by position: 0 deletion, 1 addition, 2 aggressive
        addition, 3 removal of EMPTY outgoing states. Positions listed in
        `ignore_indices` are skipped. After each executed method an
        '[INFO]' line with the number of rules it contributed is printed.

        Raises Exception when the final consistency check finds a variable
        that does not fit the state it refers to.
        '''
        def rule_items():
            return self.rules.items()

        def rule_items_all():
            # Snapshot the extended rules: add_rules() may add new keys to
            # self.extended_rules while the result is being iterated.
            return chain(self.rules.items(), list(self.extended_rules.items()))

        methods = [
            (self.generate_rules_by_deletion, rule_items, True),
            (self.generate_rules_by_addition, rule_items, True),
            (self.generate_rules_by_aggressive_addition, rule_items_all, False),
            (self.generate_rules_by_removing_empty, rule_items_all, True)
        ]
        count = 0
        for mindex, (method, sources, strict) in enumerate(methods):
            if mindex in ignore_indices:
                continue
            for rule, eqs2index in sources():
                for eqs, rule_index in eqs2index.items():
                    # Walk up to the top of the prototype chain: negative
                    # ids denote extended rules, whose prototype is stored
                    # in ex_rule_list.
                    while rule_index < 0:
                        rule_index = self.ex_rule_list[-rule_index - 2][2]
                    for new_rule, new_eqs in method(rule, eqs):
                        self.add_rules(new_rule, new_eqs, rule_index, strict)
            total = len(self.ex_rule_list)
            print('[INFO] %d rules (%s)' % (total - count, method.__name__))
            count = total

        self._check_extended_rules()

    def _check_extended_rules(self):
        '''
        Verify that every normal variable of every extended rule refers to
        a variable number that its state actually has; print the offending
        rule and raise Exception otherwise.
        '''
        for rule, eqs, prototype_index in self.ex_rule_list:
            l1, l2 = len(rule[1]), len(rule[2])
            for equation in eqs:
                for i, var in enumerate(equation):
                    if not is_normal_variable(var):
                        continue
                    state_index = var[0]
                    if state_index == -1:
                        continue
                    var_index = var[1]
                    # -1 marks a state position outside of the rule.
                    var_count = -1
                    if state_index < l1:
                        var_count = rule[1][state_index][1]
                    elif state_index - l1 < l2:
                        var_count = rule[2][state_index - l1][1]

                    # A negative variable number is tolerated only on the
                    # left-hand side (i == 0).
                    if var_index >= var_count and \
                       (i > 0 or var_index >= 0):
                        # Imported here because the module-level scope only
                        # provides these names when run as a script.
                        from rule import print_rule, print_equations
                        print('prototype_index =', prototype_index)
                        print_rule(rule, self.states)
                        print_equations(eqs)
                        raise Exception("Something wrong !!!")


if __name__ == '__main__':
    import sys
    from rules_reader import load_rules
    from rule import print_rule, print_equations
    from extract_compiled_rules import (print_compiled_rules,
                                        compile_states,
                                        compile_equations)

    # Optional first command-line argument: an integer level. Anything that
    # cannot be parsed leaves the default of 0 in place.
    level = 0
    if len(sys.argv) > 1:
        try:
            level = int(sys.argv[1].strip())
        except Exception:
            pass
    print('INFO:level =', level)

    rules, states, state2index, literals, literal2index = load_rules()
    worker = Worker(rules, states, state2index)
    # From level 9 on, the third generation method (position 2) is skipped.
    if level >= 9:
        worker.generate(ignore_indices=[2])
    else:
        worker.generate()

    extended = worker.ex_rule_list

    # Human-readable dump: running number, rule, equations, blank line.
    with open('data/extended_rules.txt', 'w') as out:
        for number, (rule, eqs, _) in enumerate(extended):
            out.write('%d\n' % number)
            print_rule(rule, states, out)
            print_equations(eqs, out)
            out.write('\n')

    # Compiled dump: rule count on the first line, then one compiled rule
    # per entry.
    with open('data/extended_compiled_rules.txt', 'w') as out:
        out.write('%d\n' % len(extended))
        for rule, eqs, _ in extended:
            compiled_rule = (rule[0],
                             compile_states(rule[1]),
                             compile_states(rule[2]))
            compiled_eqs = compile_equations(eqs, literals, literal2index)
            print_compiled_rules(compiled_rule, compiled_eqs, out)
            out.write('\n')

    # One prototype index per line, in the same order as the dumps above.
    with open('data/extended_rules_prototype.txt', 'w') as out:
        for entry in extended:
            out.write(str(entry[2]) + '\n')
