#
 #     MILIE: Modular & Iterative Multilingual Open Information Extraction
 #
 #
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm, trange

from .predict import predict
from .model_helper import save_model, get_train_dataloader, prepare_optimizer

LOGGER = logging.getLogger(__name__)

# TODO: fix this file so that the unit test test_get_model_loss_multiple_datasets
# (in test_train) passes. Maybe Mayumi had a working version before?

def get_loss(milie_args, model, batch):
    """
    Runs one forward pass of a VariableHeadsmilie model on a batch and unpacks its outputs.

    :param milie_args: the run arguments; this function reads ``plus_generation``,
                       ``plus_classify_sequence``, ``plus_classify_tokens`` (the number of heads
                       of each kind) and ``output_attentions``
    :param model: a VariableHeadsmilie model
    :param batch: the current batch, a sequence of exactly seven entries: input ids, input mask,
                  segment ids, generation labels, sequence classification labels,
                  token classification labels and one trailing entry that is ignored here
    :return: A tuple consisting of

             - loss: the overall loss
             - batch_gen_logits: a list with one entry per generation head, holding its logits
             - batch_cls_logits: a list with one entry per sequence classification head, holding its logits
             - batch_tokens_logits: a list with one entry per token classification head, holding its logits
             - gen_loss: the sum of the losses of all generation heads
             - cls_loss: the sum of the losses of all sequence classification heads
             - tok_loss: the sum of the losses of all token classification heads
             - ppl: the perplexity measured on all generation heads
             - attentions: attention scores for plotting attention heat maps, or None if they
               were not requested or the last model output is empty
    """
    (input_ids, input_mask, segment_ids, gen_label_ids,
     classify_id_cls, classify_id_tokens, _) = batch

    # Expected tensor shapes (check masking.py, especially if an empty list is passed there
    # instead of [-1] with the desired shape):
    #   input_ids, input_mask, segment_ids: (batch_size, max_sequence_length)
    #   gen_label_ids:      (batch_size, num_gen_heads, max_sequence_length)
    #   classify_id_cls:    (batch_size, num_cls_heads, max_sequence_length)
    #   classify_id_tokens: (batch_size, num_tok_heads, max_sequence_length)

    # TODO: models other than VariableHeadsmilie will need a different call here.
    # DataParallel / DistributedDataParallel run model.forward() on several GPUs in parallel
    # and gather the outputs.
    outputs = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids,
                    labels_tok=classify_id_tokens, masked_lm_labels=gen_label_ids,
                    labels_cls=classify_id_cls)

    loss = outputs[0]

    # The logits follow the loss in the output tuple, grouped by head type in the order
    # generation -> sequence classification -> token classification. A head count of 0
    # simply yields an empty list for that group.
    gen_start = 1
    batch_gen_logits = [outputs[gen_start + head]
                        for head in range(milie_args.plus_generation)]

    cls_start = gen_start + len(batch_gen_logits)
    batch_cls_logits = [outputs[cls_start + head]
                        for head in range(milie_args.plus_classify_sequence)]

    tok_start = cls_start + len(batch_cls_logits)
    batch_tokens_logits = [outputs[tok_start + head]
                           for head in range(milie_args.plus_classify_tokens)]

    # Directly after the logits come the three per-type losses and the perplexity.
    scalars_start = tok_start + len(batch_tokens_logits)
    gen_loss, cls_loss, tok_loss, ppl = outputs[scalars_start:scalars_start + 4]

    # When requested, the attentions are the last output: a list of torch.FloatTensor
    # (one per layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
    if milie_args.output_attentions and outputs[-1]:
        attentions = outputs[-1]
    else:
        attentions = None

    return (loss, batch_gen_logits, batch_cls_logits, batch_tokens_logits,
            gen_loss, cls_loss, tok_loss, ppl, attentions)


