
import numpy as np
import torch


class DataProcessor:
    """Container that keeps features, labels, predictions and groups aligned.

    The ``estimate_*`` methods derive a per-sample score ``z`` (plus ``py`` and
    ``pz``) from the predicted probabilities ``pp`` and labels ``y``. The other
    methods convert inputs to numpy, build or normalise the group assignment,
    and filter all per-sample arrays together.

    NOTE(review): the ``estimate_z_*`` methods index ``pp`` along ``axis=2``,
    so they assume ``pp`` is 3-D, presumably (n_samples, n_predictors,
    n_classes), with ``y`` holding integer class indices. TODO: confirm the
    meaning of the middle axis against callers.
    """

    def __init__(self, X, y, pp=None, z=None, groups=None, text=None):
        self.X = X
        self.y = y
        self.pp = pp
        self.groups = groups
        # Kept as an array so it can be boolean-masked with the other data.
        self.text = np.array(text) if text is not None else None
        self.z = z
        self.pz = None
        # Filled in by the estimate_* methods. Initialised here so that
        # filter_by_groupsize also works before any estimate has been made.
        self.py = None

    def _one_hot_targets(self):
        """Return one-hot labels shaped (n_samples, 1, n_classes).

        The singleton axis broadcasts against ``pp``'s second axis, so the
        labels do not need to be repeated explicitly.
        """
        return np.eye(self.pp.shape[2])[self.y][:, np.newaxis, :]

    def estimate_z_binacc(self):
        """Score every prediction as correct or incorrect.

        Sets:
            py: argmax of ``pp`` over axis 2.
            z:  boolean array, True where ``py`` equals the label ``y``.
            pz: the probability ``pp`` assigns to each sample's true class.
        """
        self.py = np.argmax(self.pp, axis=2)
        self.z = (self.py == self.y[:, None])
        # The advanced indices on axes 0 and 2 select pp[i, :, y[i]] for every i.
        self.pz = self.pp[np.arange(len(self.y)), :, self.y]

    def estimate_z_brierscore(self):
        """Score every prediction by its squared error against the one-hot label.

        Sets ``py`` to the argmax prediction. Sets ``z`` to the mean over
        axis 2 of ``(pp - one_hot(y)) ** 2``. ``pz`` is a zero placeholder.
        """
        self.py = np.argmax(self.pp, axis=2)
        y_one_hot = self._one_hot_targets()
        self.z = np.mean((self.pp - y_one_hot) ** 2, axis=2)
        self.pz = np.zeros_like(self.z)  # placeholder

    def estimate_z_crossentropy(self, epsilon=1e-12):
        """Score every prediction by the cross-entropy with the true label.

        Args:
            epsilon: ``pp`` is clipped to ``[epsilon, 1 - epsilon]`` before
                the log, so a zero probability does not produce ``-inf``.

        Sets ``py`` to the argmax prediction, as the other estimators do.
        Sets ``z`` to ``-sum(one_hot(y) * log(pp))`` over axis 2. ``pz`` is a
        zero placeholder.
        """
        pp_clipped = np.clip(self.pp, epsilon, 1. - epsilon)
        y_one_hot = self._one_hot_targets()
        self.py = np.argmax(self.pp, axis=2)
        self.z = -np.sum(y_one_hot * np.log(pp_clipped), axis=2)
        self.pz = np.zeros_like(self.z)  # placeholder

    def estimate_general_metric(self, metric_values):
        """Use externally computed per-sample metric values as ``z``.

        ``pz`` and ``py`` are set to zero placeholders with the same shape.
        """
        self.z = metric_values
        self.pz = np.zeros_like(self.z)  # placeholder
        self.py = np.zeros_like(self.z)  # placeholder

    def format_data(self):
        """Normalise ``pp`` and convert all torch tensors to numpy arrays.

        A 1-D ``pp`` is read as the probability of the positive class and is
        expanded to two columns ``[1 - p, p]``. Both torch tensors and numpy
        arrays are accepted.
        """
        if self.pp is not None and self.pp.ndim == 1:
            if isinstance(self.pp, torch.Tensor):
                self.pp = torch.stack([1 - self.pp, self.pp], dim=1)
            else:
                self.pp = np.stack([1 - self.pp, self.pp], axis=1)

        self.X = self._tensor_to_numpy(self.X)
        self.y = self._tensor_to_numpy(self.y)
        self.pp = self._tensor_to_numpy(self.pp)
        self.groups = self._tensor_to_numpy(self.groups)
        self.z = self._tensor_to_numpy(self.z)

    def _tensor_to_numpy(self, tensor):
        """Return a numpy copy of a torch tensor; pass any other value through.

        ``detach()`` is required because ``.numpy()`` refuses tensors that
        require grad. ``cpu()`` is a no-op for tensors already on the CPU.
        """
        if isinstance(tensor, torch.Tensor):
            return tensor.detach().cpu().numpy()
        return tensor

    def concatenate_group_dimensions(self):
        """Collapse multi-part group labels into single ``'a-b-...'`` strings.

        Each group may be a list, a tuple or a numpy row. If ``groups`` is
        ``None`` it is left as ``None``, so ``create_groups`` can still run.
        Any other value is converted with ``np.array``.
        """
        if self.groups is None:
            return
        if len(self.groups) > 0 and isinstance(self.groups[0], (list, tuple, np.ndarray)):
            self.groups = np.array(['-'.join(map(str, group)) for group in self.groups])
        else:
            self.groups = np.array(self.groups)

    def reduce_featuredim(self, featuredim_algorithm):
        """Replace ``X`` with ``featuredim_algorithm(X)``."""
        self.X = featuredim_algorithm(self.X)

    def create_groups(self, clustering_algorithm, min_clusters=40, max_clusters=200):
        """Assign groups by running ``clustering_algorithm`` on ``X``.

        Raises:
            ValueError: if groups are already present.
        """
        if self.groups is not None:
            raise ValueError("Groups are already present")
        self.groups = clustering_algorithm(X=self.X, min_clusters=min_clusters, max_clusters=max_clusters)

    def filter_by_groupsize(self, min_group_size):
        """Drop every sample whose group has fewer than ``min_group_size`` members.

        ``X``, ``y``, ``pp`` and ``groups`` are always filtered. ``py``,
        ``pz`` and ``z`` are filtered when they are set. ``text`` is filtered
        when it has one entry per sample.
        """
        groups = np.asarray(self.groups)
        unique, counts = np.unique(groups, return_counts=True)
        valid_clusters = unique[counts >= min_group_size]
        valid_mask = np.isin(groups, valid_clusters)

        self.X = self.X[valid_mask]
        self.y = self.y[valid_mask]
        self.pp = self.pp[valid_mask]
        self.groups = groups[valid_mask]
        if self.py is not None:
            self.py = self.py[valid_mask]
        if self.pz is not None:
            self.pz = np.asarray(self.pz)[valid_mask]
        if self.z is not None:
            self.z = np.asarray(self.z)[valid_mask]
        # Filter text only when it lines up with the samples; otherwise
        # masking it would raise or silently misalign.
        if self.text is not None and self.text.ndim >= 1 and len(self.text) == len(valid_mask):
            self.text = self.text[valid_mask]
        