# -*- coding: utf-8 -*-

import transformers.optimization as extra_optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from ...common.dataclass_options import OptionsBase, argfield


class AdvancedLearningOptions(OptionsBase):
    """Options controlling learning-rate warmup, decay and the LR-based stop criterion."""

    # Number of initial steps handed to the warmup schedule as its warmup length.
    lr_warmup_steps: int = 160
    # Decay strategy used by AdvancedScheduler; each choice maps to a scheduler:
    #   'on_plateau'                -> warmup, then ReduceLROnPlateau
    #   'linear'                    -> transformers WarmupLinearSchedule
    #   'cosine'                    -> transformers WarmupCosineSchedule
    #   'cosine_with_hard_restarts' -> transformers WarmupCosineWithHardRestartsSchedule
    lr_reduce_method: str = argfield('on_plateau',
                                     choices=['on_plateau', 'linear', 'cosine',
                                              'cosine_with_hard_restarts'])

    # ReduceLROnPlateau settings; only read when lr_reduce_method == 'on_plateau'.
    on_plateau_factor: float = 0.5
    on_plateau_patience: int = 5

    # Cycle count; only read when lr_reduce_method == 'cosine_with_hard_restarts'.
    restart_cycles: float = 1

    # NOTE(review): presumably the max gradient norm, with 0.0 meaning "no clipping";
    # not read in this file — confirm against the trainer.
    clip_grad_norm: float = 0.0
    # Training should stop once any learning rate falls below this value.
    min_learning_rate: float = 1e-6
    # NOTE(review): presumably batches per optimizer update (gradient accumulation);
    # not read in this file — confirm against the trainer.
    update_frequency: int = 1


