import argparse
import logging
import os

import numpy as np
import tensorflow.compat.v1 as tf
import torch
import math

from unilm.modeling import UniLMForSequenceClassification
from unilm.config import UnilmConfig

# Configure the root logger at import time: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)

# Checkpoint whose tensors are spliced into every encoder layer's
# ``attention.self.softmax_approximation.reciprocal`` parameters during
# conversion. NOTE(review): machine-specific absolute path.
reciprocal_model_path = "/home/LAB/chenty/workspace/personal/MSRA_project/encryption/unilm2-he/outputs/reciprocal_model_lr2e-5_scale3_iter10000.pt"


# Substrings of PyTorch parameter names whose numpy arrays are transposed with
# ``ndarray.T`` before export. Matching is by substring, so e.g.
# "attention.self.query" also matches the query bias (1-D, so ``.T`` is a no-op).
_TENSORS_TO_TRANSPOSE = (
    "dense.weight", "attention.self.query", "attention.self.key", "attention.self.value",
    "transform.weight", "predict.weight",
)

# Ordered (pattern, replacement) rewrites that turn a PyTorch parameter name
# into a TF variable name. Order matters: "." -> "/" must run before the
# "LayerNorm/..." rules, and the last rule drops the extra "reciprocal" scope
# level under which the reciprocal checkpoint's tensors are spliced in.
# (LayerNorm parameters are currently skipped by the exporter, so those two
# rules only matter if that filter is removed.)
_VAR_MAP = (
    ("layer.", "layer_"),
    ("word_embeddings.weight", "word_embeddings"),
    ("position_embeddings.weight", "position_embeddings"),
    ("token_type_embeddings.weight", "token_type_embeddings"),
    (".", "/"),
    ("LayerNorm/weight", "LayerNorm/gamma"),
    ("LayerNorm/bias", "LayerNorm/beta"),
    ("weight", "kernel"),
    ("softmax_approximation/reciprocal", "softmax_approximation"),
)


def _to_tf_var_name(name: str) -> str:
    """Translate a PyTorch parameter name into the exported TF variable name."""
    for pattern, replacement in _VAR_MAP:
        name = name.replace(pattern, replacement)
    # The approximate LayerNorm's affine parameters use TF's gamma/beta naming.
    name = name.replace("layer_norm_approximation/kernel", "layer_norm_approximation/gamma")
    name = name.replace("layer_norm_approximation/bias", "layer_norm_approximation/beta")
    return name


def _to_tf_array(var_name: str, tf_name: str, tensor: "torch.Tensor") -> np.ndarray:
    """Convert one PyTorch tensor to the numpy layout stored in the TF checkpoint.

    Args:
        var_name: Original PyTorch parameter name (used for the transpose test).
        tf_name: Translated TF variable name (used for the reshape tests).
        tensor: The parameter tensor taken from ``model.state_dict()``.

    Returns:
        A numpy array, possibly with extra trailing singleton axes and/or with
        all axes reversed.
    """
    if tf_name.endswith("kernel") and tensor.dim() == 2 and tf_name != "classifier/kernel":
        # 2-D kernels (except the classifier's) gain a trailing singleton axis;
        # softmax-approximation kernels gain two. NOTE(review): presumably so
        # the TF graph can consume them as 1-D/2-D conv kernels; confirm.
        tensor = tensor.unsqueeze(-1)
        if "softmax_approximation" in tf_name:
            tensor = tensor.unsqueeze(-1)
    array = tensor.numpy()
    if any(fragment in var_name for fragment in _TENSORS_TO_TRANSPOSE):
        # ndarray.T reverses *all* axes, e.g. (out, in, 1) -> (1, in, out).
        array = array.T
    return array


