import math

class SamplerScheduler:
    """Base scheduler that decides whether sampling is due at a training step.

    Subclasses override ``need_sampling``. This base implementation only
    handles the warmup phase: it returns False for steps before
    ``warmup_steps`` and raises NotImplementedError for any other step.
    """

    def __init__(self, warmup_steps):
        # Steps strictly below this value are warmup; nothing is sampled there.
        self.warmup_steps = warmup_steps

    def need_sampling(self, current_step):
        """Return False during warmup; every later step is left to subclasses.

        Args:
            current_step: The current training step.

        Returns:
            False when ``current_step < warmup_steps``.

        Raises:
            NotImplementedError: For any step that is not in warmup.
        """
        still_warming_up = current_step < self.warmup_steps
        if not still_warming_up:
            raise NotImplementedError("Subclasses must implement need_sampling method")
        return False


class FixedSamplerScheduler(SamplerScheduler):
    """Sample at a fixed interval once warmup is over.

    Sampling happens at ``warmup_steps``, ``warmup_steps + sample_interval``,
    ``warmup_steps + 2 * sample_interval``, and so on.
    """

    def __init__(self, warmup_steps, sample_interval):
        """
        Args:
            warmup_steps: Number of initial steps during which nothing is sampled.
            sample_interval: Number of steps between consecutive samples; must
                be positive.

        Raises:
            ValueError: If ``sample_interval`` is not positive.
        """
        # Fail fast: a zero interval would otherwise only surface later as a
        # ZeroDivisionError from the modulo in need_sampling.
        if sample_interval <= 0:
            raise ValueError(f"sample_interval must be positive, got {sample_interval!r}")
        super().__init__(warmup_steps)
        self.sample_interval = sample_interval

    def need_sampling(self, current_step):
        """Return True when ``current_step`` falls on the fixed sampling grid.

        Args:
            current_step: The current training step.

        Returns:
            False during warmup; otherwise True iff the number of steps since
            the end of warmup is a multiple of ``sample_interval``.
        """
        if current_step < self.warmup_steps:
            return False  # no sampling during warmup

        return (current_step - self.warmup_steps) % self.sample_interval == 0


class LinearSamplerScheduler(SamplerScheduler):
    """Sample at every step inside a fixed window of training steps.

    A step is sampled when it is at or after both ``warmup_steps`` and
    ``start_sampling_step``, and at or before ``end_sampling_step`` (both ends
    inclusive). Despite the class name, the sampling frequency inside the
    window is constant: every step is sampled.
    """

    def __init__(self, warmup_steps, start_sampling_step, end_sampling_step):
        """
        Args:
            warmup_steps: Number of initial steps during which nothing is sampled.
            start_sampling_step: First step (inclusive) eligible for sampling.
            end_sampling_step: Last step (inclusive) eligible for sampling.
        """
        super().__init__(warmup_steps)
        self.start_sampling_step = start_sampling_step
        self.end_sampling_step = end_sampling_step

    def need_sampling(self, current_step):
        """Return True iff ``current_step`` lies inside the sampling window."""
        # Still in warmup, or the sampling window has not opened yet.
        before_window = (current_step < self.warmup_steps
                         or current_step < self.start_sampling_step)
        # The sampling window has already closed.
        after_window = current_step > self.end_sampling_step
        return not (before_window or after_window)


