from __future__ import unicode_literals

import re
import sys
import traceback
from io import open

import pickle
from collections import Counter, defaultdict
import gzip

from multiprocessing import Pool, Process, Manager
from pathlib import Path

from typing import Any, List, Dict, Optional, Mapping, Tuple, Callable

from coli.basic_tools.dataclass_argparse import OptionsBase
from coli.hrgguru.derivation_analysis import SampleCounter
from coli.hrgguru.extra_labels import EPCountLabeler
from coli.hrgguru.graph_readers import graph_readers
from coli.hrgguru.const_tree import ConstTree, Lexicon
from coli.hrgguru.hrg import CFGRule, HRGDerivation
from coli.hrgguru.hyper_graph import HyperGraph, HyperEdge, GraphNode, strip_category
from dataclasses import dataclass, field

from coli.hrgguru.strip_utils import DONT_STRIP, STRIP_ALL_LABELS, STRIP_TO_UNLABEL, \
    FUZZY_TREE, STRIP_INTERNAL_LABELS, \
    punct_hyphen_fixer, strip_label, strip_label_internal, strip_unary, strip_to_unlabel, fuzzy_cfg, STRIP_TO_HEADS, \
    strip_to_heads
from coli.hrgguru.unlexicalized_rules import counter_factory, anonymize_derivation, transform_edge_label

# Name -> HRGDerivation detection strategy; presumably used to turn a CLI
# string into ExtractionParams.detect_func (TODO confirm against callers).
detect_funcs = dict(
    small=HRGDerivation.detect_small,
    large=HRGDerivation.detect_large,
    lexicalized=HRGDerivation.detect_lexicalized,
)

# Dataset name -> ordered list of (mode, bank mask(s)). Each mask is a glob
# pattern (or list of patterns) matched against bank files in java_out_dir
# when passed as `banks` to SyncGrammarExtractor.extract.
datasets = {
    "wsj": [
        ("dev", "wsj20*"),
        ("test", "wsj21*"),
        ("train", ["wsj0*", "wsj1*"]),
    ],
    "wsj-debug": [
        ("train", "wsj05*"),
        ("dev", "wsj03*"),
        ("test", "wsj04*"),
    ],
    "ws": [
        ("dev", "ws12"),
        ("test", "ws13"),
        ("train", ["ws0*", "ws10", "ws11"]),
    ],
}


@dataclass
class ExtractionParams(OptionsBase):
    """Options controlling tree preprocessing, graph reading and rule extraction."""
    # Label-stripping mode applied by modify_tree. FUZZY_TREE additionally makes
    # process_sentence run fuzzy_cfg. STRIP_TO_HEADS and FUZZY_TREE are handled
    # by that code, so they are accepted as choices too.
    strip_tree: int = field(default=DONT_STRIP,
                            metadata={"choices": [
                                DONT_STRIP, STRIP_ALL_LABELS,
                                STRIP_INTERNAL_LABELS, STRIP_TO_UNLABEL,
                                STRIP_TO_HEADS, FUZZY_TREE]})
    # passed to CFGRule.extract as detect_func
    detect_func: Any = HRGDerivation.detect_small
    # passed to CFGRule.extract as extra_labeler_class
    extra_labels: Any = EPCountLabeler
    # passed to CFGRule.extract as ep_permutation_methods
    ep_permutation_methods: Any = frozenset(["first_combine_time", "extra_label"])
    # key into graph_readers; "lfrg" disables Args/Names and StructualEdges output
    graph_type: str = field(default="eds",
                            metadata={"choices": [
                                "eds", "dmrs", "lfrg"]})
    # NOTE(review): not read by the code visible in this module
    strip_options: set = field(default_factory=set)
    # free-form extras; process_sentence reads extra["connective"]
    extra: dict = field(default_factory=dict)
    # if set, only the first `limit` matching bank files are processed
    limit: Optional[int] = None
    # callables applied in order to the hypergraph after it is read
    graph_transformers: List[Callable] = field(default_factory=list)
    # "left" rebuilds the tree left-branching; "normal" enables Args/Names output
    tree_type: str = "normal"
    # key into the punct_hyphen_fixer table
    punct_hyphen_fixer: str = field(default="merge", metadata={"choices": punct_hyphen_fixer.keys()})
    # passed to CFGRule.extract as fully_lexicalized
    fully_lexicalized: bool = False


