from io import open

import os
import re
from collections import Counter

import attr
import numpy as np

from coli.basic_tools.common_utils import deprecated
from coli.bilexical_base.conll_reader import CoNLLUSentence, CoNLLUNode
from coli.bilexical_base import Edge


@attr.s
class SentenceNode(object):
    """One token of a dependency-annotated sentence.

    ``id`` and ``parent_id`` are passed through ``int`` when the node is
    constructed.  After construction every node additionally carries
    ``norm``, the result of ``normalize(form)``; it is a plain instance
    attribute, not an attrs field.
    """
    id = attr.ib(type=int, convert=int)
    form = attr.ib(type=str)
    lemma = attr.ib(type=str)
    cpos = attr.ib(type=str)
    pos = attr.ib(type=str)
    feats = attr.ib(type=str)
    parent_id = attr.ib(type=int, convert=int)
    relation = attr.ib(type=str)

    def __attrs_post_init__(self):
        # Runs after the attrs-generated __init__: derive the normalized form.
        normalized_form = normalize(self.form)
        self.norm = normalized_form

    @classmethod
    def from_conllu_node(cls, conllu_node):
        """Build a node from an object exposing CoNLL-U style attributes.

        Reads ``id_``, ``form``, ``lemma``, ``cpostag``, ``postag``,
        ``feats``, ``head`` and ``deprel`` from ``conllu_node``.  Both tag
        values are upper-cased and ``head`` is converted with ``int``.
        """
        field_values = (
            conllu_node.id_,
            conllu_node.form,
            conllu_node.lemma,
            conllu_node.cpostag.upper(),
            conllu_node.postag.upper(),
            conllu_node.feats,
            int(conllu_node.head),
            conllu_node.deprel,
        )
        # noinspection PyArgumentList
        return cls(*field_values)

    def copy(self):
        """Return a new node with the same field values (via ``attr.evolve``)."""
        duplicate = attr.evolve(self)
        return duplicate

    @classmethod
    def root_node(cls):
        """Return the artificial root node: id 0, parent id -1.

        Given the positional field order, ``cpos`` receives ``'ROOT-POS'``
        and ``pos`` receives ``'ROOT-CPOS'``.
        """
        root_fields = (0, '*root*', '*root*', 'ROOT-POS', 'ROOT-CPOS', '_', -1, 'rroot')
        # noinspection PyArgumentList
        return cls(*root_fields)

    @property
    def postag(self):
        """Alias for ``pos``."""
        return self.pos

    def __str__(self):
        template = u"{}: {}-{}, head {}"
        return template.format(self.id, self.form, self.pos, self.parent_id)

    def __repr__(self):
        # NOTE(review): ``@attr.s`` is applied with its default ``repr``
        # setting, which may replace this method with the attrs-generated
        # one -- confirm which repr is actually in effect.
        return self.__str__()