class ExponentialSamplerScheduler(SamplerScheduler):
    """Increasingly sparse sampler: samples at exponentially spaced steps.

    With ``offset = current_step - warmup_steps``, a step is sampled when
    ``offset == 0`` (the end of warmup) or ``offset == scale ** k`` for some
    integer ``k``. For ``scale=2`` that is offsets 0, 1, 2, 4, 8, 16, ...
    """

    def __init__(self, warmup_steps, scale=1.0):
        """
        Args:
            warmup_steps: Number of initial steps during which nothing is sampled.
            scale: Base of the exponential spacing; must be positive. A scale
                in (0, 1) yields the same set of integer powers as
                ``1 / scale``. Only an integral base (after that inversion)
                reaches offsets beyond 1. With the default ``1.0`` the only
                power is 1, so sampling happens at offsets 0 and 1 only.
                NOTE(review): the default looks unintended; callers likely
                want an explicit ``scale`` such as 2.0.

        Raises:
            ValueError: If ``scale`` is not positive.
        """
        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale!r}")
        super().__init__(warmup_steps)
        self.scale = scale

    def need_sampling(self, current_step):
        """Return True when the post-warmup offset is 0 or a power of ``scale``.

        Args:
            current_step: The current training step.

        Returns:
            False during warmup; otherwise True iff
            ``current_step - warmup_steps`` is 0 or an integer power of ``scale``.
        """
        if current_step < self.warmup_steps:
            return False  # no sampling during warmup

        offset = current_step - self.warmup_steps
        if offset == 0 or offset == 1:
            # Offset 0 marks the end of warmup; offset 1 is scale ** 0.
            return True
        return self._is_power_of_scale(offset)

    def _is_power_of_scale(self, offset):
        """Return True if ``offset`` (> 1) equals ``scale ** k`` for an integer k.

        This uses exact integer arithmetic instead of
        ``math.log(offset, scale).is_integer()``. The float version misses
        exact powers through rounding, for example
        ``math.log(1000, 10) == 2.9999999999999996``.
        """
        # {b ** k : k in Z} == {(1 / b) ** k : k in Z}, so normalize to a base >= 1.
        base = self.scale if self.scale >= 1 else 1.0 / self.scale
        # 1 ** k is always 1, and k >= 1 powers of a non-integral float are never
        # integers, so only an integral base >= 2 can produce an offset > 1.
        if base < 2 or not float(base).is_integer():
            return False
        base = int(base)
        while offset % base == 0:
            offset //= base
        return offset == 1

class IncreasingDensityScheduler(SamplerScheduler):
    """Increasingly dense sampler: the gap between samples shrinks after each one.

    The first sample is taken at ``warmup_steps``. Each time a sample is taken,
    the interval is divided by ``increase_factor`` and floored at
    ``min_interval``. The next sample is then scheduled ``interval`` steps
    after the current one, so consecutive gaps are exactly the successive
    interval values.
    """

    def __init__(self, warmup_steps, initial_interval, increase_factor, min_interval=32):
        """
        Args:
            warmup_steps: Number of initial steps during which nothing is sampled.
            initial_interval: Starting gap between samples.
            increase_factor: Divisor applied to the interval after every sample.
                A value above 1 makes sampling denser; a value below 1 makes
                the interval grow instead.
            min_interval: Lower bound for the interval. Defaults to 32, the
                previously hard-coded floor.
        """
        super().__init__(warmup_steps)
        self.interval = initial_interval
        self.factor = increase_factor
        self.min_interval = min_interval
        # Step at which the next sample is due; the first is at the end of warmup.
        self._next_sample_step = warmup_steps

    def need_sampling(self, current_step):
        """Return True when a sample is due at ``current_step``.

        This method is stateful. When it returns True, it shrinks the interval,
        schedules the next sample ``interval`` steps later, and prints a
        message. Asking again about the same step therefore returns False. If
        the due step itself is never queried, the next queried step past it
        triggers the sample.

        Args:
            current_step: The current training step.

        Returns:
            True if a sample is due, otherwise False.
        """
        if current_step < self.warmup_steps:
            return False  # no sampling during warmup
        if current_step < self._next_sample_step:
            return False

        self.increase_density()
        # Schedule relative to this sample rather than as a multiple from the
        # end of warmup. With a changing interval, the old modulo test produced
        # erratic, non-monotonic gaps.
        self._next_sample_step = current_step + self.interval
        print(f"Sampling at step {current_step}, interval {self.interval}")
        return True

    def increase_density(self):
        """Divide the interval by the factor, never going below ``min_interval``."""
        self.interval = max(self.min_interval, int(self.interval / self.factor))


# # 使用示例
# warmup_steps = 1000

# fixed_scheduler = FixedSamplerScheduler(warmup_steps, sample_interval=50)
# linear_scheduler = LinearSamplerScheduler(warmup_steps, start_sampling_step=2000, end_sampling_step=5000)
# exponential_scheduler = ExponentialSamplerScheduler(warmup_steps, scale=2.0)

# # 在训练循环中使用：
# current_training_step = 3000  # 假设当前训练步骤为3000

# if fixed_scheduler.need_sampling(current_training_step):
#     print("进行固定时间采样")

# if linear_scheduler.need_sampling(current_training_step):
#     print("进行线性采样")

# if exponential_scheduler.need_sampling(current_training_step):
#     print("进行指数采样")
