# encoding: utf-8
from __future__ import unicode_literals, division

from dataclasses import dataclass

__assets__ = ["utils/evalb"]

import tempfile

import os
from multiprocessing.pool import Pool
from typing import List, Generator, Tuple, Union, Optional, Dict

import numpy as np
import re
import six
import attr

from coli.data_utils.dataset import PAD, UNKNOWN, START_OF_WORD, END_OF_WORD, START_OF_SENTENCE, END_OF_SENTENCE, \
    CHAR_START_OF_SENTENCE, CHAR_END_OF_SENTENCE
from coli.data_utils.vocab_utils import Dictionary
from coli.basic_tools.common_utils import smart_open, add_slots


class ConstTreeParserError(Exception):
    """Error raised when a bracketed constituency tree cannot be parsed.

    The message stays available as ``value`` and is also forwarded to
    ``Exception.__init__`` so that ``args`` is populated and the exception
    can be pickled and unpickled.

    :param message: description of the parse problem
    """

    def __init__(self, message):
        # Without this call ``args`` stays empty and unpickling, which calls
        # ``cls(*args)``, fails with a TypeError.
        super().__init__(message)
        self.value = message

    def __str__(self):
        # ``__str__`` must return a str even when a non-string payload was given.
        return str(self.value)


class Lexicon(object):
    """A terminal (word) of a constituency tree.

    Equality and hashing use ``string`` only; ``span`` is ignored.

    The Python 2 ``six.python_2_unicode_compatible`` wrapper is not applied:
    it is a no-op on Python 3, the only version able to parse this module.

    :param string: surface form of the word
    :param span: optional ``(start, end)`` position pair
    """

    def __init__(self, string, span=None):
        self.string = string
        self.span = span

    def __str__(self):
        return u"Lexicon <{}>".format(self.string)

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        return hash(self.string) + 2

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of failing with
        # AttributeError on ``other.string``; ``==`` then evaluates to False.
        if not isinstance(other, Lexicon):
            return NotImplemented
        return self.string == other.string