class Sentence(list):
    """A dependency-annotated sentence stored as a list of ``NodeType`` nodes.

    Besides the nodes themselves, every instance carries an ``extra`` dict
    (initially empty) for arbitrary per-sentence data.
    """
    # Node class used by the alternate constructors.  It is always looked up
    # through ``cls``, so a subclass can substitute its own node class.
    NodeType = SentenceNode
    # Matches, at the start of each line, a lazily captured name followed by
    # whitespace and/or '|' separators and a run of digits and dots.
    performance_pattern = re.compile(r"^(.+?)[\s|]+([\d.]+)", re.MULTILINE)

    # ``extra`` is the only per-instance attribute declared by this class.
    __slots__ = ("extra",)

    def __init__(self, seq=()):
        """Create a sentence from an iterable of nodes with an empty ``extra``."""
        super(Sentence, self).__init__(seq)
        self.extra = {}

    @classmethod
    def from_conllu_sentence(cls, sent, root_last=True):
        """Convert an iterable of CoNLL-U nodes into a sentence.

        Every element of ``sent`` is converted with
        ``NodeType.from_conllu_node``.  The artificial root node is appended
        at the end when ``root_last`` is true, otherwise inserted at index 0.
        """
        ret = cls(cls.NodeType.from_conllu_node(i) for i in sent)
        if root_last:
            ret.append(cls.NodeType.root_node())
        else:
            ret.insert(0, cls.NodeType.root_node())
        return ret

    def copy(self):
        """Return a new instance of the same class holding copies of all nodes.

        The copy is built through the constructor, so its ``extra`` dict
        starts out empty; the contents of ``self.extra`` are not carried over.
        """
        return self.__class__(i.copy() for i in self)

    @classmethod
    def from_file(cls, file_name, use_edge=True, root_last=False):
        """Read all sentences from ``file_name`` and return them as a list.

        ``use_edge`` is accepted but not used by this implementation.
        ``root_last`` is forwarded to ``from_conllu_sentence``; note that its
        default here (``False``) differs from that method's own default.
        """
        with open(file_name) as f:
            return [cls.from_conllu_sentence(i, root_last)
                    for i in CoNLLUSentence.get_all_sentences(f)]

    @classmethod
    def from_words_and_postags(cls, items):
        """Build an unparsed sentence from ``(word, postag)`` pairs.

        The root node is placed first.  Each following node is numbered from
        1, uses the word as both form and lemma and the tag as both cpos and
        pos, and gets ``None`` feats, parent id 0 and ``None`` relation.
        """
        ret = cls()
        ret.append(cls.NodeType.root_node())
        for idx, (word, postag) in enumerate(items, 1):
            ret.append(cls.NodeType(idx, word, word, postag, postag, None, 0, None))
        return ret

    @staticmethod
    def evaluate_with_external_program(gold_file, output_file):
        """Run the bundled ``conll17_ud_eval.py`` script and print its report.

        The script's output is redirected to ``<output_file>.txt``, which is
        then printed with ``cat``.  Both commands go through ``os.system``
        and their exit statuses are ignored.
        """
        current_path = os.path.dirname(__file__)
        eval_script = os.path.join(current_path, "utils/evaluation_script/conll17_ud_eval.py")
        weight_file = os.path.join(current_path, "utils/evaluation_script/weights.clas")
        # NOTE(review): the paths are interpolated unquoted into a shell
        # command line; paths containing spaces or shell metacharacters will
        # not be handled correctly.
        os.system("python {} -v -w {} {} {} > {}.txt".format(eval_script, weight_file, gold_file, output_file, output_file))
        os.system("cat {}.txt".format(output_file))

    @classmethod
    def extract_performance(cls, perf_file_name):
        """Parse a performance report file into a dict.

        Every ``performance_pattern`` match in the file contributes one
        ``name -> float(value)`` entry.  The key ``"epoch"`` is added from the
        first ``epoch_<digits>`` (followed by ``_`` or ``.``) found in
        ``perf_file_name`` itself.

        Raises:
            IndexError: if ``perf_file_name`` contains no such epoch marker.
            ValueError: if a captured value is not a valid float literal.
        """
        with open(perf_file_name) as f:
            content = f.read()

        def generate_items():
            for k, v in cls.performance_pattern.findall(content):
                yield k, float(v)

        result = dict(generate_items())
        epoch = re.findall(r"epoch_(\d+)[_.]", perf_file_name)[0]
        result["epoch"] = int(epoch)
        return result

    def to_matrix(self):
        """Return a boolean ``len(self) x len(self)`` head/dependent matrix.

        For every node after the first, ``matrix[parent_id, position]`` is
        set, where ``position`` is the node's index in this list.

        NOTE(review): position 0 is skipped and list positions are used as
        dependent indices, which lines up with node ids only when the root
        node is stored first -- confirm the behaviour for root-last sentences.
        """
        # The builtin ``bool`` is used as dtype: the ``np.bool`` alias was
        # deprecated in NumPy 1.20 and is unavailable in NumPy 1.24-1.26.
        ret = np.zeros((len(self), len(self)), dtype=bool)
        for dep, head in enumerate((i.parent_id for i in self[1:]), 1):
            ret[head, dep] = 1
        return ret

    def generate_edges(self):
        """Yield ``Edge(head, label, dep)`` for every node after the first.

        ``dep`` is the node's position in this list, ``head`` its
        ``parent_id`` and ``label`` its ``relation``.
        """
        for dep, (head, label) in enumerate(((i.parent_id, i.relation) for i in self[1:]), 1):
            yield Edge(head, label, dep)

    def replaced_edges(self, heads, labels):
        """Return a copy whose nodes carry the given heads and labels.

        ``heads`` and ``labels`` are paired with the nodes positionally,
        starting at index 0; pairing stops at the shortest of the three
        sequences, and remaining nodes keep their copied values.
        """
        ret = self.copy()
        for head, label, node in zip(heads, labels, ret):
            node.parent_id = head
            node.relation = label
        return ret

    def to_string(self):
        """Serialize the sentence through ``CoNLLUSentence.to_string``.

        Only nodes with ``id > 0`` are written; the last two CoNLL-U columns
        of every node are filled with ``"_"``.
        """
        def convert_node(node):
            # noinspection PyArgumentList
            return CoNLLUNode(node.id, node.form, node.lemma, node.cpos,
                              node.pos, node.feats,
                              node.parent_id, node.relation,
                              "_", "_")

        result = CoNLLUSentence(convert_node(node) for node in self if node.id > 0)
        return result.to_string()

    @property
    def words(self):
        """Forms of all nodes in list order (no node is filtered out)."""
        return [i.form for i in self]


