#
 #     MILIE: Modular & Iterative Multilingual Open Information Extraction
 #
 #
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #

import logging
import os
import random

import numpy as np

import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler

import warnings
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    from transformers.optimization import AdamW, get_linear_schedule_with_warmup, \
        get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

from .transformer_heads import MODEL_CLASSES

LOGGER = logging.getLogger(__name__)


def save_model(milie_args, model, prefix=None):
    """
    Saves a model via its ``save_pretrained`` method.

    :param milie_args: instance of NspArguments; only ``output_dir`` is read here
    :param model: the model to save, optionally wrapped in a parallel container
                  (e.g. ``torch.nn.DataParallel``) that exposes the real model as ``.module``
    :param prefix: optional sub-path joined onto ``output_dir``; if falsy, the model is
                   written directly to ``output_dir``
    :return: the location the model was written to
    """
    if prefix:
        output_model_file = os.path.join(milie_args.output_dir, prefix)
    else:
        output_model_file = milie_args.output_dir
    # Decide on unwrapping by looking at the model itself rather than at the number of
    # visible GPUs: a plain model on a multi-GPU host has no ``.module``, and a wrapped
    # model on a single-GPU host still needs unwrapping before ``save_pretrained`` exists.
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(output_model_file)
    return output_model_file


def set_seed(seed, n_gpu=1):
    """
    Seeds the Python, NumPy and PyTorch random number generators.

    :param seed: the seed to use; pass -1 to have one drawn at random
    :param n_gpu: number of GPUs in use; if greater than 0 the CUDA generators are seeded too
    :return: 0 on success
    """
    chosen_seed = random.randrange(2**32 - 1) if seed == -1 else seed
    # The seed is only logged, not returned, so this line is the record of a random draw.
    LOGGER.info("Seed: %s", chosen_seed)
    for seed_generator in (random.seed, np.random.seed, torch.manual_seed):
        seed_generator(chosen_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(chosen_seed)
    return 0


def get_train_dataloader(milie_args, masker, data_handler):
    """
    Builds the training DataLoader and attaches it to ``data_handler.train_dataloader``.

    If ``milie_args.cache_train`` is set, the converted dataset is read from / written to
    ``train.pt`` in the output folder so it can be re-used next time.

    :param milie_args: instance of NspArguments
    :param masker: the masker which will mask the data as appropriate, an instance of a subclass of :py:class:`~milie.masking.Masking`
    :param data_handler: the dataset handler, an instance of :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler` or a subclass
    :return: 0; the dataloader itself is stored on ``data_handler`` rather than returned
    """
    # The cache location is only relevant (and only computed) when caching is switched on.
    cache_file = None
    if milie_args.cache_train:
        cache_file = os.path.join(milie_args.output_dir, 'train.pt')

    if cache_file is not None and os.path.isfile(cache_file):
        dataset = torch.load(cache_file)
    else:
        masker.convert_examples_to_features(
            data_handler=data_handler,
            max_seq_length=milie_args.max_seq_length,
            is_training=True,
            milie_args=milie_args)
        dataset = data_handler.create_tensor_dataset()
        if cache_file is not None:
            torch.save(dataset, cache_file)

    # local_rank == -1 means no distributed training, so plain random sampling is used.
    sampler_class = RandomSampler if milie_args.local_rank == -1 else DistributedSampler
    data_handler.train_dataloader = DataLoader(
        dataset,
        sampler=sampler_class(dataset),
        batch_size=milie_args.train_batch_size)
    return 0


def get_model_elements(milie_args, data_handler):
    """
    Builds the configuration, tokenizer and model selected by the provided arguments.

    The tokenizer is not returned: it is attached to ``data_handler.tokenizer`` unless the
    handler already carries one.

    :param milie_args: instance of NspArguments
    :param data_handler: the dataset handler, an (subclass) instance of :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler`
    :return: A tuple containing config and model.

             - config: the configuration (any configuration supported by the huggingface library)
             - model: the model (either VariableHeadsNSP or a model from the huggingface library)
    """

    def restore_from_output_dir():
        # Instantiate the architecture from the config, then fill it with the weights
        # stored as pytorch_model.bin in the output directory.
        restored = model_class(config=config, **additional_arguments)
        checkpoint = os.path.join(milie_args.output_dir, 'pytorch_model.bin')
        # Without CUDA the stored tensors are mapped onto the CPU while loading.
        target = None if torch.cuda.is_available() else 'cpu'
        restored.load_state_dict(torch.load(checkpoint, map_location=target))
        return restored

    config_class, model_class, tokenizer_class = MODEL_CLASSES[milie_args.model_type]
    additional_arguments = get_additional_arguments(milie_args, data_handler)
    additional_arguments_tok = get_additional_arguments_tok(milie_args)

    # Dedicated config / tokenizer names take precedence over the model name.
    config = config_class.from_pretrained(
        milie_args.config_name or milie_args.model_name_or_path,
        output_hidden_states=milie_args.output_hidden_states,
        output_attentions=milie_args.output_attentions)
    tokenizer = tokenizer_class.from_pretrained(
        milie_args.tokenizer_name or milie_args.model_name_or_path,
        do_lower_case=milie_args.do_lower_case, **additional_arguments_tok)
    # tokenizer belongs to the data_handler
    # that is, data_handler has control how to tokenize (not masker)
    if data_handler.tokenizer is None:
        data_handler.tokenizer = tokenizer

    for template, value in (
            ('Using the following configuration: {}', config_class),
            ('Using the following tokenizer: {}', tokenizer_class),
            ('Using the following model: {}', model_class),
            ('Using the following additional arguments: {}', additional_arguments)):
        LOGGER.info(template.format(value))

    model_name = milie_args.model_name_or_path
    from_scratch = model_name in ("vanilla", "")
    if from_scratch and milie_args.do_train:
        LOGGER.info("Creating a new model with initial weights")
        model = model_class(config=config, **additional_arguments)

        # The fresh model is initialised from the multilingual BERT checkpoint.
        base_checkpoint = 'bert-base-multilingual-cased'
        pretrained_config = config_class.from_pretrained(
            base_checkpoint,
            output_hidden_states=milie_args.output_hidden_states,
            output_attentions=milie_args.output_attentions)
        pretrained_model = model_class.from_pretrained(
            base_checkpoint,
            from_tf='.ckpt' in model_name,
            config=pretrained_config, **additional_arguments)
        model = copy_pretrained_model(model, pretrained_model)
    elif not from_scratch and (milie_args.do_train or milie_args.compute_embeddings):
        LOGGER.info("Loading a previous model: %s" % model_name)
        model = model_class.from_pretrained(
            model_name,
            from_tf='.ckpt' in model_name,
            config=config, **additional_arguments)
    else:
        # Not training (and not just computing embeddings from a named model):
        # use the weights previously written to the output directory.
        model = restore_from_output_dir()

    # resize embeddings
    if not milie_args.compute_embeddings:
        if config.vocab_size != len(data_handler.tokenizer):
            # config will be updated during resizing
            config = model.resize_embeddings(len(data_handler.tokenizer))
            assert config.vocab_size == len(data_handler.tokenizer), \
                (config.vocab_size, len(data_handler.tokenizer))
    return config, model


def copy_pretrained_model(model, pretrained_model):
    """
    Loads the weights of ``pretrained_model`` into ``model``, except for the token type
    embeddings, for which ``model`` keeps its own current values.

    :param model: the model to fill; modified in place
    :param pretrained_model: the model whose weights are copied
    :return: ``model`` with the copied weights
    """
    token_type_key = "bert.embeddings.token_type_embeddings.weight"
    source_weights = pretrained_model.state_dict()
    # Keep a detached copy of the target's own token type embeddings
    # (presumably their shape differs from the pretrained checkpoint - TODO confirm).
    own_token_types = model.state_dict()[token_type_key].detach().clone()
    source_weights[token_type_key] = own_token_types
    model.load_state_dict(source_weights)
    return model

def get_additional_arguments(milie_args, data_handler):
    """
    Collects the extra keyword arguments needed for model instantiation.

    Due to the transformers library set up, num_labels needs to be passed to the config,
    whereas all other additional values (e.g. here num_labels_tok) need to be passed
    directly to the model.

    :param milie_args: instance of NspArguments
    :param data_handler: the dataset handler; its ``num_labels_cls`` / ``num_labels_tok`` are
                         read only when the matching classification head is switched on
    :return: a dictionary with the information added as needed
    """
    # (argument attribute, model keyword, label-count attribute on the data handler)
    head_specs = (
        ("plus_generation", "generate", None),
        ("plus_classify_sequence", "classify_sequence", "num_labels_cls"),
        ("plus_classify_tokens", "classify_tokens", "num_labels_tok"),
    )
    extras = {}
    for argument_name, keyword, label_attribute in head_specs:
        head_value = getattr(milie_args, argument_name)
        if head_value > 0:
            extras[keyword] = head_value
            if label_attribute is not None:
                extras[label_attribute] = getattr(data_handler, label_attribute)
    return extras


def get_additional_arguments_tok(milie_args):
    """
    Collects the extra keyword arguments for the tokenizer.

    :param milie_args: instance of NspArguments
    :return: a dictionary holding ``sentencepiece_model_file`` if one was configured,
             otherwise an empty dictionary
    """
    sentencepiece_file = milie_args.sentencepiece_model_file
    if not sentencepiece_file:
        return {}
    return {"sentencepiece_model_file": sentencepiece_file}


def set_up_device(milie_args):
    """
    Chooses the device to compute on (gpu vs cpu) and the number of gpus to use.

    :param milie_args: an instance of NspArguments
    :return: tuple of:
            1. device: the device on which computations will be run
            2. n_gpu: the number of gpu's
    """
    # CPU explicitly requested. Note that n_gpu is reported as 1 in this case.
    if milie_args.no_cuda:
        return torch.device('cpu'), 1

    # Non-distributed run: use every visible GPU, or fall back to the CPU.
    if milie_args.local_rank == -1:
        # torch.device('cuda') can cause an error. need to specify an index of the default device.
        # if you specify the default other device than zero, the index must be the first at device_ids in Dataparallel
        # i.e. torch.nn.DataParallel(model, device_ids=[1, 2, 3, 0]), if you set 'cuda:1' here (n_gpu = 4).
        device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        return torch.device(device_name), torch.cuda.device_count()

    # Distributed run: one GPU per process.
    # NOTE(review): device index 0 is hard-coded rather than taken from local_rank; this
    # presumably relies on each process seeing only its own GPU - confirm with the launcher.
    torch.cuda.set_device(0)
    # pylint: disable=not-callable
    # pylint: disable=no-member
    device = torch.device("cuda", 0)
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend='nccl')
    return device, 1


