'''
Date: 2021-07-22 15:39:11
LastEditors: Wu Xianze (wuxianze.0@bytedance.com)
LastEditTime: 2021-07-22 15:58:55
'''
import argparse

from fairseq import utils
from fairseq.file_io import PathManager
from fairseq.checkpoint_utils import load_checkpoint_to_cpu, torch_persistent_save
from fairseq import tasks

def load_checkpoints_fn(filenames, arg_overrides=None):
    """Load checkpoints from disk and rebuild a model for each of them.

    For every path in ``filenames`` the checkpoint is loaded on CPU, the
    ``args`` stored inside it are patched in place with ``arg_overrides``
    (when given), the task and model are built from those args, and the
    stored ``"model"`` weights are loaded into the freshly built model.

    Args:
        filenames: iterable of checkpoint paths.
        arg_overrides: optional mapping of attribute name -> value that is
            set on each checkpoint's ``args`` before the task is set up.

    Returns:
        A ``(models, ckpt_states)`` tuple of two lists in the same order as
        ``filenames``.

    Raises:
        IOError: if one of the paths does not exist.
    """
    # Treat "no overrides" as an empty mapping so the patch loop is a no-op.
    overrides = arg_overrides if arg_overrides is not None else {}

    loaded_models = []
    loaded_states = []
    for ckpt_path in filenames:
        if not PathManager.exists(ckpt_path):
            raise IOError("Model file not found: {}".format(ckpt_path))

        ckpt = load_checkpoint_to_cpu(ckpt_path)
        ckpt_args = ckpt["args"]
        for attr_name in overrides:
            setattr(ckpt_args, attr_name, overrides[attr_name])

        ckpt_task = tasks.setup_task(ckpt_args)
        ckpt_model = ckpt_task.build_model(ckpt_args)
        ckpt_model.load_state_dict(ckpt["model"], args=ckpt_args)

        loaded_models.append(ckpt_model)
        loaded_states.append(ckpt)
    return loaded_models, loaded_states

def check_checkpoint(args):
    """Load every checkpoint in ``args.in_ckpts`` and print each built model.

    ``args.data`` and ``args.extended_dict`` replace the values of the same
    name stored in the checkpoints' own args before the models are built.
    """
    overrides = dict(data=args.data, extended_dict=args.extended_dict)
    # Only the models are needed here; the raw checkpoint states are ignored.
    loaded_models, _ = load_checkpoints_fn(args.in_ckpts, overrides)
    for loaded_model in loaded_models:
        print(loaded_model)

def calTwoCkpts(args):
    """
    Combine the parameters of the first two checkpoints and save the result.

    Each entry of the first model's ``state_dict`` is combined with the entry
    stored under the same key in the second model's ``state_dict``:
    ``args.oper == "diff"`` stores ``first - second`` and
    ``args.oper == "add"`` stores ``first + second``. The combined weights are
    written to ``args.out_ckpt`` together with the first checkpoint's
    ``args``, the last entry of its ``optimizer_history`` and its
    ``extra_state``.

    We assume that the architecture of the first and the second checkpoint is
    the same. Checkpoints after the second one are loaded but not used.

    Raises:
        ValueError: if ``args.oper`` is not ``"diff"`` or ``"add"``, or if
            fewer than two checkpoints were loaded.
    """
    # Reject unknown operations before the (expensive) checkpoint loading;
    # previously anything other than "diff" silently saved an empty model.
    if args.oper not in ("diff", "add"):
        raise ValueError("unsupported operation: {}".format(args.oper))

    arg_overrides = {
        "data": args.data,
        "extended_dict": args.extended_dict,
    }
    models, ckpt_states = load_checkpoints_fn(args.in_ckpts, arg_overrides)

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if len(models) < 2:
        raise ValueError("at least two models must be provided")

    # Build each state dict once instead of once per key.
    first_state = models[0].state_dict()
    second_state = models[1].state_dict()

    if args.oper == "diff":
        model_state_dict = {
            k: v - second_state[k] for k, v in first_state.items()
        }
    else:  # "add"
        model_state_dict = {
            k: v + second_state[k] for k, v in first_state.items()
        }

    first_ckpt = ckpt_states[0]
    saved_ckpt_state_dict = {
        "args": first_ckpt["args"],
        "model": model_state_dict,
        "optimizer_history": [first_ckpt["optimizer_history"][-1]],
        "extra_state": first_ckpt["extra_state"],
    }

    with PathManager.open(args.out_ckpt, "wb") as f:
        torch_persistent_save(saved_ckpt_state_dict, f)


if __name__ == "__main__":
    # Command-line entry point: parse the options, register the user module
    # with fairseq, then run the operation selected with ``-m``.
    #
    # Example:
    #   python3 ckptOperations.py --in-ckpts /home/tiger/xgiga_finetune_pretrainedMspm4_freezeDecoder_bsz32i_ls0.2_siu4000_p40_eval_zh/model/mbart.cc25/model.pt  -m check_checkpoint --data /home/tiger/xgiga_finetune_pretrainedMspm4_freezeDecoder_bsz32i_ls0.2_siu4000_p40_eval_zh/resource/dataset

    # Explicit dispatch table: the value of ``-m`` is looked up here instead
    # of being passed to ``eval``, so only these functions can be run.
    operations = {
        "check_checkpoint": check_checkpoint,
        "calTwoCkpts": calTwoCkpts,
    }

    parser = argparse.ArgumentParser()

    parser.add_argument("--user-dir", type=str, default="examples/summarization")
    parser.add_argument("--in-ckpts", nargs="+", type=str, default=["/home/tiger/xgiga_finetune_pretrainedMspm4_freezeDecoder_bsz32i_ls0.2_siu4000_p40_eval_zh/model/mbart.cc25/model.pt"])
    parser.add_argument("--out-ckpt", type=str)
    parser.add_argument("--oper", choices=['add', 'diff'], default="diff", help="the type of operations")
    parser.add_argument("-m", type=str, default="check_checkpoint",
                        choices=sorted(operations),
                        help="name of the function to run")

    # fairseq task specific arguments
    parser.add_argument("--data", type=str,
        default="/home/tiger/xgiga_finetune_pretrainedMspm4_freezeDecoder_bsz32i_ls0.2_siu4000_p40_eval_zh/resource/dataset"
    )
    parser.add_argument("--extended-dict", type=str,
        default="/home/tiger/zeroshot_multi_finetune_inplaceBartNoiseMspm4_unconGenerateTgt_freezeEncoder_zh/resource/dataset/extra_extend_0716.txt"
    )
    parser.add_argument('-s', '--source-lang', default="doc", metavar='SRC',
                        help='source language')
    parser.add_argument('-t', '--target-lang', default="sum", metavar='TARGET',
                        help='target language')
    args = parser.parse_args()

    # calTwoCkpts writes its result to --out-ckpt; fail before loading any
    # checkpoint rather than after all the work has been done.
    if args.m == "calTwoCkpts" and args.out_ckpt is None:
        parser.error("--out-ckpt is required when -m calTwoCkpts is used")

    utils.import_user_module(args)

    operations[args.m](args)