# -*- coding: utf-8 -*-

import torch

import pyshrg
from framework.common.utils import DotDict, ProgressReporter

from ..batch_utils import process_batch
from ..data_utils import edsgraph_to_dict, pad_batch_derivations, pad_batch_graphs
from .utils import Evaluator

# Placeholder used by ``collect_partitions`` when a chart item has no left or
# right child: a tuple of four distinct empty lists, presumably mirroring the
# four-part result of ``ptr.to_list(graph)`` -- TODO confirm against pyshrg.
# NOTE(review): this single object is shared by every call; anything that
# mutates its lists downstream would leak into later derivations.
EMPTY_GRAPH = tuple([] for _ in range(4))


def collect_partitions(context, graph, chart_items, inputs, device):
    """Pack every candidate partition of one chart node into ``inputs``.

    All ``chart_items`` are encoded as a single derivation with one step
    (step ``0``) whose candidate set is every item, then handed to
    ``pad_batch_derivations`` (without SHRG instances) and ``process_batch``.

    Args:
        context: parser context; ``context.split_item(item, graph)`` yields
            ``(center, left_ptr, right_ptr)`` for each item.
        graph: the graph being decoded.
        chart_items: the alternative chart items to score (must support ``len``).
        inputs: batch dict, mutated in place.
        device: passed through to ``process_batch``.

    Returns:
        None; results are written into ``inputs``.
    """
    def as_subgraph(pointer):
        # A missing child contributes the shared empty placeholder.
        return EMPTY_GRAPH if pointer is None else pointer.to_list(graph)

    candidate_ids = list(range(len(chart_items)))

    grammar_ids, centers, lefts, rights = [], [], [], []
    for item in chart_items:
        center, left_ptr, right_ptr = context.split_item(item, graph)
        grammar_ids.append(item.grammar_index)
        centers.append(center)
        lefts.append(as_subgraph(left_ptr))
        rights.append(as_subgraph(right_ptr))

    # Field order: steps, partition indices, <unused here: None>, rule indices,
    # center parts, left parts, right parts.
    single_derivation = ([0], [candidate_ids], None,
                         grammar_ids, centers, lefts, rights)

    pad_batch_derivations(inputs, [single_derivation], include_shrg_instances=False)
    process_batch(inputs, device=device)


def collect_cfg_choices(shrg_rule, gold_partitions, inputs):
    """Prepare the CFG-rule choices of ``shrg_rule`` for scoring.

    Writes into ``inputs``:
        ``shrg_instances``: long tensor of shape ``[1, num_choices]`` holding
            ``cfg_rule.shrg_rule_index`` for each rule in ``shrg_rule.iter_cfgs()``.
        ``shrg_instances_mask``: uint8 tensor of ones, shape ``[1, num_choices]``.

    Both tensors are created on ``gold_partitions.device``.

    Returns:
        ``gold_partitions`` flattened and tiled to shape
        ``[1, num_choices, partition_size]``, where ``num_choices`` is
        ``shrg_rule.size``.
    """
    choice_count = shrg_rule.size
    target_device = gold_partitions.device

    # Tile first (as before) so a failing ``view`` leaves ``inputs`` untouched.
    tiled = gold_partitions.view(1, 1, -1).repeat(1, choice_count, 1)

    cfg_indices = [cfg.shrg_rule_index for cfg in shrg_rule.iter_cfgs()]
    inputs['shrg_instances'] = torch.tensor(
        [cfg_indices], dtype=torch.long, device=target_device)
    inputs['shrg_instances_mask'] = torch.ones(
        1, choice_count, dtype=torch.uint8, device=target_device)
    return tiled


