from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Set, Tuple, Dict
from itertools import chain
from collections import OrderedDict, defaultdict
from argparse import Namespace

import torch
from torch import LongTensor
import numpy as np


def a_better_than_b(a, b):
    """Return True if score dict ``a`` strictly beats score dict ``b``.

    Keys are visited in ``a``'s iteration order. The first key whose value in
    ``a`` compares strictly greater than in ``b`` returns True; the first that
    compares strictly smaller returns False. If no key decides, return False.
    """
    for key in a:
        mine, theirs = a[key], b[key]
        if mine > theirs:
            return True
        if mine < theirs:
            return False
    return False


def namespace_add(a, b):
    """Return a new Namespace whose every attribute is ``a.attr + b.attr``.

    The attribute set is taken from ``a``; each attribute must also exist on ``b``.
    """
    left = vars(a)
    right = vars(b)
    summed = {name: value + right[name] for name, value in left.items()}
    return Namespace(**summed)


class Metric(object):
    """
    Abstract base for a metric whose raw statistics accumulate over batches.

    Subclasses supply ``counter_factory`` (a fresh ``Namespace`` of raw
    counts), ``metric_factory`` (builds the score dict), ``__call__`` and
    ``get_metric``.
    """

    def __init__(self):
        # Running raw statistics, built by the subclass's counter_factory.
        self.counter = self.counter_factory()
        # Best score dict seen so far by is_best (None until the first call).
        self.best = None

    def is_best(self, metric: Dict) -> bool:
        """
        Store ``metric`` as the new best if it beats the current best.

        Scores are compared key by key in ``metric``'s key order, with earlier
        keys taking priority and larger values counting as better. The first
        call always succeeds.
        """
        improved = self.best is None or a_better_than_b(metric, self.best)
        if improved:
            self.best = metric
        return improved

    def __call__(self, predictions: torch.Tensor, gold_labels: torch.Tensor,
                 mask: torch.LongTensor) -> Dict:
        """
        Called once per batch: update the counter and return this batch's scores.
        """
        raise NotImplementedError

    def get_metric(self, counter=None, reset=False) -> Dict:
        """
        Compute the score dict from a counter.
        """
        raise NotImplementedError

    @staticmethod
    def counter_factory(**kwargs) -> Namespace:
        """Return a fresh Namespace of raw counts."""
        raise NotImplementedError

    @staticmethod
    def metric_factory(**kwargs) -> Dict:
        """
        Build the score dict. Order keys by importance, since ``is_best``
        compares them in that order.
        """
        raise NotImplementedError


class TaggingMetric(Metric):
    """
    Token-level precision / recall / F1 for sequence tagging.

    Positions whose gold label equals ``ignore_index`` are excluded from every
    count.
    """

    def __init__(self, ignore_index: int = 0):
        super().__init__()
        self.ignore_index = ignore_index

    def __call__(self, predictions: torch.Tensor, gold_labels: torch.Tensor,
                 mask: torch.LongTensor) -> Dict:
        """
        Add one batch to the running counter and return the batch's scores.

        Args:
            predictions: predicted label ids, same shape as ``gold_labels``.
            gold_labels: gold label ids.
            mask: multiplied into the position weights; presumably 1 for real
                tokens and 0 for padding — TODO confirm against callers.

        NOTE(review): ``positive`` is counted only where the gold label is not
        ``ignore_index``, so a non-ignored prediction at a gold-ignored
        position never lowers precision. Confirm this is intended.
        """
        # Positions that are both unmasked and carry a gold annotation.
        labelled = mask * (gold_labels != self.ignore_index).long()
        predicted = (predictions != self.ignore_index).long()
        matched = (predictions == gold_labels).long()

        batch = self.counter_factory()
        batch.total = labelled.sum().item()
        batch.positive = (labelled * predicted).sum().item()
        batch.correct = (labelled * matched).sum().item()

        self.counter = namespace_add(self.counter, batch)
        return self.get_metric(batch)

    @staticmethod
    def counter_factory(total=0, positive=0, correct=.0) -> Namespace:
        """Raw counts: gold positions, predicted positions, correct positions."""
        return Namespace(total=total, positive=positive, correct=correct)

    @staticmethod
    def metric_factory(f1=.0, recall=.0, precision=.0) -> Dict:
        """Score dict ordered by importance (F1 first) for ``is_best``."""
        return dict(F1=f1, recall=recall, precision=precision)

    def get_metric(self, counter=None, reset=False) -> Dict:
        """
        Compute F1 / recall / precision from ``counter``.

        If ``counter`` is not given, the running counter is used. Any ratio
        with a zero denominator is reported as 0. If ``reset`` is true, the
        running counter is replaced by a fresh one after reading.
        """
        stats = counter or self.counter
        recall = stats.correct / stats.total if stats.total != 0 else 0
        precision = stats.correct / stats.positive if stats.positive != 0 else 0
        denominator = precision + recall
        f1 = 2 * precision * recall / denominator if denominator != 0 else 0

        if reset:
            self.counter = self.counter_factory()

        return self.metric_factory(f1, recall, precision)


