import torch
from torch import Tensor
from torch.optim.lr_scheduler import LambdaLR


def required_space_param(p: Tensor) -> float:
    """Estimate the storage footprint of a tensor's elements, in bytes.

    The per-element size comes from ``p.element_size()``, so every dtype
    (e.g. ``bfloat16``, complex and 8-bit float types) is sized correctly
    instead of silently being counted as 8 bytes. Boolean tensors are
    counted as one bit per element, i.e. the size they would take if
    bit-packed; note that PyTorch itself stores ``torch.bool`` using one
    byte per element.

    Args:
        p: Tensor (or ``nn.Parameter``) whose size should be estimated.

    Returns:
        Number of bytes as a ``float`` (fractional for bit-packed booleans).
    """
    numel = p.numel()
    if p.dtype == torch.bool:
        # Keep the historical 1-bit-per-element estimate for boolean masks.
        return numel / 8.0
    # element_size() is the exact byte width of one element of p.dtype.
    return float(numel * p.element_size())


def get_available_device_count(default=1):
    """Return the number of usable CUDA devices, or ``default`` without CUDA.

    Args:
        default: Value returned when ``torch.cuda.is_available()`` is False.

    Returns:
        ``torch.cuda.device_count()`` when CUDA is available, else ``default``.
    """
    if not torch.cuda.is_available():
        return default
    return torch.cuda.device_count()


def get_sqrt_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
    """Create an inverse-square-root schedule with a linear warmup.

    At scheduler step ``s`` (with ``n = s + 1``), each parameter group's
    base learning rate is multiplied by::

        min(n ** -0.5, n * num_warmup_steps ** -1.5)

    The multiplier therefore grows linearly for the first
    ``num_warmup_steps`` steps. It peaks at ``num_warmup_steps ** -0.5``
    when ``n == num_warmup_steps``, and after that it decays proportionally
    to ``1 / sqrt(n)``. The multiplier is not normalized to 1 at the peak,
    so the effective learning rate stays below the base learning rate.
    With ``num_warmup_steps == 0`` there is no warmup and the multiplier is
    simply ``n ** -0.5``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Length of the linear warmup phase; must be >= 0.
        last_epoch: Index of the last step, forwarded to ``LambdaLR``
            (``-1`` starts a fresh schedule).

    Returns:
        A ``LambdaLR`` scheduler implementing the schedule.

    Raises:
        ValueError: If ``num_warmup_steps`` is negative.
    """
    if num_warmup_steps < 0:
        # A negative base raised to -1.5 yields a complex number, which would
        # later fail inside min() with an obscure TypeError.
        raise ValueError(
            f"num_warmup_steps must be non-negative, got {num_warmup_steps}"
        )

    # 0 ** -1.5 raises ZeroDivisionError; without warmup the schedule is pure
    # inverse-square-root decay, so the warmup term is simply skipped.
    warmup_steps_const = num_warmup_steps ** (-1.5) if num_warmup_steps > 0 else None

    def lr_lambda(current_step):
        # LambdaLR counts steps from 0; shift to 1-based so the first
        # multiplier is finite and non-zero.
        step_num = current_step + 1
        decay = step_num ** (-0.5)
        if warmup_steps_const is None:
            return decay
        return min(decay, step_num * warmup_steps_const)

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "required_space_param",
    "get_available_device_count",
    "get_sqrt_schedule_with_warmup",
]
