from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import argparse
import torch


# Configure root logging for this command-line script: timestamped,
# level-tagged lines at INFO and above.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
_LOG_DATEFMT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(level=logging.INFO, datefmt=_LOG_DATEFMT, format=_LOG_FORMAT)

# Module-scoped logger used by the rest of the script.
logger = logging.getLogger(__name__)


def get_args(argv=None):
    """Parse the command-line options for checkpoint averaging.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``model_dir``, ``steps``, ``num_ckpts``,
        ``step_length``, ``num_max_steps`` and ``decay``.
    """
    parser = argparse.ArgumentParser()
    # Required: every checkpoint path and the output path are joined onto it,
    # so leaving it out would otherwise fail later with an unclear error.
    parser.add_argument("--model_dir", type=str, required=True, help="Folder of checkpoints.")
    parser.add_argument("--steps", default=None, type=str, help="step1_step2_step3_....")
    parser.add_argument("--num_ckpts", default=0, type=int, help="Fusing last n checkpoints.")
    parser.add_argument("--step_length", default=1, type=int, help="Step of last n checkpoints.")
    parser.add_argument("--num_max_steps", default=500000, type=int, help="Maximum step of checkpoints.")
    parser.add_argument("--decay", default=None, type=float,
                        help="If set, fuse with an exponential moving average using this decay "
                             "instead of a plain mean.")
    return parser.parse_args(argv)


def _resolve_steps(args):
    """Return the checkpoint steps to fuse, as a list of ints.

    Uses ``args.steps`` ("step1_step2_...") when given. Otherwise it builds
    ``args.num_ckpts`` steps counting down from ``args.num_max_steps`` in
    strides of ``args.step_length``.

    Raises:
        ValueError: if the options select no checkpoints or would generate a
            non-positive step.
    """
    if args.steps is None:
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if args.num_ckpts <= 0:
            raise ValueError("--num_ckpts must be > 0 when --steps is not given")
        if args.step_length <= 0:
            raise ValueError("--step_length must be > 0")
        if args.num_max_steps <= 0:
            raise ValueError("--num_max_steps must be > 0")
        steps = [args.num_max_steps - i * args.step_length for i in range(args.num_ckpts)]
        if steps[-1] <= 0:
            raise ValueError("--num_ckpts %d with --step_length %d from --num_max_steps %d "
                             "reaches non-positive step %d"
                             % (args.num_ckpts, args.step_length, args.num_max_steps, steps[-1]))
        return steps

    # Skip empty pieces so stray separators (e.g. a trailing "_") are tolerated.
    steps = [int(step) for step in args.steps.split('_') if step]
    if not steps:
        raise ValueError("--steps must list at least one step, e.g. 1000_2000")
    return steps


def _is_float_tensor(value):
    """Return True if ``value`` is a floating-point tensor that can be averaged."""
    return torch.is_tensor(value) and value.is_floating_point()


def _average_checkpoints(model_dir, steps, decay=None):
    """Fuse ``<model_dir>/steps-<step>/pytorch_model.bin`` for every step.

    Checkpoints are visited in ascending step order.
    - With ``decay=None``, each floating-point entry becomes the arithmetic
      mean over ``steps``. A step listed twice counts twice.
    - Otherwise an exponential moving average
      ``avg = avg * decay + new * (1 - decay)`` is used, seeded with the
      earliest checkpoint.
    - Non-floating-point entries cannot be averaged meaningfully, so the value
      from the latest checkpoint is kept.

    Raises:
        ValueError: if a checkpoint's keys differ from the first checkpoint's.
    """
    num_steps = len(steps)
    state_dict = None

    for step in sorted(steps):
        checkpoint_dir = os.path.join(model_dir, "steps-%d" % step)
        logger.info("Load checkpoint from %s", checkpoint_dir)
        checkpoint_state_dict = torch.load(
            os.path.join(checkpoint_dir, "pytorch_model.bin"), map_location='cpu')

        if state_dict is None:
            state_dict = checkpoint_state_dict
            if decay is None:
                logger.info("Init moving average")
                for key in state_dict:
                    if _is_float_tensor(state_dict[key]):
                        # Out-of-place on purpose: entries that share storage
                        # (e.g. tied weights) still share it after torch.load,
                        # so an in-place `/=` would divide that storage once
                        # per key.
                        state_dict[key] = state_dict[key] / num_steps
            else:
                logger.info("Init exp moving average")
            continue

        mismatched = set(state_dict) ^ set(checkpoint_state_dict)
        if mismatched:
            raise ValueError("Checkpoint %s has keys that differ from the first checkpoint: %s"
                             % (checkpoint_dir, sorted(mismatched)))

        for key, value in checkpoint_state_dict.items():
            if not _is_float_tensor(value):
                # Averaging would either raise (in-place int division) or
                # silently change the dtype; keep the latest value instead.
                state_dict[key] = value
            elif decay is None:
                # Out-of-place for the same shared-storage reason as above.
                state_dict[key] = state_dict[key] + value / num_steps
            else:
                state_dict[key] = state_dict[key] * decay + value * (1 - decay)

    return state_dict


def main():
    """Fuse the checkpoints selected on the command line and save the result.

    The fused state dict is written to
    ``<model_dir>/n<num_ckpts>-l<step_length>-<decay>/pytorch_model.bin``,
    where decay is 0 when a plain mean is used.
    """
    args = get_args()
    steps = _resolve_steps(args)
    state_dict = _average_checkpoints(args.model_dir, steps, args.decay)

    # NOTE: the directory name does not encode --steps, so different explicit
    # step lists with the same other flags write to the same directory.
    output_dir_name = "n%d-l%d-%f" % (args.num_ckpts, args.step_length, 0 if args.decay is None else args.decay)
    output_dir = os.path.join(args.model_dir, output_dir_name)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger.info("***** Dump model: %s *****", output_dir)
    torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))


if __name__ == '__main__':
    # Run the checkpoint-fusing CLI only when executed directly, not on import.
    main()
