# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa)."""

from __future__ import absolute_import, division, print_function

import argparse
import glob
import logging
import os
import random
import json
import copy
import collections
from nlu_finetune.utils_for_glue import glue_output_modes, glue_processors

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler

try:
    from torch.utils.tensorboard import SummaryWriter
except:
    from tensorboardX import SummaryWriter

from tqdm import tqdm, trange

from transformers import (WEIGHTS_NAME, BertConfig,
                          BertForSequenceClassification,BertTokenizer,
                          RobertaConfig,
                          RobertaForSequenceClassification,
                          RobertaTokenizer,
                          XLMConfig, XLMForSequenceClassification,
                          XLMTokenizer, XLNetConfig,
                          XLNetForSequenceClassification,
                          XLNetTokenizer,
                          DistilBertConfig,
                          DistilBertForSequenceClassification,
                          DistilBertTokenizer,
                          AlbertConfig,
                          AlbertForSequenceClassification,
                          AlbertTokenizer,
                          XLMRobertaConfig,
                          XLMRobertaForSequenceClassification,
                          XLMRobertaTokenizer,
                          ElectraConfig,
                          ElectraForSequenceClassification,
                          ElectraTokenizer,
)

from transformers import AdamW, get_linear_schedule_with_warmup
from unilm.modeling import UniLMForSequenceClassification
from unilm.optimization_utils import get_optimizer_grouped_parameters, add_optimzation_args
from unilm.configuration_unilm import UnilmConfig
from unilm.tokenization_unilm import UniLMAutoTokenizer

from nlu_finetune.utils_for_glue import glue_compute_metrics as compute_metrics
from nlu_finetune.utils_for_glue import glue_output_modes as output_modes
from nlu_finetune.utils_for_glue import glue_processors as processors
from nlu_finetune.utils_for_glue import glue_convert_examples_to_features as convert_examples_to_features

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Flat tuple of every key found in ``pretrained_config_archive_map`` of the listed config
# classes. Only these five configs are scanned; the other model types registered in
# MODEL_CLASSES below do not contribute names here.
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, 
                                                                                RobertaConfig, DistilBertConfig)), ())

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): hard-coded absolute path on one specific machine. It is not referenced in the
# part of the file reviewed here -- confirm where it is consumed before relying on it.
reciprocal_model_path = "/home/LAB/chenty/workspace/personal/MSRA_project/encryption/unilm2-he/outputs/reciprocal_model_lr2e-5_scale3_iter10000.pt"

# Maps a model-type string to its (config class, sequence-classification model class,
# tokenizer class) triple.
MODEL_CLASSES = {
    'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
    'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
    'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
    'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
    'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
    'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
    'xlm-roberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
    'unilm': (UnilmConfig, UniLMForSequenceClassification, UniLMAutoTokenizer),
    'electra': (ElectraConfig, ElectraForSequenceClassification, ElectraTokenizer),
}


