import os
import shutil
from pathlib import Path

import torch
from collections import OrderedDict

from utils.distributed_processing import is_main_process
from utils.misc import mkdir
import torch.nn as nn
import glob
import re
from modeling.modeling_ema import EMA_Model


def load_state_dict_with_mismatch(model: nn.Module, state_dict_path: str) -> None:
    """Load the compatible subset of a saved state dict into ``model`` in place.

    Only entries whose key exists in ``model`` *and* whose shape equals the
    shape of the corresponding model entry are loaded. Everything else
    (keys only in the file, keys only in the model, shape mismatches) is
    reported on stdout and left untouched in ``model``.

    Args:
        model: Module that receives the weights. Modified in place.
        state_dict_path: Path (``str`` or ``os.PathLike``) of a file written
            with ``torch.save(state_dict, ...)``.

    Raises:
        ValueError: If ``state_dict_path`` is not a path-like value.
    """
    if not isinstance(state_dict_path, (str, os.PathLike)):
        raise ValueError(f"{state_dict_path} is not a valid state dict path.")

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # coming from a trusted source.
    loaded_state_dict = torch.load(state_dict_path, map_location="cpu")

    # Snapshot the model's state dict once instead of rebuilding it for every
    # key inside the loop.
    model_state = model.state_dict()
    model_keys = set(model_state)
    load_keys = set(loaded_state_dict)

    # Walk the model's own state dict so the loaded subset keeps the model's
    # parameter order (iterating a set would give an arbitrary order).
    toload: OrderedDict[str, torch.Tensor] = OrderedDict()
    mismatched_shape_keys = []
    for k, current in model_state.items():
        if k not in load_keys:
            continue
        if current.shape != loaded_state_dict[k].shape:
            mismatched_shape_keys.append(k)
        else:
            toload[k] = loaded_state_dict[k]

    diff_keys_k = load_keys - model_keys
    diff_keys_m = model_keys - load_keys

    parent_path = os.path.basename(os.path.dirname(state_dict_path))
    # Shape mismatches are a mismatch too: report them even when both key
    # sets are identical.
    if diff_keys_k or diff_keys_m or mismatched_shape_keys:
        print(f"Model in {parent_path} : ")
        print("Keys in loaded but not in model:")
        print(f"In total {len(diff_keys_k)}, {sorted(diff_keys_k)}")
        print("Keys in model but not in loaded")
        print(f"In total {len(diff_keys_m)}, {sorted(diff_keys_m)}")
        print("Keys in model and loaded, but shape mismatched:")
        print(f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}")
    else:
        print(f" - Model in {parent_path} : All the keys and parameter are matched.")

    # ``toload`` may deliberately lack some model keys (missing or shape
    # mismatched); strict loading would raise on them and defeat the purpose
    # of this function.
    model.load_state_dict(toload, strict=False)


def save_ema_model_from_ckpt(model:nn.Module, checkpoint_path:str, num_model:int = None) -> str:


    # Set output folder
    parent_path = Path(checkpoint_path).parent.__str__()
    EMA_model_folder = parent_path + f"/EMA_{num_model}_model_checkpoint"
    output_file = EMA_model_folder + "/pytorch_model.bin"

    # Return only the name of the folder if the process is not the main process.
    if is_main_process() is not True:
        return EMA_model_folder

    # Make a directory by copying with the current model directory tree.
    if not os.path.exists(EMA_model_folder):
        shutil.copytree(checkpoint_path, EMA_model_folder)
    else:
        print(f"Overwriting the best folder: EMA_{num_model} folder already exist.")
        shutil.rmtree(EMA_model_folder)
        shutil.copytree(checkpoint_path, EMA_model_folder)

    # Remove prediction files if it copied from other checkpoint.
    pred_files_in_path = glob.glob(EMA_model_folder + '/pred.*')
    for f in pred_files_in_path:
        os.remove(f)

    # Get the folders
    state_dict_folder_path = parent_path + '/checkpoint-*'
    state_dict_folders = glob.glob(state_dict_folder_path)

    # Get folder names and sort it by alphabetical order.
    # - Folders are saved with the name "checkpoint-#epoch-#iteration"
    # - Names will sorted with the iteration#.
    state_dict_folder_names = [
        name for name in state_dict_folders
    ]
    state_dict_folder_names.sort(key=lambda o: int(o.split('-')[-1]))
    state_dict_folder_names = state_dict_folder_names[-num_model:]

    # EMA class
    #  - Set alpha for the exponential moving average
    alpha = 1 - 2 / (num_model + 1)
    #  - Initialize EMA model class
    ema = EMA_Model(model, decay=alpha)
    # Averaging the model
    for i, state_dict_path in enumerate(state_dict_folder_names):
        # iteration = re.findall('\d+', state_dict_path)[-2]
        state_dict_file = state_dict_path + "/pytorch_model.bin"

        # Load model after checking whether the pytorch model exist.
        if os.path.exists(state_dict_file):
            load_state_dict_with_mismatch(model, state_dict_file)
        else:
            raise ValueError(f"{state_dict_path} is not exist.")

        if i == 0:
            # Set initialized model with loaded model.
            ema.set(model)
            continue

        ema.update(model)

    # Save pytorch.bin file in the output path.
    torch.save(ema.ema_model.state_dict(), output_file)
    print(f" - EMA model saved at {output_file}")

    # Load ema model to the model (in-place)
    # model.load_state_dict(ema.ema_model.state_dict())
    # print(f"EMA model loaded.")
    return EMA_model_folder