# -*- coding: utf-8 -*-

import torch

from framework.data.pad_utils import pad_2d_values, sequence_mask
from pyshrg_utils.chart_items import iter_items

from . import single_thread

# Number of chart items handed to ``network.hrg`` per call when scoring the chart.
BATCH_SIZE = 1_000


class FullSearchEvaluator(single_thread.SingleThreadEvaluator):
    """Evaluator that scores every chart item reachable from the chart root.

    All items yielded by ``iter_items(context, context.result_item)`` are
    scored by the HRG network in chunks of ``BATCH_SIZE``. Then the best
    derivation is extracted from the root. If ``self.predict_cfg`` is set, a
    CFG choice is also predicted for each item of that derivation.
    """

    def run(self, inputs, graph, context):
        """Score the chart, select the best derivation and optionally its CFG rules.

        Args:
            inputs: mutable mapping of network inputs. It is modified in place:
                ``single_thread.collect_partitions`` is called on it, and the
                ``'shrg_instances'`` / ``'shrg_instances_mask'`` keys are
                (re)assigned when CFGs are predicted.
            graph: input graph, forwarded unchanged to ``collect_partitions``.
            context: parser context. ``context.result_item`` is the chart root.
                Each chart item gets a ``score``, and derivation items get a
                ``cfg_index`` when ``self.predict_cfg`` is truthy.

        Returns:
            None. All results are written onto ``context`` and its chart items.
        """
        node_embeddings = self.network.run_encoder(inputs)

        self._score_chart_items(inputs, graph, context, node_embeddings)

        root_item = context.result_item
        context.set_best_item(root_item)

        # Items on the best derivation. This is a different collection from the
        # full chart scored above, so it gets its own name.
        derivation_items = context.find_best_derivation(root_item)

        if not derivation_items or not self.predict_cfg:
            return

        self._predict_cfg_indices(inputs, graph, context, node_embeddings,
                                  derivation_items)

    def _score_chart_items(self, inputs, graph, context, node_embeddings):
        """Set ``score`` on every chart item under the root, ``BATCH_SIZE`` items per HRG call."""
        chart_items = list(iter_items(context, context.result_item))
        for start in range(0, len(chart_items), BATCH_SIZE):
            # Slicing clamps at the end of the list, so no explicit min() is needed.
            batch = chart_items[start:start + BATCH_SIZE]
            # Presumably fills ``inputs`` with the partition features of ``batch``
            # (consumed by network.hrg below). TODO confirm.
            single_thread.collect_partitions(context, graph, batch, inputs)

            hrg_outputs = self.network.hrg(inputs, node_embeddings)

            # Scores are paired positionally with the items of this batch.
            for score, chart_item in zip(hrg_outputs.scores, batch):
                chart_item.score = score

    def _predict_cfg_indices(self, inputs, graph, context, node_embeddings,
                             derivation_items):
        """Set ``cfg_index`` on each derivation item to its highest-scoring CFG choice."""
        single_thread.collect_partitions(context, graph, derivation_items, inputs)

        # Candidate SHRG rule indices per item, in ``grammar.iter_cfgs()`` order.
        shrg_rules = [[cfg.shrg_index for cfg in chart_item.grammar.iter_cfgs()]
                      for chart_item in derivation_items]
        lengths = [len(rules) for rules in shrg_rules]

        hrg_outputs = self.network.hrg(inputs, node_embeddings, return_partitions=True)
        # Tile each item's partition once per candidate CFG (padded to the largest count).
        # shape: [batch_size, num_cfg_choices, partition_size]
        gold_partitions = hrg_outputs.partitions.unsqueeze(1).repeat(1, max(lengths), 1)

        device = gold_partitions.device
        # shape: [batch_size, num_cfg_choices]
        inputs['shrg_instances'] = torch.from_numpy(pad_2d_values(shrg_rules)).to(device=device)
        # shape: [batch_size, num_cfg_choices]
        inputs['shrg_instances_mask'] = torch.from_numpy(sequence_mask(lengths)).to(device=device)

        cfg_outputs = self.network.cfg(inputs, gold_partitions=gold_partitions)
        # NOTE(review): argmax also spans padded CFG slots. This relies on
        # network.cfg masking them via 'shrg_instances_mask'; confirm.
        # The index is presumably a position into the item's iter_cfgs() order.
        for chart_item, index in zip(derivation_items, cfg_outputs.scores.argmax(dim=1)):
            chart_item.cfg_index = index