def extract_features(hg,  # type: HyperGraph
                     cfg,  # type: ConstTree
                     log_func=print
                     ):
    """Collect predicate names and argument edges of *hg* in word-span coordinates.

    Side effect: every word node and rule node of *cfg* gets a ``word_span``
    attribute ``(start, end)`` counted in words, with ``end`` exclusive.

    Hyperedges with one node are taken as the predicate label of that node.
    Hyperedges with two nodes are taken as argument edges. Anomalies are
    reported through *log_func*: a second label on the same node (skipped), a
    non-terminal label, an edge of any other arity (skipped), or an argument
    edge touching an unlabelled node (skipped).

    :returns: ``(names, args)``. *names* is a set of ``(word_span, predicate)``.
        *args* is a set of ``(src_span, src_predicate, tgt_span, tgt_predicate,
        edge_label)``. Predicates are passed through ``strip_category``.
    """
    # node.span -> word span, shared by words and rules
    delphin_span_to_word_span = {}
    for idx, node in enumerate(cfg.generate_words()):
        node.word_span = delphin_span_to_word_span[node.span] = (idx, idx + 1)
    # NOTE(review): relies on generate_rules yielding children before their
    # parents, since a rule reads its first/last child's word_span — confirm.
    for node in cfg.generate_rules():
        node.word_span = delphin_span_to_word_span[node.span] = (
            node.child[0].word_span[0], node.child[-1].word_span[1])

    node_mapping = {}  # node -> pred edge
    real_edges = []
    for edge in hg.edges:  # type: HyperEdge
        arity = len(edge.nodes)
        if arity == 1:
            main_node = edge.nodes[0]  # type: GraphNode
            if node_mapping.get(main_node) is not None:
                log_func("Duplicate node name {} and {}!".format(
                    node_mapping[main_node],
                    edge.label
                ))
                continue
            if not edge.is_terminal:
                log_func("non-terminal edge {} found.".format(edge.label))
            node_mapping[main_node] = edge
        elif arity == 2:
            real_edges.append(edge)
        else:
            log_func("Invalid hyperedge with node count {}".format(arity))

    names = set()
    for pred_edge in node_mapping.values():
        assert pred_edge.span is not None
        names.add((delphin_span_to_word_span[pred_edge.span], strip_category(pred_edge.label)))

    args = set()
    for edge in real_edges:
        pred_edges = [node_mapping.get(i) for i in edge.nodes]
        if any(i is None for i in pred_edges):
            log_func("No span for edge {}, nodes {}!".format(edge, pred_edges))
            continue
        source, target = pred_edges
        args.add((delphin_span_to_word_span[source.span], strip_category(source.label),
                  delphin_span_to_word_span[target.span], strip_category(target.label),
                  edge.label))
    return names, args


def to_left_tree(cfg: ConstTree) -> ConstTree:
    """Rebuild *cfg* as a strictly left-branching binary tree over its words.

    Each ``Lexicon`` word is wrapped in its own unary ``X`` node. The wrapped
    words are then combined left to right as ``X(X(X(w1, w2), w3), ...)``.
    All original internal labels are discarded.

    A tree with at most one word (including an empty one) is returned
    unchanged. Previously an empty tree produced ``None``.
    """
    words = list(cfg.generate_words())
    if len(words) <= 1:
        return cfg

    def wrap(node):
        # give a bare Lexicon leaf its own "X" parent; subtrees pass through
        if isinstance(node, Lexicon):
            wrapper = ConstTree("X")
            wrapper.children.append(node)
            return wrapper
        return node

    left = wrap(words[0])
    for word in words[1:]:
        parent = ConstTree("X")
        parent.children.append(left)
        parent.children.append(wrap(word))
        left = parent
    return left