class ExactMatch(TaggingMetric):
    """
    Entity-level exact-match precision / recall / F1 over BIO tag ids.

    An entity starts at a ``B-x`` tag and extends over the immediately
    following ``I-x`` tags of the same type. A predicted entity is correct
    only if its start, end and type all equal those of a gold entity.
    """

    def __init__(self, o_id, token_to_id, ouput_class=False, output_detail=False):
        """
        Args:
            o_id: id of the ``O`` tag. Passed to ``TaggingMetric`` as
                ``ignore_index`` and appended as a sentinel in
                ``get_entities`` to close an entity at the end of a sequence.
            token_to_id: mapping from tag string (``'B-x'``, ``'I-x'``, ...)
                to a value whose first element is the tag id.
                NOTE(review): presumably a vocabulary entry — confirm the
                value layout against the caller.
            ouput_class: if True, ``get_metric(reset=True)`` also reports
                per-type scores (misspelled name kept for compatibility).
            output_detail: if True, the per-type report includes precision
                and recall, not just F1.

        Raises:
            KeyError: if an ``I-x`` tag has no matching ``B-x`` tag.
        """
        super().__init__(o_id)
        self.o_id = o_id
        self.id_to_label = dict()  # I-x id -> type name x
        self.bi_map = dict()  # B-x id -> I-x id
        # First pass: temporarily key bi_map by type name (x -> B-x id).
        for label, index in token_to_id.items():
            if label.startswith('B-'):
                self.bi_map[label[2:]] = index[0]
        # Second pass: replace each (x -> B-x id) with (B-x id -> I-x id).
        for label, index in token_to_id.items():
            if label.startswith('I-'):
                b_id = self.bi_map.pop(label[2:])
                self.bi_map[b_id] = index[0]
                self.id_to_label[index[0]] = label[2:]

        # Per-type counters keyed by I-x id (the third field of an entity).
        self.label_counter = {k: self.counter_factory() for k in self.id_to_label}
        self.ouput_class = ouput_class
        self.output_detail = output_detail
        self.data_info = dict()  # type name -> gold support at the last reset

    def __call__(self,
                 predictions: LongTensor,
                 gold_labels: LongTensor,
                 lengths: LongTensor) -> Dict:
        """
        Accumulate entity counts for one batch and return the batch's scores.

        Args:
            predictions: predicted tag ids, one row per sequence.
            gold_labels: gold tag ids, same layout as ``predictions``.
            lengths: number of valid positions per row; later positions are
                ignored.

        Returns:
            ``{'F1', 'recall', 'precision'}`` computed on this batch only.
        """
        batch = self.counter_factory()

        for prediction, gold, length in zip(predictions, gold_labels, lengths):
            predict_entities = self.get_entities(prediction.tolist()[:length])
            gold_entities = self.get_entities(gold.tolist()[:length])
            correct_entities = self.get_correct(predict_entities, gold_entities)

            for e in gold_entities:
                self.label_counter[e[2]].total += 1
                batch.total += 1
            for e in predict_entities:
                self.label_counter[e[2]].positive += 1
                batch.positive += 1
            for e in correct_entities:
                self.label_counter[e[2]].correct += e[3]
                batch.correct += e[3]

        self.counter = namespace_add(self.counter, batch)

        return self.get_metric(batch)

    def get_entities(self, labels) -> Set[Tuple[int, int, int]]:
        """
        Extract entities from a list of tag ids.

        Returns a set of ``(start, end, I-x id)`` tuples with ``end``
        inclusive. An ``I-x`` that does not continue a ``B-x`` of the same
        type is ignored.
        """
        entities, one = set(), None
        # The trailing O sentinel closes an entity that runs to the end.
        for i, label in enumerate(chain(labels, [self.o_id])):
            if one:
                if label == one[2]:  # I-x continuing the open entity
                    one[1] = i
                    continue
                else:
                    entities.add(tuple(one))
                    one = None
            if label in self.bi_map:  # B-x opens a new entity
                one = [i, i, self.bi_map[label]]  # start, end, I-x id
        return entities

    @staticmethod
    def get_correct(predict_entities, gold_entities):
        """Return the exactly-matching entities, each extended with a weight of 1."""
        correct_entities = predict_entities & gold_entities
        correct_entities = {tuple(chain(e, [1])) for e in correct_entities}
        return correct_entities

    def get_metric(self, counter=None, reset=False) -> OrderedDict:
        """
        Compute scores from ``counter`` (default: the running counter).

        Without ``reset`` this is ``TaggingMetric.get_metric``. With ``reset``
        the running counter and every per-type counter are cleared after
        reading. If ``ouput_class`` is also set, the result is a flat
        ``OrderedDict`` with keys ``main_<k>`` followed by ``<type>_<k>``
        for each entity type. Here ``<k>`` is ``F1``, plus ``precision`` and
        ``recall`` when ``output_detail`` is set. ``data_info`` records each
        type's gold support.
        """
        if not reset:
            return super().get_metric(counter)

        # Honour an explicit counter on reset as well, like the base class.
        main = super().get_metric(counter, reset=True)
        if not self.ouput_class:
            # Keep per-type state aligned with the cleared running counter.
            self._reset_label_counters()
            return main

        key_list = ['F1', 'precision', 'recall'] if self.output_detail else ['F1']
        metric_with_prefix = OrderedDict(
            (f"main_{k}", main[k]) for k in key_list)
        for label_id, label_counter in self.label_counter.items():
            # Resolve the name by id directly; works for any hashable id type.
            name = self.id_to_label[label_id]
            self.data_info[name] = label_counter.total
            label_metric = super().get_metric(label_counter)
            for k in key_list:
                metric_with_prefix[f"{name}_{k}"] = label_metric[k]
        self._reset_label_counters()

        return metric_with_prefix

    def _reset_label_counters(self):
        """Replace every per-type counter with a fresh, zeroed one."""
        for label_id in self.label_counter:
            self.label_counter[label_id] = self.counter_factory()


