
from dataclasses import dataclass

import numpy as np


@dataclass
class MetricScores:
    """Scores produced by one metric for one system on one dataset.

    This is plain data; nothing is validated on construction.
    """

    metric: str  # metric name, carried through unchanged to the paired result
    system: str  # must equal the annotations' ``system`` for pairing
    dataset: str  # must equal the annotations' ``dataset`` for pairing
    scores: np.ndarray  # per-item scores; items are counted along axis 0

    def pair_binary(self, human_binary: 'BinaryHuman') -> 'BinaryPaired':
        """Pair these scores with binary human annotations.

        Convenience wrapper that delegates to ``human_binary.pair``; any
        validation (and the ``ValueError`` it may raise) happens there.
        """
        return human_binary.pair(metric_scores=self)

    def __len__(self) -> int:
        """Return the number of scored items (size of axis 0 of ``scores``)."""
        return self.scores.shape[0]


@dataclass
class BinaryPaired(MetricScores):
    """Metric scores paired with binary human annotations.

    Extends ``MetricScores`` with ``human_binary``. ``__len__`` is inherited
    and therefore counts the entries of ``scores``.
    """

    # Binary human judgements paired with ``scores``; presumably one per item,
    # aligned by position — not validated by this class.
    human_binary: np.ndarray


@dataclass
class BinaryHuman:
    """Binary human annotations for one system on one dataset.

    ``binary_scores`` is expected to line up item-for-item with the
    ``scores`` of a ``MetricScores`` for the same system and dataset.
    """

    system: str  # must equal ``MetricScores.system`` for pairing
    dataset: str  # must equal ``MetricScores.dataset`` for pairing
    binary_scores: np.ndarray  # annotations; items are counted along axis 0

    def pair(self, metric_scores: MetricScores) -> BinaryPaired:
        """Combine ``metric_scores`` with these annotations.

        Args:
            metric_scores: Metric output for the same system and dataset,
                covering exactly the annotated items.

        Returns:
            A ``BinaryPaired`` holding the metric's ``scores`` and these
            ``binary_scores`` (as ``human_binary``) side by side.

        Raises:
            ValueError: If the system or dataset differ, or if the number of
                metric scores differs from the number of annotations.
        """
        if metric_scores.system != self.system:
            raise ValueError(
                f"cannot pair metric for {metric_scores.system}"
                f" with annotations for {self.system}")

        if metric_scores.dataset != self.dataset:
            raise ValueError(
                f"cannot pair metric for {metric_scores.dataset}"
                f" with annotations for {self.dataset}")

        # Without this check, arrays of different lengths would be paired
        # silently and only fail (or misalign) later in downstream analysis.
        # Annotations that cover only a subset of items belong in
        # PartialBinaryHuman instead.
        n_metric = metric_scores.scores.shape[0]
        n_human = self.binary_scores.shape[0]
        if n_metric != n_human:
            raise ValueError(
                f"cannot pair {n_metric} metric scores"
                f" with {n_human} annotations")

        return BinaryPaired(
            metric=metric_scores.metric,
            system=metric_scores.system,
            dataset=metric_scores.dataset,
            scores=metric_scores.scores,
            human_binary=self.binary_scores,
        )

    def __len__(self) -> int:
        """Return the number of annotated items (axis 0 of ``binary_scores``)."""
        return self.binary_scores.shape[0]


@dataclass
class PartialBinaryHuman(BinaryHuman):
    """Binary human annotations covering only a subset of a dataset's items.

    ``indices`` selects, from a metric's full ``scores``, the items that
    ``binary_scores`` annotate. Position i of the selection is paired with
    position i of ``binary_scores``. The inherited ``len()`` counts
    annotated items.
    """

    # Integer index array (or boolean mask) into ``MetricScores.scores``
    # along axis 0; used with numpy indexing in ``pair``.
    indices: np.ndarray

    def pair(self, metric_scores: MetricScores) -> BinaryPaired:
        """Pair the annotated subset of ``metric_scores`` with these annotations.

        Args:
            metric_scores: Metric output for the full dataset of the same
                system; it is restricted to ``indices`` before pairing.

        Returns:
            A ``BinaryPaired`` whose ``scores`` are
            ``metric_scores.scores[self.indices]``.

        Raises:
            ValueError: If the system or dataset differ, or if ``indices``
                select a different number of items than there are
                annotations.
        """
        if metric_scores.system != self.system:
            raise ValueError(
                f"cannot pair metric for {metric_scores.system}"
                f" with annotations for {self.system}")

        if metric_scores.dataset != self.dataset:
            raise ValueError(
                f"cannot pair metric for {metric_scores.dataset}"
                f" with annotations for {self.dataset}")

        # Restrict the metric to the annotated items, then make sure the
        # selection really lines up with the annotations instead of pairing
        # arrays of different lengths silently.
        selected = metric_scores.scores[self.indices]
        n_selected = selected.shape[0]
        n_human = self.binary_scores.shape[0]
        if n_selected != n_human:
            raise ValueError(
                f"indices select {n_selected} metric scores"
                f" but there are {n_human} annotations")

        return BinaryPaired(
            metric=metric_scores.metric,
            system=metric_scores.system,
            dataset=metric_scores.dataset,
            scores=selected,
            human_binary=self.binary_scores,
        )


