# -*- coding: utf-8 -*-

import numpy as np

from framework.data.pad_utils import pad_2d_values, pad_3d_values, sequence_mask
from framework.data.vocab import lookup_words

# Vocabulary names for the per-node properties. The order matters: it is zipped
# against ``[node.pos_tag, node.sense] + node.properties`` in edsgraph_to_dict.
PROPERTY_NAMES = ['postag', 'sense', 'tense', 'num', 'pers', 'prog', 'perf']
# Placeholder used by simplify_derivation when a partition has no left/right
# part: (nodes, borders, border_types, border_orders), all empty.
# NOTE: the same tuple (and the same inner lists) is shared by every partition
# that uses it, so it must never be mutated in place.
EMPTY_FRAGMENT = ([], [], [], [])


def edsgraph_to_dict(graph, vocabs=None):
    """
    Convert an EDS graph into a plain dict of python lists.

    Every node gets a 'Self' entry (pointing at itself) as the first item of
    its incoming and outgoing neighbour lists; edges linking fewer than two
    nodes are ignored. When ``vocabs`` is given, every lemma, carg, property
    value and edge label that is seen is also added to the matching vocabulary.

    @type graph: pyshrg.EdsGraph
    @type vocabs: VocabularySet
    @param vocabs: optional vocabulary set to populate; ``None`` disables
        vocabulary collection.
    @rtype: dict
    @return: dict with keys 'cargs', 'lemmas' (lower-cased, one per node),
        'properties' (transposed: [num_props, num_nodes]),
        'incoming_indices' / 'outgoing_indices' / 'incoming_labels' /
        'outgoing_labels' (one list per node) and 'sentence_id'.
        For a graph without nodes, 'properties' holds one empty list per
        entry of PROPERTY_NAMES instead of raising IndexError.
    """
    node_cargs = []
    node_lemmas = []
    node_properties = []
    nodes = graph.nodes
    collect_vocabs = vocabs is not None

    if collect_vocabs:
        lemma_vocab = vocabs.get_or_new('word')
        edge_vocab = vocabs.get_or_new('edge')
        carg_vocab = vocabs.get_or_new('carg')
        prop_vocabs = [vocabs.get_or_new(name) for name in PROPERTY_NAMES]

    for node in nodes:  # type: pyshrg.EdsGraph.Node
        carg, lemma = node.carg.lower(), node.lemma.lower()
        properties = [node.pos_tag, node.sense] + node.properties

        if collect_vocabs:
            lemma_vocab.add(lemma)
            carg_vocab.add(carg)
            for vocab, prop in zip(prop_vocabs, properties):
                vocab.add(prop)

        node_cargs.append(carg)
        node_lemmas.append(lemma)
        node_properties.append(properties)

    # Each node starts with a self loop so that no neighbour list is empty.
    # NOTE(review): 'Self' is never added to the edge vocabulary here;
    # presumably the vocabulary handles it elsewhere - confirm.
    num_nodes = len(nodes)
    incoming_indices = [[i] for i in range(num_nodes)]
    outgoing_indices = [[i] for i in range(num_nodes)]
    incoming_labels = [['Self'] for _ in range(num_nodes)]
    outgoing_labels = [['Self'] for _ in range(num_nodes)]

    for edge in graph.edges:  # type: pyshrg.EdsGraph.Edge
        linked_nodes = edge.linked_nodes
        if len(linked_nodes) < 2:
            continue

        from_node, to_node = linked_nodes

        incoming_indices[to_node.index].append(from_node.index)
        incoming_labels[to_node.index].append(edge.label)
        outgoing_indices[from_node.index].append(to_node.index)
        outgoing_labels[from_node.index].append(edge.label)

        if collect_vocabs:
            edge_vocab.add(edge.label)

    # transposed properties data [num_prop, num_nodes]
    # The column count comes from the first node; an empty graph falls back to
    # the declared property names so the result keeps one column per property.
    if node_properties:
        num_props = len(node_properties[0])
    else:
        num_props = len(PROPERTY_NAMES)
    transposed_properties = [[] for _ in range(num_props)]
    for props in node_properties:
        for prop, node_single_props in zip(props, transposed_properties):
            node_single_props.append(prop)

    return {
        'cargs': node_cargs,
        'lemmas': node_lemmas,
        'properties': transposed_properties,
        'incoming_indices': incoming_indices,
        'outgoing_indices': outgoing_indices,
        'incoming_labels': incoming_labels,
        'outgoing_labels': outgoing_labels,
        'sentence_id': graph.sentence_id
    }


