import datasets
from typing import Union
import numpy as np

from torch.utils.data import Dataset
from scipy.signal import resample


class SensorDataset(Dataset):
    """Sensor dataset class for loading sensor data.

    Wraps a Hugging Face ``datasets`` dataset saved with ``save_to_disk`` and
    exposes it as a PyTorch ``Dataset`` yielding ``(data, label)`` pairs.

    Args:
        data_dir: Directory passed to ``datasets.load_from_disk``.
        resample_from: Original sample rate of the ``"data"`` column. Resampling
            only happens when both ``resample_from`` and ``resample_to`` are
            given (truthy) and differ; otherwise the data is left untouched.
        resample_to: Target sample rate of the ``"data"`` column.
        filter_by_domain: If given (truthy), keep only rows whose ``"domain"``
            column equals this value.
    """

    def __init__(
        self,
        data_dir: str,
        resample_from: Union[None, int] = None,
        resample_to: Union[None, int] = None,
        filter_by_domain: Union[None, str] = None,
    ):
        self.ds = datasets.load_from_disk(data_dir)
        if filter_by_domain:
            self.ds = self.ds.filter(lambda x: x["domain"] == filter_by_domain)
        if resample_from and resample_to and resample_from != resample_to:
            # With batched=True, map hands us a dict of *lists* (one entry per
            # sample) and expects a dict of columns back, so each sample is
            # resampled individually. The static helper plus fn_kwargs keeps
            # `self` (and thus self.ds) out of the function that `datasets`
            # pickles/fingerprints for caching.
            self.ds = self.ds.map(
                self._resample_batch,
                batched=True,
                fn_kwargs={"sr": resample_from, "target_sr": resample_to},
            )

        self.resample_from = resample_from
        self.resample_to = resample_to

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # Fetch the row once; indexing a `datasets.Dataset` decodes the row.
        row = self.ds[idx]
        return np.array(row["data"]), row["label"]

    @staticmethod
    def _resample_batch(batch: dict, sr: int, target_sr: int) -> dict:
        """Resample every sample in a batched ``map`` call; returns the new ``"data"`` column."""
        return {
            "data": [
                SensorDataset._resample_array(sample, sr, target_sr)
                for sample in batch["data"]
            ]
        }

    @staticmethod
    def _resample_array(data, sr: int, target_sr: int) -> np.ndarray:
        """Fourier-resample ``data`` along axis 0 from ``sr`` to ``target_sr``."""
        data = np.asarray(data)
        # Axis 0 is treated as the time axis; its length scales by target_sr / sr
        # (truncated toward zero).
        resampled_size = int(data.shape[0] * target_sr / sr)
        return resample(data, resampled_size)

    def resample_data(self, data: np.ndarray, sr: int, target_sr: int) -> np.ndarray:
        """Resample data to target sample rate.

        Args:
            data: Array-like sample; axis 0 is resampled (see
                ``scipy.signal.resample``, default ``axis=0``).
            sr: Sample rate of ``data``.
            target_sr: Desired sample rate.

        Returns:
            ``np.ndarray`` with ``int(len(data) * target_sr / sr)`` entries along
            axis 0; any remaining axes are unchanged.
        """
        return self._resample_array(data, sr, target_sr)