def get_cfg_and_hg(tree_literal, export_contents, params: ExtractionParams):
    """Parse one DeepBank item into its constituency tree and semantic hypergraph.

    The tree is built from *tree_literal* and *export_contents*, label-stripped
    per ``params.strip_tree``, and condensed over unary chains. Its spans are
    then populated and the configured punctuation/hyphen fixer is run on the
    tree together with the graph. When ``params.tree_type == "left"``, the tree
    is finally rebuilt left-branching.

    The graph is read by ``graph_readers[params.graph_type]`` from the
    blank-line separated sections of *export_contents*. Every
    ``params.graph_transformers`` entry is then applied to it in order.

    :returns: ``(cfg, hg)``
    """
    tree = ConstTree.from_java_code_deepbank_1_1(tree_literal, export_contents)
    modify_tree(tree, params)

    sections = export_contents.strip().split("\n\n")

    # semantic graph
    read_graph = graph_readers[params.graph_type]
    graph = read_graph(sections, params, wipe_names=False)
    for transform in params.graph_transformers:
        graph = transform(graph)

    tree = tree.condensed_unary_chain()
    tree.populate_spans_internal()
    fix_punct = punct_hyphen_fixer[params.punct_hyphen_fixer]
    fix_punct(tree, graph)

    if params.tree_type == "left":
        tree = to_left_tree(tree)
        tree.populate_spans_internal()

    return tree, graph


def modify_tree(cfg, params):
    """Strip constituency labels of *cfg* in place according to ``params.strip_tree``.

    STRIP_ALL_LABELS, STRIP_INTERNAL_LABELS and STRIP_TO_HEADS run their
    stripper followed by ``strip_unary``. STRIP_TO_UNLABEL and FUZZY_TREE run
    only ``strip_to_unlabel``. Any other value leaves the tree untouched.
    """
    mode = params.strip_tree
    if mode == STRIP_ALL_LABELS:
        stripper = strip_label
    elif mode == STRIP_INTERNAL_LABELS:
        stripper = strip_label_internal
    elif mode == STRIP_TO_HEADS:
        stripper = strip_to_heads
    elif mode == STRIP_TO_UNLABEL or mode == FUZZY_TREE:
        # no strip_unary for this branch
        strip_to_unlabel(cfg)
        return
    else:
        return

    stripper(cfg)
    strip_unary(cfg)


def process_modified_sentence(sent_id, tree_literal, export_contents,
                              modified_part, params: ExtractionParams):
    """Run ``process_sentence`` on a tree paired with a caller-supplied graph.

    The tree is built from *tree_literal*/*export_contents* as in
    ``get_cfg_and_hg``. The semantic graph, however, is read from
    *modified_part* with ``direct=True`` instead of from the export sections.
    The sentence is processed as training data (``is_train=True``).

    NOTE(review): unlike ``get_cfg_and_hg``, ``params.graph_transformers`` and
    ``params.tree_type == "left"`` are not applied here — confirm intended.

    :returns: the result tuple of ``process_sentence``.
    """
    cfg = ConstTree.from_java_code_deepbank_1_1(tree_literal, export_contents)
    modify_tree(cfg, params)

    # Load the semantic graph first: the punctuation/hyphen fixer takes
    # (cfg, hg), matching its use in get_cfg_and_hg.
    hg = graph_readers[params.graph_type](modified_part, params, direct=True)

    cfg = cfg.condensed_unary_chain()
    cfg.populate_spans_internal()
    punct_hyphen_fixer[params.punct_hyphen_fixer](cfg, hg)

    return process_sentence(
        sent_id, cfg, hg, True, params)


def cfg_to_string(sent_id, cfg, additional_props=None, with_comma=False):
    """Serialise *cfg* below a header of ``# key: value`` comment lines.

    The first header line is ``# ID: <sent_id>``. Each item of
    *additional_props*, in iteration order, adds ``# <key>: <repr(value)>``.
    The tree text comes from ``cfg.to_string(with_comma=...)``, with every
    ``+++`` replaced by ``+!+``. No trailing newline is added.
    """
    lines = ["# ID: " + sent_id]
    if additional_props is not None:
        lines.extend("# {}: ".format(key) + repr(value)
                     for key, value in additional_props.items())
    tree_text = cfg.to_string(with_comma=with_comma)
    lines.append(tree_text.replace("+++", "+!+"))
    return "\n".join(lines)


