import torch
from torch.nn.functional import one_hot
from sklearn.metrics.cluster import v_measure_score
from parser.utils.alg import km_match


class Metric:
    """Base class for evaluation metrics that are ordered by their ``score``.

    Subclasses override :attr:`score`. Every rich comparison below evaluates
    ``self.score <op> other``. When ``other`` is another ``Metric`` and the
    score is a float, the float comparison yields ``NotImplemented`` and
    Python falls back to the reflected operator on ``other``, so two metrics
    end up compared by their scores. ``==`` is not overridden here.
    """

    @property
    def score(self):
        """Scalar used for ordering; the base implementation is ``0.0``."""
        return 0.0

    def __gt__(self, other):
        mine = self.score
        return mine > other

    def __ge__(self, other):
        mine = self.score
        return mine >= other

    def __le__(self, other):
        mine = self.score
        return mine <= other

    def __lt__(self, other):
        mine = self.score
        return mine < other


class ManyToOneMetric(Metric):
    """Accumulate a predicted-cluster x gold-tag co-occurrence matrix.

    Reports many-to-one accuracy (``score``), a one-to-one accuracy derived
    from ``km_match`` and the V-measure over every label seen so far.

    Args:
        n_clusters: number of predicted clusters; also used as the number of
            gold tags, so the count matrix is ``n_clusters x n_clusters``.
        device: device on which the count matrix is allocated.
        eps: added to denominators so an empty metric does not divide by 0.
    """

    def __init__(self, n_clusters, device, eps=1e-8):
        self.n_clusters = n_clusters
        self.eps = eps
        # clusters[p, g] = number of tokens predicted as p whose gold tag is g.
        self.clusters = torch.zeros((self.n_clusters, self.n_clusters),
                                    device=device)
        # Flat label histories, kept because V-measure needs the raw labels.
        self.pred = []
        self.gold = []
        # Cached km_match output; refreshed lazily when need_update is set.
        self._correct = None
        self._match = None
        self.need_update = False

    def __call__(self, predicts, golds):
        """Add a batch of predictions to the running statistics.

        Args:
            predicts: integer tensor of predicted cluster ids in
                ``[0, n_clusters)``; any shape, flattened internally.
            golds: integer tensor of gold tag ids in ``[0, n_clusters)`` with
                the same number of elements as ``predicts``.

        Raises:
            ValueError: if ``predicts`` and ``golds`` differ in size.
        """
        predicts = predicts.reshape(-1)
        golds = golds.reshape(-1)
        if predicts.numel() != golds.numel():
            raise ValueError(
                "predicts and golds must have the same number of elements, "
                f"got {predicts.numel()} and {golds.numel()}")
        # Build the counts first: one_hot rejects out-of-range ids, and doing
        # it before touching any state keeps pred/gold/clusters consistent.
        # P^T @ G (both n x K one-hot) yields the K x K co-occurrence counts
        # using O(n*K) memory instead of an O(n*K*K) broadcast product.
        # .to(self.clusters) matches the accumulator's dtype and device.
        pred_hot = one_hot(predicts, num_classes=self.n_clusters).to(self.clusters)
        gold_hot = one_hot(golds, num_classes=self.n_clusters).to(self.clusters)
        batch_counts = pred_hot.t() @ gold_hot

        self.clusters += batch_counts
        self.pred += predicts.tolist()
        self.gold += golds.tolist()
        self.need_update = True

    def _refresh_match(self):
        """Re-run km_match if the counts changed; return its first output."""
        # NOTE(review): km_match is assumed to return (matched_count, mapping),
        # presumably a Kuhn-Munkres one-to-one assignment — confirm.
        if self._match is None or self.need_update:
            self._correct, self._match = km_match(self.clusters)
            self.need_update = False
        return self._correct

    def __repr__(self):
        correct = self._refresh_match()
        return (f"M-1: {self.score:.2%} "
                f"1-1: {correct / (self.clusters.sum() + self.eps):.2%} "
                f"VM: {v_measure_score(self.gold, self.pred):.2%} ")

    @property
    def match(self):
        """Mapping returned by km_match, recomputed only after new data."""
        self._refresh_match()
        return self._match

    @property
    def score(self):
        """Many-to-one accuracy: each cluster's best gold count over the total."""
        return float((self.clusters.max(dim=1)[0]).sum() /
                     (self.clusters.sum() + self.eps))

    @property
    def tag_map(self):
        """Map each predicted cluster to the gold tag it co-occurs with most."""
        return {
            p: g
            for p, g in enumerate(self.clusters.max(dim=1)[1].tolist())
        }

    @property
    def gold_tag_map(self):
        """Map each gold tag to the predicted cluster it co-occurs with most."""
        return {
            g: p
            for g, p in enumerate(self.clusters.max(dim=0)[1].tolist())
        }