def _log_progress_and_validate(milie_args, data_handler, data_handler_predict, masker, model,
                               device, scheduler, global_step, cumulative_loss,
                               updates_since_validation, best_valid_score):
    """
    Logs the training progress since the previous validation, then validates the model.

    :param milie_args: Instance of milieArguments
    :param data_handler: the training data handler; its tokenizer is passed on to validation
    :param data_handler_predict: the data handler used for validation
    :param model: the model being trained
    :param masker: subclass instance of :py:class:`~milie.masking.Masking`
    :param device: the device to run the computation on
    :param scheduler: the learning rate scheduler, queried for the current learning rate
    :param global_step: total number of optimizer updates performed so far
    :param cumulative_loss: sum of the losses recorded since the previous validation
    :param updates_since_validation: number of optimizer updates since the previous validation
    :param best_valid_score: the best validation score seen so far
    :return: the (possibly updated) best validation score
    """
    # A validation point can be reached before any optimizer update has happened (e.g. with
    # gradient accumulation), or twice in a row at the end of an epoch; report 0.0 in that
    # case instead of dividing by zero.
    if updates_since_validation > 0:
        avg_loss = cumulative_loss / updates_since_validation
    else:
        avg_loss = 0.0
    LOGGER.info('Average loss since last validation: %0.15f', avg_loss)
    LOGGER.info('Current learning rate and number of updates performed: %0.15f, %d',
                scheduler.get_lr()[0], global_step)
    return validate(best_valid_score, milie_args, data_handler_predict, masker,
                    model, device, global_step,
                    tokenizer=data_handler.tokenizer, print_epoch=True)


def train(milie_args, data_handler, data_handler_predict, model, masker, device, n_gpu):
    """
    Runs training for a model.

    If ``milie_args.valid_every_epoch`` is set, the model is validated roughly four times per
    epoch and once more at the end of every epoch; :py:func:`validate` then takes care of saving
    the best model. Otherwise no validation is run and the final model is saved after training.

    :param milie_args: Instance of milieArguments
    :param data_handler: instance or subclass of :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler`, for training
    :param data_handler_predict: instance or subclass of :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler`, for validation
    :param model: the model that will be trained
    :param masker: subclass instance of :py:class:`~milie.masking.Masking`
    :param device: the device to run the computation on
    :param n_gpu: number of gpus used
    :return: the best score on the validation dataset (0.0 if no validation was run)
    :raises ImportError: if fp16 training is requested but apex is not installed
    """
    train_examples = data_handler.examples

    # Divide again by accum to get originally intended batch size, milie_args.train_batch_size
    # is already the size required to fit in RAM so we undo this here (t/a/a = t*a / a -> a cancels)
    num_train_steps = int((len(train_examples) / milie_args.train_batch_size /
                           milie_args.gradient_accumulation_steps) * milie_args.num_train_epochs)

    optimizer, scheduler = prepare_optimizer(milie_args, model, num_train_steps)
    num_batches = int(len(train_examples) / milie_args.train_batch_size)
    # Mid-epoch validation happens every quarter of an epoch, but at least every batch.
    validation_interval = max(int(num_batches / 4), 1)

    if milie_args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from "
                              "https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=milie_args.fp16_opt_level)

    get_train_dataloader(milie_args, masker, data_handler)

    best_valid_score = 0.0  # if validation is run during training, keep track of best
    # Loss bookkeeping for logging; both values are reset after every validation.
    # Note: the loss is only recorded on optimizer-update steps (after the division by
    # gradient_accumulation_steps), exactly as before.
    cumulative_loss = 0.0
    updates_since_validation = 0

    n_params = sum(p.nelement() for p in model.parameters())
    LOGGER.info("Number of parameters: %d", n_params)

    data_handler.tokenizer.save_pretrained(milie_args.output_dir)

    global_step = 0
    model.zero_grad()

    LOGGER.info("Number of training examples: %d", len(data_handler.examples))

    #TODO add incremental loading of batches for large datasets where not everything fits in RAM
    for epoch in trange(int(milie_args.num_train_epochs), desc="Epoch"):
        LOGGER.info("Starting Epoch %s:", epoch)

        # some masking changes at every epoch, thus reload if necessary
        if milie_args.masking_strategy is not None and epoch != 0:  # already done for first epoch
            LOGGER.info("Recreating masks")
            get_train_dataloader(milie_args, masker, data_handler)

        for step, batch in enumerate(tqdm(data_handler.train_dataloader, desc="Training")):
            model.train()

            batch = tuple(t.to(device) for t in batch)

            loss, batch_gen_logits, batch_cls_logits, batch_tokens_logits, \
                gen_loss, cls_loss, tok_loss, ppl, attentions = get_loss(milie_args, model, batch)

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
                gen_loss = gen_loss.mean()
                cls_loss = cls_loss.mean()
                tok_loss = tok_loss.mean()
                ppl = ppl.mean()
            if milie_args.gradient_accumulation_steps > 1:
                loss = loss / milie_args.gradient_accumulation_steps

            if milie_args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if milie_args.max_grad_norm > 0.0:
                    clip_grad_norm_(amp.master_params(optimizer), milie_args.max_grad_norm)
            else:
                loss.backward()
                if milie_args.max_grad_norm > 0.0:
                    clip_grad_norm_(model.parameters(), milie_args.max_grad_norm)

            # update weights once enough micro-batches have been accumulated
            if (step + 1) % milie_args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1
                updates_since_validation += 1
                cumulative_loss += loss.item()

            if milie_args.valid_every_epoch and (step + 1) % validation_interval == 0:
                best_valid_score = _log_progress_and_validate(
                    milie_args, data_handler, data_handler_predict, masker, model, device,
                    scheduler, global_step, cumulative_loss, updates_since_validation,
                    best_valid_score)
                cumulative_loss = 0.0
                updates_since_validation = 0

        # Validate on the dev set if desired.
        if milie_args.valid_every_epoch:
            best_valid_score = _log_progress_and_validate(
                milie_args, data_handler, data_handler_predict, masker, model, device,
                scheduler, global_step, cumulative_loss, updates_since_validation,
                best_valid_score)
            cumulative_loss = 0.0
            updates_since_validation = 0

    # save last model if we didn't pick the best during training
    if not milie_args.valid_every_epoch:
        LOGGER.info("Saving final model")
        save_model(milie_args, model)

    return best_valid_score


