# -*- coding: utf-8 -*-

import torch

import pyshrg
from framework.common.logger import LOGGER
from framework.common.utils import DotDict, ProgressReporter
from framework.data.pad_utils import pad_2d_values, sequence_mask

from . import single_thread
from ..batch_utils import process_batch
from ..data_utils import edsgraph_to_dict, pad_batch_derivations, pad_batch_graphs
from .utils import Evaluator


def _get_graph_batches(graph_size, skipped_indices, context_size):
    skipped_count = 0
    graph_batches = []
    graph_batch = []
    for graph_index in range(graph_size):
        if graph_index in skipped_indices:
            skipped_count += 1
            continue
        graph_batch.append(graph_index)
        if len(graph_batch) == context_size:
            graph_batches.append(graph_batch)
            graph_batch = []
    if graph_batch:
        graph_batches.append(graph_batch)
    return graph_batches


def _collect_partitions(contexts, graphs, graph_indices, batch_chart_items, inputs, device):
    """Pack the candidate partitions of every graph into ``inputs``.

    For each (chart items, context, graph) triple, every chart item is split by
    ``context.split_item(item, graph)`` into a center part and optional left /
    right pointers. A missing pointer is replaced by
    ``single_thread.EMPTY_GRAPH``. The resulting derivation tuples are padded
    into ``inputs`` by ``pad_batch_derivations`` (without SHRG instances), and
    the batch is then moved to ``device`` via ``process_batch``.

    Returns:
        Whatever ``pad_batch_derivations`` returns. The caller uses it as
        per-graph offsets that are added to a local partition index.
    """
    def _pointer_as_list(pointer, owner_graph):
        # An absent child pointer is encoded as the shared empty graph.
        if pointer is None:
            return single_thread.EMPTY_GRAPH
        return pointer.to_list(owner_graph)

    derivations = []
    for chart_items, context, graph in zip(batch_chart_items, contexts, graphs):
        rule_indices, center_parts, left_parts, right_parts = [], [], [], []
        for item in chart_items:
            center, left_pointer, right_pointer = context.split_item(item, graph)
            rule_indices.append(item.grammar_index)
            center_parts.append(center)
            left_parts.append(_pointer_as_list(left_pointer, graph))
            right_parts.append(_pointer_as_list(right_pointer, graph))

        # A single partition group containing every chart item of this graph.
        all_candidates = [list(range(len(chart_items)))]
        derivations.append((None, all_candidates, None,
                            rule_indices, center_parts, left_parts, right_parts))

    partition_offsets = pad_batch_derivations(inputs, derivations, graph_indices,
                                              include_shrg_instances=False)
    process_batch(inputs, device=device)
    return partition_offsets


def _collect_cfg_choices(shrg_rules, lengths, gold_partitions, inputs):
    """Prepare CFG-choice inputs for the heads that need a CFG decision.

    Args:
        shrg_rules: per head, the list of candidate SHRG-rule indices.
        lengths: per head, the number of candidates (``len`` of each list).
        gold_partitions: per head, the chosen partition tensor (1-D).
        inputs: batch dict; ``'shrg_instances'`` (padded candidates) and
            ``'shrg_instances_mask'`` are written into it on the device of
            ``gold_partitions``.

    Returns:
        The gold partitions stacked and tiled once per candidate, with shape
        ``[batch_size, num_cfg_choices, partition_size]``, where
        ``num_cfg_choices == max(lengths)``.
    """
    max_choices = max(lengths)

    stacked = torch.stack(gold_partitions)  # [batch_size, partition_size]
    tiled = stacked.unsqueeze(1).repeat(1, max_choices, 1)
    target_device = tiled.device

    padded_rules = pad_2d_values(shrg_rules)  # [batch_size, num_cfg_choices]
    choice_mask = sequence_mask(lengths)      # [batch_size, num_cfg_choices]
    inputs['shrg_instances'] = torch.from_numpy(padded_rules).to(target_device)
    inputs['shrg_instances_mask'] = torch.from_numpy(choice_mask).to(target_device)
    return tiled


