#
 #     MILIE: Modular & Iterative Multilingual Open Information Extraction
 #
 #
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #
import logging

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertOnlyMLMHead, BertPreTrainingHeads, BertEmbeddings
from transformers.models.bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
from transformers.configuration_utils import PretrainedConfig
from transformers import (BertConfig, BertTokenizer, XLMConfig, XLMTokenizer, XLMWithLMHeadModel,
                          XLNetConfig, XLNetTokenizer, XLNetLMHeadModel, DistilBertConfig,
                          DistilBertTokenizer)



# Module-level logger named after this module; used by the model classes below.
LOGGER = logging.getLogger(__name__)


class VariableHeadsmilie(BertPreTrainedModel):
    """milie extends the BERT model by adding a generation head and allowing any variable combination
    of token-level and sentence-level classification as well as generation.

    Each layer head type (token, cls/sequence and generation) is modelled as a list (nn.Module)
    and there can thus be arbitrarily many heads.
    Currently, the CLS head is not loaded from the pre-trained weights.
    For each generation head, we copy the weights from the pre-training.

    Params:

        - **config**: a BertConfig class instance with the configuration to build a new model.
        - **generate**: number of generation (LM) heads; >=1 if the generation head should be active
        - **classify_sequence**: number of sequence classification heads; >=1 if a sequence
          should be classified (set num_labels_cls to 1 in dataset specific handler for regression)
        - **classify_tokens**: number of token classification heads; >=1 if tokens should be classified
        - **num_labels_tok**: list with the number of classes of each token classification head.
        - **num_labels_cls**: list with the number of classes of each classifier on the [CLS] token
          (for sequence classification).

    Inputs:

        - **input_ids**: a torch.LongTensor of shape [batch_size, sequence_length]
          with the word token indices in the vocabulary.
        - **token_type_ids**: an optional torch.LongTensor of shape [batch_size, sequence_length] with
          the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and
          type 1 corresponds to a `sentence B` token (see BERT paper for more details).
        - **attention_mask**: an optional torch.LongTensor of shape [batch_size, sequence_length] with
          indices selected in [0, 1]. It's a mask to be used if the input sequence length is
          smaller than the max input sequence length in the current batch.
        - **labels_tok**: token classification labels: torch.LongTensor of shape
          [batch_size, classify_tokens, sequence_length]; -1 marks positions to ignore.
        - **masked_lm_labels**: masked language modeling labels: torch.LongTensor of shape
          [batch_size, generate, sequence_length] with indices selected in [-1, 0, ..., vocab_size).
          All labels set to -1 are ignored (masked), the loss is only computed for the other labels.
        - **labels_cls**: sequence classification labels: tensor of shape
          [batch_size, classify_sequence]; -1 marks examples to ignore for classification heads.

    Outputs:

        `Tuple` comprising various elements depending on the configuration (config) and inputs,
        in this order:

        - **loss**: (`optional`, returned when ``masked_lm_labels``,
          ``labels_tok`` or ``labels_cls`` are provided)
          scalar ``torch.FloatTensor``: Combined loss over all heads and head types.
        - **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
          Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
          One entry per generation head (``generate`` entries).
        - **cls_logits**: ``torch.FloatTensor`` of shape ``(batch_size, num_labels_cls[i])``
          Classification (or regression if num_labels_cls[i]==1) scores (before SoftMax).
          One entry per sequence classification head (``classify_sequence`` entries).
        - **tok_logits**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, num_labels_tok[i])``
          Classification scores (before SoftMax).
          One entry per token classification head (``classify_tokens`` entries).
        - **gen_total_loss**: (`optional`, returned together with **loss**)
          scalar ``torch.FloatTensor``: Combined loss over all generation heads.
        - **cls_total_loss**: (`optional`, returned together with **loss**)
          scalar ``torch.FloatTensor``: Combined loss over all CLS / sequence classification heads.
        - **tok_total_loss**: (`optional`, returned together with **loss**)
          scalar ``torch.FloatTensor``: Combined loss over all token classification heads.
        - **total_ppl**: (`optional`, returned together with **loss**)
          scalar ``torch.FloatTensor``: Combined perplexity-style score over all generation heads.
        - **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
          whatever the underlying BertModel returns after its first two outputs.
        - **attentions**: (`optional`, returned when ``config.output_attentions=True``)
          whatever the underlying BertModel returns after its first two outputs.

    Examples:

        Here is the way to use transformer heads::

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = VariableHeadsmilie.from_pretrained('bert-base-uncased', generate=1)
            # Batch size 1
            input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)
            # labels carry one extra axis for the heads: batch x generate x sequence_length
            outputs = model(input_ids, masked_lm_labels=input_ids.unsqueeze(1))
            loss, prediction_scores = outputs[:2]
    """

    def __init__(self, config: PretrainedConfig,
                 generate: int = 0,
                 classify_sequence: int = 0,
                 classify_tokens: int = 0,
                 num_labels_tok: list = None,
                 num_labels_cls: list = None):
        """
        Loads a transformer encoder with various heads.

        :param config: The configuration for the transformer encoder
        :param generate: How many LM generation heads to use (same output vocabulary as input)
        :param classify_sequence: How many sequence classification heads to use.
        :param classify_tokens: How many token classification heads to use.
        :param num_labels_tok: The number of classes for the token classification heads.
                               Given as a list where the list length should be equal to the number
                               of token classification heads.
        :param num_labels_cls: The number of classes for the sequence classification heads.
                               Given as a list where the list length should be equal to the number
                               of sequence classification heads.
        :raises ValueError: if a classification head count is positive but the matching
                            ``num_labels_*`` list is missing or too short.
        """
        super(VariableHeadsmilie, self).__init__(config)

        # Fail with a clear message instead of a TypeError/IndexError while building the heads.
        self._check_label_counts('classify_sequence', classify_sequence,
                                 'num_labels_cls', num_labels_cls)
        self._check_label_counts('classify_tokens', classify_tokens,
                                 'num_labels_tok', num_labels_tok)

        # shared
        self.bert = BertModel(config)

        self.generate = generate
        self.classify_sequence = classify_sequence
        self.classify_tokens = classify_tokens
        self.num_labels_cls = num_labels_cls
        self.num_labels_tok = num_labels_tok

        # sequence generation
        if self.generate > 0:
            # the first head should be called cls so we can load the pre-trained weights
            # further heads are in self.lm_heads
            self.cls = BertOnlyMLMHead(self.config)
            self.lm_heads = nn.ModuleList()
            for _ in range(1, self.generate):
                # TODO support different config for different heads.
                self.lm_heads.append(BertOnlyMLMHead(self.config))
            self.tie_weights()

        # CLS classification
        if self.classify_sequence > 0:
            self.cls_dropout = nn.ModuleList()
            self.cls_classifier = nn.ModuleList()
            for i in range(self.classify_sequence):
                self.cls_dropout.append(nn.Dropout(self.config.hidden_dropout_prob))
                self.cls_classifier.append(nn.Linear(self.config.hidden_size, num_labels_cls[i]))

        # token classification
        if self.classify_tokens > 0:
            self.dropout = nn.ModuleList()
            self.classifier = nn.ModuleList()
            for i in range(self.classify_tokens):
                self.dropout.append(nn.Dropout(self.config.hidden_dropout_prob))
                self.classifier.append(nn.Linear(self.config.hidden_size, num_labels_tok[i]))

        self.init_weights()
        # Argument order must follow the message: generation, token, sequence.
        LOGGER.info("Using %d generation heads, %d token classification heads "
                    "and %d sequence/sequence pair classification/regression heads.",
                    self.generate, self.classify_tokens, self.classify_sequence)

    @staticmethod
    def _check_label_counts(count_name, head_count, labels_name, num_labels):
        """
        Validates that a ``num_labels_*`` list covers every requested classification head.

        :param count_name: name of the head-count argument (used in the error message)
        :param head_count: number of heads requested
        :param labels_name: name of the label-count list argument (used in the error message)
        :param num_labels: list with the number of classes per head, or None
        :raises ValueError: if heads are requested but the list is None or too short
        """
        if head_count > 0 and (num_labels is None or len(num_labels) < head_count):
            raise ValueError(
                "{}={} requires {} to be a list with at least {} entries, got {!r}.".format(
                    count_name, head_count, labels_name, head_count, num_labels))

    @staticmethod
    def _cross_entropy_ignoring_unlabelled(logits, labels):
        """
        Cross entropy averaged over the labels that are not -1.

        Equivalent to ``CrossEntropyLoss(ignore_index=-1)`` whenever at least one label is
        valid, but yields 0 (instead of NaN) when every label is -1, so a head without any
        gold labels in the current batch does not poison the combined loss.

        :param logits: tensor of shape (N, num_classes)
        :param labels: tensor of shape (N,), -1 marks entries to ignore
        :return: scalar loss tensor
        """
        summed_loss = CrossEntropyLoss(ignore_index=-1, reduction='sum')(logits, labels)
        # clamp avoids the 0/0 division that produces NaN with mean reduction
        num_labelled = (labels != -1).sum().clamp(min=1)
        return summed_loss / num_labelled

    def resize_embeddings(self, num_tokens):
        """
        Sometimes a pre-trained model is missing tokens that we would like to use during fine-tuning
        This method allows us to add new dimensions and therefore new tokens.

        :param num_tokens: The total number of tokens after resizing (should be larger than
                           the current vocab_size); None leaves the model untouched
        :return: the updated config
        """
        vocab_size = self.config.vocab_size
        if num_tokens is None or vocab_size == num_tokens:
            return self.config

        # resize embeddings
        emb = self.resize_token_embeddings(num_tokens)
        bert_emb = self.bert.resize_token_embeddings(num_tokens)

        # it seems there is an output-only bias, which also should be resized accordingly.
        if self.generate > 0:
            for head in [self.cls] + list(self.lm_heads):
                predictions = head.predictions
                if hasattr(predictions, 'bias') and predictions.bias.size(0) != num_tokens:
                    old_bias = predictions.bias
                    # keep the new bias on the same device/dtype as the one it replaces
                    predictions.bias = nn.Parameter(torch.zeros(num_tokens,
                                                                dtype=old_bias.dtype,
                                                                device=old_bias.device))

        # sanity check
        assert emb.weight.shape == bert_emb.weight.shape, \
            (emb.weight.shape, bert_emb.weight.shape)
        # self.cls only exists when at least one generation head was created
        if self.generate > 0:
            assert num_tokens == self.cls.predictions.bias.size(0), \
                (self.cls.predictions.bias.size(0))
        assert num_tokens == self.config.vocab_size and num_tokens == \
               self.bert.config.vocab_size, \
            (num_tokens, self.config.vocab_size, self.bert.config.vocab_size)
        LOGGER.info("Token Embeddings are resized from %s to %s.", vocab_size, num_tokens)
        return self.config

    def tie_weights(self):
        """
        Make sure we are sharing the input and output embeddings.
        Export to TorchScript can't handle parameter sharing so we are cloning them instead.

        :return: 0 on success
        """
        # tie_weights is called in model_utils.py if it exists, but we only actually want
        # to tie them when we have generation on
        if self.generate > 0:
            output_embeddings = self._get_embeddings()
            if output_embeddings is not None:
                self._tie_or_clone_weights(output_embeddings, self.bert.get_input_embeddings())
                # self.lm_heads holds the generation heads beyond the first one (may be empty)
                for i in range(len(self.lm_heads)):
                    output_embeddings = self._get_embeddings(i)
                    if output_embeddings is not None:
                        self._tie_or_clone_weights(output_embeddings,
                                                   self.bert.get_input_embeddings())
        return 0

    def _get_embeddings(self, i=None):
        """
        Returns the correct embedding.
        If i is None, the cls decoder embedding.
        If i has a value, the ith lm head's decoder embedding.

        :param i: None or ith lm head of the decoder embedding that we want
        :return: the decoder embedding.
        """
        if i is None:
            return self.cls.predictions.decoder
        return self.lm_heads[i].predictions.decoder

    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor = None,
                token_type_ids: torch.LongTensor = None, position_ids: torch.LongTensor = None,
                head_mask: torch.FloatTensor = None, labels_tok: torch.Tensor = None,
                masked_lm_labels: torch.Tensor = None, labels_cls: torch.Tensor = None):
        """
        Implements the forward pass for this class.

        labels_tok, masked_lm_labels, labels_cls: carry one entry per head on their second
        axis. Use -1 for every label of a head that a training example should not affect.
        Pass overall None if no such head is used at all.

        :param input_ids: torch tensor of size batch x max_length, contains tokenized input as IDs
        :param attention_mask: torch tensor of size batch x max_length, contains 1 if position
                               should be attended to, else 0
        :param token_type_ids: torch tensor of size batch x max_length, traditionally contains 0 for
                               Part A and 1 for Part B, could be used to model other connections
        :param position_ids: from the transformer repo documentation: "Indices of positions of
                             each input sequence tokens in the position embeddings."
        :param head_mask: used to mask heads by the underlying BERT models,
                          we don't intend to use it here
        :param labels_tok: torch tensor of size batch x self.classify_tokens x max_length
                           (the head axis is moved to the front internally), -1 for tokens that
                           we do not want to predict on, else the class of that token
        :param masked_lm_labels: torch tensor of size batch x self.generate x max_length
                                 (the head axis is moved to the front internally),
                                 -1 for tokens that we do not want to predict on,
                                 else the vocabulary ID this token should become
        :param labels_cls: torch tensor of size batch x self.classify_sequence (the head axis is
                           moved to the front internally), contains the classification that
                           should be assigned to the entire input
                           (i.e. sequence or sequence pair classification)
        :return: A tuple which contains elements depending on the configuration and model call:

        (total_loss), (prediction_scores * self.generate), (cls_logits * self.classify_sequence),
        (tok_logits * self.classify_tokens), (gen_total_loss), (cls_total_loss), (tok_total_loss),
        (total_ppl), (hidden_states), (attentions)

        total_loss (torch.FloatTensor) is contained if any of the three were set when the
        function was called: labels_tok, masked_lm_labels, labels_cls
        The same is true for: (gen_total_loss), (cls_total_loss), (tok_total_loss), (total_ppl)

        - For generation, prediction_scores: logits over the vocabulary for each generation head
        - For sequence classification, cls_logits: logits for each sequence/sequence pair head
        - For token classification, tok_logits: per-token logits for each token classification head
        """
        #######################
        # Collect all outputs #
        #######################
        # model_outputs: sequence_output, pooled_output, (hidden_states), (attentions)
        model_outputs = self.bert(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  token_type_ids=token_type_ids,
                                  position_ids=position_ids,
                                  head_mask=head_mask)

        sequence_output = model_outputs[0]
        pooled_output = model_outputs[1]

        # Sequence Generation: the first head is self.cls, the remaining ones are in self.lm_heads
        prediction_scores = []
        if self.generate > 0:
            prediction_scores.append(self.cls(sequence_output))
            for lm_head in self.lm_heads:
                prediction_scores.append(lm_head(sequence_output))

        # CLS classification
        # BertForSequenceClassification uses dropout here but BertForNextSentencePrediction doesn't
        cls_logits = []
        if self.classify_sequence > 0:
            for i in range(self.classify_sequence):
                # Every head gets its own dropout of the *original* pooled output; re-assigning
                # pooled_output here would stack the dropouts of all previous heads.
                head_pooled_output = self.cls_dropout[i](pooled_output)
                cls_logits.append(self.cls_classifier[i](head_pooled_output))

        # Token Classification
        tok_logits = []
        if self.classify_tokens > 0:
            for i in range(self.classify_tokens):
                head_sequence_output = self.dropout[i](sequence_output)
                tok_logits.append(self.classifier[i](head_sequence_output))

        outputs = tuple(prediction_scores) + tuple(cls_logits) + tuple(tok_logits)

        ###################################
        # Compute losses if gold provided #
        ###################################
        # DataParallel / DistributedDataParallel expects a tuple of tensors as a return value of
        # forward(); plain Python numbers are not allowed and every device has to return tensors
        # of the same shape. All loss values are therefore 0-dim tensors from the start, created
        # on the device the model actually ran on.
        device = sequence_output.device
        total_loss = torch.tensor(0.0, device=device)
        gen_total_loss = torch.tensor(0.0, device=device)
        cls_total_loss = torch.tensor(0.0, device=device)
        tok_total_loss = torch.tensor(0.0, device=device)
        total_ppl = torch.tensor(0.0, device=device)

        # Sequence Generation
        if self.generate > 0 and masked_lm_labels is not None:
            # batch x heads x length -> heads x batch x length
            masked_lm_labels_reshaped = masked_lm_labels.permute(1, 0, 2)
            for i in range(self.generate):
                # reshape (not view): the slice of a permuted tensor is not contiguous as soon
                # as there is more than one head
                masked_lm_labels_flatten = masked_lm_labels_reshaped[i].reshape(-1)
                generation_loss = self._cross_entropy_ignoring_unlabelled(
                    prediction_scores[i].view(-1, self.config.vocab_size),
                    masked_lm_labels_flatten)
                total_loss = total_loss + generation_loss
                gen_total_loss = gen_total_loss + generation_loss
                # perplexity
                num_masked_tokens = int((masked_lm_labels_flatten > -1).sum())
                if num_masked_tokens > 0:  # otherwise zero division / nan
                    # NOTE(review): generation_loss is already averaged over the masked tokens,
                    # so dividing by num_masked_tokens again looks unintended for a perplexity;
                    # kept as-is to not silently change reported numbers - confirm.
                    total_ppl = total_ppl + torch.exp(generation_loss.sum() / num_masked_tokens)

        # CLS classification
        if self.classify_sequence > 0 and labels_cls is not None:
            # batch x heads -> heads x batch
            labels_cls_reshaped = labels_cls.permute(1, 0)
            for i in range(self.classify_sequence):
                head_labels = labels_cls_reshaped[i].reshape(-1)
                if self.num_labels_cls[i] == 1:
                    # We are doing regression; MSELoss has no ignore value, so every example
                    # contributes (a target of -1 is treated as a regular target).
                    cls_loss = MSELoss()(cls_logits[i].view(-1), head_labels)
                else:
                    cls_loss = self._cross_entropy_ignoring_unlabelled(
                        cls_logits[i].view(-1, self.num_labels_cls[i]), head_labels)
                total_loss = total_loss + cls_loss
                cls_total_loss = cls_total_loss + cls_loss

        # Token Classification
        if self.classify_tokens > 0 and labels_tok is not None:
            # batch x heads x length -> heads x batch x length
            labels_tok_reshaped = labels_tok.permute(1, 0, 2)
            # Only keep active parts of the loss (positions that are attended to)
            active_loss = None
            if attention_mask is not None:
                active_loss = attention_mask.reshape(-1) == 1
            for i in range(self.classify_tokens):
                head_logits = tok_logits[i].view(-1, self.num_labels_tok[i])
                head_labels = labels_tok_reshaped[i].reshape(-1)
                if active_loss is not None:
                    head_logits = head_logits[active_loss]
                    head_labels = head_labels[active_loss]
                token_loss = self._cross_entropy_ignoring_unlabelled(head_logits, head_labels)
                total_loss = total_loss + token_loss
                tok_total_loss = tok_total_loss + token_loss

        if masked_lm_labels is not None or labels_cls is not None or labels_tok is not None:
            outputs = (total_loss,) + outputs \
                      + (gen_total_loss, cls_total_loss, tok_total_loss, total_ppl)

        # add hidden states and attention if they are here
        outputs = outputs + tuple(model_outputs[2:])

        # outputs = a tuple of
        # - total_loss,
        # - prediction_scores * self.generate,
        # - cls_logits * self.classify_sequence,
        # - tok_logits * self.classify_tokens,
        # - gen_total_loss, cls_total_loss, tok_total_loss, total_ppl,
        # - (hidden_states),
        # - (attentions)
        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Loads the pre-trained weights and initialises every additional generation head
        with a copy of the weights of the first generation head (``model.cls``).

        :param pretrained_model_name_or_path: forwarded to the parent ``from_pretrained``
        :param model_args: forwarded to the parent ``from_pretrained``
        :param kwargs: forwarded to the parent ``from_pretrained`` (e.g. ``generate=2``)
        :return: the loaded model
        """
        model = super(VariableHeadsmilie, cls).from_pretrained(pretrained_model_name_or_path,
                                                             *model_args, **kwargs)
        # if there is more than one generation head, copy the weights to there
        if model.generate > 1:
            for lm_head in model.lm_heads:
                lm_head.load_state_dict(model.cls.state_dict())
        return model


class Featuremilie(BertPreTrainedModel):
    """BERT encoder used as a feature extractor.

    The forward pass simply runs the wrapped ``BertModel`` and hands back its outputs
    unchanged; no loss is computed, so there is nothing to call ``backward()`` on.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.

    Inputs:
        - **input_ids**: a torch.LongTensor of shape [batch_size, sequence_length]
              with the word token indices in the vocabulary.
        - **token_type_ids**: an optional torch.LongTensor of shape [batch_size, sequence_length] with
              the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and
              type 1 corresponds to a `sentence B` token (see BERT paper for more details).
        - **attention_mask**: an optional torch.LongTensor of shape [batch_size, sequence_length] with
              indices selected in [0, 1]. It's a mask to be used if the input sequence length is
              smaller than the max input sequence length in the current batch.
        - **position_ids**, **head_mask**, **inputs_embeds**: passed through to the encoder.
        - **masked_lm_labels**, **next_sentence_label**, **labels_tok**, **labels_cls**:
              accepted so the call signature matches the other models, but not used.

    Examples::
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = Featuremilie.from_pretrained('bert-base-uncased')
        # Batch size 1
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)
        outputs = model(input_ids)
        first_output, second_output = outputs[:2]
    """

    def __init__(self, config):
        """
        Builds the encoder and the pre-training heads.

        :param config: The configuration for the transformer encoder
        """
        super(Featuremilie, self).__init__(config)

        # shared encoder
        self.bert = BertModel(config)
        # Pre-training heads are instantiated but not used by forward()
        # (presumably so that checkpoints containing them load cleanly - TODO confirm).
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                masked_lm_labels=None, next_sentence_label=None, labels_tok=None, labels_cls=None):
        """
        Runs the encoder and returns its raw outputs.

        The label arguments are ignored: there is no loss function for this task yet.

        TODO: add an average or max pooling layer on top of the sequence output (to be tested
        which works better) and return the pooled vector together with the remaining encoder
        outputs.

        Intended use: store one fixed-length vector per input (case summary) compactly on disk
        and retrieve the most similar cases with a similarity-search library such as faiss.

        :return: the outputs of the underlying ``BertModel``, unchanged
        """
        encoder_outputs = self.bert(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
                                    head_mask=head_mask,
                                    inputs_embeds=inputs_embeds)
        return encoder_outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Loads pre-trained weights by delegating to the parent implementation.

        :param pretrained_model_name_or_path: forwarded to the parent ``from_pretrained``
        :param model_args: forwarded to the parent ``from_pretrained``
        :param kwargs: forwarded to the parent ``from_pretrained``
        :return: the loaded model
        """
        return super(Featuremilie, cls).from_pretrained(pretrained_model_name_or_path,
                                                        *model_args, **kwargs)