class AdvancedScheduler:
    """Learning-rate scheduler combining a per-step schedule with an optional
    per-evaluation ``ReduceLROnPlateau``.

    ``options.lr_reduce_method`` selects the strategy:

    * ``'on_plateau'``: ``extra_optim.WarmupConstantSchedule`` drives the learning
      rate for the first ``lr_warmup_steps`` steps. Afterwards it is changed only
      by ``ReduceLROnPlateau`` in :meth:`after_eval_hook`.
    * ``'linear'``, ``'cosine'``, ``'cosine_with_hard_restarts'``: the matching
      ``extra_optim.Warmup<Name>Schedule`` drives the learning rate on every step,
      and no epoch-level scheduler is used.
    """

    def __init__(self, options: 'AdvancedLearningOptions', optimizer, max_steps=None, mode=None):
        """
        :param options: learning-rate options (warmup length, reduce method, ...).
        :param optimizer: the torch optimizer whose ``param_groups`` are scheduled.
        :param max_steps: total number of training steps. It is passed to every
            schedule except the ``'on_plateau'`` one and is required for them.
        :param mode: ``'min'`` or ``'max'``, forwarded to ``ReduceLROnPlateau``
            (``'on_plateau'`` only). ``None`` falls back to ``'min'``, which is
            ``ReduceLROnPlateau``'s own default.
        :raises ValueError: if a method other than ``'on_plateau'`` is selected
            without ``max_steps``.
        """
        self.last_epoch = -1
        self.epoch_scheduler = None
        self.optimizer = optimizer

        self.min_learning_rate = options.min_learning_rate
        self.warmup_steps = warmup_steps = options.lr_warmup_steps

        method = options.lr_reduce_method
        if method == 'on_plateau':
            self.step_scheduler = extra_optim.WarmupConstantSchedule(optimizer, warmup_steps)
            # ReduceLROnPlateau only accepts 'min'/'max'; None would make it fail.
            self.epoch_scheduler = ReduceLROnPlateau(
                optimizer,
                mode=mode if mode is not None else 'min',
                factor=options.on_plateau_factor,
                patience=options.on_plateau_patience,
                verbose=True)
        else:
            if max_steps is None:
                # Fail early instead of inside the schedule's first lr computation.
                raise ValueError(f'lr_reduce_method={method!r} requires max_steps')
            # e.g. 'cosine_with_hard_restarts' -> 'WarmupCosineWithHardRestartsSchedule'
            infix = ''.join(map(str.capitalize, method.split('_')))
            scheduler_class = getattr(extra_optim, f'Warmup{infix}Schedule')

            args = [optimizer, warmup_steps, max_steps]
            if method == 'cosine_with_hard_restarts':
                args.append(options.restart_cycles)

            self.step_scheduler = scheduler_class(*args)

    def __str__(self):
        return f'{self.__class__.__qualname__}[\n' \
               f'  step={self.step_scheduler},\n' \
               f'  epoch={self.epoch_scheduler},\n' \
               f']'

    def before_step_hook(self, step, max_steps, _, logger=None):
        """Advance the per-step scheduler before optimizer step ``step``.

        The step scheduler is moved to ``max(1, step)`` (never at step 0) in two
        cases: while ``step <= warmup_steps``, or whenever there is no epoch
        scheduler. After warmup with an epoch scheduler, the learning rate is left
        to :meth:`after_eval_hook`.

        If ``logger`` is given, current learning rates are logged in two cases:
        every ``max(1, warmup_steps // 5)`` steps during warmup, and every 1000
        steps when no epoch scheduler is used.

        :param step: current global step.
        :param max_steps: unused; kept for hook-signature compatibility.
        :param _: unused; kept for hook-signature compatibility.
        :param logger: optional object with an ``info(fmt, *args)`` method.
        """
        in_warmup = step <= self.warmup_steps
        if step != 0 and (self.epoch_scheduler is None or in_warmup):
            # When step > warmup_steps and an epoch scheduler exists, the learning
            # rate must be set only by after_eval_hook.
            self.step_scheduler.step(max(1, step))

        if logger is None:
            return

        # warmup_steps // 5 is 0 for warmup_steps < 5; clamp to avoid modulo by zero.
        log_interval = max(1, self.warmup_steps // 5)
        if (in_warmup and step % log_interval == 0) or \
           (self.epoch_scheduler is None and step % 1000 == 0):
            logger.info('learning rate: %s',
                        [float(param_group['lr'])
                         for param_group in self.optimizer.param_groups])

    def after_eval_hook(self, step, max_steps, epoch, metric_value, logger=None):
        """Feed an evaluation metric to the epoch scheduler once warmup is over.

        :param step: current global step. The metric is ignored while
            ``step <= warmup_steps``.
        :param max_steps: unused; kept for hook-signature compatibility.
        :param epoch: current epoch. It is stored in ``last_epoch`` and forwarded
            to the epoch scheduler.
        :param metric_value: value monitored by ``ReduceLROnPlateau``.
        :param logger: unused.
        :returns: True if the epoch scheduler stepped and any param group's
            learning rate is now below ``min_learning_rate`` (the stop signal);
            otherwise False.
        """
        self.last_epoch = epoch
        if self.epoch_scheduler is None or step <= self.warmup_steps:
            return False

        self.epoch_scheduler.step(metric_value, epoch)
        return any(param_group['lr'] < self.min_learning_rate
                   for param_group in self.optimizer.param_groups)

    def state_dict(self):
        """Return a serializable snapshot of both schedulers and ``last_epoch``."""
        return {
            'last_epoch': self.last_epoch,
            'step': self.step_scheduler.state_dict(),
            'epoch': (self.epoch_scheduler.state_dict() if self.epoch_scheduler else None)
        }

    def load_state_dict(self, saved_state):
        """Restore a snapshot previously produced by :meth:`state_dict`."""
        self.last_epoch = saved_state['last_epoch']
        self.step_scheduler.load_state_dict(saved_state['step'])
        if self.epoch_scheduler is not None:
            self.epoch_scheduler.load_state_dict(saved_state['epoch'])
