import gzip
import pickle
import re
import sys
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import Optional

import six
from dataclasses import dataclass

from coli.hrgguru.const_tree import Lexicon, ConstTree
from coli.hrgguru.extract_sync_grammar import ExtractionParams, get_cfg_and_hg, process_modified_sentence
from coli.hrgguru.hrg import HRGDerivation, CFGRule
from coli.hrgguru.sync_grammar_statistics import get_statistics
from coli.hrgguru.unlexicalized_rules import anonymize_derivation
from coli.hrg_parser.webapi_dispatcher import draw_hg_as_eds

# Lookup table from a short detector name to the corresponding
# HRGDerivation detection function.
detect_funcs = dict(
    lexicalized=HRGDerivation.detect_lexicalized,
    small=HRGDerivation.detect_small,
)


@dataclass
class Grammar(object):
    """One extracted synchronous CFG/HRG grammar plus the data it came from.

    The sentence index, the rule table and the statistics are loaded lazily
    on first access.  ``extraction_params`` may be passed as ``None``, in
    which case it is unpickled from ``<base_dir>params-<project_name>.pickle``
    in ``__post_init__``.
    """
    # Directory of the pickled grammar artefacts.  NOTE(review): most paths
    # are built as ``base_dir + "<file>"`` with no separator, while
    # ``file_name`` inserts "/", so base_dir presumably ends with "/" — confirm.
    base_dir: str
    # Directory searched with the ``cfg_index_file`` glob for tree files.
    java_out_dir: str
    # Prefix of the DeepBank export tree; "<bank>/<sent_id>.gz" is appended.
    deepbank_export_path: str
    # Human-readable grammar name (not read by this class).
    name: str
    # Suffix used in artefact file names, e.g. "count-<project_name>.pickle".
    project_name: str
    # Extraction settings; None means "load them from the params pickle".
    extraction_params: Optional[ExtractionParams]
    # Glob pattern, relative to java_out_dir, selecting the tree files to index.
    cfg_index_file: str = "wsj*"

    @property
    def file_name(self):
        """Path of the pickled rule-count table for this project."""
        return self.base_dir + "/count-{}.pickle".format(self.project_name)

    def __post_init__(self):
        self._cfg_index = None
        self._rules = None
        # Memoize per-sentence extraction unless a trace function (e.g. a
        # debugger) is installed, so breakpoints inside get_derivations are
        # still hit on every call.  ``sys.gettrace`` has to be *called*: the
        # function object itself is never None on CPython.
        gettrace = getattr(sys, "gettrace", None)
        if gettrace is None or gettrace() is None:
            self.get_derivations = lru_cache(maxsize=32)(self.get_derivations)
        if self.extraction_params is None:
            params_path = self.base_dir + "params-{}.pickle".format(self.project_name)
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            with open(params_path, "rb") as f:
                self.extraction_params = pickle.load(f)

    @property
    def cfg_index(self):
        """``{sent_id: (bank, tree_literal)}``, built on first access."""
        if self._cfg_index is None:
            self._cfg_index = self.load_cfg_index()
        return self._cfg_index

    @property
    def rules(self):
        """Rule table from ``file_name``, most frequent first; loaded on first access."""
        if self._rules is None:
            self._rules = self.load_rules(self.file_name)
        return self._rules

    @classmethod
    def load_rules(cls, filename):
        """Load a pickled ``{rule: (count, example)}`` table.

        Returns:
            A list of ``(rule, (count, example))`` pairs sorted by count,
            highest first (ties keep the dict's order; sorting is stable).
        """
        with open(filename, "rb") as f:
            rule_count = pickle.load(f)
        # Re-pack every value as a 2-tuple regardless of the pickled sequence type.
        entries = [(rule, (count, example))
                   for rule, (count, example) in rule_count.items()]
        entries.sort(key=lambda entry: entry[1][0], reverse=True)
        return entries

    def load_cfg_index(self):
        """Build ``{sent_id: (bank, tree_literal)}`` from the tree files.

        Each file matched by ``cfg_index_file`` under ``java_out_dir`` holds
        alternating lines: ``#<sent_id>`` followed by a tree literal.  The
        file's own name is used as the bank.  Reading a file stops at the
        first empty id line (or end of file).

        Raises:
            ValueError: if an id line does not start with ``#``.
        """
        results = {}
        for tree_file in Path(self.java_out_dir).glob(self.cfg_index_file):
            bank = tree_file.name
            with open(tree_file) as f:
                while True:
                    id_line = f.readline().strip()
                    if not id_line:
                        break
                    if not id_line.startswith("#"):
                        raise ValueError("{}: expected a '#<sent_id>' line, got {!r}"
                                         .format(tree_file, id_line))
                    tree_literal = f.readline().strip()
                    results[id_line[1:]] = (bank, tree_literal)
        return results

    def _read_export(self, bank, sent_id):
        """Return the UTF-8 decoded, gzip-compressed DeepBank export of *sent_id*."""
        export_path = self.deepbank_export_path + bank + "/" + sent_id + ".gz"
        with gzip.open(export_path, "rb") as f_gz:
            return f_gz.read().decode("utf-8")

    def get_cfg_and_hg(self, sent_id):
        """Return ``(cfg, hg)`` for *sent_id* built from its tree and export."""
        bank, tree_literal = self.cfg_index[sent_id]
        contents = self._read_export(bank, sent_id)
        # Module-level extractor, not this method.
        return get_cfg_and_hg(tree_literal, contents, self.extraction_params)

    def get_derivations(self, sent_id):
        """Extract the anonymized HRG derivation of one sentence.

        Returns:
            ``(ret, cfg, hg)``.  ``ret`` is the dict from ``CFGRule.extract``
            with ``"derivations"`` replaced by the anonymized derivation; the
            original is kept under ``"orig_derivations"``.  Results may be
            memoized (see ``__post_init__``), so callers should treat the
            returned objects as shared and copy before mutating.
        """
        cfg, hg = self.get_cfg_and_hg(sent_id)
        params = self.extraction_params
        ret = CFGRule.extract(hg, cfg,
                              params.ep_permutation_methods,
                              extra_labeler_class=params.extra_labels,
                              sent_id=sent_id,
                              draw=True,
                              graph_type=params.graph_type,
                              fully_lexicalized=params.fully_lexicalized,
                              detect_func=params.detect_func)
        new_derivation, _lemma_mappings = anonymize_derivation(ret["derivations"])
        ret["orig_derivations"] = ret["derivations"]
        ret["derivations"] = new_derivation
        return ret, cfg, hg

    def get_derivations_mod(self, sent_id, mod_part, return_pics=True):
        """Re-extract *sent_id* with *mod_part* applied.

        Returns ``(ret, contents, tree_literal)``, where ``ret`` is the second
        item returned by ``process_modified_sentence`` and ``contents`` is
        the decoded export text.  ``return_pics`` is accepted for
        compatibility but is not used.
        """
        bank, tree_literal = self.cfg_index[sent_id]
        contents = self._read_export(bank, sent_id)
        _, ret, _, _ = process_modified_sentence(
            sent_id, tree_literal, contents, mod_part, self.extraction_params
        )
        return ret, contents, tree_literal

    def get_statistics(self):
        """Compute (once) and return statistics over the CFG/HRG mapping pickle."""
        statistics = getattr(self, "statistics", None)
        if statistics is not None:
            return statistics
        mapping_path = self.base_dir + "cfg_hrg_mapping-{}.pickle".format(self.project_name)
        with open(mapping_path, "rb") as f:
            cfg_hrg_mapping = pickle.load(f)

        self.statistics = get_statistics(cfg_hrg_mapping, True)
        return self.statistics