def convert_pytorch_checkpoint_to_tf(pytorch_model_path: str, ckpt_dir: str, model_name: str,
                                     reciprocal_path: "str | None" = None):
    """Convert a fine-tuned UniLM PyTorch checkpoint into a TF v1 checkpoint.

    Steps:
      1. Load the ``UnilmConfig`` from ``pytorch_model_path`` and enable the
         ``softmax_approximation`` and ``app_ln_layer`` flags; ``app_ln_loss``
         is disabled.
      2. Load ``pytorch_model.bin`` and splice every tensor of the reciprocal
         checkpoint into each encoder layer under
         ``bert.encoder.layer.<i>.attention.self.softmax_approximation.reciprocal``.
      3. Build ``UniLMForSequenceClassification`` from that state dict and
         run ``squash_model_encoder(..., squash_attention_scale=True)`` on it.
      4. Write every parameter of the squashed model, except those whose name
         contains "LayerNorm", as a TF variable, and save all trainable
         variables with ``tf.train.Saver``.

    Args:
        pytorch_model_path: Directory holding the UniLM config and
            ``pytorch_model.bin``.
        ckpt_dir: Output directory; created (with parents) if missing.
        model_name: Checkpoint basename; "-" is replaced with "_" and ".ckpt"
            is appended.
        reciprocal_path: Optional path to the reciprocal checkpoint. When
            ``None``, the module-level ``reciprocal_model_path`` is used.

    Returns:
        The checkpoint path prefix returned by ``tf.train.Saver.save``.
    """
    if reciprocal_path is None:
        reciprocal_path = reciprocal_model_path

    config = UnilmConfig.from_pretrained(pytorch_model_path)
    config.softmax_approximation = True
    config.app_ln_layer = True
    config.app_ln_loss = False

    # torch.load unpickles its input: only point this at trusted checkpoints.
    pt_state_dict = torch.load(os.path.join(pytorch_model_path, "pytorch_model.bin"), map_location="cpu")
    reciprocal_state_dict = torch.load(reciprocal_path, map_location="cpu")
    for layer_id in range(config.num_hidden_layers):
        prefix = "bert.encoder.layer.%d.attention.self.softmax_approximation.reciprocal" % layer_id
        for key, value in reciprocal_state_dict.items():
            pt_state_dict["%s.%s" % (prefix, key)] = value

    model = UniLMForSequenceClassification.from_pretrained(
        pytorch_model_path, config=config, state_dict=pt_state_dict)
    # The model now owns its parameters; drop the raw dicts before the export.
    del pt_state_dict, reciprocal_state_dict

    from unilm.squash_model_for_he import squash_model_encoder
    squash_model_encoder(model, model.config, squash_attention_scale=True)

    os.makedirs(ckpt_dir, exist_ok=True)
    export_state_dict = model.state_dict()

    tf.reset_default_graph()
    with tf.Session() as session:
        for var_name, torch_tensor in export_state_dict.items():
            if "LayerNorm" in var_name:
                # NOTE(review): LayerNorm weights are never exported; presumably
                # superseded by layer_norm_approximation after squashing. Confirm.
                continue
            tf_name = _to_tf_var_name(var_name)
            array = _to_tf_array(var_name, tf_name, torch_tensor)

            tf_var = tf.get_variable(
                name=tf_name, shape=array.shape, dtype=tf.dtypes.as_dtype(array.dtype),
                initializer=tf.zeros_initializer())
            session.run(tf.variables_initializer([tf_var]))
            tf.keras.backend.set_value(tf_var, array)

            # Read the value back to verify the round trip.
            tf_weight = session.run(tf_var)
            matches = np.allclose(tf_weight, array)
            logger.info("Successfully created %s: %s: %s", tf_name, tf_weight.shape, matches)
            if not matches:
                logger.warning("Value mismatch after writing TF variable %s", tf_name)
            if tf_name == "classifier/bias" or "bert/encoder/layer_0/attention/self/query" in tf_name:
                logger.info("%s =\n%s", tf_name, tf_weight)

        saver = tf.train.Saver(tf.trainable_variables())
        save_path = os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")
        return saver.save(session, save_path)


def main(raw_args=None):
    """Parse command-line options and run the PyTorch -> TF checkpoint conversion.

    Args:
        raw_args: Argument list handed to ``parse_args``; ``None`` makes
            argparse read ``sys.argv[1:]``.
    """
    cli = argparse.ArgumentParser()
    # (flag, default) pairs; every option is a plain string.
    for flag, default_value in (
        ("--model_name", "transformer"),
        ("--pytorch_model_path", r"F:\users\addf400\S\test\epoch-4"),
        ("--tf_cache_dir", "../tf_models"),
    ):
        cli.add_argument(flag, type=str, default=default_value)
    opts = cli.parse_args(raw_args)

    convert_pytorch_checkpoint_to_tf(
        pytorch_model_path=opts.pytorch_model_path,
        ckpt_dir=opts.tf_cache_dir,
        model_name=opts.model_name,
    )


if __name__ == "__main__":
    # Script entry point: arguments come from the real command line.
    main(raw_args=None)