def move_model(milie_args, model, device, n_gpu):
    """
    Move model to correct device and wrap it for multi-GPU use if required.

    :param milie_args: instance of NspArguments
    :param model: a model (VariableHeadsNSP or a model from the huggingface library)
    :param device: the device to move to
    :param n_gpu: the number of GPUs to use
    :return: the moved model; wrapped in apex ``DistributedDataParallel`` for distributed
             runs (``local_rank != -1``) or in ``torch.nn.DataParallel`` if ``n_gpu > 1``
    :raises ImportError: if a distributed run is requested but apex is not installed
    """
    if milie_args.fp16:
        model.half()
    model.to(device)
    if milie_args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError as err:
            # Fail with the actionable message instead of continuing into an
            # UnboundLocalError on the missing DDP name.
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex "
                              "to use distributed and fp16 training.") from err
        model = DDP(model)
    elif n_gpu > 1:
        # NOTE: the DataParallel wrapper has no `save_pretrained`; the wrapped model is
        # reachable through its `.module` attribute and must be unwrapped before saving.
        model = torch.nn.DataParallel(model)
    return model


def argument_sanity_check(milie_args):
    """
    Validates the provided program arguments and creates the output directory if needed.

    :param milie_args: instance of NspArguments
    :return: 0 on success
    :raises ValueError: if ``gradient_accumulation_steps`` is below 1, if neither training nor
                        prediction is requested, if the input file for a requested mode is
                        missing, or if training would overwrite a saved model without
                        ``load_prev_model`` being set
    """
    accumulation_steps = milie_args.gradient_accumulation_steps
    if accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(accumulation_steps))

    if not (milie_args.do_train or milie_args.do_predict):
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    # Each requested mode needs its input file.
    if milie_args.do_train and not milie_args.train_file:
        raise ValueError(
            "If `do_train` is True, then `train_file` must be specified.")
    if milie_args.do_predict and not milie_args.predict_file:
        raise ValueError(
            "If `do_predict` is True, then `predict_file` must be specified.")

    # Training into a directory that already holds a model is only allowed when the
    # previous model is to be loaded.
    saved_model = os.path.join(milie_args.output_dir, "pytorch_model.bin")
    if os.path.exists(saved_model) and milie_args.do_train and not milie_args.load_prev_model:
        raise ValueError("Output directory already contains a saved model (pytorch_model.bin): "
                         "%s" % milie_args.output_dir)
    os.makedirs(milie_args.output_dir, exist_ok=True)
    return 0