@six.python_2_unicode_compatible
class ConstTree(object):
    """
    c-structure of LFG.
    """

    def __init__(self, tag, span=None, extra_info=None):
        """
        Create a node labelled ``tag`` that has no children yet.

        :param tag: label of the node
        :param span: optional ``(start, end)`` pair
        :param extra_info: optional metadata dict; a new empty dict is used
            when the argument is falsy
        """
        no_children: List[Union["ConstTree", Lexicon]] = []
        self.children = no_children
        self.tag: str = tag
        self.span: Optional[Tuple[int, int]] = span
        if extra_info:
            self.extra: Dict = extra_info
        else:
            self.extra = {}
        self.score: float = 0.0

    def __hash__(self):
        """
        Hash by object identity.

        Note that ``__eq__`` compares structure, so two equal trees that are
        distinct objects generally hash differently.
        """
        identity = id(self)
        return identity

    def __str__(self):
        """
        Render only this node as a rewrite rule: ``tag -> child + child``.

        :return: the node tag followed by the tag of each sub-tree child, or
            the word string of each lexical child
        """
        right_hand_side = []
        for child in self.children:
            if isinstance(child, Lexicon):
                right_hand_side.append(child.string)
            else:
                right_hand_side.append(child.tag)
        return "{} -> {}".format(self.tag, " + ".join(right_hand_side))

    def __repr__(self):
        """Same text as ``__str__``."""
        text = self.__str__()
        return text

    def __eq__(self, other):
        """
        Structural equality: same tag and pairwise-equal children.

        Spans, ``extra`` and ``score`` are not compared.

        :param other: object to compare with
        :return: ``NotImplemented`` when ``other`` is not a ConstTree (so
            ``==`` evaluates to False instead of raising AttributeError),
            otherwise a bool
        """
        if not isinstance(other, ConstTree):
            return NotImplemented
        if self.tag != other.tag:
            return False
        if len(self.children) != len(other.children):
            return False
        for mine, theirs in zip(self.children, other.children):
            # A sub-tree never equals a word; decide that here so the
            # comparison never reaches attributes the other kind lacks.
            if isinstance(mine, ConstTree) != isinstance(theirs, ConstTree):
                return False
            if mine != theirs:
                return False
        return True

    def __getitem__(self, item):
        """
        Look up a child by tag or by position.

        :param item: a string, matched case-insensitively against the tags of
            sub-tree children (the first match wins), or an int index into
            ``children``
        :return: the matching child
        :raises KeyError: if no child carries the tag, or ``item`` is neither
            a string nor an int; the exception carries ``item``
        :raises IndexError: if an int index is out of range
        """
        if isinstance(item, six.string_types):
            wanted = item.upper()  # normalise once, not once per child
            for child in self.children:
                if isinstance(child, ConstTree) and child.tag.upper() == wanted:
                    return child
        if isinstance(item, int):
            return self.children[item]
        # Report what was asked for; a bare KeyError is hard to diagnose.
        raise KeyError(item)

    @classmethod
    def make_leaf_node(cls, tag, lexicon):
        """
        Build a node labelled ``tag`` whose only child is the word ``lexicon``.

        :param tag: label of the new node
        :param lexicon: word string, wrapped in a ``Lexicon``
        :return: the new node (spans are not populated)
        """
        leaf = cls(tag)
        word = Lexicon(lexicon)
        leaf.children.append(word)
        return leaf

    @classmethod
    def make_internal_node(cls, tag, sub_trees):
        """
        Build a node labelled ``tag`` with ``sub_trees`` as its children.

        :param tag: label of the new node
        :param sub_trees: iterable of children, kept in order
        :return: the new node (spans are not populated)
        """
        parent = cls(tag)
        for sub_tree in sub_trees:
            parent.children.append(sub_tree)
        return parent

    @classmethod
    def from_string(cls, s):
        """
        Parse one bracketed tree such as ``(S (NP (N dog)) (VP (V barks)))``.

        Parentheses are the only delimiters. The token after ``(`` is the
        node label. Any other token is a word: it becomes a ``Lexicon`` child
        of the innermost open node, which must be closed right after it.
        Spans are populated on the returned tree.

        :param s: text containing exactly one tree
        :return: the root node of the parsed tree
        :raises ConstTreeParserError: if the text is malformed (unbalanced
            parentheses, missing label, a word not followed by ``)``, an
            empty constituent, or not exactly one top-level tree)
        """
        tokens = s.replace("(", " ( ").replace(")", " ) ").split()
        # Sentinel that collects the top-level tree(s); it is never returned.
        root = cls("__ROOT__")
        stack = [root]
        total = len(tokens)
        index = 0

        while index < total:
            token = tokens[index]
            index += 1

            if token == "(":
                # The next token is the label of the node being opened.
                if index >= total:
                    raise ConstTreeParserError("Missing label after '(' at end of input")
                stack.append(cls(tokens[index]))
                index += 1
            elif token == ")":
                if len(stack) < 2:
                    raise ConstTreeParserError("Unbalanced ')': no open constituent to close")
                closed = stack.pop()
                if not closed.children:
                    # Span population needs at least one child per node.
                    raise ConstTreeParserError(
                        "Empty constituent '{}'".format(closed.tag))
                stack[-1].children.append(closed)
            else:
                # A word: it ends the innermost open node.
                if len(stack) < 2:
                    raise ConstTreeParserError(
                        "Word '{}' found outside of any constituent".format(token))
                if index >= total or tokens[index] != ")":
                    raise ConstTreeParserError(
                        "Expected ')' after word '{}'".format(token))
                index += 1  # consume the closing parenthesis
                postag_node = stack.pop()
                postag_node.children.append(Lexicon(token))
                stack[-1].children.append(postag_node)

        if len(stack) != 1:
            raise ConstTreeParserError(
                "Unbalanced '(': {} constituent(s) left open".format(len(stack) - 1))
        if len(root.children) != 1:
            raise ConstTreeParserError(
                "Expected exactly one tree, found {}".format(len(root.children)))

        tree = root.children[0]
        tree.populate_spans_terminal()
        tree.populate_spans_leaf()
        return tree

    def populate_spans_terminal(self):
        """Give the k-th word (left to right, from 0) the span ``(k, k + 1)``."""
        position = 0
        for lexicon in self.generate_words():
            lexicon.span = (position, position + 1)
            position += 1

    def populate_spans_leaf(self):
        """
        Set ``span`` on this node and on every sub-tree below it.

        Sub-trees are processed first; the node then spans from the start of
        its first child to the end of its last child. Word spans must already
        be set (see ``populate_spans_terminal``).
        """
        for child in self.children:
            if not isinstance(child, ConstTree):
                continue
            child.populate_spans_leaf()
        leftmost = self.children[0]
        rightmost = self.children[-1]
        self.span = (leftmost.span[0], rightmost.span[1])

    def generate_words(self) -> Generator[Lexicon, None, None]:
        """
        Yield the ``Lexicon`` objects of the tree from left to right.

        The traversal is iterative: the stack holds ``(node, child index)``
        pairs that are still to be visited.
        """
        frames = [(self, 0)]
        while frames:
            parent, position = frames.pop()
            more_siblings = position + 1 < len(parent.children)
            if more_siblings:
                # Come back for the next sibling after this child is done.
                frames.append((parent, position + 1))
            current = parent.children[position]

            if not isinstance(current, ConstTree):
                assert isinstance(current, Lexicon)
                yield current
                continue
            frames.append((current, 0))

    @property
    def words(self) -> List[str]:
        """
        Word strings of the tree, left to right.

        The list is computed on first access and cached in ``_words``; later
        changes to the tree are not reflected.
        """
        try:
            return self._words
        except AttributeError:
            pass
        self._words = [lexicon.string for lexicon in self.generate_words()]
        return self._words

    def generate_preterminals(self) -> Generator["ConstTree", None, None]:
        """
        Yield, from left to right, the parent node of every word.

        A node is yielded once per ``Lexicon`` child it has. The traversal
        uses a stack of ``(node, child index)`` pairs.
        """
        frames = [(self, 0)]
        while frames:
            parent, position = frames.pop()
            more_siblings = position + 1 < len(parent.children)
            if more_siblings:
                frames.append((parent, position + 1))
            current = parent.children[position]

            if not isinstance(current, ConstTree):
                assert isinstance(current, Lexicon)
                yield parent
                continue
            frames.append((current, 0))

    @property
    def postags(self) -> List[str]:
        """Tags of the nodes yielded by ``generate_preterminals``, in order."""
        return [preterminal.tag for preterminal in self.generate_preterminals()]

    def generate_word_and_postag(self) -> Generator[Tuple[str, str], None, None]:
        """
        Yield ``(word, tag)`` for each node from ``generate_preterminals``.

        The word is the string of the node's first child.
        """
        for preterminal in self.generate_preterminals():
            first_child = preterminal.children[0]
            pair = (first_child.string, preterminal.tag)
            yield pair

    def generate_rules(self) -> Generator["ConstTree", None, None]:
        """
        Yield every node of the tree in post-order (children before parent).

        Words are skipped; only ``ConstTree`` nodes are yielded.
        """
        frames = [(self, 0)]
        while frames:
            parent, position = frames.pop()
            if position >= len(parent.children):
                # Every child has been handled, so emit the node itself.
                yield parent
                continue

            frames.append((parent, position + 1))
            current = parent.children[position]
            if isinstance(current, ConstTree):
                frames.append((current, 0))

    def root_first(self) -> Generator["ConstTree", None, None]:
        """
        Yield every node of the tree in pre-order (parent before children).

        Words are skipped; only ``ConstTree`` nodes are yielded.
        """
        frames = [(self, 0)]
        while frames:
            parent, position = frames.pop()
            first_visit = position == 0
            if first_visit:
                yield parent
            if len(parent.children) > position + 1:
                frames.append((parent, position + 1))
            current = parent.children[position]

            if isinstance(current, ConstTree):
                frames.append((current, 0))

    def generate_spans(self):
        """
        Yield the span of every word and node below and including this one.

        Children come first, in order; the node's own span is yielded last.
        """
        for child in self.children:
            if not isinstance(child, ConstTree):
                assert isinstance(child, Lexicon)
                yield child.span
                continue
            yield from child.generate_spans()
        yield self.span

    def expanded_unary_chain(self, wrap_top=None):
        """
        Return a copy of the tree in which ``+++``-joined tags are unfolded.

        A node tagged ``A+++B+++C`` becomes the unary chain ``A -> B -> C``;
        every node of the chain gets this node's span. The original children
        hang below the last node of the chain: words are kept as they are,
        sub-trees are expanded recursively.

        :param wrap_top: if truthy, label of an extra root node placed above
            the result; it receives this node's span
        :return: the new root; its ``extra`` and ``score`` are taken from
            this node (``extra`` is shared, not copied)
        """
        tags = self.tag.split("+++")
        root_node = last_node = ConstTree(tags[0], self.span)
        for tag in tags[1:]:
            current_node = ConstTree(tag, self.span)
            last_node.children = [current_node]
            last_node = current_node

        if isinstance(self.children[0], Lexicon):
            last_node.children = list(self.children)
        else:
            last_node.children = [child.expanded_unary_chain()
                                  for child in self.children]

        if wrap_top:
            wrapper = self.__class__(wrap_top)
            # The wrapper covers exactly what it wraps. Without a span,
            # span-based traversals yield None or fail on the new root.
            wrapper.span = self.span
            wrapper.children.append(root_node)
            root_node = wrapper

        root_node.extra = self.extra
        root_node.score = self.score

        return root_node

    def condensed_unary_chain(self, include_postag=True, strip_top=None):
        """
        Return a copy of the tree in which unary chains are merged.

        A chain of single-child nodes becomes one node whose tag joins the
        chain's tags with ``+++``. The merged node keeps this node's span.

        :param include_postag: when False, a node directly above a word is
            tagged ``___EMPTY___`` and its tag is left out of merged tags
        :param strip_top: if this node's tag equals it, the node is dropped
            and its only child is condensed instead
        :return: the root of the condensed copy
        """
        if self.tag == strip_top:
            assert len(self.children) == 1
            return self.children[0].condensed_unary_chain(include_postag)

        # Not unary: copy the node and condense each child.
        if len(self.children) != 1:
            branching = ConstTree(self.tag, self.span)
            branching.children = [child.condensed_unary_chain(include_postag)
                                  for child in self.children]
            return branching

        only_child = self.children[0]

        # Directly above a word: nothing to merge.
        if isinstance(only_child, Lexicon):
            label = self.tag if include_postag else "___EMPTY___"
            preterminal = ConstTree(label, self.span)
            preterminal.children = list(self.children)
            return preterminal

        assert isinstance(only_child, ConstTree)
        bottom = self
        merged_tag = self.tag

        # Walk down the chain, collecting tags on the way.
        while len(bottom.children) == 1 and isinstance(bottom.children[0], ConstTree):
            bottom = bottom.children[0]
            if include_postag or isinstance(bottom.children[0], ConstTree):
                merged_tag += "+++" + bottom.tag

        merged = ConstTree(merged_tag, self.span)
        if len(bottom.children) == 1:
            merged.children = list(bottom.children)
        else:
            merged.children = [child.condensed_unary_chain(include_postag)
                               for child in bottom.children]
        return merged

    def generate_scoreable_spans(self):
        """
        Yield ``(start, end, tag)`` for this node and the nodes below it.

        Descendants come first (children left to right), the node itself
        last. A node whose first child is not a ``ConstTree`` yields only its
        own entry. Each child is visited exactly once, including when it is
        the only child.
        """
        if self.children and isinstance(self.children[0], ConstTree):
            for child in self.children:
                yield from child.generate_scoreable_spans()
        yield self.span + (self.tag,)

    def to_parathesis(self, suffix="\n"):
        """
        Serialise the tree in bracketed form: ``(tag child child ...)``.

        :param suffix: text appended after the closing parenthesis of this
            node; nested nodes are rendered with an empty suffix
        :return: the bracketed string
        """
        rendered = []
        for child in self.children:
            if isinstance(child, ConstTree):
                rendered.append(child.to_parathesis(""))
            else:
                rendered.append(child.string)
        body = " ".join(rendered)
        return "({} {}){}".format(self.tag, body, suffix)

    def to_string(self, with_comments=True):
        """
        Serialise the tree, optionally preceded by its metadata.

        :param with_comments: when true, each ``extra`` item is written as a
            ``# key: value`` line before the tree
        :return: comment lines (if any) followed by the bracketed tree and a
            newline
        """
        header = ""
        if with_comments:
            comment_lines = ["# {}: {}".format(key, value)
                             for key, value in self.extra.items()]
            header = "\n".join(comment_lines)
            if self.extra:
                header += "\n"
        return header + self.to_parathesis()

    def __len__(self):
        """Number of words in the tree (length of the ``words`` list)."""
        word_list = self.words
        return len(word_list)

    @classmethod
    def from_file(cls, file_name, use_edge=None, limit=float("inf")):
        """
        Read trees from a file that holds one bracketed tree per line.

        Blank lines are skipped. Lines starting with ``#`` are read as
        ``# key: value`` metadata; they are collected into the ``extra`` dict
        of the next tree (lines without a value are ignored).

        :param file_name: path handed to ``smart_open``
        :param use_edge: not used by this method
        :param limit: stop after this many trees
        :return: list of trees in file order
        """
        trees = []
        with smart_open(file_name) as f:
            pending_comments = {}
            for raw_line in f:
                if len(trees) >= limit:
                    break
                text = raw_line.strip()
                if not text:
                    continue
                if not text.startswith("#"):
                    tree = cls.from_string(text)
                    tree.extra = pending_comments
                    trees.append(tree)
                    # Start a new metadata dict for the following tree.
                    pending_comments = {}
                    continue
                key, _, value = text[1:].partition(":")
                if value:
                    pending_comments[key.strip()] = value.strip()
        return trees

    @classmethod
    def from_file_multiprocess(cls, file_name, use_edge=None, limit=float("inf")):
        """
        Same result as ``from_file``, but trees are parsed in a process pool.

        The file is read in this process; the ``(line, metadata)`` pairs are
        then parsed by ``line_mapper`` in a pool of 8 processes and put back
        into file order.

        :param file_name: path handed to ``smart_open``
        :param use_edge: not used by this method
        :param limit: stop after this many trees
        :return: list of trees in file order
        """
        jobs = []
        with smart_open(file_name) as f:
            pending_comments = {}
            for raw_line in f:
                if len(jobs) >= limit:
                    break
                text = raw_line.strip()
                if not text:
                    continue
                if not text.startswith("#"):
                    jobs.append((text, pending_comments))
                    pending_comments = {}
                    continue
                key, _, value = text[1:].partition(":")
                if value:
                    pending_comments[key.strip()] = value.strip()

        numbered_jobs = [(position, job) for position, job in enumerate(jobs)]
        with Pool(processes=8) as pool:
            finished = list(pool.imap_unordered(cls.line_mapper,
                                                numbered_jobs,
                                                chunksize=400))

        # Results arrive in arbitrary order; the index carried by each one
        # restores the order of the file.
        ordered = [None] * len(finished)
        for position, tree in finished:
            ordered[position] = tree
        return ordered

    @classmethod
    def line_mapper(cls, args):
        """
        Parse one numbered job; used as the pool worker function.

        :param args: ``(index, (tree text, metadata dict))``
        :return: ``(index, tree)`` with the metadata stored in ``tree.extra``
        """
        idx, payload = args
        line_s, extra_info = payload
        parsed = cls.from_string(line_s)
        parsed.extra = extra_info
        return idx, parsed

    def to_sentence(self, insert_root=False):
        """
        Convert the words and tags of the tree into a ``Sentence``.

        One ``SentenceNode`` is appended per ``(word, tag)`` pair, numbered
        from 1. The word and the tag are each passed twice; the remaining
        arguments are ``None, -1, None``.
        NOTE(review): the meaning of the ``SentenceNode`` fields is defined
        in ``coli.bilexical_base.tree_utils`` - confirm there.

        :param insert_root: when true, ``Sentence.NodeType.root_node()`` is
            appended first
        :return: the new ``Sentence``
        """
        from coli.bilexical_base.tree_utils import Sentence, SentenceNode
        sentence = Sentence()
        if insert_root:
            sentence.append(Sentence.NodeType.root_node())
        position = 0
        for word, postag in self.generate_word_and_postag():
            position += 1
            node = SentenceNode(position, word, word, postag, postag, None, -1, None)
            sentence.append(node)
        return sentence

    @classmethod
    def internal_evaluate(cls, gold_sents, system_sents, log_file, print=True):
        """
        Score in-memory trees with evalb.

        Both tree collections are written, without metadata comments, to
        temporary UTF-8 files that are passed to ``run_evalb``. The temporary
        files are removed afterwards, also when writing or scoring fails.

        :param gold_sents: iterable of gold trees
        :param system_sents: iterable of predicted trees
        :param log_file: path that receives the evalb report
        :param print: forwarded to ``run_evalb``
        :return: the dict returned by ``run_evalb``
        """
        temp_paths = []
        try:
            for sentences in (gold_sents, system_sents):
                with tempfile.NamedTemporaryFile(
                        "w", encoding="utf-8", delete=False) as tmp:
                    # Record the path first so it is cleaned up even if a
                    # write below fails.
                    temp_paths.append(tmp.name)
                    for sent in sentences:
                        tmp.write(sent.to_string(with_comments=False))

            gold_path, system_path = temp_paths
            return cls.run_evalb(gold_path, system_path, log_file, print)
        finally:
            for path in temp_paths:
                os.remove(path)

    @classmethod
    def evaluate_with_external_program(cls, gold_file, output_file, perf_file=None, print=True):
        """
        Score a tree file against a gold tree file with evalb.

        Blank lines and lines starting with ``#`` are filtered out of both
        inputs; the remaining lines are copied to temporary UTF-8 files that
        are passed to ``run_evalb``. The temporary files are removed
        afterwards, also when copying or scoring fails.

        :param gold_file: path of the gold trees
        :param output_file: path of the predicted trees
        :param perf_file: path that receives the evalb report; defaults to
            ``output_file + ".txt"``
        :param print: forwarded to ``run_evalb``
        :return: the dict returned by ``run_evalb``
        """
        if perf_file is None:
            perf_file = output_file + ".txt"

        temp_paths = []
        try:
            for source_path in (gold_file, output_file):
                with open(source_path) as source, \
                        tempfile.NamedTemporaryFile(
                            "w", encoding="utf-8", delete=False) as tmp:
                    # Record the path first so it is cleaned up even if the
                    # copy below fails.
                    temp_paths.append(tmp.name)
                    for line in source:
                        line_s = line.strip()
                        if line_s and not line_s.startswith("#"):
                            tmp.write(line)

            gold_path, output_path = temp_paths
            return cls.run_evalb(gold_path, output_path, perf_file, print)
        finally:
            for path in temp_paths:
                os.remove(path)

    # Captures the text between the "-- All --" marker and the first blank
    # line that follows it (DOTALL lets "." match line breaks). Applied to
    # the report file in extract_performance.
    performance_block_pattern = re.compile(r"-- All --(.*?)\n\n", re.DOTALL)
    # Matches one "<name> =<any single character> <number>" line at a time
    # (MULTILINE anchors "^" at every line start): group 1 is the name,
    # group 2 a run of digits and dots.
    performance_pattern = re.compile(r"^(.*?) +=. +([0-9.]+)", re.MULTILINE)

    @classmethod
    def run_evalb(cls, gold_file, output_file, perf_file, print=True):
        """
        Run the bundled ``utils/evalb`` program and parse its report.

        The program is looked up next to this module and run through the
        shell with its output redirected to ``perf_file``. Every path is
        shell-quoted, so names with spaces or shell metacharacters are passed
        through unchanged.

        :param gold_file: path of the gold trees
        :param output_file: path of the predicted trees
        :param perf_file: path that receives the evalb report
        :param print: when true, the report is echoed to stdout from the
            first line matching ``Summary`` onwards (via ``awk``)
        :return: the dict built by ``extract_performance``
        """
        from shlex import quote

        # abspath first: dirname() of a bare file name is "", which would
        # turn the program path into "/utils/evalb".
        module_dir = os.path.dirname(os.path.abspath(__file__))
        evalb_path = os.path.join(module_dir, "utils", "evalb")

        os.system('{} {} {} > {}'.format(
            quote(evalb_path),
            quote(os.fspath(gold_file)),
            quote(os.fspath(output_file)),
            quote(os.fspath(perf_file))))
        if print:
            os.system('awk \'/Summary/,EOF { print $0 }\' %s'
                      % quote(os.fspath(perf_file)))
        return cls.extract_performance(perf_file)

    @classmethod
    def extract_performance(cls, perf_file_name):
        """
        Read the numbers of the ``-- All --`` block from a report file.

        :param perf_file_name: path of the report; if the name contains
            ``epoch_<digits>`` followed by ``_`` or ``.``, that number is
            stored under the key ``epoch``, otherwise ``-1`` is stored
        :return: dict mapping each name found by ``performance_pattern`` to
            a float, plus ``epoch``
        :raises IndexError: if the report has no ``-- All --`` block
        """
        with open(perf_file_name) as f:
            report = f.read()
        summary = cls.performance_block_pattern.findall(report)[0]

        metrics = {name: float(number)
                   for name, number in cls.performance_pattern.findall(summary)}

        epoch_match = re.search(r"epoch_(\d+)[_.]", perf_file_name)
        if epoch_match is None:
            metrics["epoch"] = -1
        else:
            metrics["epoch"] = int(epoch_match.group(1))
        return metrics

    @classmethod
    def from_words_and_postags(cls, items, escape=True):
        """
        Build a flat tree ``(TOP (TOP_2 (tag word) (tag word) ...))``.

        :param items: iterable of ``(word, postag)`` pairs
        :param escape: when true, every ``(``, ``[`` and ``{`` in a word is
            replaced by ``-LRB-``, every ``)``, ``]`` and ``}`` by ``-RRB-``,
            and every double quote by two backticks
        :return: the ``TOP`` node, with spans populated
        """
        def item_to_preterminal(item):
            word, postag = item
            if escape:
                # Raw strings: "\]" in a plain string literal is an invalid
                # escape sequence and triggers a warning at compile time.
                word = re.sub(r"[(\[{]", "-LRB-", word)
                word = re.sub(r"[)\]}]", "-RRB-", word)
                word = word.replace('"', '``')
            ret = cls(postag)
            ret.children = [Lexicon(word)]
            return ret

        tree = cls("TOP")
        tree_2 = cls("TOP_2")
        tree.children = [tree_2]
        tree_2.children = [item_to_preterminal(item) for item in items]
        tree.populate_spans_terminal()
        tree.populate_spans_leaf()
        return tree


