from collections import Counter
from prettytable import PrettyTable


def print_table(data):
    """Print per-entity metrics as an ASCII table.

    Args:
        data: mapping of entity name -> dict holding "precision", "recall"
            and "f1" entries (the shape of the per-class dict produced by
            ``SpanEntityScore.result``).
    """
    metric_keys = ("precision", "recall", "f1")
    report = PrettyTable()
    report.field_names = ["Entity", "Precision", "Recall", "F1"]
    for name, scores in data.items():
        row = [name]
        row.extend(scores[key] for key in metric_keys)
        report.add_row(row)
    print(report)


class SpanEntityScore(object):
    """Accumulate gold / predicted entity spans and report span-level P/R/F1.

    Spans are extracted from tag sequences with the scheme given by
    ``markup`` ("bio", "bios" or "bmes") and are compared by exact equality
    of ``[type, start, end]`` (``end`` is inclusive).

    Args:
        id2label: mapping used to turn non-string tags (e.g. integer ids)
            into tag strings such as ``"B-PER"``.
        markup: tagging scheme, one of ``"bio"``, ``"bios"``, ``"bmes"``.
    """

    def __init__(self, id2label, markup="bio"):
        self.id2label = id2label
        self.markup = markup
        self.reset()

    def reset(self):
        """Discard every span accumulated by ``update``."""
        self.origins = []  # gold spans
        self.founds = []  # predicted spans
        self.rights = []  # predicted spans that exactly match a gold span

    def compute(self, origin, found, right):
        """Return ``(recall, precision, f1)`` computed from span counts.

        Args:
            origin: number of gold spans.
            found: number of predicted spans.
            right: number of predicted spans that match a gold span.

        Each ratio is 0 when its denominator is 0.
        """
        recall = 0 if origin == 0 else (right / origin)
        precision = 0 if found == 0 else (right / found)
        f1 = (
            0.0
            if recall + precision == 0
            else (2 * precision * recall) / (precision + recall)
        )
        return recall, precision, f1

    def result(self):
        """Return ``(overall, class_info)`` metrics for the accumulated spans.

        ``overall`` is ``{"precision", "recall", "f1"}`` over all spans
        (unrounded). ``class_info`` maps each entity type to the same keys,
        rounded to 4 decimals. It covers every type seen in the gold spans,
        followed by types that were only ever predicted.
        """
        origin_counter = Counter(x[0] for x in self.origins)
        found_counter = Counter(x[0] for x in self.founds)
        right_counter = Counter(x[0] for x in self.rights)
        # Types that are only predicted (never gold) would otherwise be
        # silently missing from the per-class report; include them too.
        types = list(origin_counter) + [
            t for t in found_counter if t not in origin_counter
        ]
        class_info = {}
        for type_ in types:
            recall, precision, f1 = self.compute(
                origin_counter[type_], found_counter[type_], right_counter[type_]
            )
            class_info[type_] = {
                "precision": round(precision, 4),
                "recall": round(recall, 4),
                "f1": round(f1, 4),
            }
        recall, precision, f1 = self.compute(
            len(self.origins), len(self.founds), len(self.rights)
        )
        return {"precision": precision, "recall": recall, "f1": f1}, class_info

    def update(self, label_paths, pred_paths):
        """Extract spans from paired gold / predicted tag sequences and store them.

        Args:
            label_paths: list of gold tag sequences.
            pred_paths: list of predicted tag sequences, aligned with
                ``label_paths`` (pairs are formed with ``zip``, so extra
                sequences in the longer list are ignored).

        Example:
            >>> labels_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
            >>> pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        """
        for label_path, pre_path in zip(label_paths, pred_paths):
            label_entities = self.get_entities(label_path, self.markup)
            pre_entities = self.get_entities(pre_path, self.markup)
            # Spans are lists (unhashable); compare them as tuples in a set.
            gold = {tuple(entity) for entity in label_entities}
            self.origins.extend(label_entities)
            self.founds.extend(pre_entities)
            self.rights.extend(
                pre_entity for pre_entity in pre_entities if tuple(pre_entity) in gold
            )

    def _tag_to_str(self, tag):
        """Map a non-string tag (e.g. an id) to its label via ``id2label``."""
        return tag if isinstance(tag, str) else self.id2label[tag]

    def get_entities(self, seq, markup="bio"):
        """Extract ``[type, start, end]`` spans from ``seq`` using ``markup``.

        :param seq: sequence of tags (strings, or keys of ``id2label``).
        :param markup: one of "bio", "bios", "bmes".
        :return: list of ``[type, start, end]`` spans, ``end`` inclusive.
        """
        assert markup in ["bio", "bios", "bmes"]
        if markup == "bio":
            return self.get_entity_bio(seq)
        elif markup == "bios":
            return self.get_entity_bios(seq)
        else:
            return self.get_entity_bmes(seq)

    def get_entity_bio(self, seq):
        """Gets entities from a BIO tag sequence.

        A ``B-`` tag opens a span. ``I-`` tags of the same type extend it.
        An ``I-`` tag of another type neither extends nor closes the open
        span. Any other tag closes it. An ``I-`` tag with no open span is
        ignored.

        Args:
            seq (list): sequence of labels (strings or ``id2label`` keys).
        Returns:
            list: list of [span_type, span_start, span_end].
        Example:
            seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
            get_entity_bio(seq)
            #output
            [['PER', 0, 1], ['LOC', 3, 3]]
        """
        spans = []
        span = [-1, -1, -1]
        last = len(seq) - 1
        for indx, tag in enumerate(seq):
            tag = self._tag_to_str(tag)
            if tag.startswith("B-"):
                if span[2] != -1:
                    spans.append(span)
                # maxsplit=1 keeps entity types that themselves contain "-".
                span = [tag.split("-", 1)[1], indx, indx]
                if indx == last:
                    spans.append(span)
            elif tag.startswith("I-") and span[1] != -1:
                if tag.split("-", 1)[1] == span[0]:
                    span[2] = indx
                if indx == last:
                    spans.append(span)
            else:
                if span[2] != -1:
                    spans.append(span)
                span = [-1, -1, -1]
        return spans

    def get_entity_bios(self, seq):
        """Gets entities from a BIOS tag sequence.

        ``S-`` yields a single-token span. ``B-`` followed by same-type
        ``I-`` tags yields a multi-token span. A ``B-`` with no matching
        ``I-`` after it produces no span.

        Args:
            seq (list): sequence of labels (strings or ``id2label`` keys).
        Returns:
            list: list of [span_type, span_start, span_end].
        Example:
            # >>> seq = ['B-PER', 'I-PER', 'O', 'S-LOC']
            # >>> get_entity_bios(seq)
            [['PER', 0, 1], ['LOC', 3, 3]]
        """
        spans = []
        span = [-1, -1, -1]
        last = len(seq) - 1
        for indx, tag in enumerate(seq):
            tag = self._tag_to_str(tag)
            if tag.startswith("S-"):
                if span[2] != -1:
                    spans.append(span)
                spans.append([tag.split("-", 1)[1], indx, indx])
                span = [-1, -1, -1]
            elif tag.startswith("B-"):
                if span[2] != -1:
                    spans.append(span)
                # The end stays -1 until a matching I- tag is seen.
                span = [tag.split("-", 1)[1], indx, -1]
            elif tag.startswith("I-") and span[1] != -1:
                if tag.split("-", 1)[1] == span[0]:
                    span[2] = indx
                # Only flush a span that actually has an end; a B- followed by
                # a different-type I- at the end must not yield [type, start, -1].
                if indx == last and span[2] != -1:
                    spans.append(span)
            else:
                if span[2] != -1:
                    spans.append(span)
                span = [-1, -1, -1]
        return spans

    def get_entity_bmes(self, seq, classes_to_ignore=None):
        """Gets entities from a BMES tag sequence (kNN-NER style BMES metric).

        ``S-`` yields a single-token span. ``B-`` followed by zero or more
        same-type ``M-`` tags and a same-type ``E-`` yields a multi-token
        span. A ``B-`` run that is not properly terminated falls back to a
        single-token span at the ``B-`` position. The tag that broke the run
        is then re-examined, so an ``S-``/``B-`` right after it is not lost.

        Args:
            seq (list): sequence of labels (strings or ``id2label`` keys).
            classes_to_ignore (list, optional): entity types to drop from
                the result.
        Returns:
            list: list of [span_type, span_start, span_end].
        Example:
            # >>> seq = ['B-PER', 'E-PER', 'O', 'S-LOC']
            # >>> get_entity_bmes(seq)
            [['PER', 0, 1], ['LOC', 3, 3]]
        """
        labels = [self._tag_to_str(tag) for tag in seq]
        ignored = classes_to_ignore or []
        spans = []
        total = len(labels)
        index = 0
        while index < total:
            label = labels[index]
            if label[0] == "S":
                spans.append([label.split("-", 1)[1], index, index])
                index += 1
            elif label[0] == "B":
                start = index
                start_cate = label.split("-", 1)[1]
                end = None
                probe = index + 1
                while probe < total:
                    nxt = labels[probe]
                    # Any tag that is not a same-type M/E ends the run.
                    if nxt[0] not in ("M", "E") or nxt.split("-", 1)[1] != start_cate:
                        break
                    if nxt[0] == "E":
                        end = probe
                        break
                    probe += 1
                if end is None:
                    spans.append([start_cate, start, start])
                    # Resume AT the breaking tag (not after it) so that e.g.
                    # an S-/B- tag there still starts its own entity.
                    index = probe
                else:
                    spans.append([start_cate, start, end])
                    index = end + 1
            else:
                # "O" and stray M-/E- tags carry no entity on their own.
                index += 1
        return [span for span in spans if span[0] not in ignored]
