from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model
import torch


class BMO(SingleModelAlgorithm):
    """Single-model algorithm with a squared-prediction penalty.

    The training objective is the mean of the strictly positive per-label
    losses from ``self.loss.compute_label_wise``, plus ``bmo_lambda`` times
    the mean squared value of every element of ``results['y_pred']``. The
    penalty value is also logged under the ``'bmo_penalty'`` field.
    """

    def __init__(self, config, d_out, grouper, loss, metric, n_train_steps):
        """Build the model and register the extra logging field.

        Args:
            config: Experiment configuration. ``config.device`` and
                ``config.bmo_lambda`` are read here. The whole object is
                also passed to ``initialize_model`` and to the parent
                constructor.
            d_out: Output dimension, forwarded to ``initialize_model``.
            grouper: Forwarded unchanged to ``SingleModelAlgorithm``.
            loss: Forwarded unchanged to ``SingleModelAlgorithm``. It must
                provide ``compute_label_wise`` (used in ``objective``).
            metric: Forwarded unchanged to ``SingleModelAlgorithm``.
            n_train_steps: Forwarded unchanged to ``SingleModelAlgorithm``.
        """
        # initialize model
        model = initialize_model(config, d_out).to(config.device)
        # Weight of the squared-prediction penalty in `objective`.
        self.bmo_lambda = config.bmo_lambda
        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # additional logging: `objective` writes this key into `results`
        self.logged_fields.append('bmo_penalty')

    def objective(self, results):
        """Compute the training loss for one batch.

        Args:
            results: Dict containing at least ``'y_pred'`` and ``'y_true'``
                tensors. On return it also holds ``'bmo_penalty'``, the
                penalty as a Python float.

        Returns:
            A scalar tensor: the mean of the strictly positive label-wise
            losses (0 if none are positive) plus
            ``bmo_lambda * mean(y_pred ** 2)``.
        """
        group_losses, _, _ = self.loss.compute_label_wise(
            results['y_pred'],
            results['y_true'],
            device=self.device,
            return_dict=False)
        # Average only over strictly positive entries (presumably the labels
        # that actually occur in the batch -- TODO confirm against
        # compute_label_wise). torch.mean of an empty selection is NaN, which
        # would poison the whole loss, so divide by a count clamped to >= 1.
        # That yields 0 when nothing is positive and the plain mean otherwise.
        positive = group_losses > 0
        loss = group_losses[positive].sum() / positive.sum().clamp(min=1)
        # Mean of squared predictions over every element of y_pred.
        penalty = (results['y_pred'].flatten() ** 2).mean()
        loss = loss + self.bmo_lambda * penalty
        # .item() detaches to a Python float for logging.
        results['bmo_penalty'] = penalty.item()
        return loss
