


from torch.optim import AdamW

class MyOptimizer(AdamW):
    """AdamW with a built-in warm-up / linear-decay learning-rate schedule.

    On every regular ``step`` the optimizer advances an internal step counter,
    recomputes the learning rate and writes it into every parameter group
    before delegating to ``AdamW.step``:

    * while ``step_num < warmup`` the rate grows linearly from 0 towards
      ``learning_rate``;
    * afterwards it decays linearly from ``learning_rate`` (at ``warmup``)
      towards 0 (at ``max_training_steps``), clamped from below by ``min_lr``;
    * once ``step_num >= max_training_steps`` it stays at ``min_lr``.

    A separate "cold start" mode (see ``set_cold_start`` /
    ``close_cold_start``) freezes the schedule and uses the fixed ``cold_lr``.

    Args:
        params: Parameters or parameter groups, forwarded to ``AdamW``.
        train_batch_cnt: Number of batches per epoch. When given together with
            ``train_epoch_cnt``, it overrides ``max_training_steps`` and
            ``warmup_steps``.
        train_epoch_cnt: Number of epochs (see ``train_batch_cnt``).
        warm_up_ratio: Fraction of the derived total steps used for warm-up;
            only used when both ``train_batch_cnt`` and ``train_epoch_cnt``
            are given.
        warmup_steps: Number of warm-up steps when not derived from the counts.
        max_training_steps: Step at which the decay reaches its end.
        cold_lr: Fixed learning rate applied while in cold-start mode.
        learning_rate: Peak learning rate reached at the end of warm-up.
        min_lr: Lower bound for the learning rate after warm-up.
        betas: Forwarded to ``AdamW``.
        weight_decay: Forwarded to ``AdamW``.
        eps: Forwarded to ``AdamW``.
    """

    def __init__(
        self,
        params,
        train_batch_cnt=None,
        train_epoch_cnt=None,
        warm_up_ratio=0.1,
        warmup_steps=10000,
        max_training_steps=100000,
        cold_lr=1e-3,
        learning_rate=1e-4,
        min_lr=1e-5,
        betas=(0.9, 0.999),
        weight_decay=0.01,
        eps=1e-9,
    ):
        self.learning_rate = learning_rate

        # When the training length is known, derive the schedule from it
        # instead of using the explicit step arguments.
        if train_batch_cnt is not None and train_epoch_cnt is not None:
            max_training_steps = train_batch_cnt * train_epoch_cnt
            warmup_steps = int(max_training_steps * warm_up_ratio)

        self.warmup = warmup_steps
        self.lr = 0
        self.step_num = 0
        self.min_lr = min_lr
        self.max_training_steps = max_training_steps

        self.cold_lr = cold_lr
        self.in_cold_start = False
        # ``lr`` is intentionally not forwarded: the value in each param group
        # is overwritten by ``step`` (or by ``set_cold_start``) before use.
        super().__init__(params, betas=betas, weight_decay=weight_decay, eps=eps)

    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; forwarded to ``AdamW.step``.

        Returns:
            Whatever ``AdamW.step`` returns (the closure's loss, or ``None``
            when no closure is given).
        """
        if self.in_cold_start:
            # Cold start: keep the fixed learning rate set by
            # ``set_cold_start`` and do not advance the schedule.
            return super().step(closure)

        self.step_num += 1
        self.lr = self._learning_rate()
        for group in self.param_groups:
            group['lr'] = self.lr
        return super().step(closure)

    def _learning_rate(self):
        """Return the scheduled learning rate for the current ``step_num``."""
        # Linear warm-up from 0 towards ``learning_rate``.
        if self.step_num < self.warmup:
            return self.step_num / self.warmup * self.learning_rate

        # Past the end of the schedule: stay at the floor.
        if self.step_num >= self.max_training_steps:
            return self.min_lr

        # Linear decay: factor is 1 at ``warmup`` and approaches 0 at
        # ``max_training_steps``. The denominator is non-zero here because
        # warmup <= step_num < max_training_steps.
        decay = (self.step_num - self.max_training_steps) / (
            self.warmup - self.max_training_steps
        )
        return max(decay * self.learning_rate, self.min_lr)

    def set_cold_start(self):
        """Enter cold-start mode and set every group's rate to ``cold_lr``."""
        self.in_cold_start = True
        for group in self.param_groups:
            group['lr'] = self.cold_lr

    def close_cold_start(self):
        """Leave cold-start mode and restart the schedule from step 0."""
        self.in_cold_start = False
        self.lr = 0
        self.step_num = 0