class Sentence06(Sentence):
    """Sentence variant that is evaluated with the bundled ``utils/eval.pl``."""

    @staticmethod
    def evaluate_with_external_program(gold_file, output_file):
        """Launch the Perl evaluation script through the shell.

        The script's output is written to ``<output_file>.txt`` and, if the
        script succeeds, the first three lines of that file are printed.
        The command line ends with ``&``, so the shell is asked to run the
        whole group in the background.
        """
        script_dir = os.path.dirname(__file__)
        eval_script = os.path.join(script_dir, "utils/eval.pl")
        command_template = "(perl {} -g {} -s {} > {}.txt && head -n 3 {}.txt) &"
        command = command_template.format(
            eval_script, gold_file, output_file, output_file, output_file)
        os.system(command)


class Sentence06RootLast(Sentence06):
    """``Sentence06`` whose constructors default to ``root_last=True``."""

    @classmethod
    def from_conllu_sentence(cls, sent, root_last=True):
        """Delegate to the parent implementation, forwarding ``root_last``."""
        parent = super(Sentence06RootLast, cls)
        return parent.from_conllu_sentence(sent, root_last)

    @classmethod
    def from_file(cls, file_name, use_edge=True, root_last=True):
        """Delegate to the parent implementation with ``root_last`` defaulting to True."""
        parent = super(Sentence06RootLast, cls)
        return parent.from_file(file_name, use_edge, root_last)


@deprecated
def vocab(sentences):
    """Collect word, POS and relation statistics over ``sentences``.

    Only elements that are ``SentenceNode`` instances are counted.

    Returns a 4-tuple:
        * a ``Counter`` of normalized word forms (``node.norm``),
        * a dict mapping each counted word to its index in the counter's
          key order,
        * the keys of the POS tag counter (``node.pos``),
        * the keys of the relation counter (``node.relation``).
    """
    word_counts = Counter()
    pos_counts = Counter()
    rel_counts = Counter()
    counted_fields = (
        (word_counts, "norm"),
        (pos_counts, "pos"),
        (rel_counts, "relation"),
    )

    for sentence in sentences:
        # One pass over the sentence per counted attribute.
        for counter, field_name in counted_fields:
            values = [getattr(node, field_name)
                      for node in sentence
                      if isinstance(node, SentenceNode)]
            counter.update(values)

    word_to_index = {word: index for index, word in enumerate(word_counts.keys())}
    return (word_counts, word_to_index, pos_counts.keys(), rel_counts.keys())


# Used with ``match`` only, i.e. anchored at the start of the word, so the
# first alternative already accepts any word that begins with a digit.
numberRegex = re.compile("[0-9]+|[0-9]+\\.[0-9]+|[0-9]+[0-9,]+")


def normalize(word):
    """Return ``'NUM'`` for words starting with a digit, else ``word`` lower-cased."""
    if numberRegex.match(word) is None:
        return word.lower()
    return 'NUM'