#
 #     MILIE: Modular & Iterative Multilingual Open Information Extraction
 #
 #
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #

import sys
import random
import logging
import numpy as np
from tqdm import tqdm

LOGGER = logging.getLogger(__name__)


class GenInputFeatures:
    """
    Container for the model-ready features of a single data point.

    General possible structure of an encoded input sentence:
    [CLS] Part A [SEP] Part B [SEP] <Padding until max_seq_length>
    """

    def __init__(self,
                 input_ids,
                 input_mask,
                 segment_ids,
                 gen_label_ids,
                 classify_id_cls,
                 classify_id_tokens=None):
        """
        Store the given values on the instance as-is (no copying, no validation).

        :param input_ids: the vocabulary id of each unmasked token,
                          masked tokens receive the id of [MASK]
        :param input_mask: 1 for every position before the padding, 0 for padding
        :param segment_ids: 0 for Part A, 1 for Part B, 0 for padding
        :param gen_label_ids: -1 for unmasked tokens, the vocabulary id for masked tokens
        :param classify_id_cls: the gold value to be predicted on the [CLS] token
        :param classify_id_tokens: the gold values for token classification (optional)
        """
        # model inputs
        self.input_ids, self.input_mask, self.segment_ids = \
            input_ids, input_mask, segment_ids
        # gold labels for the generation and classification heads
        self.gen_label_ids = gen_label_ids
        self.classify_id_cls, self.classify_id_tokens = classify_id_cls, classify_id_tokens


def get_masker(milie_args):
    """
    Factory that builds the masker.

    All masking options are united in the single :py:class:`~milie.masking.Masking` class,
    so this always returns an instance of it.

    :param milie_args: the command line arguments
    :return: an instance of :py:class:`~milie.masking.Masking`
    """
    masker = Masking(milie_args)
    return masker