# Contains information about which pre-trained models are available.
#ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
#                  for conf in (BertConfig, XLNetConfig, XLMConfig)), ())

# The keys are specified when a model training is called, (e.g. .sh scripts in example_files)
# The corresponding values are (config class, model class, tokenizer class) triples.
MODEL_CLASSES = {
    'bert': (BertConfig, VariableHeadsmilie, BertTokenizer),
    'jabert_mecab': (BertConfig, VariableHeadsmilie, BertJapaneseTokenizer),
    'xlnet': (XLNetConfig, XLNetLMHeadModel, XLNetTokenizer),
    'xlm': (XLMConfig, XLMWithLMHeadModel, XLMTokenizer),
    'distilbert': (DistilBertConfig, VariableHeadsmilie, DistilBertTokenizer),
    'bert_feature': (BertConfig, Featuremilie, BertTokenizer),
}

# JumanTokenizer, SentencePieceTokenizer and FB15KTokenizer are not imported or defined in
# this module, so naming them directly in the literal above raised a NameError as soon as the
# module was imported. Their entries are registered only when the tokenizer class has been
# brought into module scope (e.g. by adding the corresponding import at the top of the file).
_OPTIONAL_TOKENIZERS = (
    ('jabert_juman', 'JumanTokenizer'),
    ('jabert_sentencepiece', 'SentencePieceTokenizer'),
    ('kg', 'FB15KTokenizer'),
)
for _model_key, _tokenizer_name in _OPTIONAL_TOKENIZERS:
    _tokenizer_class = globals().get(_tokenizer_name)
    if _tokenizer_class is None:
        LOGGER.warning("Model type '%s' is not available because %s is not imported.",
                       _model_key, _tokenizer_name)
    else:
        MODEL_CLASSES[_model_key] = (BertConfig, VariableHeadsmilie, _tokenizer_class)
