import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, SequentialSampler, DataLoader

from nlu_finetune.utils_for_glue import glue_compute_metrics as compute_metrics
from nlu_finetune.utils_for_glue import glue_output_modes as output_modes
from nlu_finetune.utils_for_glue import glue_processors as processors
from nlu_finetune.utils_for_glue import glue_convert_examples_to_features as convert_examples_to_features
from transformers import BertTokenizer, BertConfig


class BertLayerNorm(nn.Module):
    """Layer normalisation over the last dimension, TF style.

    Epsilon is added to the variance *inside* the square root. The statistics
    are computed on a float32 copy of the input (``.float()``), the normalised
    values are cast back to the input's dtype, and only then is the learned
    scale/shift applied.
    """

    def __init__(self, hidden_size, eps=1e-5):
        """Create the scale (init 1) and shift (init 0) parameters of width ``hidden_size``."""
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x_in):
        original = x_in
        hidden = original.float()
        centered = hidden - hidden.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        # Cast back before the affine step, exactly as the statistics were upcast.
        return self.weight * normalized.type_as(original) + self.bias


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    The word, position and (optional) token-type lookups are summed and
    layer-normalised. When the embedding width reported by
    ``get_embedding_size(config)`` differs from ``config.hidden_size``, the
    result is linearly projected up to ``hidden_size``.

    Constructor args:
        config: object providing ``vocab_size``, ``max_position_embeddings``,
            ``type_vocab_size``, ``layer_norm_eps`` and ``hidden_size`` (plus
            optionally ``embedding_size``). A ``type_vocab_size`` of 0 disables
            the token-type embedding table.
    """

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        embedding_size = get_embedding_size(config)
        self.word_embeddings = nn.Embedding(config.vocab_size, embedding_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, embedding_size)
        # type_vocab_size == 0 means the model has no segment/token-type table.
        if config.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, embedding_size)
        else:
            self.token_type_embeddings = None

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(embedding_size, eps=config.layer_norm_eps)

        # Only add a projection when the embedding width differs from hidden_size.
        if embedding_size != config.hidden_size:
            self.embedding_projection = nn.Linear(
                embedding_size, config.hidden_size, bias=True)
        else:
            self.embedding_projection = None

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """Embed a batch of tokens.

        Args:
            input_ids: token-id tensor, expected as (batch, seq_len) since
                ``size()[1]`` is taken as the sequence length. Either this or
                ``inputs_embeds`` must be given.
            token_type_ids: segment ids; zeros of the input shape when omitted.
                Ignored when the module has no token-type embeddings.
            position_ids: position indices; ``0..seq_len-1`` broadcast over the
                batch when omitted.
            inputs_embeds: precomputed word embeddings (last dim = embedding
                width) used instead of looking up ``input_ids``.

        Returns:
            Tuple ``(embeddings, position_ids)``; the position ids actually used
            (given or generated) are returned alongside the embeddings.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = inputs_embeds + position_embeddings

        # Explicit None check (consistent with embedding_projection below) rather
        # than relying on nn.Module truthiness; the default zero segment ids are
        # only materialised when they are actually consumed.
        if self.token_type_embeddings is not None:
            if token_type_ids is None:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            embeddings = embeddings + self.token_type_embeddings(token_type_ids)

        embeddings = self.LayerNorm(embeddings)

        if self.embedding_projection is not None:
            embeddings = self.embedding_projection(embeddings)

        return embeddings, position_ids


def get_embedding_size(config):
    """Return the embedding width for ``config``.

    Uses ``config.embedding_size`` when that attribute exists and is truthy;
    otherwise falls back to ``config.hidden_size``.
    """
    return getattr(config, "embedding_size", None) or config.hidden_size


