# -*- coding: utf-8 -*-

import base64
import functools
import gzip
import json
import os
import pickle

from flask import Flask, abort, jsonify

from framework.common.dataclass_options import TRAINING_KEY, field_choices

from .const_tree import Lexicon
from .dataset_utils import load_cfg_strings
from .graph_draw import (draw_const_tree,
                         draw_derivation,
                         draw_hrg_rule,
                         draw_hyper_graph,
                         draw_tree_decomposition)
from .shrg_extract import DeepBankExtractor, ExtractionOptions
from .tree_decomposition import tree_decomposition


def strange_rules(graph):
    """Flag rule graphs that mix terminal edges with "floating" nonterminals.

    A node is called *free* here when no terminal edge touches it. The graph is
    reported as strange when it contains at least one terminal edge AND at least
    one nonterminal edge whose nodes are all free (i.e. it fails the "weakly
    regular" test computed below). This helper is exposed by name to
    ``search_rule`` conditions.

    Args:
        graph: hyper graph exposing ``nodes`` and ``edges``; every edge exposes
            ``is_terminal`` and ``nodes``.

    Returns:
        bool: True for such "strange" graphs.
    """
    from SHRG.shrg_detect import _get_edges_by_node

    # NOTE(review): presumably maps each node to the edges incident to it --
    # confirm against SHRG.shrg_detect.
    incident_edges = _get_edges_by_node(graph.edges)
    free_nodes = {
        node for node in graph.nodes
        if not any(edge.is_terminal for edge in incident_edges[node])
    }

    has_terminal = any(edge.is_terminal for edge in graph.edges)
    has_free_nonterminal = any(
        not edge.is_terminal and all(node in free_nodes for node in edge.nodes)
        for edge in graph.edges
    )
    return has_free_nonterminal and has_terminal


def _format_cfg_rule_item(tag, edge):
    """Render one right-hand-side item of a CFG rule for display.

    Args:
        tag: a ``Lexicon`` or a str label.
        edge: the hyper edge paired with ``tag``, or None.

    Returns:
        str: the quoted lexicon string for a ``Lexicon``; the bare tag when there
        is no edge; ``tag(carg)`` for a terminal edge whose tag starts with
        ``<`` (``???`` when the edge has no ``carg`` attribute); otherwise
        ``tag#<number of edge nodes>``.
    """
    if isinstance(tag, Lexicon):
        return f'"{tag.string}"'
    assert isinstance(tag, str)

    if edge is None:
        return tag

    if not (tag.startswith('<') and edge.is_terminal):
        return f'{tag}#{len(edge.nodes)}'
    return f"{tag}({getattr(edge, 'carg', '???')})"