def get_entities(seq, suffix=False):
    """Gets entities from a sequence of IOB/IOBES tag strings.

    Args:
        seq (list): sequence of tag strings such as ``['B-PER', 'I-PER', 'O']``,
            or a list of such lists (one per sentence). Nested lists are
            flattened with an ``'O'`` after each sentence, and offsets then
            index into the flattened sequence. A tuple is accepted as well.
        suffix (bool): if True, tags are written type-first, e.g. ``'PER-B'``.
    Returns:
        list: list of (chunk_type, chunk_start, chunk_end), ``chunk_end``
        inclusive. Entity types may themselves contain hyphens
        (``'B-ORG-CORP'`` has type ``'ORG-CORP'``).
    Example:
        >>> seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
        >>> get_entities(seq)
        [('PER', 0, 1), ('LOC', 3, 3)]
    """
    # Flatten a list of sentences, separating them with 'O' so that no chunk
    # can span a sentence boundary.
    if any(isinstance(s, list) for s in seq):
        seq = [item for sublist in seq for item in sublist + ['O']]

    prev_tag = 'O'
    prev_type = ''
    begin_offset = 0
    chunks = []
    # The trailing 'O' sentinel closes a chunk that runs to the end of seq.
    for i, chunk in enumerate(chain(seq, ['O'])):
        if suffix:
            # 'TYPE-X': split on the LAST hyphen so TYPE may contain hyphens.
            tag = chunk[-1]
            type_ = chunk.rsplit('-', 1)[0]
        else:
            # 'X-TYPE': split on the FIRST hyphen so TYPE may contain hyphens.
            tag = chunk[0]
            type_ = chunk.split('-', 1)[-1]

        if end_of_chunk(prev_tag, tag, prev_type, type_):
            chunks.append((prev_type, begin_offset, i - 1))
        if start_of_chunk(prev_tag, tag, prev_type, type_):
            begin_offset = i
        prev_tag = tag
        prev_type = type_

    return chunks