@dataclass
class ConstTreeStatistics(object):
    """
    Vocabularies and rule tables collected from a set of trees.

    Build instances with ``from_sentences``.
    """
    # Word forms.
    words: Dictionary
    # Tags of the original (uncondensed) trees; filled only when
    # ``include_postags`` is true.
    postags: Dictionary
    # Tags found directly above the words of the condensed trees.
    leaftags: Dictionary
    # Tags of all nodes of the condensed trees; "___EMPTY___" is registered first.
    labels: Dictionary
    # Characters of the words.
    characters: Dictionary
    # One row ``(parent, left child, right child)`` of label ids per distinct
    # two-child rule.
    rules: np.ndarray
    # For each leaftag id, the id of the same tag in ``labels``.
    leaftag_to_label: np.ndarray
    # Sorted label ids of internal nodes; 0 is always included.
    internal_labels: np.ndarray
    # Row indices into ``rules`` of the rules seen at a tree root.
    root_rules: np.ndarray
    # Word count of the longest sentence.
    max_sentence_length: int

    @classmethod
    def from_sentences(cls, sentences, include_postags=True, strip_top=None):
        """
        Collect the statistics of ``sentences``.

        Each tree is first condensed with
        ``condensed_unary_chain(include_postags, strip_top)``; words, leaf
        tags, labels and rules are read from the condensed tree.

        :param sentences: iterable of ``ConstTree``
        :param include_postags: forwarded to ``condensed_unary_chain``; when
            true, the tags of the original trees are also recorded in
            ``postags``; when false, every label counts as internal
        :param strip_top: forwarded to ``condensed_unary_chain``
        :return: a populated ``ConstTreeStatistics``
        """
        words = Dictionary(initial=(PAD, UNKNOWN, START_OF_SENTENCE, END_OF_SENTENCE))
        postags = Dictionary(initial=(PAD, UNKNOWN, START_OF_SENTENCE, END_OF_SENTENCE))
        leaftags = Dictionary(initial=())
        labels = Dictionary(initial=("___EMPTY___",))
        characters = Dictionary(initial=(PAD, UNKNOWN, START_OF_WORD, END_OF_WORD,
                                         CHAR_START_OF_SENTENCE, CHAR_END_OF_SENTENCE))
        rules = set()
        internal_labels = {0}
        root_rules = set()
        max_sentence_length = 0
        for sentence in sentences:
            tree_t = sentence.condensed_unary_chain(include_postags, strip_top)
            words_and_postags = list(tree_t.generate_word_and_postag())
            if len(words_and_postags) > max_sentence_length:
                max_sentence_length = len(words_and_postags)

            for word, postag in words_and_postags:
                words.update_and_get_id(word)
                leaftags.update_and_get_id(postag)
                characters.update(word)

            if include_postags:
                # Tags come from the original tree, before unary chains were merged.
                for word, postag in sentence.generate_word_and_postag():
                    postags.update_and_get_id(postag)

            for rule in tree_t.generate_rules():
                labels.update_and_get_id(rule.tag)
                if isinstance(rule.children[0], ConstTree) or not include_postags:
                    internal_labels.add(labels.word_to_int[rule.tag])

                # Only two-child rules are recorded; post-order traversal
                # guarantees the child labels are registered already.
                if isinstance(rule.children[0], ConstTree) and len(rule.children) == 2:
                    rule_int = (labels.word_to_int[rule.tag],) + \
                               tuple(labels.word_to_int[i.tag]
                                     for i in rule.children)
                    rules.add(rule_int)
                    if rule is tree_t:  # root node
                        root_rules.add(rule_int)

        leaftag_to_label = [labels.word_to_int[tag_name]
                            for tag_name in leaftags.int_to_word]

        rule_list = list(rules)
        rule_reverse = {i: idx for idx, i in enumerate(rule_list)}
        root_rules_idx = [rule_reverse[i] for i in root_rules]

        # noinspection PyArgumentList
        ret = cls(words, postags, leaftags, labels, characters,
                  np.array(rule_list, dtype=np.int32),
                  np.array(leaftag_to_label, dtype=np.int32),
                  np.array(sorted(internal_labels), dtype=np.int32),
                  np.array(root_rules_idx, dtype=np.int32),
                  max_sentence_length
                  )
        return ret

    def __str__(self):
        """One-line summary of the sizes of all tables."""
        return "{} words, {} postags, {} leaftags, " \
               "{} labels, {} characters, {} rules, {} internal labels, " \
               "longest sentence has {} words".format(
            len(self.words), len(self.postags), len(self.leaftags), len(self.labels),
            len(self.characters), len(self.rules), len(self.internal_labels), self.max_sentence_length
        )


if __name__ == '__main__':
    # Manual check: parse a sample tree and print the set of its spans.
    demo_tree = ConstTree.from_string(u"""(S (NP (N (N (NP (N (N "pierre"))) (N (N (N "Vinken,")))) (AP (ADV (N (ADJ "61") (N (N "years")))) (AP (AP "old,")))))
 (VP (V "will")
  (VP (VP (V (V "join")) (NP (DET "the") (N (N (N "board")))))
   (PP (PP (P "as") (NP (DET "a") (N (AP "nonexecutive") (N (N (N "director")))))) (PP (NP (NP (DET (N (N "nov."))) (N (N (N "29."))))))))))""")
    unique_spans = set(demo_tree.generate_spans())
    print(unique_spans)