class Grammar:
    """An extracted grammar together with the corpus it was extracted from.

    The constituency-tree strings are loaded eagerly, the pickled rule list is
    loaded lazily (``rules``), and per-sentence derivations are re-extracted on
    demand by ``DeepBankExtractor`` (``get_derivations``) and memoised in an
    unbounded in-memory dict.
    """

    def __init__(self, grammar_dir, data_path, tree_path, options=None):
        """Create a grammar view.

        Args:
            grammar_dir: directory of the extracted grammar. ``rules`` reads
                ``train.counter.p`` from it, and ``config`` is read from it when
                ``options`` is None. Its basename is the grammar's ``name``.
            data_path: root of the gzipped graph files, read from
                ``<data_path>/<bank>/<sentence_id>.gz``.
            tree_path: tree file passed to ``load_cfg_strings``.
            options: ``ExtractionOptions``; loaded from ``<grammar_dir>/config``
                when None.
        """
        self.grammar_dir = grammar_dir
        self.data_path = data_path
        self.tree_path = tree_path

        self._rules = None  # filled lazily by the ``rules`` property
        self._cache = {}  # sentence_id -> DeepBankExtractor.extract(...) output

        # Also builds self._sample_ids / self._indices as a side effect.
        self.tree_strings = self._load_tree(tree_path)

        if options is None:
            options = ExtractionOptions.from_file(os.path.join(grammar_dir, 'config'))

        self.options = options

    def _build_indices(self):
        """No-op hook: the id/index lookups are already built by ``_load_tree``."""

    def sample_id_to_index(self, sample_id):
        """Return the position of ``sample_id`` in sorted id order, or None if unknown."""
        self._build_indices()
        return self._indices.get(sample_id)

    def index_to_sample_id(self, index):
        """Return the sample id at ``index``, wrapping around in both directions.

        ``-1`` maps to the last id and ``len(ids)`` to the first one. Raises
        ZeroDivisionError when the tree file contained no samples.
        """
        self._build_indices()
        return self._sample_ids[index % len(self._sample_ids)]

    def _load_tree(self, tree_path):
        """Load ``{sample_id: (bank, tree_string)}`` and build sorted id <-> index lookups."""
        trees = load_cfg_strings(tree_path)
        self._sample_ids = sorted(trees)
        self._indices = {sample_id: index for index, sample_id in enumerate(self._sample_ids)}
        return trees

    def copy(self, options):
        """Return a new grammar over the same files with different extraction ``options``.

        The tree file is re-read and the derivation cache starts empty.
        """
        return self.__class__(self.grammar_dir, self.data_path, self.tree_path, options)

    def __str__(self):
        return f'<Grammar {self.name}>'

    def __repr__(self):
        return str(self)

    @property
    def name(self):
        """Basename of ``grammar_dir``; used as the grammar's key in the service."""
        return os.path.basename(self.grammar_dir)

    @property
    def rules_filename(self):
        """Path of the pickled rule counter inside ``grammar_dir``."""
        return os.path.join(self.grammar_dir, 'train.counter.p')

    @property
    def rules(self):
        """The extracted rules, or None when ``rules_filename`` does not exist.

        The pickle holds a pair whose first element is the rule sequence (the
        service indexes it as ``(rule, counter_item)`` pairs). While the file
        is missing, loading is retried on every access.
        """
        if self._rules is None:
            try:
                # NOTE(security): pickle can execute arbitrary code while loading;
                # only point grammar_dir at trusted, locally generated grammars.
                # The `with` block closes the handle (previously leaked).
                with open(self.rules_filename, 'rb') as fin:
                    self._rules, _ = pickle.load(fin)
            except FileNotFoundError:
                pass
        return self._rules

    def get_derivations(self, sentence_id):
        """Extract (and memoise) the derivation of ``sentence_id``.

        Reads ``<data_path>/<bank>/<sentence_id>.gz``, decodes it as UTF-8 and
        passes it with the sentence's tree string to
        ``DeepBankExtractor.extract``, whose result is returned unchanged.
        Raises KeyError for ids that are not in the tree file.
        """
        output = self._cache.get(sentence_id)
        if output is None:
            bank, tree_string = self.tree_strings[sentence_id]
            filename = os.path.join(self.data_path, bank, sentence_id + '.gz')
            with gzip.open(filename, 'rb') as fin_bank:
                graph_data = fin_bank.read().decode('utf-8')

            output = DeepBankExtractor.extract(self.options, tree_string, graph_data, sentence_id)
            self._cache[sentence_id] = output
        return output