def validate(best_valid_score, milie_args, data_handler_predict, masker, model, device,
             global_step, tokenizer=None, print_epoch=False):
    """
    Runs prediction on the validation set and keeps the model if it is the best so far.

    The model is put into evaluation mode for the duration of the call and switched back to
    training mode before returning.

    :param best_valid_score: the currently best validation score
    :param milie_args: an instance of milieArguments
    :param data_handler_predict: instance or subclass instance of :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler`,
     on which to run prediction
    :param masker: an instance of a subclass of :py:class:`~milie.masking.Masking`
    :param model: the BERT model
    :param device: where to run computations
    :param global_step: the current global step
    :param tokenizer: fallback tokenizer, assigned to ``data_handler_predict`` only if it does
     not have a tokenizer yet
    :param print_epoch: passed through unchanged to :py:func:`predict`
    :return: the new best validation score
    """
    model.eval()

    # A best score of exactly 0.0 means nothing better has been recorded yet (e.g. first
    # validation), so save the current model up front.
    nothing_recorded_yet = best_valid_score == 0.0
    if nothing_recorded_yet:
        save_model(milie_args, model)

    # Hand over the training tokenizer if the prediction handler does not have one.
    if data_handler_predict.tokenizer is None:
        if tokenizer is not None:
            data_handler_predict.tokenizer = tokenizer

    results_collection = predict(milie_args, data_handler_predict, masker, model, device,
                                 global_step, print_epoch=print_epoch)
    LOGGER.info("Validation results: %s", results_collection)

    # The data handler decides which of the reported scores determines the model to keep.
    deciding_score = data_handler_predict.select_deciding_score(results_collection)

    is_new_best = best_valid_score < deciding_score
    if is_new_best:
        LOGGER.info("Step %s: Saving new best model: %s vs. previous %s",
                    global_step, deciding_score, best_valid_score)
        save_model(milie_args, model)
        best_valid_score = deciding_score

    model.train()
    return best_valid_score