class Masking:
    """
    Masking objects take a data_handler, (sub)class instance of :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler`,
    and converts the elements of data_handler.examples into a set of features,
    which are stored in data_handler.features.

    Same index indicates same example/feature.

    data_handler.examples is a list of (sub)class instances of :py:class:`~milie.dataset_handlers.dataset_bitext.GenExample`
    data_handler.features is a list of instances of :py:class:`~milie.masking.GenInputFeatures`
    """

    def __init__(self, milie_args):
        """
        Initialise the statistics counters and copy the masking configuration from milie_args.

        The following attributes will be set:

        - **violate_max_part_a_len**: how often in data_handler, the maximum query length (Part A) was violated
        - **violate_max_gen_len**: how often in data_handler, the maximum generation length (Part B) was violated
        - **trunc_part_b**: how often Part B was truncated
        - **trunc_part_a**: how often Part A was truncated
        - **max_gen_b_length**: If we generate in part B, how long it is allowed to be
        - **max_gen_a_length**: If we generate in part A, how long it is allowed to be
        - **plus_generation**: number of generation heads
        - **plus_classify_tokens**: number of token classification heads
        - **max_part_a**: the maximum length of part A
        - **mask_in_a**: If true, masking will be applied to part A
        - **mask_in_b**: If true, masking will be applied to part B
        - **masking_strategy**: Options are: Masks are drawn from a bernoulli or gaussian distribution
          or dataset dependent (then the function data_handler.possible_mask_locations
          will return a list the same length as there are tokens and an entry will be 1
          if it is okay to mask there. Then based on this list, masks will be drawn from
          a bernoulli dist.)
        - **mean**: The mean of either the bernoulli or gaussian distribution
        - **stdev**: The standard deviation for the gaussian distribution

        :param milie_args: the command line arguments; must provide max_gen_b_length,
                           max_gen_a_length, plus_generation, plus_classify_tokens, max_part_a,
                           mask_in_a, mask_in_b, masking_strategy, distribution_mean and
                           distribution_stdev
        """
        # statistics, updated while examples are converted into features
        self.violate_max_part_a_len = 0
        self.violate_max_gen_len = 0
        self.trunc_part_b = 0
        self.trunc_part_a = 0
        # configuration taken over from the command line arguments
        self.max_gen_b_length = milie_args.max_gen_b_length
        self.max_gen_a_length = milie_args.max_gen_a_length
        self.plus_generation = milie_args.plus_generation
        self.plus_classify_tokens = milie_args.plus_classify_tokens
        self.max_part_a = milie_args.max_part_a
        self.mask_in_a = milie_args.mask_in_a
        self.mask_in_b = milie_args.mask_in_b
        self.masking_strategy = milie_args.masking_strategy
        self.mean = milie_args.distribution_mean
        # for masking_strategy = 'gaussian'
        self.stdev = milie_args.distribution_stdev
        # the logged value is the standard deviation, so label it as such (not as variance)
        LOGGER.info("Mean: %s, Stdev: %s", self.mean, self.stdev)

    def _binomial(self, mask_list):
        """
        Thin out a list of mask candidates in place.

        Every entry equal to 1 is kept with probability self.mean (one binomial draw per
        candidate) and reset to 0.0 otherwise; all other entries are left untouched.

        :param mask_list: the list of 0s and 1s, modified in place
        :return: the same list object, where each former 1 survived with probability self.mean
        """
        for position, flag in enumerate(mask_list):
            if flag != 1.0:
                # not a candidate, nothing to sample
                continue
            keep = np.random.binomial(1, self.mean)
            if keep == 0:
                mask_list[position] = 0.0
        return mask_list

    def set_up_mask(self, data_handler, part_a, part_b):
        """
        Create the mask lists for the current instance (defined via part_a and part_b).

        For the 'dataset_dependent' strategy the data_handler proposes the positions that may
        be masked and a binomial draw then decides which of them are really masked. For every
        other strategy the lists are drawn by :py:meth:`_create_stochastic_mask`.

        :param data_handler: (sub)class instance of :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler`,
                             only used for the 'dataset_dependent' strategy
        :param part_a: Still strings, but tokenized.
        :param part_b: Still strings, but tokenized.
        :return: a tuple of:

                 - mask_list_a: same length as part_a, 1.0 if the position should be masked, 0.0 else
                 - mask_list_b: same length as part_b, 1.0 if the position should be masked, 0.0 else
        """
        if self.masking_strategy != 'dataset_dependent':
            stochastic_a = self._create_stochastic_mask(len(part_a))
            stochastic_b = self._create_stochastic_mask(len(part_b))
            return stochastic_a, stochastic_b

        # candidate positions (does not cover [CLS] or [SEP])
        candidates_a, candidates_b = data_handler.possible_mask_locations(
            part_a, part_b, is_training=True)
        # so far every candidate would be masked, the binomial draw thins them out
        sampled_a = self._binomial(candidates_a)
        sampled_b = self._binomial(candidates_b)
        return sampled_a, sampled_b

    def _create_stochastic_mask(self, len_mask_list):
        """
        Draw a masking list of the requested length with the configured masking strategy.

        - 'bernoulli': every position is masked independently with probability self.mean
        - 'gaussian': a fraction is drawn from a normal distribution (self.mean, self.stdev),
          clipped to [0, 1], and that share of positions (rounded) is masked at random places
        - any other strategy: nothing is masked

        :param len_mask_list: the length the masking list will have to be.
        :return: the masking list, where it is 1.0 if a mask should be placed in that position
        """
        strategy = self.masking_strategy
        if strategy == 'bernoulli':
            # 1.0 means mask; one uniform draw per position
            return [1.0 if random.random() < self.mean else 0.0
                    for _ in range(len_mask_list)]
        if strategy == 'gaussian':
            fraction = np.random.normal(self.mean, self.stdev)
            # clip the sampled fraction to the valid range [0, 1]
            fraction = 1.0 if fraction > 1.0 else (0.0 if fraction < 0.0 else fraction)
            masked_count = int(round(fraction * len_mask_list))
            drawn = [1.0] * masked_count + [0.0] * (len_mask_list - masked_count)
            random.shuffle(drawn)
            return drawn
        return [0.0] * len_mask_list

    def handle_masking(self, part_a, part_b, is_training, max_seq_length, tokenizer,
                       example_index=-1, data_handler=None):
        """
        Encode one example as "[CLS] Part A [SEP] Part B [SEP]" followed by padding and
        return the 4 lists needed to instantiate :py:class:`~milie.masking.GenInputFeatures`

        :param part_a: the tokens of Part A (already tokenized), i.e. part_a of a subclass instance of :py:class:`~milie.dataset_handlers.dataset_bitext.GenExample`
        :param part_b: the tokens of Part B (already tokenized), i.e. part_b of a subclass instance of :py:class:`~milie.dataset_handlers.dataset_bitext.GenExample`
        :param is_training: masks are only sampled if this is True. If it is False:

                - if self.max_gen_a_length > 0: Part A consists of this many mask tokens
                - if self.max_gen_b_length > 0: Part B consists of this many mask tokens
                - other possibility: [MASK] is directly written into the test file.
        :param max_seq_length: the maximum sequence length (Part A + Part B), shorter
                               sequences are padded up to this length
        :param tokenizer: an instance of :py:class:`BertTokenizer`
        :param example_index: the index of the current sample, e.g. i when iterating over data_handler.examples[i]
        :param data_handler: an instance or subclass instance of :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler`
        :return: a 4-tuple:

                 - input_ids: ids of "[cls] part a [sep] part b [sep]" or a masking thereof
                 - input_mask: 1 for all spots that should be attended to, 0 for padding
                 - segment_ids: 0 for [cls] and Part A, 1 for Part B, 0 for padding
                 - gen_label_ids: a list that wraps exactly one label list, which holds -1 for
                   positions that should not be predicted and else the id of the
                   to-be-predicted token
                   (Note that although we can have several generation heads, only one
                   label list is created here. So far we have not needed the other scenario.)
        """
        # the [SEP] token needs to be learnt too in the case of masking
        part_a_with_sep = part_a + ['[SEP]']
        part_b_with_sep = part_b + ['[SEP]']

        training = is_training is True
        inference = is_training is False

        # note: even if only one part is masked, both lists are set up, one just stays unused
        mask_list_a, mask_list_b = None, None
        if training and (self.mask_in_a is True or self.mask_in_b is True):
            mask_list_a, mask_list_b = self.set_up_mask(data_handler,
                                                        part_a_with_sep, part_b_with_sep)
            assert len(mask_list_a) == len(part_a_with_sep)
            assert len(mask_list_b) == len(part_b_with_sep)

        tokens = ["[CLS]"]
        segment_ids = [0]
        gen_label_ids = [-1]

        def add_mask_slots(count, seg_id):
            """Append count [MASK] tokens without gold label to segment seg_id."""
            tokens.extend(['[MASK]'] * count)
            segment_ids.extend([seg_id] * count)
            gen_label_ids.extend([-1] * count)

        # ----- Part A (segment 0) -----
        if inference and self.max_gen_a_length > 0:
            # Part A is to be generated: it only consists of mask tokens
            add_mask_slots(self.max_gen_a_length, 0)
        else:
            # masked if training and requested, else written as given (used for sequence and
            # token classification, or the input already has [MASK] in the correct positions)
            limit_a = self.max_part_a - 1
            if training and self.mask_in_a is True:
                encoded_a = self.prepare_input_with_masking(part_a_with_sep, example_index,
                                                            tokenizer, mask_list_a, 0, limit_a,
                                                            part_type='A')
            else:
                encoded_a = self.prepare_input_no_masking(part_a_with_sep, example_index, 0,
                                                          limit_a, part_type='A')
            a_tokens, a_segments, a_labels = encoded_a
            tokens.extend(a_tokens)
            segment_ids.extend(a_segments)
            gen_label_ids.extend(a_labels)

            # If generation is set up, pad with mask tokens up to max_gen_a_length
            shortfall_a = self.max_gen_a_length - len(part_a_with_sep)
            if shortfall_a > 0 and self.plus_generation > 0:
                add_mask_slots(shortfall_a, 0)

        # ----- Part B (segment 1) -----
        if inference and self.max_gen_b_length > 0:
            # Part B is to be generated: it only consists of mask tokens
            add_mask_slots(self.max_gen_b_length, 1)
        else:
            if training and self.mask_in_b is True:
                encoded_b = self.prepare_input_with_masking(part_b_with_sep, example_index,
                                                            tokenizer, mask_list_b, 1,
                                                            max_seq_length, part_type='B')
            else:
                encoded_b = self.prepare_input_no_masking(part_b_with_sep, example_index, 1,
                                                          max_seq_length, part_type='B')
            b_tokens, b_segments, b_labels = encoded_b
            tokens.extend(b_tokens)
            segment_ids.extend(b_segments)
            gen_label_ids.extend(b_labels)

            # If generation is set up, pad with mask tokens up to max_gen_b_length
            shortfall_b = self.max_gen_b_length - len(part_b_with_sep)
            if shortfall_b > 0 and self.plus_generation > 0:
                add_mask_slots(shortfall_b, 1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        # Pad to maximum sequence length
        padding = max_seq_length - len(input_ids)
        if padding > 0:
            input_ids.extend([0] * padding)
            input_mask.extend([0] * padding)
            segment_ids.extend([0] * padding)
            gen_label_ids.extend([-1] * padding)

        # VariableHeadsNSP expects a list, mock it here
        # (it is not further implemented because it has not been needed yet)
        return input_ids, input_mask, segment_ids, [gen_label_ids]

    def convert_examples_to_features(self, data_handler, max_seq_length, is_training, milie_args):
        """
        From a list of examples, (sub)class instances of :py:class:`~milie.dataset_handlers.dataset_bitext.GenExample`,
        creates a list of instances of :py:class:`~milie.masking.GenInputFeatures`

        For every example, Part A and Part B are tokenized and truncated if they exceed their
        budget (at the end or at the beginning, depending on data_handler.truncate_end), then
        encoded via :py:meth:`handle_masking`. The truncation counters of this object are
        updated along the way and logged at the end.

        :param data_handler: a subclass instance of :py:class:`~milie.dataset_handlers.dataset_bitext.BitextHandler` will access
                             data_handler.examples (list of subclass instances of :py:class:`~milie.dataset_handlers.dataset_bitext.GenExample`)
                             and will set data_handler.features (list of instances of :py:class:`~milie.masking.GenInputFeatures`)
        :param max_seq_length: the maximum sequence length ([CLS] + Part A + [SEP] + Part B + [SEP])
        :param is_training: true if training, handles gold label construction
        :param milie_args: the command line arguments (max_part_a, plus_generation,
                           plus_classify_sequence and plus_classify_tokens are read)
        :return: 0 on success (data_handler.features is now set)
        """
        max_part_a = milie_args.max_part_a
        # Part B gets what is left of the sequence once the budget of Part A is reserved
        # (never negative, so the slicing below stays well defined)
        max_part_b = max(max_seq_length - max_part_a, 0)
        tokenizer = data_handler.tokenizer
        data_handler.features = []
        max_a = 0
        max_b = 0
        # each mismatch warning is only given once per call
        plus_generation_warning_given = False
        plus_classify_sequence_warning_given = False
        plus_classify_tokens_warning_given = False

        # iterate over subclass instances of :py:class:GenExample
        for i, example in enumerate(tqdm(data_handler.examples)):
            # Part A
            part_a = tokenizer.tokenize(example.part_a)
            max_a = max(max_a, len(part_a))
            if len(part_a) > max_part_a:
                if data_handler.truncate_end:
                    part_a = part_a[0:max_part_a]
                    self.trunc_part_a += 1
                else:  # truncate beginning
                    # +2 because we save space for [CLS] and [SEP]
                    first_trunc_index = len(part_a) - max_part_a + 2
                    part_a = part_a[first_trunc_index:]
                    self.trunc_part_a += 1

            # Part B
            part_b = tokenizer.tokenize(example.part_b)
            max_b = max(max_b, len(part_b))
            if len(part_b) > max_part_b:
                if data_handler.truncate_end:
                    # NOTE(review): unlike the branch below, no space is saved for [SEP] here
                    part_b = part_b[0:max_part_b]
                    self.trunc_part_b += 1
                else:  # truncate beginning
                    # +1 because we save space for [SEP]
                    first_trunc_index = len(part_b) - max_part_b + 1
                    part_b = part_b[first_trunc_index:]
                    self.trunc_part_b += 1

            classify_id_cls = []
            if is_training is True:
                classify_id_cls = example.classify_id_cls

            # Masking for one example
            input_ids, input_mask, segment_ids, gen_label_ids = \
                self.handle_masking(part_a, part_b, is_training, max_seq_length, tokenizer, i,
                                    data_handler)

            # How tokens should be classified is the job of the dataset specific class
            classify_id_tokens = []
            if self.plus_classify_tokens > 0 and is_training is True:
                classify_id_tokens = data_handler.get_token_classification_ids(example,
                                                                               input_ids)
            # the dataset handler has the final say on the segment ids, this replaces the
            # segment ids computed by handle_masking
            segment_ids = data_handler.get_segment_ids(example, input_ids)
            # sanity checks
            # NOTE(review): the length of the handler-provided segment_ids is not checked
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            for head in gen_label_ids:
                assert len(head) == max_seq_length
            if len(gen_label_ids) != milie_args.plus_generation and \
                    plus_generation_warning_given is not True:
                LOGGER.warning("The number of generation heads assumed by the dataset handler does "
                               "not match the number of generation heads specified by the command "
                               "line argument plus_generation")
                plus_generation_warning_given = True
            if len(classify_id_cls) != milie_args.plus_classify_sequence and \
                    plus_classify_sequence_warning_given is not True:
                LOGGER.warning("The number of sequence classification heads assumed by the dataset "
                               "handler does not match the number of sequence classification "
                               "heads specified by the command line argument "
                               "plus_classify_sequence")
                plus_classify_sequence_warning_given = True
            # token heads have to be compared with plus_classify_tokens
            # (and not with plus_classify_sequence)
            if len(classify_id_tokens) != milie_args.plus_classify_tokens and \
                    plus_classify_tokens_warning_given is not True:
                LOGGER.warning("The number of token classification heads assumed by the dataset "
                               "handler does not match the number of token classification "
                               "heads specified by the command line argument "
                               "plus_classify_tokens")
                plus_classify_tokens_warning_given = True

            # log the example(s) with example_index < 1 in full detail for inspection
            if example.example_index < 1:
                LOGGER.info("*** Example ***")
                LOGGER.info("Feature for example: %s", i)
                LOGGER.info("part_a: %s", part_a)
                LOGGER.info("part_b: %s", part_b)
                LOGGER.info("input_ids: %s", input_ids)
                LOGGER.info("input_mask: %s", input_mask)
                LOGGER.info("segment_ids: %s", segment_ids)
                LOGGER.info("gen_label_ids: %s", gen_label_ids)
                LOGGER.info("classify_id_cls: %s", classify_id_cls)
                LOGGER.info("classify_id_tokens: %s", classify_id_tokens)

            feature = GenInputFeatures(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                gen_label_ids=gen_label_ids,
                classify_id_cls=classify_id_cls,
                classify_id_tokens=classify_id_tokens)
            data_handler.features.append(feature)

        # Every example has exactly one corresponding features at the same index
        assert len(data_handler.examples) == len(data_handler.features)
        LOGGER.info("Maximum Part A is: %s", max_a)
        LOGGER.info("Maximum Part B is: %s", max_b)
        LOGGER.warning("Couldn't encode query length %s times.", self.violate_max_part_a_len)
        LOGGER.warning("Couldn't encode generation length %s times.", self.violate_max_gen_len)
        LOGGER.warning("Truncated part b %s times.", self.trunc_part_b)
        LOGGER.warning("Truncated part a %s times.", self.trunc_part_a)
        return 0

    def prepare_input_no_masking(self, part, example_index, seg_id, max_len, part_type='A'):
        """
        Prepare the input data of the current example without applying any masking.

        Tokens are copied until the part ends or max_len tokens have been written. Reaching
        max_len is logged (debug level) and counted: in self.violate_max_gen_len for
        part_type 'B' and in self.violate_max_part_a_len for any other part_type.

        :param part: Either part a or part b, a list of string tokens
        :param example_index: the index of the current example (just for debugging message)
        :param seg_id: which segment ID to write.
        :param max_len: The maximum length for this part.
        :param part_type: Whether its part a ('A') or b ('B'), used for the debugging message
                          and to pick the violation counter
        :return: A tuple of:

                 - tokens: a list of stringtokens (still needs converting to IDs)
                 - segment_ids: the segment ids
                 - gen_label_ids: As no masking is applied, a list of just -1
        """
        tokens = []
        segment_ids = []
        gen_label_ids = []
        for token in part:
            tokens.append(token)
            segment_ids.append(seg_id)
            gen_label_ids.append(-1)
            if len(tokens) == max_len:  # the length budget of this part is used up
                # the values are passed as separate arguments, one per placeholder
                LOGGER.debug("Can't encode the maximum Part %s length of example number %s",
                             part_type, example_index)
                # Part B violations have their own counter
                if part_type == 'B':
                    self.violate_max_gen_len += 1
                else:
                    self.violate_max_part_a_len += 1
                break
        return tokens, segment_ids, gen_label_ids

    def prepare_input_with_masking(self, part, example_index, tokenizer, mask_list, seg_id, max_len,
                                   part_type='A'):
        """
        Prepare the input data of the current example while applying the given masking.

        Positions whose mask_list entry is 1.0 are written as '[MASK]' and receive the
        vocabulary id of the original token as gold label, all other tokens are copied and
        receive the label -1. Processing stops once max_len tokens have been written, which is
        logged (debug level) and counted: in self.violate_max_gen_len for part_type 'B' and in
        self.violate_max_part_a_len for any other part_type.

        :param part: Either part a or part b, still string format
        :param example_index: the index of the current example (just for debugging message)
        :param tokenizer: the tokenizer, its vocab is used to write the correct IDs to gen_label_ids (the strings for tokens are converted into IDs later)
        :param mask_list: a list that contains 1.0 if the position should be masked
        :param seg_id: which segment ID to write.
        :param max_len: The maximum length for this part.
        :param part_type: Whether its part a ('A') or b ('B'), used for the debugging message
                          and to pick the violation counter
        :return: A tuple of:

                 - tokens: a list of stringtokens (still needs converting to IDs)
                 - segment_ids: the segment ids
                 - gen_label_ids: -1 if no mask should be applied, else the ID of the token
        """
        tokens = []
        segment_ids = []
        gen_label_ids = []
        for idx, token in enumerate(part):
            if mask_list[idx] == 1.0:
                tokens.append('[MASK]')
                gen_label_ids.append(tokenizer.vocab[token])
            else:
                # supplying the gold label for unmasked tokens too previously produced
                # worse results, so they get -1
                tokens.append(token)
                gen_label_ids.append(-1)
            segment_ids.append(seg_id)
            if len(tokens) == max_len:  # the length budget of this part is used up
                # the values are passed as separate arguments, one per placeholder
                LOGGER.debug("Can't encode the maximum Part %s length of example number %s",
                             part_type, example_index)
                # Part B violations have their own counter
                if part_type == 'B':
                    self.violate_max_gen_len += 1
                else:
                    self.violate_max_part_a_len += 1
                break
        return tokens, segment_ids, gen_label_ids