def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators from ``args.seed``.

    The CUDA generators are seeded as well whenever ``args.n_gpu`` is positive.
    """
    seed_value = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if not args.n_gpu > 0:
        return
    torch.cuda.manual_seed_all(seed_value)


class NoisedDataGenerator(object):
    """Builds noised (augmented) tensor datasets from training examples.

    Each kind of noise is toggled by its own ``enable_*`` flag:

    * code switch: replace a word by an entry of a bilingual dictionary,
    * BPE switch: tokenize a word with a randomly chosen per-language tokenizer,
    * BPE sampling: pass ``nbest_size`` / ``alpha`` to the tokenizer,
    * word dropout: replace sub-word ids by the tokenizer's ``unk_token_id``,
    * translate data: replace the texts of the English XNLI training examples by
      their translations.

    ``enable_kl_loss``, ``kl_lambda``, ``original_loss``, ``noised_loss`` and the
    ``noise_*`` / ``enable_random_noise`` settings are only stored on the instance;
    nothing in this class reads them.

    Per-technique ratios are divided by ``overall_ratio`` because ``encode_sentence``
    first selects a token with probability ``overall_ratio`` and only then applies a
    technique with the rescaled ratio, so the product equals the ratio passed in.
    """

    def __init__(self,
                 enable_kl_loss,
                 kl_lambda=5.0,
                 original_loss=True,
                 noised_loss=False,
                 max_length=512,
                 overall_ratio=1.0,
                 enable_bpe_switch=False,
                 bpe_switch_ratio=0.5,
                 tokenizer_dir=None,
                 do_lower_case=False,
                 tokenizer_languages=None,
                 enable_bpe_sampling=False,
                 tokenizer=None,
                 bpe_sampling_ratio=0.5,
                 sampling_alpha=0.3,
                 sampling_nbest_size=-1,
                 enable_random_noise=False,
                 noise_detach_embeds=False,
                 noise_eps=1e-5,
                 noise_type='uniform',
                 enable_code_switch=False,
                 code_switch_ratio=0.5,
                 dict_dir=None,
                 dict_languages=None,
                 enable_word_dropout=False,
                 word_dropout_rate=0.1,
                 enable_translate_data=False,
                 train_language=None,
                 data_dir=None,
                 translate_different_pair=False,
                 translate_en_data=False):
        """Store the configuration and load dictionaries / per-language tokenizers.

        ``dict_languages`` and ``tokenizer_languages`` may be left as ``None``; in
        that case no dictionary / extra tokenizer is loaded.

        Raises:
            AssertionError: if ``tokenizer`` is ``None``, if a feature is enabled
                without the directories / language lists it needs, if ``noise_type``
                is unknown while random noise is enabled, if a rescaled ratio exceeds
                1.0, or if a dictionary file does not exist.
        """
        if enable_code_switch:
            assert dict_dir is not None
            assert dict_languages is not None
        if enable_bpe_switch:
            # Fail early instead of crashing later inside encode_sentence.
            assert tokenizer_dir is not None
            assert tokenizer_languages is not None
        assert tokenizer is not None
        if enable_random_noise:
            assert noise_type in ['uniform', 'normal']

        # Token counters used to report the share of code-switched tokens.
        self.n_tokens = 0
        self.n_cs_tokens = 0
        self.enable_kl_loss = enable_kl_loss
        self.kl_lambda = kl_lambda
        self.original_loss = original_loss
        self.noised_loss = noised_loss
        self.max_length = max_length
        self.overall_ratio = overall_ratio

        self.enable_bpe_switch = enable_bpe_switch
        self.bpe_switch_ratio = bpe_switch_ratio / self.overall_ratio
        assert self.bpe_switch_ratio <= 1.0
        self.tokenizer_dir = tokenizer_dir
        self.tokenizer_languages = tokenizer_languages

        self.enable_bpe_sampling = enable_bpe_sampling
        self.bpe_sampling_ratio = bpe_sampling_ratio / self.overall_ratio
        assert self.bpe_sampling_ratio <= 1.0
        self.tokenizer = tokenizer
        self.sampling_alpha = sampling_alpha
        self.sampling_nbest_size = sampling_nbest_size

        self.enable_random_noise = enable_random_noise
        self.noise_detach_embeds = noise_detach_embeds
        self.noise_eps = noise_eps
        self.noise_type = noise_type

        self.enable_word_dropout = enable_word_dropout
        self.word_dropout_rate = word_dropout_rate

        self.enable_translate_data = enable_translate_data
        self.train_language = train_language
        self.data_dir = data_dir
        self.translate_different_pair = translate_different_pair
        self.translate_en_data = translate_en_data

        self.enable_code_switch = enable_code_switch
        self.code_switch_ratio = code_switch_ratio / self.overall_ratio
        assert self.code_switch_ratio <= 1.0
        self.dict_dir = dict_dir
        self.dict_languages = dict_languages

        # lang -> {source word -> [target words]}; stays empty when no dictionary
        # languages were given (the default), which previously raised a TypeError.
        self.lang2dict = {}
        for lang in dict_languages or ():
            dict_path = os.path.join(self.dict_dir, "{}2.txt".format(lang))
            assert os.path.exists(dict_path)
            logger.info("reading dictionary from {}".format(dict_path))
            entries = {}
            with open(dict_path) as reader:
                for line in reader:
                    line = line.strip()
                    if not line:
                        # Blank lines carry no entry; skip them instead of crashing.
                        continue
                    try:
                        src, tgt = line.split("\t")
                    except ValueError:
                        # Not exactly two tab-separated fields: try space-separated.
                        src, tgt = line.split(" ")
                    entries.setdefault(src, []).append(tgt)
            self.lang2dict[lang] = entries

        # lang -> tokenizer loaded from ``<tokenizer_dir>/<lang>``.
        self.lang2tokenizer = {}
        for lang in tokenizer_languages or ():
            self.lang2tokenizer[lang] = XLMRobertaTokenizer.from_pretrained(
                os.path.join(tokenizer_dir, "{}".format(lang)), do_lower_case=do_lower_case)

    def get_noised_dataset(self, examples, **kwargs):
        """Return a noised ``TensorDataset`` for ``examples``.

        When translated data is enabled the given examples are ignored and the
        examples returned by ``load_translate_data`` are used instead; otherwise a
        deep copy of ``examples`` is encoded. ``kwargs`` are forwarded to
        ``convert_examples_to_dataset``.

        Raises:
            AssertionError: if translated data and code switching are both enabled.
        """
        if self.enable_translate_data:
            assert not self.enable_code_switch
            examples = self.load_translate_data()
        else:
            # maybe do not save augmented examples
            examples = copy.deepcopy(examples)

        if self.enable_code_switch:
            self.n_tokens = 0
            self.n_cs_tokens = 0

        dataset = self.convert_examples_to_dataset(examples, **kwargs)

        if self.enable_code_switch:
            # Guard against an empty example list (no token was counted).
            switched = self.n_cs_tokens / self.n_tokens * 100 if self.n_tokens else 0.0
            logger.info("{:.2f}% tokens have been code-switched.".format(switched))
        return dataset

    def encode_sentence(self, text, switch_text=False):
        """Encode ``text`` word by word into a flat list of ids, applying noise.

        ``text`` is split on single spaces. Every word is encoded on its own with
        ``add_special_tokens=True`` and the first and last id of the result are
        dropped, i.e. this assumes the tokenizer adds exactly one special token on
        each side. Noise is only applied when ``switch_text`` is true.

        Returns:
            ``None`` when ``text`` is ``None``, otherwise the list of ids.
        """
        if text is None:
            return None
        ids = []
        for token in text.split(" "):
            # The order of the random draws below is kept fixed on purpose so that
            # results stay reproducible for a given seed.
            switch_token = random.random() <= self.overall_ratio
            self.n_tokens += 1
            if self.enable_code_switch and switch_text and switch_token and random.random() <= self.code_switch_ratio:
                lang = self.dict_languages[random.randint(0, len(self.dict_languages) - 1)]
                candidates = self.lang2dict[lang].get(token.lower())
                if candidates is not None:
                    self.n_cs_tokens += 1
                    token = candidates[random.randint(0, len(candidates) - 1)]

            if self.enable_bpe_switch and switch_text and switch_token and random.random() <= self.bpe_switch_ratio:
                lang = self.tokenizer_languages[random.randint(0, len(self.tokenizer_languages) - 1)]
                tokenizer = self.lang2tokenizer[lang]
            else:
                tokenizer = self.tokenizer

            if self.enable_bpe_sampling and switch_text and switch_token and random.random() <= self.bpe_sampling_ratio:
                token_ids = tokenizer.encode_plus(token, add_special_tokens=True,
                                                  nbest_size=self.sampling_nbest_size,
                                                  alpha=self.sampling_alpha)["input_ids"]
            else:
                token_ids = tokenizer.encode_plus(token, add_special_tokens=True)["input_ids"]

            inner_ids = token_ids[1:-1]
            if self.enable_word_dropout:
                for token_id in inner_ids:
                    if random.random() <= self.word_dropout_rate:
                        ids.append(tokenizer.unk_token_id)
                    else:
                        ids.append(token_id)
            else:
                ids.extend(inner_ids)
        return ids

    def encode_plus(self, text_a, text_b):
        """Encode a text (pair) with noise and add the tokenizer's special tokens.

        Returns:
            A dict with ``input_ids``, ``token_type_ids`` and ``total_len``.
            ``total_len`` is the length including special tokens *before* truncation
            to ``self.max_length``.
        """
        # switch all sentences
        switch_text = True
        ids = self.encode_sentence(text_a, switch_text)
        pair_ids = self.encode_sentence(text_b, switch_text)

        pair = pair_ids is not None
        len_pair_ids = len(pair_ids) if pair else 0

        # Handle max sequence length
        total_len = len(ids) + len_pair_ids + self.tokenizer.num_added_tokens(pair=pair)

        encoded_inputs = {
            "total_len": total_len,
        }

        if self.max_length and total_len > self.max_length:
            ids, pair_ids, _ = self.tokenizer.truncate_sequences(
                ids,
                pair_ids=pair_ids,
                num_tokens_to_remove=total_len - self.max_length,
                truncation_strategy="longest_first",
                stride=0,
            )

        # Handle special_tokens
        encoded_inputs["input_ids"] = self.tokenizer.build_inputs_with_special_tokens(ids, pair_ids)
        encoded_inputs["token_type_ids"] = self.tokenizer.create_token_type_ids_from_sequences(ids, pair_ids)

        return encoded_inputs

    def convert_examples_to_dataset(self, examples, tokenizer,
                                    max_length=512,
                                    task=None,
                                    label_list=None,
                                    output_mode=None,
                                    pad_on_left=False,
                                    pad_token=0,
                                    pad_token_segment_id=0,
                                    mask_padding_with_zero=True):
        """
        Encode examples with noise and pack them into a padded ``TensorDataset``.

        Args:
            examples: Iterable of examples exposing ``text_a``, ``text_b``, ``label`` and ``guid``.
            tokenizer: Only used to log the tokens of the first few examples; encoding itself
                goes through ``self.encode_plus`` (and therefore ``self.tokenizer``).
            max_length: Length every sequence is padded to. Truncation uses ``self.max_length``,
                so a larger ``self.max_length`` makes the length assertions below fail.
            task: GLUE task; when given, missing ``label_list`` / ``output_mode`` are looked up from it
            label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method
            output_mode: String indicating the output mode. Either ``regression`` or ``classification``
            pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
            pad_token: Padding token
            pad_token_segment_id: The segment ID for the padding token
            mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
                and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
                actual values)

        Returns:
            A ``TensorDataset`` of ``(input_ids, attention_mask, token_type_ids)``. Labels are
            validated and logged but are not part of the returned dataset.

        Raises:
            KeyError: if ``output_mode`` is neither ``classification`` nor ``regression``,
                or a classification label is not in ``label_list``.
            AssertionError: if a padded sequence does not have length ``max_length``.
        """
        if task is not None:
            processor = glue_processors[task]()
            if label_list is None:
                label_list = processor.get_labels()
                logger.info("Using label list %s for task %s" % (label_list, task))
            if output_mode is None:
                output_mode = glue_output_modes[task]
                logger.info("Using output mode %s for task %s" % (output_mode, task))

        label_map = {label: i for i, label in enumerate(label_list)}

        all_input_ids = []
        all_attention_mask = []
        all_token_type_ids = []

        # Histogram of pre-truncation lengths, bucketed in steps of ``bucket``.
        length_counter = collections.defaultdict(int)
        bucket = 8
        sum_of_lengths = 0

        real_mask_value = 1 if mask_padding_with_zero else 0
        pad_mask_value = 0 if mask_padding_with_zero else 1

        for (ex_index, example) in enumerate(examples):
            if ex_index % 10000 == 0:
                logger.info("Writing example %d" % (ex_index))

            inputs = self.encode_plus(
                example.text_a,
                example.text_b,
            )

            input_ids = inputs["input_ids"]
            token_type_ids = inputs["token_type_ids"] if "token_type_ids" in inputs else []

            # The mask marks real tokens; only real tokens are attended to.
            attention_mask = [real_mask_value] * len(input_ids)

            # Pad up to the sequence length.
            padding_length = max_length - len(input_ids)

            length_range = ((inputs["total_len"] + bucket - 1) // bucket) * bucket
            length_counter[length_range] += 1
            sum_of_lengths += inputs["total_len"]

            # Without token type ids the whole segment vector consists of padding.
            type_padding_length = max_length if len(token_type_ids) == 0 else padding_length

            if pad_on_left:
                input_ids = ([pad_token] * padding_length) + input_ids
                attention_mask = ([pad_mask_value] * padding_length) + attention_mask
                token_type_ids = ([pad_token_segment_id] * type_padding_length) + token_type_ids
            else:
                input_ids = input_ids + ([pad_token] * padding_length)
                attention_mask = attention_mask + ([pad_mask_value] * padding_length)
                token_type_ids = token_type_ids + ([pad_token_segment_id] * type_padding_length)

            assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
            assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
            assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length)

            if output_mode == "classification":
                label = label_map[example.label]
            elif output_mode == "regression":
                label = float(example.label)
            else:
                raise KeyError(output_mode)

            if ex_index < 5:
                logger.info("*** Example ***")
                logger.info("guid: %s" % (example.guid))
                logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                logger.info("input_tokens: %s" % " ".join(tokenizer.convert_ids_to_tokens(input_ids)))
                logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
                logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
                logger.info("label: %s (id = %d)" % (example.label, label))

            all_input_ids.append(input_ids)
            all_attention_mask.append(attention_mask)
            all_token_type_ids.append(token_type_ids)

        dataset = TensorDataset(
            torch.tensor(all_input_ids, dtype=torch.long),
            torch.tensor(all_attention_mask, dtype=torch.long),
            torch.tensor(all_token_type_ids, dtype=torch.long),
        )

        infos = collections.OrderedDict()
        for key in sorted(length_counter.keys()):
            infos["[%d ~ %d]" % (key - bucket + 1, key)] = length_counter[key]

        # ``max(..., 1)`` avoids a ZeroDivisionError for an empty example list.
        logger.info("Average lengths = %.4f" % (sum_of_lengths / max(len(dataset), 1)))
        logger.info("Length Counter = %s" % json.dumps(infos, indent=2))

        return dataset

    @staticmethod
    def merge_noised_dataset(dataset, noised_dataset):
        """Return a 7-tensor dataset: the four tensors of ``dataset`` followed by
        the first three tensors of ``noised_dataset``.

        ``dataset`` must hold exactly four tensors (input ids, attention mask,
        token type ids, labels); otherwise unpacking raises ``ValueError``.
        """
        all_input_ids, all_attention_mask, all_token_type_ids, all_labels = dataset.tensors

        all_noised_input_ids, all_noised_attention_mask, all_noised_token_type_ids = noised_dataset.tensors[:3]

        return TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels,
                             all_noised_input_ids, all_noised_attention_mask, all_noised_token_type_ids)

    @staticmethod
    def replace_noised_dataset(dataset, noised_dataset):
        """Return a 4-tensor dataset: the first three tensors of ``noised_dataset``
        followed by the labels (fourth tensor) of ``dataset``.

        ``dataset`` must hold exactly four tensors; otherwise unpacking raises
        ``ValueError``.
        """
        _, _, _, all_labels = dataset.tensors

        all_noised_input_ids, all_noised_attention_mask, all_noised_token_type_ids = noised_dataset.tensors[:3]

        return TensorDataset(all_noised_input_ids, all_noised_attention_mask, all_noised_token_type_ids, all_labels)

    def load_translate_data(self):
        """Return the English XNLI training examples with translated texts.

        ``self.train_language`` is a comma-separated list that must contain ``en``.
        For every example a language is drawn at random for ``text_a`` (and an
        independent one for ``text_b`` when ``translate_different_pair`` is set);
        the text is replaced by its entry in that language's translation dict.
        With ``translate_en_data`` English is one of the candidates (index 0) and
        leaves the text unchanged.

        Raises:
            AssertionError: if ``en`` is not a training language or a text has no
                translation in the drawn language.
        """
        train_languages = self.train_language.split(',')
        assert "en" in train_languages
        train_languages.remove("en")
        translate_train_dicts = []
        for language in train_languages:
            logger.info("reading training data from lang {}".format(language))
            processor = processors["xnli"](language=language, train_language=language)
            translate_train_dicts.append(processor.get_translate_train_dict(self.data_dir))

        if self.translate_en_data:
            # Index 0 stands for "keep the English text"; it has no dictionary.
            train_languages = ["en"] + train_languages
            translate_train_dicts = [None] + translate_train_dicts

        processor = processors["xnli"](language="en", train_language="en")
        en_train_examples = processor.get_train_examples(self.data_dir)

        for example in en_train_examples:
            lang_id_a = random.randint(0, len(train_languages) - 1)
            if self.translate_different_pair:
                lang_id_b = random.randint(0, len(train_languages) - 1)
            else:
                lang_id_b = lang_id_a

            if not self.translate_en_data or lang_id_a > 0:
                key_a = example.text_a.strip()
                assert key_a in translate_train_dicts[lang_id_a]
                example.text_a = translate_train_dicts[lang_id_a][key_a]

            if not self.translate_en_data or lang_id_b > 0:
                key_b = example.text_b.strip()
                assert key_b in translate_train_dicts[lang_id_b]
                example.text_b = translate_train_dicts[lang_id_b][key_b]
        return en_train_examples


def train(args, train_examples, train_dataset, model, tokenizer, noised_data_generator=None):
    """ Train the model """
    if args.local_rank in [-1, 0] and args.log_dir:
        tb_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        tb_writer = None

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, num_workers=1)

    if args.num_training_steps > 0:
        t_total = args.num_training_steps
        args.num_training_epochs = t_total // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_training_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    optimizer_grouped_parameters = get_optimizer_grouped_parameters(
        model=model, weight_decay=args.weight_decay, learning_rate=args.learning_rate,
        layer_decay=args.layer_decay, n_layers=model.config.num_hidden_layers,
    )
    
    warmup_steps = t_total * args.warmup_ratio
    correct_bias = not args.disable_bias_correct

    logger.info("*********** Optimizer setting: ***********")
    logger.info("Learning rate = %.10f" % args.learning_rate)
    logger.info("Adam epsilon = %.10f" % args.adam_epsilon)
    logger.info("Adam_betas = (%.4f, %.4f)" % (float(args.adam_betas[0]), float(args.adam_betas[1])))
    logger.info("Correct_bias = %s" % str(correct_bias))
    optimizer = AdamW(
        optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon,
        betas=(float(args.adam_betas[0]), float(args.adam_betas[1])),
        correct_bias=correct_bias,
    )
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
        amp_state_dict = amp.state_dict()
        amp_state_dict['loss_scaler0']['loss_scale'] = args.fp16_init_loss_scale
        logger.info("Set fp16_init_loss_scale to %.1f" % args.fp16_init_loss_scale)
        amp.load_state_dict(amp_state_dict)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_training_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    metric_for_best = args.metric_for_choose_best_checkpoint
    best_performance = None
    best_epoch = None
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    tr_original_loss, logging_original_loss = 0.0, 0.0
    tr_noised_loss, logging_noised_loss = 0.0, 0.0
    tr_kl_loss, logging_kl_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(int(args.num_training_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproductibility (even between python 2 and 3)
    for _ in train_iterator:
        if noised_data_generator is not None:
            assert noised_data_generator.enable_kl_loss or noised_data_generator.noised_loss

            processor = processors[args.task_name]()

            label_list = processor.get_labels()
            if args.task_name in ['mnli', 'mnli-mm'] and args.model_type in ['roberta', 'xlmroberta']:
                # HACK(label indices are swapped in RoBERTa pretrained model)
                label_list[1], label_list[2] = label_list[2], label_list[1]

            noised_train_dataset = noised_data_generator.get_noised_dataset(
                train_examples, tokenizer=tokenizer, label_list=label_list,
                max_length=args.max_seq_length, output_mode=output_modes[args.task_name], 
                pad_on_left=bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
                pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
                pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
            )
            if not (noised_data_generator.original_loss or noised_data_generator.enable_kl_loss):
                logger.info("Creating replaced dataset")
                merged_train_dataset = noised_data_generator.replace_noised_dataset(train_dataset, noised_train_dataset)
            else:
                logger.info("Creating merged dataset")
                merged_train_dataset = noised_data_generator.merge_noised_dataset(train_dataset, noised_train_dataset)

            train_sampler = RandomSampler(merged_train_dataset) if args.local_rank == -1 else DistributedSampler(
                merged_train_dataset)
            train_dataloader = DataLoader(merged_train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

            if not args.num_training_steps > 0:
                assert t_total == len(train_dataloader) // args.gradient_accumulation_steps * args.num_training_epochs

        if args.disable_tqdm:
            epoch_iterator = train_dataloader
        else:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            # freeze params except layernorm
            if args.only_train_layernorm:
                for name, param in model.named_parameters():
                    if ("layer_norm_approximation" in name):
                        param.requires_grad = True
                    else:
                        param.requires_grad = False
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1],
                      'labels':         batch[3]}
            if len(batch) == 7:
                inputs['noised_input_ids'] = batch[4]
                inputs['noised_attention_mask'] = batch[6]
            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet', 'unilm'] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
            # import pudb;pu.db;
            if args.app_ln_layer:
                app_loss = None
                for layer_loss in outputs[2]:
                    if app_loss is None:
                        app_loss = layer_loss
                    else:
                        app_loss = layer_loss + app_loss
                # print("loss:{}, app loss:{}".format(loss,app_loss))
                loss = loss + app_loss

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()

            if len(batch) == 7:
                original_loss, noised_loss, kl_loss = outputs[1:4]
                if args.n_gpu > 1:
                    original_loss = original_loss.mean()
                    noised_loss = noised_loss.mean()
                    kl_loss = kl_loss.mean()
                if args.gradient_accumulation_steps > 1:
                    original_loss = original_loss / args.gradient_accumulation_steps
                    noised_loss = noised_loss / args.gradient_accumulation_steps
                    kl_loss = kl_loss / args.gradient_accumulation_steps
                tr_original_loss += original_loss.item()
                tr_noised_loss += noised_loss.item()
                tr_kl_loss += kl_loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.max_grad_norm > 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    original_loss_scalar = (tr_original_loss - logging_original_loss) / args.logging_steps
                    noised_loss_scalar = (tr_noised_loss - logging_noised_loss) / args.logging_steps
                    kl_loss_scalar = (tr_kl_loss - logging_kl_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs['learning_rate'] = learning_rate_scalar
                    logs['loss'] = loss_scalar
                    logs['tr_original_loss'] = original_loss_scalar
                    logs['tr_noised_loss'] = noised_loss_scalar
                    logs['tr_kl_loss'] = kl_loss_scalar
                    logging_loss = tr_loss
                    logging_original_loss = tr_original_loss
                    logging_noised_loss = tr_noised_loss
                    logging_kl_loss = tr_kl_loss

                    if tb_writer is not None:
                        for key, value in logs.items():
                            tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{'step': global_step}}))

                if args.num_save_ckpts > 0 and (t_total - global_step) % args.save_checkpoint_steps == 0 and \
                        global_step + args.save_checkpoint_steps * args.num_save_ckpts >= t_total:

                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'steps-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)

                    model_to_save = model.module if hasattr(model, 'module') else model
                    # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    logger.info("Saving model checkpoint to %s", output_dir)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))

            if args.num_training_steps > 0 and global_step > args.num_training_steps:
                if not args.disable_tqdm:
                    epoch_iterator.close()
                break

        if args.local_rank in [-1, 0]:
            logs = {}
            if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                results = evaluate(args, model, tokenizer, prefix='epoch-{}'.format(_ + 1))
                for key, value in results.items():
                    eval_key = 'eval_{}'.format(key)
                    logs[eval_key] = value

                if metric_for_best is None:
                    metric_for_best = list(list(results.values())[0].keys())[0]
                if best_epoch is None:
                    best_epoch = _ + 1
                    best_performance = results
                else:
                    for eval_task in results:
                        if best_performance[eval_task][metric_for_best] < results[eval_task][metric_for_best]:
                            best_performance[eval_task] = results[eval_task]
                            best_epoch = _ + 1

            loss_scalar = (tr_loss - logging_loss) / args.logging_steps
            learning_rate_scalar = scheduler.get_lr()[0]
            logs['learning_rate'] = learning_rate_scalar
            logs['loss'] = loss_scalar
            logging_loss = tr_loss

            if tb_writer is not None:
                for key, value in logs.items():
                    tb_writer.add_scalar(key, value, global_step)
            print(json.dumps({**logs, **{'step': global_step}}))

            # Save model checkpoint
            output_dir = os.path.join(args.output_dir, 'epoch-{}'.format(_ + 1))
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            if not args.do_not_save:
                model_to_save = model.module if hasattr(model, 'module') else model
                # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                logger.info("Saving model checkpoint to %s", output_dir)
            torch.save(args, os.path.join(output_dir, 'training_args.bin'))

        if args.num_training_steps > 0 and global_step > args.num_training_steps:
            train_iterator.close()
            break

        if args.fp16:
            logger.info("Amp state dict = %s" % json.dumps(amp.state_dict()))

    if args.local_rank in [-1, 0] and tb_writer is not None:
        tb_writer.close()

    if best_epoch is not None:
        logger.info(" ***************** Best checkpoint: {}, choosed by {} *****************".format(
            best_epoch, metric_for_best))
        logger.info("Best performance = %s" % json.dumps(best_performance))
        save_best_result(best_epoch, best_performance, args.output_dir)

    return global_step, tr_loss / global_step


def save_best_result(best_epoch, best_performance, output_dir):
    """Write the best evaluation result to ``<output_dir>/best_performance.json``.

    Args:
        best_epoch: Identifier of the best checkpoint. Callers in this file pass
            either an epoch number or a checkpoint directory. It is stored under
            the ``"checkpoint"`` key of the written JSON object.
        best_performance: JSON-serializable dict with the metrics of the best
            checkpoint.
        output_dir: Existing directory in which the JSON file is created.

    The ``"checkpoint"`` entry is added to a shallow copy, so the dict owned by
    the caller is not modified as a hidden side effect of saving.
    """
    record = dict(best_performance)
    record["checkpoint"] = best_epoch
    target_path = os.path.join(output_dir, "best_performance.json")
    with open(target_path, mode="w", encoding="utf-8") as writer:
        writer.write(json.dumps(record, indent=2))


def evaluate(args, model, tokenizer, prefix=""):
    """Run evaluation (and optionally prediction) for ``args.task_name``.

    For every split the features are obtained through ``load_and_cache_examples``,
    the model is run batch by batch without gradients, metrics are computed with
    ``compute_metrics`` and the outcome is written below
    ``os.path.join(args.output_dir, prefix)``:

    * ``eval_results.<task>.txt``: JSON metrics (dev splits only);
    * ``prediction.<task>.txt``: one predicted label per line (prediction pass,
      only with ``args.do_predict``);
    * ``error_examples.<task>.txt``: wrongly predicted examples (only with
      ``args.write_error_on_dev``).

    Args:
        args: Parsed command-line namespace. ``args.eval_batch_size`` is set here.
        model: Model to evaluate. It is wrapped in ``torch.nn.DataParallel`` when
            ``args.n_gpu > 1`` and it is not already wrapped.
        tokenizer: Tokenizer forwarded to ``load_and_cache_examples``.
        prefix: Sub-directory name for the output files; also stored in each
            returned metrics dict under ``"prefix"``.

    Returns:
        dict mapping every evaluated dev task name to its metrics dict. The
        prediction pass is not part of the returned dict.

    Raises:
        ValueError: if a split produces no batches at all.
    """
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = [args.task_name]
    if args.task_name == "mnli":
        eval_task_names.append("mnli-mm")
    if args.task_name == "cb":
        eval_task_names.append("cb-debug")
    eval_output_dir = args.output_dir
    is_predict_list = [False] * len(eval_task_names)
    if args.do_predict:
        # The prediction pass re-uses the main task name with is_predict=True.
        is_predict_list.append(True)
        eval_task_names.append(args.task_name)

    # The call below is disabled, so the imported name is currently unused. The
    # import itself is kept so that importing behaviour stays as it was.
    from unilm.squash_model_for_he import squash_model_encoder  # noqa: F401
    # squash_model_encoder(model, model.config)

    # multi-gpu eval: wrap exactly once. Wrapping inside the task loop (or wrapping a
    # model that the caller already wrapped) would nest DataParallel modules.
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model.eval()

    results = {}
    for eval_task, is_predict in zip(eval_task_names, is_predict_list):
        cached_file = args.cached_predict_file if is_predict else args.cached_dev_file
        if cached_file is not None:
            cached_file = cached_file + '_' + eval_task
        eval_dataset, eval_examples = load_and_cache_examples(
            args, eval_task, tokenizer, cached_features_file=cached_file,
            evaluate=True, is_predict=is_predict, return_examples=True)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        # Per-batch arrays are collected and concatenated once after the loop instead
        # of growing one array with np.append, which copies everything on every batch.
        logit_chunks = []
        label_chunks = []
        if args.disable_tqdm:
            epoch_iterator = eval_dataloader
        else:
            epoch_iterator = tqdm(eval_dataloader, desc="Evaluating")
        for batch in epoch_iterator:
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {'input_ids':      batch[0],
                          'attention_mask': batch[1],
                          'labels':         batch[3]}
                if args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            logit_chunks.append(logits.detach().cpu().numpy())
            label_chunks.append(inputs['labels'].detach().cpu().numpy())

        if nb_eval_steps == 0:
            # Fail with a clear message instead of a ZeroDivisionError further down.
            raise ValueError("No evaluation batches for task %s: the dataset is empty." % eval_task)

        preds = np.concatenate(logit_chunks, axis=0)
        out_label_ids = np.concatenate(label_chunks, axis=0)
        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        eval_result_output_dir = os.path.join(eval_output_dir, prefix)
        os.makedirs(eval_result_output_dir, exist_ok=True)
        if not is_predict:
            result["prefix"] = prefix
            results[eval_task] = result
        else:
            prediction_file = os.path.join(eval_result_output_dir, "prediction.%s.txt" % eval_task)
            logger.info("Write prediction into %s" % prediction_file)
            all_labels = processors[eval_task]().get_labels()
            with open(prediction_file, mode='w', encoding="utf-8") as f_out:
                for i in preds:
                    f_out.write(all_labels[i])
                    f_out.write('\n')

        logger.info("***** Eval results {} for {} *****".format(prefix, eval_task))
        if is_predict:
            logger.info("***** Do prediction ! *****")
        logger.info("  Eval loss = %s", eval_loss)
        logger.info("Result = %s" % json.dumps(result, indent=2))

        if args.write_error_on_dev:
            error_file = os.path.join(eval_result_output_dir, "error_examples.%s.txt" % eval_task)
            all_labels = processors[eval_task]().get_labels()
            with open(error_file, mode="w", encoding="utf-8") as log:
                for i, example in enumerate(eval_examples):
                    if preds[i] != out_label_ids[i]:
                        # text_b may be None (presumably for single-sentence tasks;
                        # confirm against the processors), so do not concatenate it blindly.
                        if example.text_b is None:
                            input_text = example.text_a
                        else:
                            input_text = example.text_a + "   ###   " + example.text_b
                        log.write(json.dumps({
                            "example-id": i,
                            "input": input_text,
                            "label": all_labels[out_label_ids[i]],
                            "pred": all_labels[preds[i]],
                        }, indent=2))
                        # Sanity check: the cached label id must match the raw example label.
                        assert all_labels[out_label_ids[i]] == example.label
                        log.write('\n')

        if not is_predict:
            # Metrics are only persisted for dev splits; the context manager closes
            # the file even when writing fails.
            output_eval_file = os.path.join(eval_result_output_dir, "eval_results.%s.txt" % eval_task)
            with open(output_eval_file, "w") as writer:
                writer.write(json.dumps(result, indent=2))

    return results


def load_and_cache_examples(
        args, task, tokenizer, cached_features_file=None, evaluate=False, return_examples=False, is_predict=False):
    """Build the ``TensorDataset`` for ``task``, re-using a feature cache on disk when possible.

    Args:
        args: Parsed command-line namespace (reads ``local_rank``, ``data_dir``,
            ``predict_file``, ``model_type``, ``model_name_or_path``,
            ``max_seq_length``, ``overwrite_cache`` and ``disable_auto_cache``).
        task: Key into ``processors`` / ``output_modes``.
        tokenizer: Tokenizer handed to ``convert_examples_to_features``.
        cached_features_file: Explicit cache path. When ``None`` a path below
            ``args.data_dir`` is derived unless ``args.disable_auto_cache`` is set.
        evaluate: Load the dev split instead of the train split.
        return_examples: Also return the raw examples.
        is_predict: Load ``args.predict_file`` instead; takes precedence over ``evaluate``.

    Returns:
        ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``, or a
        ``(dataset, examples)`` tuple when ``return_examples`` is true.

    Raises:
        RuntimeError: if no cache file is given, auto caching is disabled and
            ``args.local_rank != -1``.
        ValueError: if the output mode of ``task`` is neither ``"classification"``
            nor ``"regression"``.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    processor = processors[task]()
    output_mode = output_modes[task]

    def read_examples():
        # Single place that maps the requested split to the processor call.
        if is_predict:
            return processor.get_predict_examples(args.predict_file)
        if evaluate:
            return processor.get_dev_examples(args.data_dir)
        return processor.get_train_examples(args.data_dir)

    if cached_features_file is None:
        if args.disable_auto_cache and args.local_rank != -1:
            logger.warning("Please cache the features in DDP mode !")
            raise RuntimeError()
        if not args.disable_auto_cache:
            # Load data features from cache or dataset file
            seg = "train"
            if evaluate:
                seg = "dev"
            if is_predict:
                seg = "predict"
            cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
                seg,
                list(filter(None, args.model_name_or_path.split('/'))).pop(),
                str(args.max_seq_length),
                str(task)))

    examples = None
    if cached_features_file is not None and os.path.exists(cached_features_file) and not args.overwrite_cache:
        # NOTE(review): torch.load unpickles the file; only point this at cache
        # files produced by this script.
        features = torch.load(cached_features_file)
        logger.info("Loading %d features from cached file %s" % (len(features), cached_features_file))
        if return_examples:
            # The raw examples are only needed by callers that ask for them, so the
            # dataset file is not parsed again when the cache already has the features.
            examples = read_examples()
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta', 'xlmroberta']:
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        examples = read_examples()
        features = convert_examples_to_features(examples,
                                                tokenizer,
                                                label_list=label_list,
                                                max_length=args.max_seq_length,
                                                output_mode=output_mode,
                                                pad_on_left=bool(args.model_type in ['xlnet']),
                                                pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
                                                pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
        )

        if args.local_rank in [-1, 0] and cached_features_file is not None:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Pick the label dtype up front so an unexpected output mode fails with a clear
    # message instead of an unbound-variable error when the dataset is assembled.
    if output_mode == "classification":
        label_dtype = torch.long
    elif output_mode == "regression":
        label_dtype = torch.float
    else:
        raise ValueError("Unsupported output mode: %s" % output_mode)

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_labels = torch.tensor([f.label for f in features], dtype=label_dtype)

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    if return_examples:
        return dataset, examples
    return dataset