def load_and_cache_examples(task, data_dir, tokenizer, model_type, max_seq_length):
    """Tokenize the dev split of a GLUE task into a ``TensorDataset``.

    Despite the name, nothing is cached: features are rebuilt from
    ``data_dir`` on every call.

    Args:
        task: key into the GLUE processor / output-mode registries.
        data_dir: directory passed to the processor's ``get_dev_examples``.
        tokenizer: tokenizer used for feature conversion and to look up the
            pad token id.
        model_type: ``'xlnet'`` switches to left padding with segment id 4;
            anything else pads on the right with segment id 0.
        max_seq_length: maximum feature length.

    Returns:
        TensorDataset of (input_ids, attention_mask, token_type_ids, labels).
        The first three are ``torch.long``; labels are ``torch.long`` for
        classification and ``torch.float`` for regression.

    Raises:
        NotImplementedError: for any other output mode.
    """
    processor = processors[task]()
    output_mode = output_modes[task]
    is_xlnet = model_type in ['xlnet']

    features = convert_examples_to_features(
        processor.get_dev_examples(data_dir),
        tokenizer,
        label_list=processor.get_labels(),
        max_length=max_seq_length,
        output_mode=output_mode,
        pad_on_left=is_xlnet,
        pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
        pad_token_segment_id=4 if is_xlnet else 0,
    )

    def _column(attr, dtype):
        # Stack one attribute of every feature into a single tensor.
        return torch.tensor([getattr(feat, attr) for feat in features], dtype=dtype)

    columns = [
        _column(name, torch.long)
        for name in ("input_ids", "attention_mask", "token_type_ids")
    ]

    label_dtypes = {"classification": torch.long, "regression": torch.float}
    if output_mode not in label_dtypes:
        raise NotImplementedError()
    columns.append(_column("label", label_dtypes[output_mode]))

    return TensorDataset(*columns)


def get_args(argv=None):
    """Parse the command-line options of this embedding-inspection script.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests or programmatic calls).

    Returns:
        argparse.Namespace with ``task``, ``data_dir`` and ``pytorch_model``.
    """
    parser = argparse.ArgumentParser(
        description="Run a checkpoint's BERT embedding layer over a GLUE dev set.")
    # NOTE(review): the path defaults below are developer-specific Windows
    # locations; pass --data_dir / --pytorch_model explicitly elsewhere.
    parser.add_argument("--task", type=str, default="sst-2",
                        help="GLUE task name, e.g. sst-2.")
    parser.add_argument("--data_dir", type=str, default=r"D:\datasets\glue_data\sst-2",
                        help="Directory containing the task's dev data.")
    parser.add_argument("--pytorch_model", type=str, default=r"F:\users\addf\S\test\epoch-4",
                        help="Checkpoint directory with config.json and pytorch_model.bin.")
    return parser.parse_args(argv)


def get_embedding_layer(args):
    """Build a standalone ``BertEmbeddings`` and load its weights from a checkpoint.

    Reads ``config.json`` and ``pytorch_model.bin`` from ``args.pytorch_model``,
    keeps only the ``bert.embeddings.*`` entries of the state dict (with that
    prefix stripped) and loads them strictly into a fresh ``BertEmbeddings``.

    Args:
        args: namespace whose ``pytorch_model`` attribute is the checkpoint dir.

    Returns:
        The populated ``BertEmbeddings`` module (on CPU).

    Raises:
        RuntimeError: from ``load_state_dict`` when the selected weights do not
            match the module's parameters.
    """
    prefix = "bert.embeddings."
    config_file = os.path.join(args.pytorch_model, "config.json")
    weight_file = os.path.join(args.pytorch_model, "pytorch_model.bin")
    config = BertConfig.from_pretrained(config_file)
    # map_location="cpu": tensors saved from a GPU run would otherwise be
    # restored onto CUDA, which fails on CPU-only machines.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(weight_file, map_location="cpu")
    selected_dict = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }

    print("Select dict keys = %s" % str(selected_dict.keys()))

    embedding_layer = BertEmbeddings(config)

    # NOTE(review): some checkpoints may carry a non-learned "position_ids"
    # buffer under the embeddings prefix. This module generates positions in
    # forward() and has no such entry, so strict loading would reject it.
    if "position_ids" in selected_dict and "position_ids" not in embedding_layer.state_dict():
        del selected_dict["position_ids"]

    embedding_layer.load_state_dict(selected_dict)
    return embedding_layer


def main():
    """Feed a GLUE dev set through the checkpoint's embedding layer.

    Tokenizes the dev examples for ``--task`` with the ``bert-base-uncased``
    tokenizer (max length 64), rebuilds the embedding layer from
    ``--pytorch_model`` and runs it in eval mode, without gradients, over
    sequential batches of 10. The outputs are not used yet (presumably a
    placeholder for inspection).
    """
    args = get_args()
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    dataset = load_and_cache_examples(
        task=args.task,
        data_dir=args.data_dir,
        tokenizer=tokenizer,
        model_type="bert",
        max_seq_length=64,
    )

    embedding_layer = get_embedding_layer(args)
    embedding_layer.eval()

    eval_dataloader = DataLoader(
        dataset, sampler=SequentialSampler(dataset), batch_size=10
    )

    for input_ids, _attention_mask, token_type_ids, _labels in eval_dataloader:
        with torch.no_grad():
            _embeddings, _position_ids = embedding_layer(
                input_ids, token_type_ids=token_type_ids
            )


if __name__ == "__main__":
    # Only run the script when executed directly, not when imported.
    main()