class SingleThreadEvaluator(Evaluator):
    """Evaluator that parses and decodes graphs one at a time on one context.

    For each graph, ``context.parse`` fills the parser chart. ``run`` then walks
    the chart from ``context.result_item``, letting the network pick the
    highest-scoring partition at every node. When ``self.predict_cfg`` is set
    and the chosen SHRG rule has more than one CFG variant, the network also
    picks the highest-scoring CFG rule.
    """

    @torch.no_grad()
    def run(self, inputs, graph, context):
        """Greedily choose the best derivation in the parsed chart of ``graph``.

        Only ``argmax`` decisions are taken here, so the whole walk runs under
        ``torch.no_grad()``. This avoids keeping an autograd graph alive for
        ``node_embeddings`` and every per-node network call.

        Args:
            inputs: padded batch for ``graph``; mutated in place at every node
                by ``collect_partitions`` / ``collect_cfg_choices``.
            graph: the graph currently being decoded.
            context: parser context whose chart was filled by ``context.parse``.
        """
        node_embeddings = self.network.run_encoder(inputs)

        # Depth-first walk; the left child is pushed last so it is visited first.
        stack = [context.result_item]
        while stack:
            item_head = stack.pop()
            chart_items = list(item_head.all())

            # NOTE(review): keys written by collect_cfg_choices on an earlier
            # node (shrg_instances*) are not cleared here, since
            # include_shrg_instances=False -- confirm network.hrg ignores them.
            collect_partitions(context, graph, chart_items, inputs, self.device)

            hrg_outputs = self.network.hrg(inputs, node_embeddings,
                                           return_partitions=True,
                                           hrg_min_choices=1)

            best_partition_index = 0
            if len(chart_items) > 1:
                scores = hrg_outputs.scores[0]  # only this sentence

                best_partition_index = scores.argmax().item()
                best_item = chart_items[best_partition_index]
                if item_head is not best_item:
                    # Presumably exchanges the two items' contents so that
                    # item_head now holds the chosen partition (its .grammar
                    # is read right below) -- confirm against pyshrg.
                    item_head.swap(best_item)

            shrg_rule = item_head.grammar
            num_choices = shrg_rule.size
            if num_choices > 1 and self.predict_cfg:
                gold_partitions = hrg_outputs.partitions[best_partition_index]
                # shape: [1, num_choices, partition_size]
                gold_partitions = collect_cfg_choices(shrg_rule, gold_partitions, inputs)
                cfg_outputs = self.network.cfg(inputs, gold_partitions=gold_partitions)
                # index of the cfg rule in this shrg rule
                item_head.cfg_index = cfg_outputs.scores.argmax().item()

            left_ptr, right_ptr = context.split_item(item_head)
            if right_ptr:
                stack.append(right_ptr)
            if left_ptr:
                stack.append(left_ptr)

        context.set_best_item(context.result_item)

    def forward(self):
        """Parse and decode every graph provided by the ``pyshrg`` manager.

        Graphs whose index is already in ``self.skipped_indices`` are skipped.
        Graphs that ``context.parse`` rejects are counted as errors and added
        to ``self.skipped_indices``, so later calls skip them too. Every
        successfully parsed graph is decoded by ``run`` and then passed to
        ``self.generate``.
        """
        manager = pyshrg.get_manager()
        error_count = 0
        skipped_count = 0

        # The lambda reads the counters when it is called, so the progress
        # message always shows their current values.
        progress = ProgressReporter(
            stop=manager.graph_size,
            message_fn=lambda _: f'error/skipped: {error_count}/{skipped_count}')

        context = manager.get_context(0)

        for graph_index, graph in progress(enumerate(manager.iter_graphs())):
            if graph_index in self.skipped_indices:
                skipped_count += 1
                continue

            code = context.parse(graph_index)
            if code != pyshrg.ParserError.kNone:
                error_count += 1
                self.skipped_indices.add(graph_index)
                continue

            inputs = DotDict(pad_batch_graphs([edsgraph_to_dict(graph)], self.vocabs,
                                              training=False))
            process_batch(inputs, device=self.device)

            self.run(inputs, graph, context)
            self.generate(context, graph)
