#!/usr/bin/env python3
"""
Adapted from the learn2learn library.
Uses a separate (typically smaller) learning rate for BERT (the text_field_embedder) than for the decoders.
"""
import traceback
from torch.autograd import grad

from learn2learn.algorithms.base_learner import BaseLearner
from learn2learn.utils import clone_module


def _step_own_tensors(module, step_lr):
    """Take one SGD step on `module`'s *own* parameters and buffers (non-recursive).

    Every entry of ``module._parameters`` / ``module._buffers`` that is not None
    and carries a ``.grad`` is replaced by ``tensor - step_lr * tensor.grad``.
    The entry is re-bound to a new tensor rather than modified in place, so the
    autograd graph back to the original tensor is preserved.
    """
    for store in (module._parameters, module._buffers):
        # Only existing keys are re-assigned, so mutating while iterating is safe.
        for key in store:
            tensor = store[key]
            if tensor is not None and tensor.grad is not None:
                store[key] = tensor - step_lr * tensor.grad


def maml_update(model, lr, lr_small, grads=None, in_recursion=False, doing_BERT=False):
    """
    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/maml.py)

    **Description**

    Performs a MAML update on model using the given gradients and two learning
    rates: `lr_small` for the BERT encoder (`model.text_field_embedder`) and
    `lr` for `model.decoders` and `model.scalar_mix`.
    Modification of the original learn2learn code.
    The function re-routes the Python object, thus avoiding in-place
    operations.

    NOTE: The model itself is updated in-place (no deepcopy), but the
          parameters' tensors are not.

    NOTE: At the top level only `text_field_embedder`, `decoders` and
          `scalar_mix` are visited. Parameters, buffers or submodules attached
          directly to `model` outside those three are NOT updated.

    **Arguments**

    * **model** (Module) - The model to update. When `in_recursion` is False it
        must expose `text_field_embedder`, `decoders` and `scalar_mix` submodules.
    * **lr** (float) - Learning rate for the decoders and the scalar mix.
    * **lr_small** (float) - Learning rate for the text_field_embedder (BERT).
    * **grads** (list, *optional*, default=None) - A list of gradients for each parameter
        of the model, in `model.parameters()` order. If None, will use the
        gradients already stored in the .grad attributes.
    * **in_recursion** (bool, *optional*, default=False) - Internal flag; True when
        `model` is a submodule reached by recursion rather than the top-level model.
    * **doing_BERT** (bool, *optional*, default=False) - Internal flag; only read when
        `in_recursion` is True. It selects `lr_small` for `model` and all its
        descendants (i.e. `model` lies inside the text_field_embedder).

    **Returns**

    The same `model` object, with its parameter/buffer entries re-bound.

    **Example**
    ~~~python
    maml = MAML(Model(), lr=0.1, lr_small=0.01)
    model = maml.clone() # The next two lines essentially implement model.adapt(loss)
    grads = autograd.grad(loss, model.parameters(), create_graph=True)
    maml_update(model, lr=0.1, lr_small=0.01, grads=grads)
    ~~~
    """
    if grads is not None:
        params = list(model.parameters())
        if len(grads) != len(params):
            msg = "WARNING:maml_update(): Parameters and gradients have different length. ("
            msg += str(len(params)) + " vs " + str(len(grads)) + ")"
            print(msg)
        # Stash each gradient on its parameter; the update below reads `.grad`.
        for p, g in zip(params, grads):
            p.grad = g

    if not in_recursion:
        # Top level: route each known component to its own learning rate.
        components = (
            (model.text_field_embedder, lr_small, True),
            (model.decoders, lr, False),
            (model.scalar_mix, lr, False),
        )
        for component, component_lr, is_bert in components:
            _step_own_tensors(component, component_lr)
            for module_key in component._modules:
                component._modules[module_key] = maml_update(
                    component._modules[module_key],
                    lr=lr,
                    lr_small=lr_small,
                    grads=None,
                    in_recursion=True,
                    doing_BERT=is_bert,
                )
    else:
        # Inside a component: the BERT flag is inherited by every descendant.
        real_lr = lr_small if doing_BERT else lr
        _step_own_tensors(model, real_lr)
        for module_key in model._modules:
            model._modules[module_key] = maml_update(
                model._modules[module_key],
                lr=lr,
                lr_small=lr_small,
                grads=None,
                in_recursion=True,
                doing_BERT=doing_BERT,
            )

    # Finally, rebuild the flattened parameters for RNNs
    # See this issue for more details:
    # https://github.com/learnables/learn2learn/issues/139
    model._apply(lambda x: x)
    return model


