# -*- coding: utf-8 -*-

from collections import Counter

from framework.common.logger import open_file
from framework.evaluate.bleu import BLEU

from .token_utils import post_process_tokens


def extract_tree(context, graph):
    """Flatten ``context.derivation`` into a list of per-step rows.

    Each row has the form
    ``[(grammar_label, sorted_edges), hrg_index, None, left, right]`` where

    * ``sorted_edges`` is a sorted tuple with one entry per edge covered by
      the step: ``(*ids_of_linked_nodes, edge_label)``;
    * ``hrg_index`` is ``context.manager.index_of_hrg(step.grammar)``;
    * ``left`` / ``right`` are the step's children, padded with ``-1`` when
      the step has fewer than two.

    Raises AssertionError if a step has more than two children.
    """
    hrg_manager = context.manager
    graph_edges = graph.edges

    def _edge_key(edge_index):
        # One edge -> (id of each linked node..., label).
        edge = graph_edges[edge_index]
        linked_ids = tuple(linked.id for linked in edge.linked_nodes)
        return linked_ids + (edge.label,)

    rows = []
    for step in context.derivation:
        covered_edges = tuple(sorted(
            _edge_key(edge_index)
            for edge_index in step.item.edge_set.to_edge_indices()
        ))

        kids = list(step.children)
        assert len(kids) <= 2
        # Binary-tree rows: absent children are encoded as -1.
        left, right = (kids + [-1, -1])[:2]

        rows.append([
            (step.grammar.label, covered_edges),
            hrg_manager.index_of_hrg(step.grammar),
            None,
            left,
            right,
        ])
    return rows


class EvalBForGraph:
    """Accumulate EVALB-style precision / recall / F1 over derivation trees.

    A tree is an iterable of nodes, or None, which is scored as an empty tree.
    ``node[-2]`` and ``node[-1]`` are the node's two child slots. A slot
    equal to ``-1`` or ``None`` is empty. Nodes whose child slots are both
    empty are scored separately as "lexicon" items.

    ``item_fns`` is a sequence of ``(name, fn)`` pairs, where ``fn(node)``
    maps a node to a hashable item. For each sentence, the system and gold
    items are compared as multisets. Counter slot ``i`` holds the
    non-lexicon statistics for ``item_fns[i]``, and slot
    ``i + len(item_fns)`` holds its lexicon statistics.
    """

    def __init__(self, item_fns=None):
        # Honour the declared default instead of crashing on len(None).
        self.item_fns = list(item_fns) if item_fns is not None else []

        num_slots = 2 * len(self.item_fns)
        self.num_system_items = [0] * num_slots
        self.num_gold_items = [0] * num_slots
        self.num_matched_items = [0] * num_slots
        self.num_completes = [0] * num_slots

        # One (sentence_id, system_tree, gold_tree) tuple per added sentence.
        self.data = []

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data)

    @staticmethod
    def _is_empty_child(child):
        """Return True if a child slot holds no child (``None`` or ``-1``)."""
        return child is None or child == -1

    @classmethod
    def _tree_to_items(cls, tree, item_fn):
        """Split ``tree`` into ``(internal_items, lexicon_items)`` using ``item_fn``."""
        internal_items, lexicon_items = [], []
        if tree is None:
            return internal_items, lexicon_items
        for node in tree:
            item = item_fn(node)
            if cls._is_empty_child(node[-1]) and cls._is_empty_child(node[-2]):
                lexicon_items.append(item)
            else:
                internal_items.append(item)
        return internal_items, lexicon_items

    def _match(self, system_items, gold_items, index):
        """Add one sentence's system/gold items to the counters in slot ``index``.

        An item occurring k times in the system and m times in the gold list
        contributes min(k, m) matches.
        """
        num_system = len(system_items)
        num_gold = len(gold_items)
        self.num_system_items[index] += num_system
        self.num_gold_items[index] += num_gold

        num_matched = sum((Counter(system_items) & Counter(gold_items)).values())

        # Compare against the total item counts. len(Counter) would give the
        # number of *distinct* items, so repeated items could never form a
        # complete match.
        self.num_completes[index] += (num_matched == num_system == num_gold)
        self.num_matched_items[index] += num_matched

    def add(self, system_tree, gold_tree, sentence_id):
        """Score one sentence under every item function and remember it for ``save``."""
        num_fns = len(self.item_fns)
        for index, (_, item_fn) in enumerate(self.item_fns):
            system_items, lexicon_system_items = self._tree_to_items(system_tree, item_fn)
            gold_items, lexicon_gold_items = self._tree_to_items(gold_tree, item_fn)

            self._match(system_items, gold_items, index)
            self._match(lexicon_system_items, lexicon_gold_items, index + num_fns)

        self.data.append((sentence_id, system_tree, gold_tree))

    def _recall(self, index):
        """Recall in percent for slot ``index``. The denominator is clamped to 1."""
        return self.num_matched_items[index] / max(self.num_gold_items[index], 1) * 100

    def _precision(self, index):
        """Precision in percent for slot ``index``. The denominator is clamped to 1."""
        return self.num_matched_items[index] / max(self.num_system_items[index], 1) * 100

    def _f1(self, index):
        """Harmonic mean of precision and recall for slot ``index``, or 0 if both are 0."""
        r = self._recall(index)
        p = self._precision(index)
        if p + r == 0:
            return 0
        return 2 * p * r / (p + r)

    def get(self):
        """Return the F1 of the first item function, or 0 when there are none."""
        if not self.item_fns:
            return 0
        return self._f1(0)

    def _write(self, fp, name, index):
        """Write a human-readable score report for slot ``index`` to ``fp``."""
        num_samples = len(self.data)
        # Guard the ratio so an evaluator with no sentences can still be reported.
        complete_ratio = self.num_completes[index] / max(num_samples, 1) * 100
        fp.write(f'### {name}\n')
        fp.write(f'System Items   : {self.num_system_items[index]}\n')
        fp.write(f'Gold Items     : {self.num_gold_items[index]}\n')
        fp.write(f'Complete match : {self.num_completes[index]} / {num_samples}'
                 f' {complete_ratio:.2f}\n')
        fp.write(f'Recall         : {self._recall(index):.2f}\n')
        fp.write(f'Precision      : {self._precision(index):.2f}\n')
        fp.write(f'F1             : {self._f1(index):.2f}\n')
        fp.write('\n')

    def save(self, prefix, sort_by_id=False, no_output=False, no_score=False):
        """Write ``prefix + '.output'`` and/or ``prefix + '.score'``.

        ``sort_by_id`` is accepted but not used by this implementation.
        Returns the list of paths that were written.
        """
        paths = []

        if not no_output:
            paths.append(prefix + '.output')
            with open_file(paths[-1], 'w') as fp:
                for sentence_id, system_tree, _ in self.data:
                    fp.write(f'{sentence_id}\n{repr(system_tree)}\n')

        if not no_score:
            paths.append(prefix + '.score')
            num_fns = len(self.item_fns)
            with open_file(paths[-1], 'w') as fp:
                for index, (name, _) in enumerate(self.item_fns):
                    self._write(fp, name, index)
                    self._write(fp, name + '(lexicon)', index + num_fns)

        return paths

    def print_sample(self):
        """Hook for verbose sample printing. Intentionally a no-op here."""
        pass