def end_of_chunk(prev_tag, tag, prev_type, type_):
    """Return True if a chunk ended between the previous and the current word.

    Args:
        prev_tag: previous tag letter (``'B'``, ``'I'``, ``'E'``, ``'S'``, ``'O'``, ``'.'``).
        tag: current tag letter.
        prev_type: previous entity type.
        type_: current entity type.
    Returns:
        bool: True if the chunk ending at the previous word is complete.
    """
    # E and S always close the chunk they belong to.
    if prev_tag in ('E', 'S'):
        return True
    # An open B/I chunk is closed by anything that is not a continuation.
    if prev_tag in ('B', 'I') and tag in ('B', 'S', 'O'):
        return True
    # A type change closes any chunk (O and '.' are never inside one).
    return prev_tag not in ('O', '.') and prev_type != type_


def start_of_chunk(prev_tag, tag, prev_type, type_):
    """Return True if a chunk started between the previous and the current word.

    Args:
        prev_tag: previous tag letter (``'B'``, ``'I'``, ``'E'``, ``'S'``, ``'O'``, ``'.'``).
        tag: current tag letter.
        prev_type: previous entity type.
        type_: current entity type.
    Returns:
        bool: True if the current word opens a new chunk.
    """
    # B and S always open a chunk.
    if tag in ('B', 'S'):
        return True
    # A dangling I/E after a closed chunk (or nothing) opens a new one.
    if tag in ('E', 'I') and prev_tag in ('E', 'S', 'O'):
        return True
    # A type change opens a chunk unless the current word is O or '.'.
    return tag not in ('O', '.') and prev_type != type_


def f1_score(y_true, y_pred, average='micro', digits=2, suffix=False):
    """Compute micro-averaged entity-level precision, recall and F1.

    Entities are extracted with ``get_entities``. A predicted entity counts as
    correct only if its type, start and end all equal those of a gold entity.
    F1 = 2 * (precision * recall) / (precision + recall).

    Args:
        y_true: gold tag strings, either a flat list or a list of per-sentence lists.
        y_pred: predicted tag strings, in the same layout as ``y_true``.
        average: unused; kept for signature compatibility.
        digits: unused; kept for signature compatibility.
        suffix: passed to ``get_entities``; True for ``'TYPE-B'`` style tags.
    Returns:
        tuple: ``(precision, recall, f1)``, each a percentage in [0, 100].
        A component whose denominator is zero is reported as 0.
    Example:
        >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        >>> f1_score(y_true, y_pred)
        (50.0, 50.0, 50.0)
    """
    true_entities = set(get_entities(y_true, suffix))
    pred_entities = set(get_entities(y_pred, suffix))

    nb_correct = len(true_entities & pred_entities)
    nb_pred = len(pred_entities)
    nb_true = len(true_entities)

    p = 100 * nb_correct / nb_pred if nb_pred > 0 else 0
    r = 100 * nb_correct / nb_true if nb_true > 0 else 0
    score = 2 * p * r / (p + r) if p + r > 0 else 0

    return p, r, score