def eval_str_list(x, type=float):
    """Convert ``x`` into a list whose items were passed through ``type``.

    ``None`` is returned unchanged. A string is first evaluated as a Python
    expression. An iterable value yields one converted item per element; a
    non-iterable value yields a single-element list.
    """
    if x is None:
        return None
    # NOTE(review): eval() executes arbitrary expressions, so this must only ever
    # receive trusted strings (e.g. command-line values typed by the operator).
    value = eval(x) if isinstance(x, str) else x
    try:
        converted = [type(item) for item in value]
    except TypeError:
        # Not iterable (or not convertible item by item): convert it as one value.
        converted = [type(value)]
    return converted


def add_stable_finetuning_args(parser):
    """Register the stable fine-tuning command-line options on ``parser``."""

    def switch(name, description):
        # Boolean flag: False unless given on the command line.
        parser.add_argument(name, action="store_true", help=description)

    def valued(name, default, kind, description):
        # Option carrying a value that is converted with ``kind``.
        parser.add_argument(name, default=default, type=kind, help=description)

    # stable fine-tuning parameters
    valued("--overall_ratio", 1.0, float, "overall ratio")
    switch("--enable_kl_loss", "Whether to enable kl loss.")
    valued("--kl_lambda", 5.0, float, "lambda of KL loss")
    switch("--original_loss", "Whether to use cross entropy loss on the former example.")
    switch("--noised_loss", "Whether to use cross entropy loss on the latter example.")

    # bpe switch / bpe sampling
    switch("--enable_bpe_switch", "Whether to enable bpe-switch.")
    valued("--bpe_switch_ratio", 0.5, float, "bpe_switch_ratio")
    valued("--tokenizer_dir", None, str, "tokenizer dir")
    valued("--tokenizer_languages", None, str, "tokenizer languages")
    switch("--enable_bpe_sampling", "Whether to enable bpe sampling.")
    valued("--bpe_sampling_ratio", 0.5, float, "bpe_sampling_ratio")
    valued("--sampling_alpha", 5.0, float, "alpha of sentencepiece sampling")
    valued("--sampling_nbest_size", -1, int, "nbest_size of sentencepiece sampling")

    # random noise
    switch("--enable_random_noise", "Whether to enable random noise.")
    switch("--noise_detach_embeds", "Whether to detach noised embeddings.")
    valued("--noise_eps", 1e-5, float, "noise eps")
    parser.add_argument(
        '--noise_type',
        type=str,
        default='uniform',
        choices=['normal', 'uniform'],
        help='type of noises for RXF methods',
    )

    # code switch
    switch("--enable_code_switch", "Whether to enable code switch.")
    valued("--code_switch_ratio", 0.5, float, "code_switch_ratio")
    valued("--dict_dir", None, str, "dict dir")
    valued("--dict_languages", None, str, "dict languages")

    # word dropout
    switch("--enable_word_dropout", "Whether to enable word dropout.")
    valued("--word_dropout_rate", 0.1, float, "word dropout rate.")

    # translated data
    switch("--enable_translate_data", "Whether to enable translate data.")
    switch("--translate_different_pair", "Whether to translate different pair.")
    switch("--translate_en_data", "Whether to translate en data.")


