
#
 #     MILIE: Modular & Iterative Multilingual Open Information Extraction
 #
 #
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import torch

from .predict import predict
from .masking import get_masker
from .model_helper import get_model_elements, argument_sanity_check, set_seed, set_up_device, \
    move_model
from .dataset_handlers.datasets_factory import get_data_handler
from .train import train

# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)


def milie_runner(milie_args):
    """
    Main entry point to run MILIE training and/or prediction.

    Steps, in order: build the masker, build the data handlers required by the
    requested modes, sanity-check the arguments, set up the device(s), derive the
    per-step training batch size, seed the RNGs, train (if ``do_train``) and finally
    predict (if ``do_predict``; only on the main process when running distributed).

    Note that ``milie_args`` is mutated in place: ``train_batch_size`` is divided by
    ``gradient_accumulation_steps``, and when both training and prediction are
    requested, ``config_name``, ``tokenizer_name`` and ``model_name_or_path`` are
    redirected to ``output_dir`` so that prediction loads from there.

    :param milie_args: the run-arguments object (read here: ``do_train``,
                       ``do_predict``, ``valid_every_epoch``, ``compute_embeddings``,
                       ``train_batch_size``, ``gradient_accumulation_steps``, ``seed``,
                       ``local_rank``, ``output_dir``, ...)
    :return: A tuple of:
            1. the deciding score returned by ``train`` (the best validation score
               according to the original documentation), or -1 if no training ran
            2. the deciding score selected from the prediction results, or -1 if
               no prediction ran on this process
    """
    # Set up masker which decides what parts in the input will be masked
    masker = get_masker(milie_args)
    LOGGER.info("Masker: %s", masker)

    # Set up the data handlers needed by the requested modes
    data_handler = None
    data_handler_predict = None
    if milie_args.do_train:
        data_handler = get_data_handler(milie_args)
        LOGGER.info("Data Handler for training: %s", data_handler)
    if milie_args.do_predict or milie_args.valid_every_epoch:
        # The prediction handler is needed both for prediction and for validation
        # during training (it is passed to train() below).
        data_handler_predict = get_data_handler(milie_args, predict=True)
        LOGGER.info("Data Handler for prediction: %s", data_handler_predict)

    # Computing embeddings needs the (non-prediction) data handler. Reuse the one built
    # for training, if any, instead of constructing an identical second instance.
    # NOTE(review): within this function the handler is only consumed by the training
    # branch; confirm where it is used when only computing embeddings.
    if milie_args.compute_embeddings and data_handler is None:
        data_handler = get_data_handler(milie_args)

    # Sanity-check the arguments (expected to raise on invalid settings); skipped when
    # computing embeddings.
    if not milie_args.compute_embeddings:
        argument_sanity_check(milie_args)

    # Set up where computations will be run (gpu vs cpu, number of gpus)
    device, n_gpu = set_up_device(milie_args)

    # The configured train_batch_size is split across gradient accumulation steps,
    # giving the batch size actually processed per step.
    milie_args.train_batch_size = int(milie_args.train_batch_size //
                                      milie_args.gradient_accumulation_steps)

    # Set seed
    set_seed(milie_args.seed, n_gpu)

    # Training
    deciding_score_train = -1
    if milie_args.do_train:
        # Prepare model (the tokenizer is instantiated inside get_model_elements)
        _, model = get_model_elements(milie_args, data_handler)
        model = move_model(milie_args, model, device, n_gpu)

        data_handler.read_examples(is_training=True)
        deciding_score_train = train(milie_args, data_handler, data_handler_predict, model,
                                     masker, device, n_gpu)
        # Drop our reference to the training model so it can be freed before the
        # prediction model is loaded below.
        del model

    # Prediction (only on the main process when running distributed)
    deciding_score = -1
    if milie_args.do_predict and (milie_args.local_rank == -1 or
                                  torch.distributed.get_rank() == 0):
        if milie_args.do_train:
            # Load config, tokenizer and model from output_dir (presumably where
            # train() saved the trained model — confirm against train()).
            milie_args.config_name = milie_args.output_dir
            milie_args.tokenizer_name = milie_args.output_dir
            milie_args.model_name_or_path = milie_args.output_dir
        _, model = get_model_elements(milie_args, data_handler_predict)
        model = move_model(milie_args, model, device, n_gpu)
        model.eval()
        results = predict(milie_args, data_handler_predict, masker, model, device)
        deciding_score = data_handler_predict.select_deciding_score(results)

    return deciding_score_train, deciding_score
