# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

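"""
Train a network across multiple GPUs, keeping integrated-adapter (iadapter)
parameters in per-adapter checkpoint files separate from the main checkpoint.
"""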
import os
import logging

import torch

from fairseq import checkpoint_utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics
from fairseq.trainer import Trainer
from examples.summarization.utils import loadAdapterConfig

# Module-level logger for this file.
logger = logging.getLogger(__name__)

def ckpt_to_adapter_path(
    ckpt_path,
    adapter_key
):
    """
    Map a model checkpoint path to the matching adapter checkpoint path.

    Only the file-name component is rewritten: every occurrence of
    "checkpoint" in it is replaced by ``adapter_key``. The directory part
    is kept unchanged.

    Input:
        ckpt_path: str
        adapter_key: str
    Return:
        adapter_path: str
    """
    parent_dir, base_name = os.path.split(ckpt_path)
    return os.path.join(parent_dir, base_name.replace("checkpoint", adapter_key))

def save_iadapter_state(
    args,
    iadapter_state_dict,
    model_ckpt_path,
    model_ckpt_paths
):
    """
    Save each integrated adapter's parameters to its own checkpoint file.

    Adapter names are read from ``layers[*]["adapter_name"]`` in the adapter
    config at ``args.iadapter_config``. For each adapter name:

    - The entries of ``iadapter_state_dict`` belonging to that adapter are
      collected.
    - They are moved to CPU together with ``args``.
    - They are written to the path ``ckpt_to_adapter_path(model_ckpt_path,
      "<name>_adapter")``.
    - That file is then copied to the matching adapter path of every entry
      in ``model_ckpt_paths[1:]``.

    Input:
        args: training arguments; must provide ``iadapter_config``
        iadapter_state_dict: dict mapping parameter name -> value
        model_ckpt_path: str, path the main model checkpoint is written to
        model_ckpt_paths: list of str, checkpoint paths for this save; the
            first entry is skipped when copying (presumably it is
            ``model_ckpt_path`` itself -- confirm against the caller)
    """
    from fairseq import utils

    iadapter_config = loadAdapterConfig(args.iadapter_config)
    for layer_config in iadapter_config["layers"]:
        for iadapter_name in layer_config["adapter_name"]:
            key = "{}_adapter".format(iadapter_name)
            # Match on a whole dotted name component. A plain substring test
            # would let adapter "a" also capture the parameters of adapter
            # "ba", which the loader cannot strip and then fails to load.
            model_state_dict = {
                k: v
                for k, v in iadapter_state_dict.items()
                if key in k.split(".")
            }
            state_dict = utils.move_to_cpu({
                "args": args,
                "model": model_state_dict,
            })

            # Reuse the shared naming helper so save and copy targets agree.
            filepath = ckpt_to_adapter_path(model_ckpt_path, key)
            with PathManager.open(filepath, "wb") as f:
                checkpoint_utils.torch_persistent_save(state_dict, f)
            for cp in model_ckpt_paths[1:]:
                PathManager.copy(
                    filepath, ckpt_to_adapter_path(cp, key), overwrite=True
                )

def load_iadapter_state(
    model, args, dir
):
    """
    Load per-adapter checkpoints into the model's integrated adapter.

    We assume that the iadapter has already been created in the model
    (``model.iadapter``). Adapter names come from ``args.iadapter_config``.
    For each adapter name:

    - ``<dir>/<adapter_name>_adapter_best.pt`` is loaded if it exists.
    - The text "iadapter.<adapter_name>_adapter." is removed from every
      parameter name.
    - The result is loaded into ``model.iadapter.adapters[adapter_name]``.

    Adapters with no checkpoint file are skipped, and this is logged.

    Input:
        model: model exposing ``iadapter.adapters`` indexable by adapter name
        args: training arguments; must provide ``iadapter_config``
        dir: str, directory containing the adapter checkpoints
    """
    iadapter_config = loadAdapterConfig(args.iadapter_config)
    integrated_adapter = model.iadapter
    for layer_config in iadapter_config["layers"]:
        for iadapter_name in layer_config["adapter_name"]:
            adapter_ckpt_path = os.path.join(
                dir,
                iadapter_name + "_adapter_best.pt"
            )
            if not PathManager.exists(adapter_ckpt_path):
                logger.info("Adapter {} is not loaded".format(iadapter_name))
                continue

            # Load onto CPU so restoring does not depend on the device the
            # checkpoint was saved from; load_state_dict then copies the
            # values into the module's own parameters.
            with PathManager.open(adapter_ckpt_path, "rb") as f:
                model_state_dict = torch.load(f, map_location="cpu")["model"]

            # rename: remove "iadapter.{iadapter_name}_adapter." from keys
            prefix = "iadapter.{}_adapter.".format(iadapter_name)
            renamed_state_dict = {
                k.replace(prefix, ""): v for k, v in model_state_dict.items()
            }
            logger.info("Loading adapter from {} for adapter {}".format(
                adapter_ckpt_path,
                iadapter_name
            ))
            integrated_adapter.adapters[iadapter_name].load_state_dict(
                renamed_state_dict
            )
    logger.info("finish loading")


class iadapterTrainer(Trainer):
    """Data-parallel trainer that stores integrated-adapter (iadapter)
    parameters separately from the rest of the model.

    On save, every model parameter whose name contains "iadapter" is left
    out of the main checkpoint. Those parameters are written instead to
    per-adapter files by :func:`save_iadapter_state`.

    On load, the regular checkpoint is restored first. Adapters are then
    loaded from ``args.trained_iadapter_dir`` (when set) by
    :func:`load_iadapter_state`.

    Synchronous distributed data-parallel training itself is inherited from
    :class:`Trainer`.
    """

    def __init__(self, args, task, model, criterion, quantizer=None):
        super().__init__(
            args, task, model, criterion, quantizer
        )

    def save_checkpoint(self, filename, extra_state, model_ckpt_paths=None):
        """Save all training state in a checkpoint file.

        Adapter parameters are saved separately from the main checkpoint.

        Args:
            filename: path of the main (adapter-free) checkpoint.
            extra_state: dict saved alongside the model; its "metrics"
                entry is overwritten with the current metrics state.
            model_ckpt_paths: all checkpoint paths for this save. Adapter
                files are copied to the adapter path matching each entry
                after the first. It is optional so that callers passing only
                ``(filename, extra_state)`` still work; in that case no
                copies are made.
        """
        if not self.is_data_parallel_master:
            # Only one worker writes checkpoints.
            return

        extra_state["metrics"] = metrics.state_dict()
        # Split the parameters: adapter weights go to their own files,
        # everything else goes into the main checkpoint.
        state_dict_wo_iadapter = {}
        state_dict_w_iadapter = {}
        for k, v in self.get_model().state_dict().items():
            if "iadapter" in k:
                state_dict_w_iadapter[k] = v
            else:
                state_dict_wo_iadapter[k] = v

        checkpoint_utils.save_state(
            filename,
            self.args,
            state_dict_wo_iadapter,
            self.get_criterion(),
            self.optimizer,
            self.lr_scheduler,
            self.get_num_updates(),
            self._optim_history,
            extra_state,
        )

        save_iadapter_state(
            self.args,
            state_dict_w_iadapter,
            filename,
            model_ckpt_paths if model_ckpt_paths is not None else [],
        )

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load a checkpoint via :class:`Trainer`, then load adapter weights.

        Adapters are read from ``args.trained_iadapter_dir`` when that
        argument exists and is not None.

        Returns:
            The extra state returned by ``Trainer.load_checkpoint``.
        """
        extra_state = super().load_checkpoint(
            filename,
            reset_optimizer,
            reset_lr_scheduler,
            optimizer_overrides,
            reset_meters,
        )
        # getattr: tolerate argument namespaces that never defined this option.
        trained_iadapter_dir = getattr(self.args, "trained_iadapter_dir", None)
        if trained_iadapter_dir is not None:
            load_iadapter_state(
                self.get_model(), self.args, trained_iadapter_dir
            )

        return extra_state