def prepare_optimizer(milie_args, model, t_total):
    """
    Creates the AdamW optimizer and, if requested, a warmup learning rate scheduler.

    :param milie_args: instance of NspArguments
    :param model: the model for which the optimizer will be created
    :param t_total: the total number of training steps that will be performed
                    (need for learning rate schedules that depend on this)
    :return: a tuple (optimizer, scheduler); the scheduler is None unless ``adam_schedule`` is
             one of 'warmup_linear', 'warmup_constant' or 'warmup_cosine'
    """
    # Parameters whose name contains one of these markers are excluded from weight decay.
    no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
    decayed, undecayed = [], []
    for name, parameter in model.named_parameters():
        exempt = any(marker in name for marker in no_decay)
        (undecayed if exempt else decayed).append(parameter)
    optimizer_grouped_parameters = [
        {'params': decayed, 'weight_decay': 0.01},
        {'params': undecayed, 'weight_decay': 0.0},
    ]

    # In distributed training the step count is divided by the number of processes.
    if milie_args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    optimizer = AdamW(optimizer_grouped_parameters, lr=milie_args.learning_rate,
                      weight_decay=0.01, correct_bias=False, eps=1e-6)
    warmup_steps = milie_args.warmup_proportion * t_total

    # schedule name -> (log message, factory creating the scheduler)
    schedule_builders = {
        'warmup_linear': (
            "Using linear warmup",
            lambda: get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                    num_training_steps=t_total)),
        'warmup_constant': (
            "Using constant warmup",
            lambda: get_constant_schedule_with_warmup(optimizer,
                                                      num_warmup_steps=warmup_steps)),
        'warmup_cosine': (
            "Using cosine warmup",
            lambda: get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                    num_training_steps=t_total)),
    }
    scheduler = None
    selected = schedule_builders.get(milie_args.adam_schedule)
    if selected is not None:
        message, build_scheduler = selected
        LOGGER.info(message)
        scheduler = build_scheduler()

    return optimizer, scheduler