def main():
    """Command-line entry point of the fine-tuning script.

    Parses the arguments, sets up the device / distributed backend, logging and
    the random seed, loads config, tokenizer and model, and then optionally
    trains (``--do_train``) and evaluates checkpoints (``--do_eval``).
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--model_type", default="unilm", type=str, 
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--task_name", default=None, type=str, required=True,
                        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--do_predict", action='store_true',
                        help="Whether to run prediction for the --predict_file. ")
    parser.add_argument("--predict_file", default=None, type=str,
                        help="The predict file. For example, the test dataset. ")
    parser.add_argument("--do_not_save", action='store_true',
                        help="Disable save models after each epoch. ")
    parser.add_argument("--log_dir", default=None, type=str,
                        help="The output directory where the log will be written.")
    parser.add_argument("--prefix_need_remove", default=None, type=str,
                        help="To load model weight by removing the given prefix. ")
    parser.add_argument('--mean_pooling', action='store_true')
    parser.add_argument('--plus_pooler', action='store_true')

    parser.add_argument('--app_ln_layer', action='store_true')
    parser.add_argument('--app_ln_loss', action='store_true')

    parser.add_argument("--cached_train_file", default=None, type=str,
                        help="Path to cache the train set features. ")
    parser.add_argument("--cached_dev_file", default=None, type=str,
                        help="Path to cache the dev set features. ")
    parser.add_argument("--cached_predict_file", default=None, type=str,
                        help="Path to cache the prediction file. ")
    parser.add_argument('--disable_auto_cache', action='store_true',
                        help='Disable the function for automatic cache the training/dev features.')
    parser.add_argument('--disable_tqdm', action='store_true',
                        help='Disable the tqdm bar. ')

    ## Other parameters
    parser.add_argument("--convert_checkpoint_to_unilm", action='store_true',
                        help="Convert model checkpoint to UniLM format, then finetuning. ")

    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")
    parser.add_argument("--write_error_on_dev", action='store_true',
                        help="Write error examples for dev set. ")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--softmax_approximation", action='store_true')

    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")

    parser.add_argument("--dropout_prob", default=None, type=float,
                        help="Set dropout prob, default value is read from config. ")
    parser.add_argument("--cls_dropout_prob", default=None, type=float,
                        help="Set cls layer dropout prob. ")
    parser.add_argument("--drop_task_layers", action='store_true',
                        help="Drop task layers' parameters for continue training. ")

    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--metric_for_choose_best_checkpoint', type=str, default=None,
                        help="Set the metric to choose the best checkpoint")

    parser.add_argument('--save_checkpoint_steps', type=int, default=50)
    parser.add_argument('--num_save_ckpts', type=int, default=-1)

    parser.add_argument('--checkpoints_to_eval', type=str, default=None,
                        help="Checkpoint dirs which need to eval")

    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--only_train_layernorm', action='store_true')
    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")

    # Further option groups; add_optimzation_args is defined outside this section
    # (presumably it registers the optimizer options and --fp16, which is read
    # below; confirm).
    add_stable_finetuning_args(parser)
    add_optimzation_args(parser)

    args = parser.parse_args()

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1

    # Dump the arguments for reference. NOTE(review): this runs on every process,
    # not only on rank -1/0. args.device is attached only after the dump
    # (presumably because a torch.device is not JSON serializable).
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, 'training_args.json'), mode='w', encoding="utf-8") as writer:
        writer.write(json.dumps(args.__dict__, indent=2, sort_keys=True))

    args.device = device

    # Setup logging
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                    args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # Set seed
    set_seed(args)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
                                          num_labels=num_labels,
                                          finetuning_task=args.task_name,
                                          cache_dir=args.cache_dir if args.cache_dir else None)
    tokenizer_name = args.tokenizer_name if args.tokenizer_name else args.model_name_or_path
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name,
                                                do_lower_case=args.do_lower_case,
                                                cache_dir=args.cache_dir if args.cache_dir else None)

    # The pooler is always enabled for fine-tuning, whatever the stored config says.
    if not hasattr(config, 'need_pooler') or config.need_pooler is not True:
        setattr(config, 'need_pooler', True)
    if args.dropout_prob is not None:
        # One command-line value overrides both the hidden and the attention dropout.
        config.hidden_dropout_prob = args.dropout_prob
        config.attention_probs_dropout_prob = args.dropout_prob

    if args.cls_dropout_prob is not None:
        config.cls_dropout_prob = args.cls_dropout_prob

    # Forward the model-variant switches from the command line to the config.
    setattr(config, 'mean_pooling', args.mean_pooling)
    setattr(config, 'plus_pooler', args.plus_pooler)
    setattr(config, 'app_ln_layer', args.app_ln_layer)
    setattr(config, 'app_ln_loss', args.app_ln_loss)
    if args.app_ln_loss:
        # --app_ln_loss is only valid together with --app_ln_layer.
        assert args.app_ln_layer

    logger.info("Final model config for finetuning: ")
    logger.info("%s" % config.to_json_string())

    # NOTE(review): drop_parameters is assigned but not used anywhere in this function.
    drop_parameters = None
    if args.drop_task_layers:
        drop_parameters = ["classifier"]
    # NOTE(review): the state dict is read directly from args.model_name_or_path, so
    # that argument has to be a file that torch.load can read.
    state_dict = torch.load(os.path.join(args.model_name_or_path), map_location=device)
    setattr(config, 'softmax_approximation', args.softmax_approximation)
    if args.softmax_approximation:
        # reciprocal_model_path is not defined in this function (presumably a
        # module-level name; confirm). Its weights are copied into every layer's
        # softmax-approximation sub-module.
        app_state_dict = torch.load(reciprocal_model_path, map_location=device)
        for layer_id in range(config.num_hidden_layers):
            head = "bert.encoder.layer.%d.attention.self.softmax_approximation.reciprocal" % layer_id
            for key in app_state_dict:
                state_dict["%s.%s" % (head, key)] = app_state_dict[key]
    # import pudb;pu.db;
    model = model_class.from_pretrained(
        args.model_name_or_path, config=config,
        cache_dir=args.cache_dir if args.cache_dir else None, state_dict=state_dict)

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset, train_examples = load_and_cache_examples(
            args, args.task_name, tokenizer, cached_features_file=args.cached_train_file, evaluate=False, return_examples=True)
        global_step, tr_loss = train(args, train_examples, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

        tokenizer.save_pretrained(args.output_dir)

    # Evaluation
    if args.do_eval and args.local_rank in [-1, 0]:
        # The tokenizer is re-loaded from the output directory (it is saved there
        # after training above).
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)

        if args.checkpoints_to_eval is not None:
            checkpoints = args.checkpoints_to_eval.split(',')
        else:
            # Every directory below output_dir that contains a weights file.
            checkpoints = \
                list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
        # logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        metric_for_best = args.metric_for_choose_best_checkpoint
        best_performance = None
        best_epoch = None

        for checkpoint in checkpoints:
            # Only checkpoint paths containing 'epoch' get their own result sub-directory.
            prefix = checkpoint.split('/')[-1] if checkpoint.find('epoch') != -1 else ""
            checkpoint_config = config_class.from_pretrained(checkpoint)
            setattr(checkpoint_config, "softmax_approximation", args.softmax_approximation)
            state_dict = torch.load(os.path.join(checkpoint, "pytorch_model.bin"), map_location=device)
            if args.softmax_approximation:
                app_state_dict = torch.load(reciprocal_model_path, map_location=device)
                # NOTE(review): the layer count comes from the initial config, not
                # from checkpoint_config.
                for layer_id in range(config.num_hidden_layers):
                    head = "bert.encoder.layer.%d.attention.self.softmax_approximation.reciprocal" % layer_id
                    for key in app_state_dict:
                        state_dict["%s.%s" % (head, key)] = app_state_dict[key]

            # The auxiliary app_ln loss is always switched off for evaluation.
            setattr(checkpoint_config, 'app_ln_layer', args.app_ln_layer)
            setattr(checkpoint_config, 'app_ln_loss', False)

            model = model_class.from_pretrained(checkpoint, config=checkpoint_config, state_dict=state_dict)

            
            model.to(args.device)

            # for name, param in model.named_parameters():
            #     if not torch.all(torch.eq(param.data, groud_truth[name])):
            #         print(name)
            #         print(param)
            #         print(groud_truth[name])
            #         print("\n")

            result = evaluate(args, model, tokenizer, prefix=prefix)

            if metric_for_best is None:
                # Default: the first metric key of the first evaluated task.
                metric_for_best = list(list(result.values())[0].keys())[0]
            if best_epoch is None:
                best_epoch = checkpoint
                best_performance = result
            else:
                # NOTE(review): the best result is tracked per task, but best_epoch is
                # a single value that is overwritten whenever any task improves.
                for eval_task in result:
                    if best_performance[eval_task][metric_for_best] < result[eval_task][metric_for_best]:
                        best_performance[eval_task] = result[eval_task]
                        best_epoch = checkpoint

        if best_epoch is not None:
            logger.info(" ***************** Best checkpoint: {}, choosed by {} *****************".format(
                best_epoch, metric_for_best))
            logger.info("Best performance = %s" % json.dumps(best_performance))

            save_best_result(best_epoch, best_performance, args.output_dir)


# Script entry point: call main() only when this file is executed directly,
# so importing the module does not start the training / evaluation run.
if __name__ == "__main__":
    main()