class MultiThreadsEvaluator(Evaluator):
    """Evaluator that parses and decodes graphs batch by batch.

    ``pyshrg.Runner`` parses a batch of graphs (each graph of the batch is paired
    with the parser context at the same position in the batch). :meth:`run` then
    lets the network pick the best derivation for all graphs of the batch at
    once.
    """

    def run(self, inputs, graphs, contexts):
        """Choose the best derivation for every graph of the batch.

        Derivations are decided top-down, one tree level at a time. Each stack
        entry holds, for a subset of the batch, the chart-item heads that still
        need a partition. Each pop pushes at most one new entry, so the stack
        never holds more than one level. All heads of an entry are scored together
        by ``self.network.hrg``. The winning chart item is swapped into the head
        position. When ``self.predict_cfg`` is set and the head's SHRG rule has
        more than one CFG rule, the best CFG rule is also predicted. The
        left/right children of every decided head form the next entry.

        Args:
            inputs: padded batch dict built from ``graphs``. The helper
                functions write partition / CFG-choice tensors into it in place.
            graphs: graphs of the batch.
            contexts: parser contexts, aligned index-by-index with ``graphs``.

        Results are written into the contexts' chart items; nothing is returned.
        """
        node_embeddings = self.network.run_encoder(inputs)

        indices_in_batch = range(len(graphs))
        stack = [(contexts, graphs, indices_in_batch,
                  [context.result_item for context in contexts])]

        for context in contexts:
            context.set_best_item(context.result_item)

        while stack:
            current_contexts, current_graphs, current_indices, current_item_heads = stack.pop()

            # NOTE(review): item_head.all() is assumed to enumerate the alternative
            # chart items of this head (including the head itself) - confirm in pyshrg.
            batch_chart_items = [list(item_head.all()) for item_head in current_item_heads]
            partition_offsets = \
                _collect_partitions(current_contexts, current_graphs, current_indices,
                                    batch_chart_items, inputs, self.device)

            hrg_outputs = self.network.hrg(inputs, node_embeddings,
                                           return_partitions=True,
                                           hrg_min_choices=1)

            gold_partitions = []
            shrg_rules = []
            cfg_lengths = []
            cfg_indices = []
            # A single host transfer instead of one `.item()` sync per head.
            best_partition_indices = hrg_outputs.scores.argmax(dim=1).tolist()
            for index, (item_head, chart_items) in \
                    enumerate(zip(current_item_heads, batch_chart_items)):
                best_index = best_partition_indices[index]

                best_item = chart_items[best_index]
                if item_head is not best_item:
                    # The winning alternative is moved into the head position.
                    item_head.swap(best_item)

                # Local index -> index into hrg_outputs.partitions (presumably
                # flattened over the whole batch by pad_batch_derivations).
                best_index += partition_offsets[index]
                shrg_rule = item_head.grammar
                num_choices = shrg_rule.size
                if num_choices > 1 and self.predict_cfg:
                    shrg_rules.append([cfg_rule.shrg_rule_index
                                       for cfg_rule in shrg_rule.iter_cfgs()])
                    cfg_indices.append(index)
                    cfg_lengths.append(num_choices)
                    gold_partitions.append(hrg_outputs.partitions[best_index])

            if gold_partitions:  # at least one head needs a CFG decision
                gold_partitions = \
                    _collect_cfg_choices(shrg_rules, cfg_lengths, gold_partitions, inputs)
                cfg_outputs = self.network.cfg(inputs, gold_partitions=gold_partitions)
                # Plain ints (not 0-d tensors) are stored, matching how the
                # partition choice above is handled.
                best_cfg_indices = cfg_outputs.scores.argmax(dim=1).tolist()
                for index, best_cfg_index in zip(cfg_indices, best_cfg_indices):
                    # index of cfg rule in this hrg rule
                    current_item_heads[index].cfg_index = best_cfg_index

            next_contexts = []
            next_graphs = []
            next_item_heads = []
            next_indices = []
            for context, graph, index, item_head in \
                    zip(current_contexts, current_graphs, current_indices, current_item_heads):
                left_ptr, right_ptr = context.split_item(item_head)
                if left_ptr:
                    next_contexts.append(context)
                    next_graphs.append(graph)
                    next_item_heads.append(left_ptr)
                    next_indices.append(index)
                if right_ptr:
                    next_contexts.append(context)
                    next_graphs.append(graph)
                    next_item_heads.append(right_ptr)
                    next_indices.append(index)

            if next_contexts:
                stack.append((next_contexts, next_graphs, next_indices, next_item_heads))

    def forward(self):
        """Parse and decode every graph of the pyshrg manager that is not skipped.

        Graphs are processed in batches of ``manager.context_size``. A graph
        whose parse result is not ``ParserError.kNone`` is counted as an error
        and added to ``self.skipped_indices``. A graph whose graph/context lookup
        raises is counted as an error and logged. All other graphs are decoded
        by :meth:`run` and passed to ``self.generate``.
        """
        manager = pyshrg.get_manager()

        graph_size = manager.graph_size
        context_size = manager.context_size
        runner = pyshrg.Runner(manager, verbose=self.verbose)

        batches = _get_graph_batches(graph_size, self.skipped_indices, context_size)

        error_count = 0
        # Graphs excluded up front because they were already in skipped_indices.
        skipped_count = graph_size - sum(len(batch) for batch in batches)
        progress = ProgressReporter(
            stop=len(batches),
            message_fn=lambda _: f'error/skip: {error_count}/{skipped_count}',
            print_time=True)

        for graph_batch in progress(batches):
            results = runner(graph_batch)

            graphs = []
            contexts = []
            for context_index, (code, graph_index) in enumerate(zip(results, graph_batch)):
                if code != pyshrg.ParserError.kNone:
                    error_count += 1
                    self.skipped_indices.add(graph_index)
                    continue

                try:
                    graph = manager.get_graph(graph_index)
                    context = manager.get_context(context_index)
                except Exception as err:
                    error_count += 1
                    LOGGER.error('%s', err)
                    continue

                # Append only after both lookups succeed, so that graphs[i] and
                # contexts[i] always refer to the same parse.
                graphs.append(graph)
                contexts.append(context)

            if not graphs:
                continue

            inputs = [edsgraph_to_dict(graph) for graph in graphs]
            inputs = DotDict(pad_batch_graphs(inputs, self.vocabs, training=False))
            process_batch(inputs, device=self.device)

            self.run(inputs, graphs, contexts)

            for context, graph in zip(contexts, graphs):
                self.generate(context, graph)