class MAML(BaseLearner):
    """

    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/maml.py)

    **Description**

    High-level implementation of *Model-Agnostic Meta-Learning*.

    This class wraps an arbitrary nn.Module and augments it with `clone()` and `adapt()`
    methods. Adaptation uses two learning rates (see `maml_update`): `lr_small`
    for the BERT encoder (`text_field_embedder`) and `lr` for the decoders and
    scalar mix.

    For the first-order version of MAML (i.e. FOMAML), set the `first_order` flag to `True`
    upon initialization.

    **Arguments**

    * **model** (Module) - Module to be wrapped.
    * **lr** (float) - Fast adaptation learning rate (decoders / scalar mix).
    * **lr_small** (float, *optional*, default=None) - Fast adaptation learning rate
        for the text_field_embedder (BERT). Defaults to `lr`.
    * **first_order** (bool, *optional*, default=False) - Whether to use the first-order
        approximation of MAML. (FOMAML)
    * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation
        of unused parameters. Defaults to `allow_nograd`.
    * **allow_nograd** (bool, *optional*, default=False) - Whether to allow adaptation with
        parameters that have `requires_grad = False`.

    **References**

    1. Finn et al. 2017. "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks."

    **Example**

    ~~~python
    linear = l2l.algorithms.MAML(nn.Linear(20, 10), lr=0.01)
    clone = linear.clone()
    error = loss(clone(X), y)
    clone.adapt(error)
    error = loss(clone(X), y)
    error.backward()
    ~~~
    """

    def __init__(
        self,
        model,
        lr,
        lr_small=None,
        first_order=False,
        allow_unused=None,
        allow_nograd=False,
    ):
        super(MAML, self).__init__()
        self.module = model
        self.lr = lr
        # BERT learning rate; falls back to the general rate when not given.
        self.lr_small = lr if lr_small is None else lr_small
        self.first_order = first_order
        self.allow_nograd = allow_nograd
        if allow_unused is None:
            allow_unused = allow_nograd
        self.allow_unused = allow_unused

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def adapt(self, loss, first_order=None, allow_unused=None, allow_nograd=None):
        """
        **Description**

        Takes a gradient step on the loss and updates the cloned parameters in place.

        **Arguments**

        * **loss** (Tensor) - Loss to minimize upon update.
        * **first_order** (bool, *optional*, default=None) - Whether to use first- or
            second-order updates. Defaults to self.first_order.
        * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation
            of unused parameters. Defaults to self.allow_unused.
        * **allow_nograd** (bool, *optional*, default=None) - Whether to allow adaptation with
            parameters that have `requires_grad = False`. Defaults to self.allow_nograd.

        **Raises**

        * **RuntimeError** - Re-raised from `torch.autograd.grad` when gradients
            cannot be computed (after printing a hint about the flags).

        """
        if first_order is None:
            first_order = self.first_order
        if allow_unused is None:
            allow_unused = self.allow_unused
        if allow_nograd is None:
            allow_nograd = self.allow_nograd
        second_order = not first_order

        if allow_nograd:
            # Compute relevant gradients
            diff_params = [p for p in self.module.parameters() if p.requires_grad]
            grad_params = grad(
                loss,
                diff_params,
                retain_graph=second_order,
                create_graph=second_order,
                allow_unused=allow_unused,
            )
            gradients = []
            grad_counter = 0

            # Handles gradients for non-differentiable parameters: keep one
            # entry per parameter so the list aligns with module.parameters().
            for param in self.module.parameters():
                if param.requires_grad:
                    gradient = grad_params[grad_counter]
                    grad_counter += 1
                else:
                    gradient = None
                gradients.append(gradient)
        else:
            try:
                gradients = grad(
                    loss,
                    self.module.parameters(),
                    retain_graph=second_order,
                    create_graph=second_order,
                    allow_unused=allow_unused,
                )
            except RuntimeError:
                print(
                    "learn2learn: Maybe try with allow_nograd=True and/or allow_unused=True ?"
                )
                # Without gradients there is nothing to update; falling through
                # would only mask this error behind an UnboundLocalError.
                raise

        # Update the module
        self.module = maml_update(self.module, self.lr, self.lr_small, gradients)

    def clone(self, first_order=None, allow_unused=None, allow_nograd=None):
        """
        **Description**

        Returns a `MAML`-wrapped copy of the module whose parameters and buffers
        are `torch.clone`d from the original module.

        Because the copies are produced by `torch.clone`, back-propagating
        losses on the cloned module also sends gradients back to the
        original module's tensors.
        For more information, refer to learn2learn.clone_module().

        **Arguments**

        * **first_order** (bool, *optional*, default=None) - Whether the clone uses first-
            or second-order updates. Defaults to self.first_order.
        * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation
            of unused parameters. Defaults to self.allow_unused.
        * **allow_nograd** (bool, *optional*, default=None) - Whether to allow adaptation with
            parameters that have `requires_grad = False`. Defaults to self.allow_nograd.

        """
        if first_order is None:
            first_order = self.first_order
        if allow_unused is None:
            allow_unused = self.allow_unused
        if allow_nograd is None:
            allow_nograd = self.allow_nograd
        return MAML(
            clone_module(self.module),
            lr=self.lr,
            lr_small=self.lr_small,
            first_order=first_order,
            allow_unused=allow_unused,
            allow_nograd=allow_nograd,
        )