class DerivationServiceDispatcher(object):
    """Expose grammar and derivation inspection calls on a web service.

    ``grammars`` maps the grammar name sent by clients to a ``Grammar``
    (or an object with the same interface).
    """

    def __init__(self, grammars):
        self.grammars = grammars

    def dispatch(self, service):
        """Register the public query methods on ``service.api.dispatcher``."""
        for method in (self.get_rule, self.get_sentence, self.get_mod_sentence,
                       self.get_statistics, self.search_rule, self.read_dataset):
            service.api.dispatcher.add_method(method)

    def encode_cfg_label(self, rule: CFGRule):
        """Render *rule*'s CFG side as ``"LHS#k -> A#m + B#n + word"``.

        ``k`` is the node count of the HRG left-hand side (0 without an HRG).
        Each nonterminal's number is the node count of its aligned edge
        (0 if none).  Lexical leaves are shown as their string, and "+++"
        inside labels is displayed as "@".
        """
        def format_tag_and_edge(tag, edge):
            if isinstance(tag, Lexicon):
                return tag.string
            assert isinstance(tag, six.string_types)
            return "{}#{}".format(tag.replace("+++", "@"),
                                  0 if edge is None else len(edge.nodes))

        external_point_count = len(rule.hrg.lhs.nodes) if rule.hrg is not None else 0
        cfg_string = "{}#{}".format(rule.lhs.replace("+++", "@"), external_point_count) + " -> " + \
                     " + ".join(format_tag_and_edge(tag, edge) for tag, edge in rule.rhs)
        return cfg_string

    def search_by_lexical_edge(self, rule: CFGRule, tag, expect_lexical_label):
        """True if *rule*'s LHS is *tag* and its first RHS edge has the given label.

        A missing first edge is compared as the label "None".  Presumably
        meant to be called from ``search_rule`` conditions via ``self``.
        """
        may_be_lexical_edge = rule.rhs[0][1]
        lexical_label = may_be_lexical_edge.label \
            if may_be_lexical_edge is not None else "None"
        return tag == rule.lhs and lexical_label == expect_lexical_label

    def get_rule(self, grammar, rule_id):
        """Return the drawing, count, CFG label and example of rule *rule_id*."""
        rule, (count, example) = self.grammars[grammar].rules[rule_id]
        viz = rule.draw_source()

        return {"viz": viz, "count": count,
                "cfg": self.encode_cfg_label(rule),
                "comment": rule.hrg.comment if rule.hrg is not None else None,
                "label": rule.hrg.lhs.label if rule.hrg is not None else 0,
                "example_file": example[0], "step": example[1]}

    def search_rule(self, grammar, condition):
        """Return indices of rules in *grammar* for which *condition* holds.

        *condition* is a Python expression evaluated per rule with ``idx``,
        ``rule``, ``count``, ``example`` and ``self`` in scope; ``%23`` in it
        is unescaped to ``#``.

        SECURITY: *condition* is passed to ``eval``, so any caller of this
        method can run arbitrary Python in the server process.  Only expose it
        to trusted clients.
        """
        rules = self.grammars[grammar].rules
        # The rule list is passed through the eval namespace rather than by
        # splicing the grammar name into the source text, so grammar names
        # containing quotes or backslashes can neither break nor inject code.
        query = ("[idx for idx, (rule, (count, example)) in enumerate(_rules) "
                 "if {}]".format(condition.replace("%23", "#")))
        return eval(query, {"self": self, "_rules": rules})

    @classmethod
    def to_string_with_spans(cls, cfg):
        """Bracketed tree string where each node reads ``TAG___postorder___start,end``."""
        return "({}___{}___{},{} {})".format(cfg.tag, cfg.postorder_idx, cfg.span[0], cfg.span[1],
                                             " ".join([(cls.to_string_with_spans(i))
                                                       if isinstance(i, ConstTree)
                                                       else "{}".format(
                                                 i.string.strip()) for i in cfg.child]))

    @staticmethod
    def _encode_cfg_rhs(tree_node):
        """Describe each child of *tree_node* as ``{"tag", "start", "stop"}``."""
        return [{"tag": child.tag if isinstance(child, ConstTree) else child.string,
                 "start": child.span[0],
                 "stop": child.span[1]}
                for child in tree_node.child]

    @staticmethod
    def _child_postorder_range(tree_node, child_pos):
        """Inclusive postorder-index range covered by ``tree_node.child[child_pos]``.

        Returns ``None`` if that child is missing or is not a ConstTree.
        Requires ``add_postorder_idx`` to have been run.  It assumes
        ``generate_rules`` yields a subtree's nodes in postorder, so the first
        one is the subtree's lowest-numbered node; ``get_sentence``'s
        ``rules[step]`` indexing relies on the same ordering.
        """
        if len(tree_node.child) <= child_pos:
            return None
        child = tree_node.child[child_pos]
        if not isinstance(child, ConstTree):
            return None
        return next(child.generate_rules()).postorder_idx, child.postorder_idx

    def read_dataset(self, grammar, sent_id):
        """Return the annotated tree text and the drawn graph of *sent_id*."""
        cfg, hg = self.grammars[grammar].get_cfg_and_hg(sent_id)
        if self.grammars[grammar].extraction_params.graph_type != "lfrg":
            graph = draw_hg_as_eds(hg, svg=False)
        else:
            graph = hg.draw(file_format="source")
        return {"text": self.to_string_with_spans(cfg).replace("+!+", "@"),
                "graph": graph}

    def get_sentence(self, grammar, sent_id, step):
        """Describe derivation step *step* of sentence *sent_id*.

        The "panorama" draws the whole graph with items blamed on this step
        in red.  Items blamed on steps inside the first child's subtree are
        blue, and those inside the second child's subtree are green.  The
        applied rule's LHS nodes are drawn as squares.
        """
        ret, cfg, hg = self.grammars[grammar].get_derivations(sent_id)
        # get_derivations may be memoized: label rewriting and postorder
        # numbering below must operate on a private copy of the tree.
        cfg: ConstTree = cfg.copy()
        node_blame_dict, edge_blame_dict = ret["blame_dicts"]
        ret["extra_labeler"].rewrite_cfg_label(cfg, ret["orig_derivations"], ret["original_node_map"])
        rules = list(cfg.generate_rules())
        tree_node = rules[step]

        cfg.add_postorder_idx()
        left_range = self._child_postorder_range(tree_node, 0)
        right_range = self._child_postorder_range(tree_node, 1)

        def blame_color(blame_step):
            # One classifier for both edges and nodes; -1 means "unblamed".
            if blame_step == step:
                return "red"
            if left_range is not None and left_range[0] <= blame_step <= left_range[1]:
                return "blue"
            if right_range is not None and right_range[0] <= blame_step <= right_range[1]:
                return "green"
            return None

        attrs = defaultdict(dict)
        for items, blame_dict in ((hg.edges, edge_blame_dict), (hg.nodes, node_blame_dict)):
            for item in items:
                color = blame_color(blame_dict.get(item, -1))
                if color is not None:
                    attrs[item] = {"color": color}

        rule: CFGRule = ret["derivations"][step]
        original_node_map = ret["original_node_map"][step]

        node_name_map = None
        if rule.hrg is not None:
            for node in rule.hrg.lhs.nodes:
                attrs[original_node_map[node]].update({"shape": "square"})
            node_name_map = {original_node_map[node]: name
                             for node, name in rule.hrg.get_node_name_map().items()}

        panorama = hg.draw("", file_format="source", attr_map=attrs,
                           node_name_map=node_name_map)

        return {"before": ret["pics"][step + 1],
                "after": ret["pics"][step],
                "lhs": rule.lhs.replace("+++", "@"),
                "span": list(tree_node.span),
                "text": self.to_string_with_spans(cfg).replace("+!+", "@"),
                "rule": rule.draw_source(),
                "panorama": panorama,
                "comment": rule.hrg.comment if rule.hrg is not None else None,
                "cfg_rhs": self._encode_cfg_rhs(tree_node)}

    def get_mod_sentence(self, grammar, sent_id, mod_part, step):
        """Like ``get_sentence`` but for *sent_id* re-extracted with *mod_part*.

        Returns no panorama or comment.  Its ``text`` is the quoted sentence
        from the export followed by the tree literal.
        """
        (derivations, pics, cfg), gzip_contents, tree_literal = \
            self.grammars[grammar].get_derivations_mod(sent_id, mod_part)
        tree_node = list(cfg.generate_rules())[step]
        # First `...' span ending a line in the export's second
        # blank-line-separated section (presumably the raw sentence).
        text = re.findall(r"`(.*?)'$", gzip_contents.strip().split("\n\n")[1],
                          re.MULTILINE)[0]

        rule: CFGRule = derivations[step]
        return {"before": pics[step + 1],
                "after": pics[step],
                "lhs": rule.lhs.replace("+++", "@"),
                "span": list(tree_node.span),
                "text": text + "   " + tree_literal,
                "rule": rule.draw_source(),
                "cfg_rhs": self._encode_cfg_rhs(tree_node)}

    def get_statistics(self, grammar):
        """Return the cached statistics of *grammar*."""
        return self.grammars[grammar].get_statistics()