class Evaluator:
    """Run a network over graphs and score its generated output.

    ``__call__`` configures the metric and then calls ``self.forward()``.
    ``forward`` is not defined in this class and must be supplied by a
    subclass or mixin. Presumably it calls ``generate`` once per graph
    (TODO confirm).

    The attributes ``predict_cfg``, ``metrics`` and (for tree scoring)
    ``gold_trees`` are set by ``__call__``, so ``generate`` must only run
    after the evaluator has been called.
    """

    def __init__(self, network, vocabs, device, return_output=False, verbose=False):
        self.network = network
        self.vocabs = vocabs

        self.return_output = return_output
        self.verbose = verbose
        self.device = device

        # Persisted through state_dict / load_state_dict.
        self.skipped_indices = set()

        # (sentence_id, repr(derivation)) for each generated sample.
        self.derivation_strings = []

    def post_process_tokens(self, tokens):
        """Delegate to the module-level ``post_process_tokens`` helper."""
        return post_process_tokens(tokens)

    def state_dict(self):
        """Return the evaluator state that should be checkpointed."""
        return {'skipped_indices': self.skipped_indices}

    def load_state_dict(self, saved_state):
        """Restore state produced by ``state_dict``."""
        self.skipped_indices = saved_state['skipped_indices']

    def generate(self, context, graph):
        """Add one generated sample to ``self.metrics`` and record its derivation.

        With ``predict_cfg`` set, the lower-cased, stripped generated sentence
        is compared with the graph's lemma sequence. Otherwise, the tree from
        ``extract_tree`` is compared with the gold tree keyed by the last
        ``'/'``-separated component of the sentence id. A missing gold tree
        is passed as None.
        """
        context.generate()  # make sure derivation is prepared

        sentence_id = graph.sentence_id
        if self.predict_cfg:
            self.metrics.add(context.sentence.lower().strip(),
                             graph.lemma_sequence.lower().strip(),
                             sentence_id)
        else:
            self.metrics.add(extract_tree(context, graph),
                             self.gold_trees.get(sentence_id.rsplit('/', 1)[-1]),
                             sentence_id)

        self.derivation_strings.append((sentence_id, repr(context.derivation)))
        if self.verbose:
            # Print a sample every graph_size // 20 additions (graph_size is
            # presumably the dataset size). Clamp to 1 so fewer than 20
            # graphs no longer triggers a modulo by zero.
            interval = max(context.manager.graph_size // 20, 1)
            if len(self.metrics) % interval == 0:
                self.metrics.print_sample()

    def __call__(self, predict_cfg=True, gold_trees=None):
        """Evaluate with BLEU (``predict_cfg``) or tree F1 against ``gold_trees``.

        Returns the metric object when ``return_output`` is set, otherwise
        ``self.metrics.get()``.
        """
        self.predict_cfg = predict_cfg

        if predict_cfg:
            # NOTE(review): passes the module-level function, so overriding
            # self.post_process_tokens does not affect BLEU scoring. Confirm
            # this is intended.
            self.metrics = BLEU(post_process_tokens)
        else:
            self.gold_trees = gold_trees or {}
            self.metrics = EvalBForGraph([
                ('Labelled', lambda node: node[0]),  # (label, edges)
                ('Unlabelled', lambda node: node[0][1]),  # unlabeled
                ('Labelled+Rule', lambda node: node[0] + (node[1], ))  # (label, edges, rule)
            ])

        self.forward()

        if self.return_output:
            return self.metrics

        return self.metrics.get()
