
import numpy as np

class DataBalancer:
    """Draw group-balanced subsamples of a dataset.

    The instance holds the arrays ``X``, ``z`` and ``pz`` together with a
    per-observation ``groups`` label array. Each ``subsample_*`` method picks
    row indices without replacement within every group. It then stores the
    selected rows in ``X_train``, ``z_train``, ``pz_train`` and
    ``groups_train``, and the sorted distinct labels in ``unique_groups``.

    NOTE(review): all four inputs are indexed along their first axis with an
    integer index array (``arr[idx]``). ``groups`` is compared element-wise
    against each label and read via ``.shape``. So the inputs are assumed to
    be NumPy arrays (or array-likes with the same semantics) of equal length;
    confirm with callers.
    """

    def __init__(self, X, groups, z, pz):
        self.X = X
        self.groups = groups
        self.z = z
        self.pz = pz

    def _validate_min_group_size(self, counts, min_group_size):
        """Raise ``ValueError`` if any entry of ``counts`` is below ``min_group_size``."""
        if (counts < min_group_size).any():
            raise ValueError(f'Smallest group has fewer than {min_group_size} observations!')

    def _subsample(self, n_examples, seed):
        """Return row indices with ``n_examples[i]`` draws from ``self.unique_groups[i]``.

        Draws are made without replacement within each group. They use a
        ``RandomState`` seeded with ``seed``, so results are reproducible.
        The indices are concatenated group by group, in ``unique_groups``
        order.
        """
        rng = np.random.RandomState(seed)
        idx = []
        for group, n in zip(self.unique_groups, n_examples):
            group_indices = np.where(self.groups == group)[0]
            idx.extend(rng.choice(group_indices, size=n, replace=False))
        # An explicit integer dtype keeps the result usable as an index even
        # when nothing is selected: ``np.array([])`` would be float64, and
        # indexing with a float array raises ``IndexError``.
        return np.array(idx, dtype=np.intp)

    def subsample_proportional(self, min_group_size=10, seed=0):
        """Subsample every group while preserving the original group shares.

        The smallest group contributes ``min_group_size`` rows. Every other
        group contributes ``min_group_size * share / smallest_share`` rows,
        rounded with ``np.round`` (round-half-to-even).

        Args:
            min_group_size: number of rows drawn from the smallest group.
                Every group must have at least this many observations.
            seed: seed for the ``RandomState`` used to draw rows.

        Raises:
            ValueError: if any group has fewer than ``min_group_size`` rows.
        """
        self.unique_groups, counts = np.unique(self.groups, return_counts=True)
        self._validate_min_group_size(counts, min_group_size)

        if counts.size == 0:
            # No observations at all: ``shares.min()`` would fail on an empty
            # array, so select nothing instead.
            n_examples = np.zeros(0, dtype=int)
        else:
            shares = counts / self.groups.shape[0]
            # count_i * m / count_min <= count_i because m <= count_min (validated
            # above), so no group is asked for more rows than it has.
            n_examples = np.round(shares / shares.min() * min_group_size).astype(int)

        idx = self._subsample(n_examples, seed)
        self._update_train_data(idx)

    def subsample_allequal(self, min_group_size=10, seed=0):
        """Subsample exactly ``min_group_size`` rows from every group.

        Args:
            min_group_size: number of rows drawn from each group. Every group
                must have at least this many observations.
            seed: seed for the ``RandomState`` used to draw rows.

        Raises:
            ValueError: if any group has fewer than ``min_group_size`` rows.
        """
        self.unique_groups, counts = np.unique(self.groups, return_counts=True)
        self._validate_min_group_size(counts, min_group_size)

        n_examples = np.full(len(self.unique_groups), min_group_size, dtype=int)

        idx = self._subsample(n_examples, seed)
        self._update_train_data(idx)

    def _update_train_data(self, idx):
        """Store the rows selected by ``idx`` from every input array as ``*_train``."""
        self.X_train = self.X[idx]
        self.z_train = self.z[idx]
        self.pz_train = self.pz[idx]
        self.groups_train = self.groups[idx]
