# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""

import logging

import torch
import torch.nn as nn


logger = logging.getLogger(__name__)


class KDWrapper(nn.Module):
    """Knowledge-distillation wrapper holding a student and a frozen teacher.

    The student model is the one that is trained, saved and loaded: ``forward``,
    ``state_dict`` and ``load_state_dict`` all operate on it, and unknown
    attribute lookups are forwarded to it. The teacher model is rebuilt from a
    checkpoint, has gradients disabled, is kept in evaluation mode and is only
    used by :meth:`get_soft_targets`.
    """

    def __init__(self, args, model, task):
        """Wrap ``model`` as the student and load the teacher checkpoint.

        Args:
            args: student arguments; must expose ``teacher_file`` (path of the
                teacher checkpoint) and be usable with ``vars()``.
            model: the student model to train.
            task: object whose ``build_model(args)`` builds the teacher
                architecture from the teacher's arguments.
        """
        super().__init__()
        self.student_args = args
        self.student_model = model

        try:
            # "DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter" (Sanh et al., 2019)
            model.encoder.sentence_encoder.embed_positions.weight.requires_grad = False
        except AttributeError:
            # Best effort: students without this attribute path (or without a
            # positional-embedding weight) are simply left untouched.
            pass

        # NOTE(review): torch.load unpickles the file, so ``teacher_file`` must
        # come from a trusted source.
        state = torch.load(args.teacher_file, map_location="cpu")
        logger.info("loaded the teacher model from {}".format(args.teacher_file))
        self.teacher_args = state['args']
        logger.info("teacher args: {}".format(state['args']))
        # Fill in options the teacher checkpoint does not know about with the
        # student's values; options already present on the teacher are kept.
        for k, v in vars(args).items():
            if not hasattr(self.teacher_args, k):
                setattr(self.teacher_args, k, v)
        # build and load the teacher model
        teacher = task.build_model(self.teacher_args)
        teacher.load_state_dict(state["model"], strict=True)
        # do not update the teacher model
        for param in teacher.parameters():
            param.requires_grad = False
        # Soft targets must not be perturbed by dropout and similar
        # train-time behavior, so the teacher lives in eval mode.
        teacher.eval()
        self.teacher_model = teacher

    def __getattr__(self, name):
        """Fall back to the student model for attributes the wrapper lacks."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            wrapped_module = super().__getattr__("student_model")
            return getattr(wrapped_module, name)

    def train(self, mode=True):
        """Set the training mode of the wrapper while keeping the teacher in eval mode.

        ``nn.Module.train`` recurses into every registered submodule, which
        would otherwise switch the frozen teacher into training mode too.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        # Look in ``_modules`` directly so the student fallback in
        # ``__getattr__`` is not triggered.
        teacher = self._modules.get("teacher_model")
        if teacher is not None:
            teacher.eval()
        return self

    def get_soft_targets(
        self,
        *args,
        **kwargs
    ):
        """Run the teacher without gradient tracking.

        All arguments are passed through to the teacher model, which is
        expected to return a 2-tuple.

        Returns:
            The ``(x, extra)`` pair produced by the teacher.

        Raises:
            NotImplementedError: if no ``teacher_model`` is available.
        """
        if not hasattr(self, "teacher_model"):
            raise NotImplementedError("Do not support distllation without a teacher")
        with torch.no_grad():
            x, extra = self.teacher_model(*args, **kwargs)
        return x, extra

    def upgrade_state_dict_named(self, state_dict, name):
        """Rewrite ``state_dict`` in place so its keys address ``student_model``.

        Keys located under ``name`` that are not already under
        ``student_model.`` (or ``teacher_model.``) get a ``student_model.``
        component inserted; keys outside ``name`` are left alone. Afterwards
        the student's own ``upgrade_state_dict_named`` is invoked, if it has
        one, with the student's qualified name.

        Args:
            state_dict: mapping of parameter names to values, modified in place.
            name: qualified name of this module inside ``state_dict``
                (``""`` for the root).
        """
        prefix = name + "." if name != "" else ""
        student_name = prefix + "student_model"
        skip_prefixes = (student_name + ".", prefix + "teacher_model.")

        # add student_model. before upgrading children modules
        for k in list(state_dict.keys()):
            if not k.startswith(prefix) or k.startswith(skip_prefixes):
                continue
            state_dict[student_name + "." + k[len(prefix):]] = state_dict.pop(k)

        # upgrade children modules; nn.Module itself has no such hook, so
        # delegate to the student when it provides one.
        upgrade = getattr(self.student_model, "upgrade_state_dict_named", None)
        if callable(upgrade):
            upgrade(state_dict, student_name)

    def load_state_dict(self, state_dict, strict=True, args=None):
        """Load ``state_dict`` into the student model only.

        The keys are expected in the form produced by :meth:`state_dict`,
        i.e. without a ``student_model.`` component. ``args`` is forwarded as
        a keyword argument, so the student's ``load_state_dict`` must accept it.

        Returns:
            Whatever the student's ``load_state_dict`` returns.
        """
        return self.student_model.load_state_dict(state_dict, strict=strict, args=args)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the student's state only, keyed as if there were no wrapper.

        Teacher entries are dropped and the ``student_model.`` component is
        removed from student entries. ``prefix`` is honored, so this also
        works when the state is collected through a parent module.
        """
        state = super(KDWrapper, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        teacher_prefix = prefix + 'teacher_model.'
        student_prefix = prefix + 'student_model.'
        for k in list(state.keys()):
            if k.startswith(teacher_prefix):
                del state[k]
            elif k.startswith(student_prefix):
                state[prefix + k[len(student_prefix):]] = state.pop(k)
        return state

    def forward(self, *args, **kwargs):
        """Run the student model with the given arguments."""
        return self.student_model(*args, **kwargs)