class SHRGVisualizationService(Flask):
    """Flask application serving the SHRG visualization page and its JSON API.

    Routes registered in ``__init__``:

    * ``GET /`` -- ``index.html`` from the static folder;
    * ``GET /api/grammars`` -- extraction-option schema and the known grammars;
    * ``GET /api/rule/<grammar_name>/<rule_index>`` -- one extracted rule;
    * ``GET /api/search-rule/<grammar_name>/<condition>`` -- indices of rules
      matching a Python expression;
    * ``GET /api/sentence/<sentence_id>/<step>/<options>`` -- one derivation
      step of a sentence, re-extracted with client-supplied options.
    """

    def __init__(self, grammars, custom_grammar, static_folder):
        """Create the app and register its routes.

        Args:
            grammars: iterable of ``Grammar`` objects, addressed by ``grammar.name``
                (a later grammar with the same name replaces an earlier one).
            custom_grammar: ``Grammar`` that is re-created with the options sent
                by the client (see ``_get_grammar``).
            static_folder: directory served under ``/files``; its ``index.html``
                is also served at ``/``.
        """
        # NOTE(review): import_name evaluates to '__call__', which is not an
        # importable module; Flask then presumably falls back to the current
        # working directory as root_path, so a relative static_folder resolves
        # against the CWD. Kept unchanged so the static root does not move.
        super().__init__(self.__call__.__name__,
                         static_url_path='/files', static_folder=static_folder)
        self.custom_options = custom_grammar.options
        self.custom_grammar = custom_grammar

        self.grammars = {grammar.name: grammar for grammar in grammars}

        self.add_url_rule('/',
                          view_func=self.index, methods=['GET'])
        self.add_url_rule('/api/grammars',
                          view_func=self.get_grammars, methods=['GET'])

        self.add_url_rule('/api/rule/<string:grammar_name>/<int:rule_index>',
                          view_func=self.get_rule, methods=['GET'])
        self.add_url_rule('/api/search-rule/<string:grammar_name>/<string:condition>',
                          view_func=self.search_rule, methods=['GET'])

        self.add_url_rule('/api/sentence/<string:sentence_id>/<int:step>/<string:options>',
                          view_func=self.get_sentence, methods=['GET'])

    def index(self, path='index.html'):
        """Serve ``path`` (``index.html`` by default) from the static folder."""
        return self.send_static_file(path)

    def _get_rule(self, rule):
        """Convert a SHRG rule into the JSON-ready dict used by the front end.

        Returns a dict with:

        * ``cfg`` -- ``LHS#n ⇒ item + item ...``, where ``n`` is the number of
          nodes of the HRG left-hand side (0 when the rule has no HRG part);
        * ``hrgSource`` -- drawing source of the HRG rule, or None;
        * ``label`` -- HRG left-hand-side label, or ``'no semantic'``;
        * ``comment`` -- the HRG rule's comment, or ``{}``.

        ``rule=None`` yields a placeholder whose ``comment`` is a fresh dict,
        so callers may mutate it.
        """
        if rule is None:
            return {'cfg': '???', 'hrgSource': None, 'label': '???', 'comment': {}}

        hrg_rule = rule.hrg
        ep_count, hrg_source = 0, None
        if hrg_rule is not None:
            hrg_source = draw_hrg_rule(hrg_rule, output_format='source')
            ep_count = len(hrg_rule.lhs.nodes)

        cfg = '{}#{} ⇒ {}'.format(rule.cfg.lhs, ep_count,
                                  ' + '.join(_format_cfg_rule_item(*item) for item in rule.cfg.rhs))
        return {
            'cfg': cfg,
            'hrgSource': hrg_source,
            'label': hrg_rule.lhs.label if hrg_rule is not None else 'no semantic',
            'comment': hrg_rule.comment if hrg_rule is not None else {},
        }

    def get_grammars(self):
        """Describe the extraction options and the loaded grammars.

        ``optionsType`` lists ``[field name, default value, choices or None]`` for
        every ``ExtractionOptions`` field. ``grammars`` lists ``[name, options]``
        where ``options`` is a JSON string of the grammar's option state without
        keys starting with ``_``.
        """
        options = ExtractionOptions()
        info = []
        for field_def in options.iter_fields():
            choices = field_choices(field_def)
            info.append([field_def.name,
                         getattr(options, field_def.name),
                         list(choices) if choices is not None else None])

        return jsonify({
            'optionsType': info,
            'grammars': [[grammar.name,
                          json.dumps({key: value
                                      for key, value in grammar.options.state_dict().items()
                                      if not key.startswith('_')})]
                         for grammar in self.grammars.values()]})

    def get_rule(self, grammar_name, rule_index):
        """Return rule ``rule_index`` of ``grammar_name`` with its statistics.

        The JSON contains the rule's count, example samples, the grammar's total
        rule count, the tree decomposition of its HRG right-hand side (None
        without an HRG part) and the fields produced by ``_get_rule``.

        Responds 404 for an unknown grammar, a grammar whose rule file is
        missing, or an out-of-range index.
        """
        grammar = self.grammars.get(grammar_name)
        if grammar is None:
            return abort(404)

        rules = grammar.rules
        if rules is None:  # rule file missing; len(None) would be a 500
            return abort(404)

        total_rule_count = len(rules)
        if not 0 <= rule_index < total_rule_count:
            return abort(404)
        rule, counter_item = rules[rule_index]

        decomposition = None
        if rule.hrg is not None:
            hyper_graph = rule.hrg.rhs
            external_nodes = rule.hrg.lhs.nodes
            tree_root = tree_decomposition(hyper_graph, external_nodes)
            decomposition = draw_tree_decomposition(tree_root, external_nodes,
                                                    output_format='source')

        return jsonify({
            'count': counter_item.count,
            'examples': counter_item.samples,
            'totalRuleCount': total_rule_count,
            'treeDecomposition': decomposition,
            **self._get_rule(rule)
        })

    def search_rule(self, grammar_name, condition):  # TODO: Change to POST
        """Return the indices of the rules of ``grammar_name`` matching ``condition``.

        ``condition`` is a Python expression evaluated once per rule with the
        names ``count``, ``label`` (CFG left-hand side), ``cfg_rhs``, ``graph``
        (HRG right-hand side or None) and ``strange_rules`` in scope. Rules whose
        evaluation raises are skipped. Any other failure (unknown grammar,
        forbidden keyword, syntax error, ...) is reported as
        ``{'errorMessage': ...}``.
        """
        # NOTE(security): ``condition`` comes straight from the URL and is run
        # with eval(). The keyword blacklist below is NOT a sandbox: string
        # concatenation ("imp" + "ort") defeats it, and eval() with an empty
        # globals dict still exposes every builtin. Only expose this service to
        # trusted users.
        try:
            rules = self.grammars[grammar_name].rules
            for key in ('import', 'open', 'while'):
                assert key not in condition, f'can not use "{key}" keyword'
            code = compile(condition, '<???>', mode='eval')

            indices = []
            scope = {'strange_rules': strange_rules}
            for index, (rule, counter_item) in enumerate(rules):
                scope.update(count=counter_item.count,
                             label=rule.cfg.lhs,
                             cfg_rhs=rule.cfg.rhs,
                             graph=rule.hrg.rhs if rule.hrg else None)
                try:
                    if eval(code, {}, scope):
                        indices.append(index)
                except Exception:
                    # Deliberate best effort: a condition that does not apply
                    # to this rule (e.g. graph is None) just does not match.
                    pass

            return jsonify(indices)
        except Exception as err:
            return jsonify({'errorMessage': str(err)})

    def _get_grammar(self, options_string):
        """Return the custom grammar extracted with the client's options.

        ``options_string`` is base64-encoded JSON of an ``ExtractionOptions``
        state; the training flag is always forced on. The custom grammar is
        re-created (via ``Grammar.copy``) only when the options differ from the
        current ones. Responds 400 when ``options_string`` cannot be decoded.
        """
        try:
            state = dict(json.loads(base64.b64decode(options_string)))
        except (ValueError, TypeError):
            # binascii.Error, json.JSONDecodeError and UnicodeDecodeError all
            # subclass ValueError; dict() of a non-mapping raises TypeError or
            # ValueError. Malformed client input is a 400, not a 500.
            abort(400)
        state[TRAINING_KEY] = True
        options = ExtractionOptions()
        options.load_state_dict(state)

        if options != self.custom_options:
            self.custom_options = options
            self.custom_grammar = self.custom_grammar.copy(options)

        return self.custom_grammar

    # NOTE(review): lru_cache on a method keys on ``self`` (flake8-bugbear B019)
    # and stores the jsonify() Response object itself, so a cache hit re-serves
    # that same Response and skips ``_get_grammar``. abort() raises, so error
    # responses are never cached.
    @functools.lru_cache(maxsize=4096)
    def get_sentence(self, sentence_id, step, options):
        """Render derivation step ``step`` of sentence ``sentence_id``.

        Args:
            sentence_id: id of a sample in the custom grammar's tree file.
            step: index into the post-order traversal of the sentence's
                constituency tree (and into the extracted SHRG rules).
            options: base64-encoded JSON ``ExtractionOptions`` state.

        Returns:
            JSON with the constituency tree (current node filled red), the
            derivation before and after this step, the applied rule, the full
            graph (``panorama``) with blamed nodes/edges coloured, ``totalStep``
            and the neighbouring sample ids ``prevId``/``nextId`` (wrapping
            around). If rule extraction failed, ``rule.comment.error`` carries
            the error, ``before``/``after`` are empty and ``totalStep`` is 0.

        Responds 404 for an unknown ``sentence_id`` or an out-of-range ``step``,
        and 400 for undecodable ``options``.
        """
        grammar = self._get_grammar(options)

        sample_index = grammar.sample_id_to_index(sentence_id)
        if sample_index is None:  # unknown id; ``None + 1`` below would be a 500
            return abort(404)
        ret = {
            'nextId': grammar.index_to_sample_id(sample_index + 1),
            'prevId': grammar.index_to_sample_id(sample_index - 1)
        }

        result, const_tree, hyper_graph = grammar.get_derivations(sentence_id)

        cfg_nodes = list(const_tree.traverse_postorder())
        # Validate before indexing so a bad step is a 404, not an IndexError.
        if not 0 <= step < len(cfg_nodes):
            return abort(404)
        cfg_node = cfg_nodes[step]
        nodes_attrs = {cfg_node: {'fillcolor': 'red', 'style': 'filled'}}

        if isinstance(result, Exception):  # error while extracting rules
            rule = self._get_rule(None)
            rule['comment']['error'] = f'{result.__class__.__qualname__} {str(result)}'
            return jsonify({
                'constTree': draw_const_tree(const_tree, output_format='source',
                                             nodes_attrs=nodes_attrs),
                'before': '', 'after': '',
                'rule': rule,
                'panorama': draw_hyper_graph(hyper_graph, output_format='source'),
                'totalStep': 0,
                **ret,
            })

        shrg_rules, (node_blame_dict, edge_blame_dict), derivation_infos = result

        if step >= len(shrg_rules):
            return abort(404)

        # Post-order positions around the current node. In post-order a subtree
        # occupies the contiguous range ending at its root, so for a binary node
        # [leftmost_index, left_index] is the left child's subtree and
        # (left_index, right_index] the right child's.
        node2index = {node: index for index, node in enumerate(cfg_nodes)}
        left_index, right_index, leftmost_index = -1, -1, -1
        if cfg_node.children:
            child = cfg_node.children[0]
            left_index = node2index.get(child, -1)
            if not isinstance(child, Lexicon):
                leftmost_index = node2index.get(next(child.traverse_postorder()), -1)
        if len(cfg_node.children) > 1:
            right_index = node2index.get(cfg_node.children[1], -1)

        attrs_map = {}

        def _fill_attrs_map(items, blame_dict):
            """Colour items by the step that produced them: red = this step,
            blue = left subtree, green = right subtree."""
            for item in items:
                blame_step = blame_dict.get(item, -1)
                if blame_step < 0:
                    # Not blamed on any step. Skip explicitly: otherwise the -1
                    # sentinel falls inside the left-subtree range whenever
                    # leftmost_index is also -1 (leaf or lexical left child) and
                    # the item would wrongly be painted blue.
                    continue
                if blame_step == step:
                    color = 'red'
                elif leftmost_index <= blame_step <= left_index:
                    color = 'blue'
                elif left_index < blame_step <= right_index:
                    color = 'green'
                else:
                    continue
                attrs_map[item] = {'color': color}

        _fill_attrs_map(hyper_graph.nodes, node_blame_dict)
        _fill_attrs_map(hyper_graph.edges, edge_blame_dict)

        shrg_rule = shrg_rules[step]

        derivation_info = derivation_infos[step]
        labels_map = {}
        if derivation_info is not None:
            # External nodes of the current derivation state are drawn as boxes.
            for node in derivation_info.external_nodes:
                attrs_map.setdefault(node, {})['shape'] = 'box'

        panorama = draw_hyper_graph(hyper_graph, output_format='source',
                                    attrs_map=attrs_map,
                                    nodes_labels_map=labels_map,
                                    use_node_name=True)

        before = draw_derivation(derivation_info, output_format='source')
        # NOTE(review): assumes derivation_infos holds one more entry than
        # shrg_rules (the state after the last step) -- confirm with the extractor.
        after = draw_derivation(derivation_infos[step + 1], output_format='source')

        return jsonify({
            'constTree': draw_const_tree(const_tree, output_format='source',
                                         nodes_attrs=nodes_attrs),
            'before': before, 'after': after,
            'rule': self._get_rule(shrg_rule),
            'panorama': panorama,
            'totalStep': len(shrg_rules),
            **ret,
        })