def simplify_derivation(derivation):
    """
    Flatten a derivation into parallel lists.

    ``derivation`` is a sequence whose items are either ``None`` (skipped) or a
    pair ``(shrg_index, partitions)``; every partition is a 4-tuple
    ``(hrg_index, center, left, right)``. A falsy ``left`` / ``right`` is
    replaced by the shared EMPTY_FRAGMENT placeholder.

    @return: ``(num_partitions, all_partitions)`` where ``all_partitions`` is
        the tuple
        (steps,                  # [num_steps] positions of the kept items
         all_partition_indices,  # [num_steps, num_partitions_per_step]
         shrg_indices,           # [num_steps] first element of each kept item
         hrg_indices,            # [num_partitions]
         center_parts, left_parts, right_parts)  # [num_partitions] each
    """
    kept_steps, step_rules, step_groups = [], [], []
    fragment_rules, centers, lefts, rights = [], [], [], []

    for position, entry in enumerate(derivation):
        if entry is None:
            continue

        rule_index, fragments = entry
        kept_steps.append(position)
        step_rules.append(rule_index)

        # Partitions of one step occupy a contiguous range of the flat lists.
        first = len(fragment_rules)
        step_groups.append(list(range(first, first + len(fragments))))

        for fragment_rule, center, left, right in fragments:
            fragment_rules.append(fragment_rule)
            centers.append(center)
            lefts.append(left or EMPTY_FRAGMENT)
            rights.append(right or EMPTY_FRAGMENT)

    flattened = (kept_steps, step_groups, step_rules, fragment_rules,
                 centers, lefts, rights)
    return len(fragment_rules), flattened


def pad_batch_graphs(batch_graphs, vocabs, property_names=PROPERTY_NAMES,
                     exclude_edges=False, training=True):
    """
    Turn a list of graph dicts (as built by edsgraph_to_dict) into one padded
    batch dict.

    @param batch_graphs: list of graph dicts; when ``training`` is true each
        item is a ``(graph, derivation)`` pair instead.
    @param vocabs: object whose ``get(name)`` returns the vocabularies named
        'word', 'edge', 'carg' and each entry of ``property_names``.
    @param property_names: names of the node properties to export; each one
        becomes a key of the returned batch.
    @param exclude_edges: when true, no edge index/label/mask entries are
        produced.
    @param training: when true, the derivations are padded into the same batch
        through pad_batch_derivations.
    @return: the batch dict. Keys: 'sentence_ids', 'word_lengths', 'words',
        'nodes_mask', 'carg', one key per property name, and unless
        ``exclude_edges``: 'incoming_mask', 'outgoing_mask',
        'incoming_indices', 'outgoing_indices', 'incoming_labels',
        'outgoing_labels'.
    """
    include_edges = not exclude_edges

    word_vocab = vocabs.get('word')
    label_vocab = vocabs.get('edge')
    const_vocab = vocabs.get('carg')
    property_vocabs = [vocabs.get(name) for name in property_names]

    sentence_ids, lemma_ids, carg_ids = [], [], []
    property_ids = [[] for _ in property_vocabs]  # one column per property
    in_indices, out_indices = [], []
    in_labels, out_labels = [], []
    derivations = []

    for item in batch_graphs:
        if training:
            graph, derivation = item
            derivations.append(derivation)
        else:
            graph = item

        sentence_ids.append(graph['sentence_id'])
        lemma_ids.append(lookup_words(graph['lemmas'], word_vocab))
        carg_ids.append(lookup_words(graph['cargs'], const_vocab))

        for column, values, vocab in zip(property_ids, graph['properties'],
                                         property_vocabs):
            column.append(lookup_words(values, vocab))

        if include_edges:
            in_indices.append(graph['incoming_indices'])
            out_indices.append(graph['outgoing_indices'])
            in_labels.append([lookup_words(labels, label_vocab)
                              for labels in graph['incoming_labels']])
            out_labels.append([lookup_words(labels, label_vocab)
                               for labels in graph['outgoing_labels']])

    def _degree_mask(neighbours_per_graph):
        # Mask built from the padded per-node neighbour counts.
        # presumably [batch_size, num_nodes, degree] - confirm against
        # sequence_mask
        degrees = [[len(neighbours) for neighbours in graph_neighbours]
                   for graph_neighbours in neighbours_per_graph]
        return sequence_mask(pad_2d_values(degrees))

    lengths = np.array([len(ids) for ids in lemma_ids])
    batch = {
        # shape: [batch_size]
        'sentence_ids': sentence_ids,
        'word_lengths': lengths,
        # shape: [batch_size, num_nodes]
        'words': pad_2d_values(lemma_ids),
        'nodes_mask': sequence_mask(lengths),
        # shape: [batch_size, num_nodes]
        'carg': pad_2d_values(carg_ids),
    }
    for name, column in zip(property_names, property_ids):
        batch[name] = pad_2d_values(column)

    if include_edges:
        batch['incoming_mask'] = _degree_mask(in_indices)
        batch['outgoing_mask'] = _degree_mask(out_indices)
        # shape: [batch_size, num_nodes, incoming_degree]
        batch['incoming_indices'] = pad_3d_values(in_indices)
        # shape: [batch_size, num_nodes, outgoing_degree]
        batch['outgoing_indices'] = pad_3d_values(out_indices)
        # shape: [batch_size, num_nodes, incoming_degree]
        batch['incoming_labels'] = pad_3d_values(in_labels)
        # shape: [batch_size, num_nodes, outgoing_degree]
        batch['outgoing_labels'] = pad_3d_values(out_labels)

    if training:
        pad_batch_derivations(batch, derivations)

    return batch


