# -*- coding: utf-8 -*-

import torch.optim as optim
from transformers.optimization import AdamW

from ...common.dataclass_options import BranchSelect, OptionsBase
from .lookahead import Lookahead
from .ranger import Ranger


class AdamOptions(OptionsBase):
    """Hyper-parameters for ``torch.optim.Adam``.

    The two Adam betas are kept as separate scalar fields (``beta_1`` and
    ``beta_2``) and recombined into the ``betas`` tuple at build time.
    """
    lr: float = 1e-3
    beta_1: float = 0.9
    beta_2: float = 0.999
    eps: float = 1e-8
    weight_decay: float = 0.0
    amsgrad: bool = False

    def create(self, trainable_params):
        """Instantiate ``optim.Adam`` over ``trainable_params``.

        Args:
            trainable_params: forwarded unchanged as the first argument of
                ``optim.Adam`` (parameters or parameter groups).

        Returns:
            The newly built ``optim.Adam`` instance.
        """
        adam_kwargs = dict(
            lr=self.lr,
            betas=(self.beta_1, self.beta_2),
            eps=self.eps,
            weight_decay=self.weight_decay,
            amsgrad=self.amsgrad,
        )
        return optim.Adam(trainable_params, **adam_kwargs)


class AdamWOptions(AdamOptions):
    """Hyper-parameters for ``AdamW`` (Adam with decoupled weight decay).

    Reuses every field declared on ``AdamOptions``. ``amsgrad`` must stay
    ``False``: this ``create`` does not forward it to ``AdamW``.
    """

    def create(self, trainable_params):
        """Instantiate ``AdamW`` over ``trainable_params``.

        Args:
            trainable_params: forwarded unchanged as the first argument of
                ``AdamW`` (parameters or parameter groups).

        Returns:
            The newly built ``AdamW`` instance.

        Raises:
            ValueError: if ``amsgrad`` is enabled, since it would otherwise be
                silently ignored.
        """
        # An explicit check instead of ``assert``: asserts are stripped when
        # Python runs with ``-O``, which would let amsgrad=True be dropped
        # without any error.
        if self.amsgrad:
            raise ValueError('amsgrad is not supported by AdamWOptions; '
                             'set amsgrad=False')
        return AdamW(trainable_params,
                     lr=self.lr,
                     betas=(self.beta_1, self.beta_2),
                     eps=self.eps,
                     weight_decay=self.weight_decay)


class SGDOptions(OptionsBase):
    """Hyper-parameters for ``torch.optim.SGD``.

    Unlike the other option classes in this module, this one defines no
    ``create`` method. Its field names match ``optim.SGD``'s keyword
    arguments, so construction is presumably handled generically by
    ``OptionsBase`` / ``BranchSelect`` (NOTE(review): confirm against
    ``dataclass_options``).
    """
    lr: float  # no default declared, so a learning rate must be supplied
    momentum: float = 0
    dampening: float = 0
    weight_decay: float = 0
    nesterov: bool = False


class RangerOptions(OptionsBase):
    """Hyper-parameters for the ``Ranger`` optimizer.

    The lookahead settings (``k`` and ``alpha``) are not fields here; they are
    supplied by the caller when the optimizer is created.
    """
    lr: float = 1e-3
    N_sma_threshold: int = 5
    beta_1: float = 0.95
    beta_2: float = 0.999
    eps: float = 1e-5
    weight_decay: float = 0.0

    def create(self, trainable_params, k, alpha):
        """Instantiate ``Ranger`` over ``trainable_params``.

        Args:
            trainable_params: forwarded unchanged as the first argument of
                ``Ranger``.
            k: lookahead step count, passed to ``Ranger`` as ``k``.
            alpha: lookahead interpolation factor, passed as ``alpha``.

        Returns:
            The newly built ``Ranger`` instance.
        """
        beta_pair = (self.beta_1, self.beta_2)
        return Ranger(
            trainable_params,
            lr=self.lr,
            alpha=alpha,
            k=k,
            N_sma_threshold=self.N_sma_threshold,
            betas=beta_pair,
            eps=self.eps,
            weight_decay=self.weight_decay,
        )


class OptimizerOptions(BranchSelect):
    """Select an optimizer by name and optionally add Lookahead on top.

    ``type`` picks one key of ``branches``. Each value pairs an optimizer
    class with the options class that holds its hyper-parameters. When
    ``lookahead_step > 1``, the built optimizer is wrapped in ``Lookahead``.
    The exception is ``'ranger'``: it receives the lookahead settings as
    ``k``/``alpha`` keyword arguments and is never wrapped.
    """
    type = 'adam'
    branches = {
        'adam': (optim.Adam, AdamOptions),
        'sgd': (optim.SGD, SGDOptions),
        'adamw': (AdamW, AdamWOptions),
        'ranger': (Ranger, RangerOptions),
    }

    lookahead_step: int = 6
    lookahead_alpha: float = 0.5

    @property
    def learning_rate(self):
        """Learning rate (``lr``) of the currently selected branch's options.

        Assumes ``BranchSelect`` exposes each branch's options object as the
        attribute ``<type>_options`` (NOTE(review): confirm).
        """
        selected_options = getattr(self, f'{self.type}_options')
        return selected_options.lr

    def create(self, *args, **kwargs):
        """Build the selected optimizer, adding Lookahead when configured.

        All arguments are forwarded to ``BranchSelect.create``. For
        ``'ranger'``, ``k`` and ``alpha`` keyword arguments are injected
        (overriding any caller-supplied values), presumably so they reach
        ``RangerOptions.create``.

        Returns:
            The optimizer, wrapped in ``Lookahead`` when
            ``lookahead_step > 1`` and the type is not ``'ranger'``.
        """
        step_requests_lookahead = self.lookahead_step > 1
        is_ranger = self.type == 'ranger'

        if is_ranger:
            # Ranger applies lookahead internally, so hand it the settings
            # rather than wrapping it a second time.
            kwargs.update(k=self.lookahead_step, alpha=self.lookahead_alpha)

        built = super().create(*args, **kwargs)
        if is_ranger or not step_requests_lookahead:
            return built
        return Lookahead(built,
                         k=self.lookahead_step,
                         alpha=self.lookahead_alpha)