def process_sentence(sent_id, cfg, hg, is_train, params: ExtractionParams, log_func=print):
    """Extract the synchronous derivation and annotated tree strings of one sentence.

    :param sent_id: sentence identifier, written into the ``# ID:`` header.
    :param cfg: constituency tree. It is modified in place: ``extract_features``
        sets ``word_span`` on its nodes, and the extra labeler's
        ``rewrite_cfg_label`` rewrites its labels.
    :param hg: semantic hypergraph of the sentence.
    :param is_train: if true, a non-binary tree is only logged; otherwise it
        raises.
    :param params: extraction options.
    :param log_func: receives all diagnostic messages, including those from
        ``extract_features`` and ``CFGRule.extract``.
    :returns: ``(sent_id, new_derivation, original_cfg_string,
        additional_cfg_string, lemma_mappings, original_word_and_tags,
        rule_keywords)``. The two strings are the tree before and after label
        rewriting.
    :raises Exception: if ``params.extra["connective"]`` is truthy and *hg* is
        not connected, or if the tree is not binary and *is_train* is false.
    """
    if params.extra.get("connective") and not hg.is_connected():
        raise Exception(f"Sentence {sent_id} is not connected")
    if params.graph_type != "lfrg" and params.strip_tree == FUZZY_TREE:
        # features are taken from the tree as it is before fuzzy_cfg replaces it
        names, args = extract_features(hg, cfg, log_func)
        cfg = fuzzy_cfg(cfg, names)
    else:
        names, args = None, None

    ret = CFGRule.extract(hg, cfg,
                          ep_permutation_methods=params.ep_permutation_methods,
                          sent_id=sent_id,
                          detect_func=params.detect_func,
                          extra_labeler_class=params.extra_labels,
                          graph_type=params.graph_type,
                          fully_lexicalized=params.fully_lexicalized,
                          log_func=log_func
                          )

    derivations = ret["derivations"]

    additional_props = {
        "DelphinSpans": [i.span for i in cfg.generate_words()],
    }

    if params.graph_type != "lfrg" and params.tree_type == "normal":
        if names is None:
            names, args = extract_features(hg, cfg, log_func)
        additional_props["Args"] = list(args)
        additional_props["Names"] = list(names)

    # check binary
    if any(len(rule.child) > 2 for rule in cfg.generate_rules()):
        if is_train:
            log_func("{} Not binary tree!".format(sent_id))
        else:
            raise Exception("Not binary tree!")

    new_derivation, lemma_mappings = anonymize_derivation(derivations)
    lexical_labels = []
    lexical_attachments = []
    internal_attachments = []
    lexical_nodes = []
    internal_nodes = []
    for sync_rule, tree_node in zip(new_derivation, cfg.generate_rules()):
        is_lexical = isinstance(sync_rule.rhs[0][0], Lexicon)
        may_be_lexical_edge = sync_rule.rhs[0][1] if is_lexical else None

        # labels of single-node terminal edges in the rule's HRG, excluding
        # those sharing the lexical edge's label
        attachments_i = []
        if sync_rule.hrg is not None:
            for edge in sync_rule.hrg.rhs.edges:
                if edge.is_terminal and len(edge.nodes) == 1:
                    if may_be_lexical_edge is None or edge.label != may_be_lexical_edge.label:
                        attachments_i.append(edge.label)

        if is_lexical:
            lexicon = tree_node.children[0].string
            lexical_labels.append(may_be_lexical_edge.label
                                  if may_be_lexical_edge is not None else "None")
            lexical_attachments.append(attachments_i)
            lexical_nodes.append(
                [(i.nodes[0].name, transform_edge_label(i.label, lexicon, False))
                 for i in ret["node_distribution"][tree_node]])
        else:
            internal_attachments.append(attachments_i)
            internal_nodes.append(
                [(i.nodes[0].name, i.label) for i in ret["node_distribution"][tree_node]])

    additional_props["LexicalLabels"] = lexical_labels
    additional_props["LexicalAttachments"] = lexical_attachments
    additional_props["InternalAttachments"] = internal_attachments
    additional_props["LexicalNodes"] = lexical_nodes
    additional_props["InternalNodes"] = internal_nodes

    if params.graph_type != "lfrg":
        # (source name, target name) -> labels of all binary edges between them
        structual_edges = defaultdict(list)
        for i in hg.edges:
            if len(i.nodes) == 2:
                structual_edges[tuple(node.name for node in i.nodes)].append(i.label)
        additional_props["StructualEdges"] = [(s, t, "&&&".join(sorted(labels)))
                                              for (s, t), labels in structual_edges.items()
                                              ]

    original_cfg_string = cfg_to_string(sent_id, cfg, additional_props)
    # TODO: remove side effect on cfg
    ret["extra_labeler"].rewrite_cfg_label(cfg, ret["derivations"], ret["original_node_map"])
    additional_cfg_string = cfg_to_string(sent_id, cfg, additional_props)

    # (word, rewritten tag) for lexical steps, None for internal ones
    original_word_and_tags = []
    for sync_rule, tree_node in zip(derivations, cfg.generate_rules()):
        if isinstance(sync_rule.rhs[0][0], Lexicon):
            original_word_and_tags.append((sync_rule.rhs[0][0].string,
                                           tree_node.tag.replace("+++", "+!+")))
        else:
            original_word_and_tags.append(None)

    def get_rule_keyword(rule):
        # (tag, child tags) with every lexical child replaced by a placeholder
        return (rule.tag.replace("+++", "+!+"),
                tuple(Lexicon("{NEWLEMMA}") if isinstance(i, Lexicon) else i.tag.replace("+++", "+!+")
                      for i in rule.children))

    rule_keywords = [get_rule_keyword(i) for i in cfg.generate_rules()]

    return sent_id, new_derivation, original_cfg_string, additional_cfg_string, \
           lemma_mappings, original_word_and_tags, rule_keywords


