# -*- coding: utf-8 -*-

import sys
from collections import Counter


class LabelingFScorer:
    """Accumulate gold/system/correct counts and report precision, recall and F1.

    Counts are kept per label in three ``Counter`` objects. The special key
    ``all_label`` holds the totals over every label.
    """

    # Key under which overall (all-label) totals are accumulated.
    all_label = '__ALL__'

    def __init__(self, name=None):
        """Create an empty scorer.

        Args:
            name: Optional display name used in the headers of ``get_report``.
        """
        self.name = name
        self.correct_count = Counter()  # label -> number of correct system items
        self.gold_count = Counter()     # label -> number of gold (reference) items
        self.system_count = Counter()   # label -> number of system (predicted) items

    @classmethod
    def add(cls, a: 'LabelingFScorer', b: 'LabelingFScorer', new_name=None):
        """Return a new scorer whose counts are the sum of ``a`` and ``b``.

        Neither input scorer is modified.

        Args:
            a: First scorer.
            b: Second scorer.
            new_name: Name of the combined scorer.

        Returns:
            A new instance of ``cls`` holding the summed counts.
        """
        ret = cls(new_name)
        for scorer in (a, b):
            # Counter.update adds counts rather than replacing them.
            ret.correct_count.update(scorer.correct_count)
            ret.gold_count.update(scorer.gold_count)
            # System counts must go into system_count (previously they were
            # mistakenly accumulated into gold_count).
            ret.system_count.update(scorer.system_count)
        return ret

    def update_item(self, gold_item, system_item, except_tag=None):
        """Count one aligned (gold, system) label pair.

        A gold item equal to ``except_tag`` is not counted as gold, and a
        system item equal to ``except_tag`` is not counted as a prediction
        (e.g. an "outside" tag that should not contribute to the score).
        A match is counted as correct only when the system item is counted.

        Args:
            gold_item: Reference label.
            system_item: Predicted label.
            except_tag: Optional label that is excluded from counting.
        """
        if except_tag is None or gold_item != except_tag:
            self.gold_count[self.all_label] += 1
            self.gold_count[gold_item] += 1
        if except_tag is None or system_item != except_tag:
            self.system_count[self.all_label] += 1
            self.system_count[system_item] += 1
            if gold_item == system_item:
                self.correct_count[self.all_label] += 1
                self.correct_count[gold_item] += 1

    def update(self, gold_items, system_items, except_tag=None):
        """Count aligned sequences of gold and system labels position by position.

        Pairs are formed with ``zip``, so extra items in the longer sequence
        are ignored.

        Args:
            gold_items: Iterable of reference labels.
            system_items: Iterable of predicted labels.
            except_tag: Optional label excluded from counting; see ``update_item``.
        """
        for gold_item, system_item in zip(gold_items, system_items):
            self.update_item(gold_item, system_item, except_tag)

    def update_set(self, gold_set, system_set):
        """Count one unordered collection of gold items against system items.

        Items are matched irrespective of position. Duplicates are matched
        with multiset semantics, i.e. ``min(gold multiplicity,
        system multiplicity)`` copies of an item count as correct.

        Args:
            gold_set: Sized collection of reference items.
            system_set: Sized collection of predicted items.
        """
        self.gold_count[self.all_label] += len(gold_set)
        for gold_item in gold_set:
            self.gold_count[gold_item] += 1
        self.system_count[self.all_label] += len(system_set)
        for system_item in system_set:
            self.system_count[system_item] += 1

        correct_set = Counter(gold_set) & Counter(system_set)
        self.correct_count[self.all_label] += sum(correct_set.values())
        # Add the matched multiplicity (not just 1) so per-label counts agree
        # with the all-label total and with the per-occurrence gold/system counts.
        for correct_item, count in correct_set.items():
            self.correct_count[correct_item] += count

    def update_sets(self, gold_sets, system_sets):
        """Call ``update_set`` for each aligned pair of collections (zip semantics)."""
        for gold_set, system_set in zip(gold_sets, system_sets):
            self.update_set(gold_set, system_set)

    def get_p_r_f(self, label):
        """Return ``(precision, recall, f1)`` for ``label``, as percentages.

        ``sys.float_info.epsilon`` is added to each denominator so that a zero
        count yields 0.0 instead of raising ``ZeroDivisionError``.

        Args:
            label: A label, or ``all_label`` for the overall score.
        """
        p = (self.correct_count[label] /
             (self.system_count[label] + sys.float_info.epsilon) * 100)
        r = (self.correct_count[label] /
             (self.gold_count[label] + sys.float_info.epsilon) * 100)
        f = 2 * p * r / (p + r + sys.float_info.epsilon)
        return p, r, f

    def get_report(self):
        """Build a plain-text report of total and per-label scores.

        The per-label table lists labels seen in gold, ordered by ascending gold
        count, followed by labels seen only in system output, ordered by
        ascending system count. ``all_label`` appears in the table as well,
        because it is a key of the count tables.

        Returns:
            A tuple ``(report_text, total_f1)``.
        """
        total_p, total_r, total_f = self.get_p_r_f(self.all_label)
        parts = [
            f'===== {self.name} (total) =====\n',
            f'P: {total_p:.4f}\n',
            f'R: {total_r:.4f}\n',
            f'F: {total_f:.4f}\n\n',
            f'===== {self.name} (each label) =====\n',
            'Label\t\tGold\t\tSystem\t\tP\t\tR\t\tF\n',
        ]
        labels = [label for label, _ in
                  sorted(self.gold_count.items(), key=lambda x: x[1])]
        labels.extend(label for label, _ in
                      sorted(self.system_count.items(), key=lambda x: x[1])
                      if label not in self.gold_count)
        for label in labels:
            label_p, label_r, label_f = self.get_p_r_f(label)
            # Counter lookups of missing keys return 0 without inserting them.
            parts.append(f'{label}\t\t{self.gold_count[label]}\t\t'
                         f'{self.system_count[label]}\t\t'
                         f'{label_p:.4f}\t\t{label_r:.4f}\t\t{label_f:.4f}\n')
        parts.append('\n')
        return ''.join(parts), total_f