def accuracy_score(y_true, y_pred):
    """Token-level accuracy: the fraction of positions whose predicted tag equals the gold tag.

    Args:
        y_true : gold tags, either a flat list or a list of per-sentence lists.
            If it is nested, ``y_pred`` is flattened the same way.
        y_pred : predicted tags, in the same layout as ``y_true``. Positions
            are paired with ``zip``, so extra predictions are ignored and
            missing ones count as wrong.
    Returns:
        score : float in [0, 1]. The score is 0 when ``y_true`` is empty.
    Example:
        >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        >>> accuracy_score(y_true, y_pred)
        0.8
    """
    if any(isinstance(s, list) for s in y_true):
        y_true = [item for sublist in y_true for item in sublist]
        y_pred = [item for sublist in y_pred for item in sublist]

    nb_true = len(y_true)
    if nb_true == 0:
        # No tokens: report 0, like the zero-denominator cases in f1_score,
        # instead of raising ZeroDivisionError.
        return 0.0

    nb_correct = sum(y_t == y_p for y_t, y_p in zip(y_true, y_pred))
    return nb_correct / nb_true


def classification_report(y_true, y_pred, digits=2, suffix=False):
    """Build a text report of per-type entity-level precision, recall and F1.

    Args:
        y_true : gold tag strings, either a flat list or a list of per-sentence lists.
        y_pred : predicted tag strings, in the same layout as ``y_true``.
        digits : int. Number of digits after the decimal point for scores.
        suffix : passed to ``get_entities``; True for ``'TYPE-B'`` style tags.
    Returns:
        report : str. The report has one row per entity type present in
        ``y_true``, sorted by type name, then a support-weighted average row.
        Scores are percentages in [0, 100]. Types that appear only in
        ``y_pred`` get no row. With no gold entities the average row is all
        zeros.
    Examples:
        >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        >>> print(classification_report(y_true, y_pred))
                     precision    recall  f1-score   support
        <BLANKLINE>
               MISC       0.00      0.00      0.00         1
                PER     100.00    100.00    100.00         1
        <BLANKLINE>
        avg / total      50.00     50.00     50.00         2
        <BLANKLINE>
    """
    true_entities = set(get_entities(y_true, suffix))
    pred_entities = set(get_entities(y_pred, suffix))

    # Group (start, end) spans by entity type.
    true_by_type = defaultdict(set)
    pred_by_type = defaultdict(set)
    for type_name, start, end in true_entities:
        true_by_type[type_name].add((start, end))
    for type_name, start, end in pred_entities:
        pred_by_type[type_name].add((start, end))

    last_line_heading = 'avg / total'
    name_width = max((len(type_name) for type_name in true_by_type), default=0)
    width = max(name_width, len(last_line_heading), digits)

    headers = ["precision", "recall", "f1-score", "support"]
    head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
    report = head_fmt.format(u'', *headers, width=width)
    report += u'\n\n'

    row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'

    ps, rs, f1s, s = [], [], [], []
    # Sorted so the row order does not depend on set/hash iteration order.
    for type_name in sorted(true_by_type):
        gold_spans = true_by_type[type_name]
        pred_spans = pred_by_type.get(type_name, set())
        nb_correct = len(gold_spans & pred_spans)
        nb_pred = len(pred_spans)
        nb_true = len(gold_spans)

        p = 100 * nb_correct / nb_pred if nb_pred > 0 else 0
        r = 100 * nb_correct / nb_true if nb_true > 0 else 0
        f1 = 2 * p * r / (p + r) if p + r > 0 else 0

        report += row_fmt.format(type_name, p, r, f1, nb_true, width=width, digits=digits)

        ps.append(p)
        rs.append(r)
        f1s.append(f1)
        s.append(nb_true)

    report += u'\n'

    # Support-weighted averages. np.average raises ZeroDivisionError when the
    # weights sum to zero, i.e. when there are no gold entities at all.
    if s:
        avg_p = np.average(ps, weights=s)
        avg_r = np.average(rs, weights=s)
        avg_f1 = np.average(f1s, weights=s)
    else:
        avg_p = avg_r = avg_f1 = 0.0
    report += row_fmt.format(last_line_heading, avg_p, avg_r, avg_f1, sum(s),
                             width=width, digits=digits)

    return report