class SyncGrammarExtractor(object):
    """Holds the input and output locations shared by extraction sessions."""

    def __init__(self,
                 deepbank_export_path,
                 java_out_dir,
                 output_dir="./deepbank-preprocessed/"
                 ):
        """
        :param deepbank_export_path: prefix under which the gzipped DeepBank
            export of each bank lives.
        :param java_out_dir: directory of bank files holding sentence ids and
            tree literals.
        :param output_dir: directory where all result files are written.
        """
        self.deepbank_export_path = deepbank_export_path
        self.java_out_dir = java_out_dir
        self.output_dir = output_dir

    def extract(self, name, params, mode, banks="*", process_count=8):
        """Run one ExtractionSession for *mode* over the bank files matching *banks*.

        The multiprocessing manager backing the session queue is shut down
        deterministically when extraction finishes or fails. Previously it was
        left to garbage collection.

        :returns: the result of ``ExtractionSession.extract``.
        """
        with Manager() as manager:
            return ExtractionSession(
                self, name, params, mode, banks, process_count, manager).extract()


class ExtractionSession(object):
    """Parallel extraction of one dataset split (``mode``) into output files.

    ``worker`` runs in a process pool, one call per bank file. It sends
    ``(command, payload)`` messages over ``self.bus``, a queue obtained from
    the given manager. ``results_writer`` runs in a separate process, consumes
    those messages, and writes every output under ``extractor.output_dir``.
    Commands are ``"___RESULT___"``, ``"___LOG___"`` and ``"___TASK_DONE___"``.
    """

    def __init__(self, extractor: SyncGrammarExtractor, name, params, mode,
                 banks, process_count, sync_manager):
        """
        :param banks: a glob mask, or an iterable of masks, matched against the
            files in ``extractor.java_out_dir``.
        :param sync_manager: multiprocessing manager used to create the queue.
        """
        self.extractor = extractor
        self.name = name
        self.mode = mode
        self.params = params
        self.process_count = process_count
        self.bus = sync_manager.Queue()
        # NOTE(review): not read by this class; results_writer keeps a local table
        self.lexicon_lookup = defaultdict(counter_factory)

        if isinstance(banks, (str, bytes)):
            banks = [banks]

        self.bank_names = [i.name for bank_mask in banks
                           for i in Path(extractor.java_out_dir).glob(bank_mask)]

        if params.limit is not None:
            self.bank_names = self.bank_names[:params.limit]

    def log(self, log_str):
        """Send *log_str* to the collector, which appends it to the session log file."""
        self.bus.put(("___LOG___", log_str))

    def extract(self):
        """Run all workers and wait for the collector to write the outputs.

        If the pool raises (e.g. an unreadable export file in a worker), the
        collector is terminated and the exception re-raised. Otherwise the
        collector would block forever waiting for ``___TASK_DONE___``.
        """
        collector = Process(target=self.results_writer)
        collector.start()
        try:
            with Pool(processes=self.process_count) as pool:
                pool.map(self.worker, self.bank_names)
        except BaseException:
            # don't let the collector write partial outputs or block on the queue
            collector.terminate()
            collector.join()
            raise
        self.bus.put(("___TASK_DONE___", None))
        collector.join()

    def worker(self, bank_name):
        """Process every sentence listed in one bank file of ``java_out_dir``.

        The bank file holds line pairs: ``#<sent_id>``, then the tree literal.
        Reading stops at the first empty id line. The graph export is read
        from ``<deepbank_export_path><bank_name>/<sent_id>.gz``. A failure on a
        sentence is printed to stderr and sent as a failure result, so the
        collector still counts that sentence.
        """
        with open(self.extractor.java_out_dir + "/" + bank_name, encoding="utf-8") as f:
            while True:
                sent_id = f.readline().strip()
                if not sent_id:
                    break
                assert sent_id.startswith("#")
                sent_id = sent_id[1:]
                tree_literal = f.readline().strip()
                with gzip.open(self.extractor.deepbank_export_path + bank_name + "/" + sent_id + ".gz",
                               "rb") as f_gz:
                    contents = f_gz.read().decode("utf-8")
                try:
                    cfg, hg = get_cfg_and_hg(tree_literal, contents, self.params)
                except Exception:
                    # same 7-tuple shape as process_sentence; no tree strings
                    self.bus.put(("___RESULT___", (sent_id, None, None, None, [], [], None)))
                    print(f"Error on {sent_id}:\n"
                          f"{traceback.format_exc()}", file=sys.stderr)
                    continue

                try:
                    result = process_sentence(
                        sent_id, cfg, hg, self.mode == "train", self.params, self.log)
                    self.bus.put(("___RESULT___", result))
                except Exception:
                    # keep the tree, marked FAILED, but without a derivation
                    cfg_string = cfg_to_string(
                        sent_id, cfg,
                        {"Status": "FAILED", "DelphinSpans":
                            [i.span for i in cfg.generate_words()]
                         })
                    self.bus.put(("___RESULT___", (sent_id, None, cfg_string, cfg_string, [], [], None)))
                    print(f"Error on {sent_id}:\n"
                          f"{traceback.format_exc()}", file=sys.stderr)

    def results_writer(self):
        """Consume worker messages until ``___TASK_DONE___`` and write all outputs.

        The following files are always written:

        * ``<name>-<mode>.log``
        * ``<name>.<mode>`` and ``<name>.fulllabel.<mode>`` — tree strings
          sorted by sentence id.

        In ``train`` mode, the derivations, rule counts, CFG->HRG lookup
        tables and params are additionally pickled.
        """
        all_rules: Dict[CFGRule, CFGRule] = {}
        derivations = {}
        output_dir = self.extractor.output_dir
        output_file = output_dir + "/" + self.name + "." + self.mode
        output_fulllabel_file = output_dir + "/" + self.name + ".fulllabel." + self.mode
        lexicon_lookup: Mapping[Tuple[str, str], Mapping[CFGRule, int]] = defaultdict(Counter)
        lemma_dict: Mapping[str, Mapping[str, int]] = defaultdict(Counter)

        def convert_derivation(sent_id, original_word_and_tags, derivation: List[CFGRule]) -> List[CFGRule]:
            """Yield the canonical (first-seen) object for each rule of *derivation*.

            On first sight, a rule's "EP permutation" comment is moved to a
            per-(sentence, step) key. On later sightings, that comment is
            copied onto the canonical rule unless the same value is already
            there. Lexical steps are counted in ``lexicon_lookup``.
            """
            for rule_idx, (original_word_and_tag, rule) in enumerate(zip(original_word_and_tags, derivation)):
                standard_rule = all_rules.get(rule)
                if standard_rule is None:
                    standard_rule = all_rules[rule] = rule
                    if rule.hrg is not None:
                        ep_permutation = rule.hrg.comment.pop("EP permutation", None)
                        if ep_permutation:
                            ep_permutation = ep_permutation.split("***")[0].strip()
                            standard_rule.hrg.comment[
                                "EP permutation at {} step {}".format(sent_id, rule_idx)] = ep_permutation
                else:
                    if standard_rule.hrg is not None:
                        assert rule.hrg is not None
                        ep_permutation = rule.hrg.comment.get("EP permutation")
                        if ep_permutation is not None:
                            ep_permutation = ep_permutation.split("***")[0].strip()
                            if ep_permutation not in standard_rule.hrg.comment.values():
                                standard_rule.hrg.comment[
                                    "EP permutation at {} step {}".format(sent_id, rule_idx)] = ep_permutation

                if original_word_and_tag is not None:
                    # count the canonical object so the pickled lookup shares
                    # rule instances with the yielded derivations
                    lexicon_lookup[original_word_and_tag][standard_rule] += 1

                yield standard_rule

        rule_count = SampleCounter()
        sync_grammar_lookup: Dict[tuple, Dict[CFGRule, int]] = defaultdict(Counter)

        total_sent = 0
        success_sent = 0
        derivation_sent = 0

        cfg_outputs = {}

        with open(f"{self.extractor.output_dir}/{self.name}-{self.mode}.log", "w",
                  encoding="utf-8") as f_log:
            while True:
                command, result = self.bus.get()
                if command == "___TASK_DONE___":
                    break
                elif command == "___LOG___":
                    f_log.write(result)
                    f_log.write("\n")
                    continue
                elif command != "___RESULT___":
                    # the payload is not a result tuple; do not try to unpack it
                    print(f"ignore invalid command: {command}")
                    continue

                sent_id, derivation, original_cfg, additional_cfg, \
                lemma_mappings, original_word_and_tags, keywords = result
                total_sent += 1
                has_tree = original_cfg is not None
                has_derivation = derivation is not None

                if has_tree:
                    success_sent += 1
                    if self.mode != "train" or has_derivation:
                        # only write failed tree when it is not training
                        cfg_outputs[sent_id] = original_cfg, additional_cfg

                if has_derivation:
                    derivation_sent += 1

                    if self.mode == "train":
                        for lexicon, lemma in lemma_mappings:
                            lemma_dict[lexicon][lemma] += 1
                        derivation = derivations[sent_id] = list(
                            convert_derivation(sent_id, original_word_and_tags, derivation))

                        # count rules
                        for step, rule in enumerate(derivation):
                            rule_count.add(rule, (sent_id, step))

                        # fill sync grammar lookup table
                        for keyword, rule in zip(keywords, derivation):
                            sync_grammar_lookup[keyword][rule] = rule_count[rule].count

        with open(output_file, "w", encoding="utf-8") as f_1, \
                open(output_fulllabel_file, "w", encoding="utf-8") as f_2:
            for sent_id, (original_cfg, additional_cfg) in sorted(cfg_outputs.items()):
                f_1.write(original_cfg + "\n")
                f_2.write(additional_cfg + "\n")

        print("Done processing \"{}\", {}/{} read, {}/{} has derivation".format(
            self.mode, success_sent, total_sent, derivation_sent, total_sent))

        if self.mode != "train":
            return

        postag_mapping = defaultdict(Counter)
        lexical_label_mapping = defaultdict(Counter)
        for keyword, rules_counter in sync_grammar_lookup.items():
            tag = keyword[0]
            # lexical rules carry the Lexicon("{NEWLEMMA}") placeholder as first child
            if isinstance(keyword[1][0], Lexicon) and keyword[1][0].string == "{NEWLEMMA}":
                postag_mapping[tag].update(rules_counter)
                for rule, count in rules_counter.items():
                    may_be_lexical_edge = rule.rhs[0][1]
                    lexical_label = may_be_lexical_edge.label \
                        if may_be_lexical_edge is not None else "None"
                    lexical_label_mapping[lexical_label, tag][rule] += count

        name = self.name
        # write derivations
        with open(output_dir + "/derivations-{}.pickle".format(name), "wb") as f:
            pickle.dump(derivations, f)

        # write count
        with open(output_dir + "/count-{}.pickle".format(name), "wb") as f:
            pickle.dump(rule_count, f)

        with open(output_dir + "/cfg_hrg_mapping-{}.pickle".format(name), "wb") as f:
            pickle.dump((sync_grammar_lookup, lexicon_lookup, lemma_dict, postag_mapping, lexical_label_mapping), f)

        with open(output_dir + "/params-{}.pickle".format(name), "wb") as f:
            pickle.dump(self.params, f)


def extract_dataset(name, dataset, params, java_out_dir, deepbank_export_path, output_dir):
    """Extract every split of *dataset* with one shared SyncGrammarExtractor.

    *dataset* is a sequence of ``(mode, bank_mask)`` pairs, shaped like the
    values of ``datasets``. The splits are processed in order, each one with
    the default process count.
    """
    extractor = SyncGrammarExtractor(deepbank_export_path, java_out_dir, output_dir)
    for split_mode, bank_mask in dataset:
        extractor.extract(name, params, split_mode, banks=bank_mask)