def pad_batch_derivations(batch, derivations, graph_indices=None, include_shrg_instances=True):
    """
    Pad simplified derivations and store the results in ``batch`` (in place).

    Each derivation is a tuple
    ``(steps, partition_indices, shrg_indices, hrg_indices, center_parts,
    left_parts, right_parts)`` as produced by simplify_derivation. Node
    indices local to a graph are turned into batch-wide indices with
    ``graph_index * batch['words'].shape[1] + node_index``.

    @param batch: batch dict; must already contain 'words'.
    @param derivations: one simplified derivation per graph.
    @param graph_indices: position of every derivation's graph inside the
        batch; defaults to ``0 .. len(derivations) - 1``.
    @param include_shrg_instances: when true, also writes 'shrg_instances'
        and 'shrg_instances_mask'.
    @return: list with, for every derivation, the offset of its first
        partition in the flattened partition list.

    Keys written: 'steps', 'hrg_indices', 'center_nodes', 'center_nodes_mask',
    '{left,right}_nodes' (+ '_nodes_mask'), '{left,right}_borders'
    (+ '_borders_mask', '_border_types', '_border_orders'), 'instances',
    'instances_mask'. When no left/right part has any node (or border), the
    corresponding '<tag>_nodes' (or '<tag>_borders') entry is set to None and
    its companion entries are not written.
    """
    if graph_indices is None:
        graph_indices = range(len(derivations))

    # Padded number of nodes per graph: stride between graphs in flat indices.
    stride = batch['words'].shape[1]

    def _to_flat(graph_id, value):
        # Recursively shift (possibly nested) local node indices.
        if not isinstance(value, int):
            return [_to_flat(graph_id, item) for item in value]
        return graph_id * stride + value

    def _length_mask(rows):
        return sequence_mask(np.array([len(row) for row in rows]))

    # Flat list of (graph_index, hrg_index, center, left, right) records.
    records = []
    steps_per_graph = []
    groups_per_graph = []
    flat_shrg = []
    offsets = []

    for graph_id, derivation in zip(graph_indices, derivations):
        steps, partition_indices, shrg_indices, *columns = derivation

        steps_per_graph.append(steps)
        groups_per_graph.append(partition_indices)
        if include_shrg_instances:
            flat_shrg.extend(shrg_indices)

        offsets.append(len(records))
        records.extend((graph_id,) + row for row in zip(*columns))

    batch['steps'] = steps_per_graph
    batch['hrg_indices'] = np.array([record[1] for record in records])

    center_ids = [_to_flat(record[0], record[2]) for record in records]
    batch['center_nodes'] = pad_2d_values(center_ids)
    batch['center_nodes_mask'] = _length_mask(center_ids)

    # record[3] / record[4] = (nodes, borders, border_types, border_orders)
    for tag, slot in (('left', 3), ('right', 4)):
        node_ids = [_to_flat(record[0], record[slot][0]) for record in records]
        widest_nodes = max(len(row) for row in node_ids)
        if widest_nodes > 0:
            batch[f'{tag}_nodes'] = pad_2d_values(node_ids)
            batch[f'{tag}_nodes_mask'] = _length_mask(node_ids)
        else:
            batch[f'{tag}_nodes'] = None

        border_ids = [_to_flat(record[0], record[slot][1]) for record in records]
        widest_borders = max(len(row) for row in border_ids)
        if widest_borders > 0:
            batch[f'{tag}_borders'] = pad_2d_values(border_ids)
            batch[f'{tag}_borders_mask'] = _length_mask(border_ids)
            batch[f'{tag}_border_types'] = \
                pad_2d_values([record[slot][2] for record in records])
            batch[f'{tag}_border_orders'] = \
                pad_2d_values([record[slot][3] for record in records])
        else:
            batch[f'{tag}_borders'] = None

    # One row per step: partition indices shifted into the flat record list.
    instances = []
    for offset, partition_indices in zip(offsets, groups_per_graph):
        for step_indices in partition_indices:
            instances.append([index + offset for index in step_indices])
    batch['instances'] = pad_2d_values(instances)
    batch['instances_mask'] = _length_mask(instances)

    if include_shrg_instances:
        batch['shrg_instances'] = pad_2d_values(flat_shrg)
        batch['shrg_instances_mask'] = _length_mask(flat_shrg)

    return offsets
