# -*- coding: utf-8 -*-
from pytorch_transformers import XLMPreTrainedModel
from transformers import XLMModel
from torch import nn
from torch.nn import CrossEntropyLoss


class XLMForQuestionAnswering(XLMPreTrainedModel):
    """Extractive question-answering head on top of a pretrained XLM encoder.

    The encoder's sequence output is passed through ``dis_model.mean1`` (an
    external module supplied to :meth:`forward`). The result then goes through a
    linear layer that produces one start logit and one end logit per token.

    Args:
        config: model configuration. ``config.num_labels`` sets the output size
            of the span layer; it must be 2, because :meth:`forward` splits the
            logits into exactly a start and an end part.
        encoder_name_or_path: name or path handed to ``XLMModel.from_pretrained``
            to build the encoder. Defaults to ``"xlm-mlm-100-1280"``.
        head_input_size: input feature size of the span layer. It must match the
            last dimension of ``dis_model.mean1``'s output. Defaults to 200.
    """

    def __init__(self, config, encoder_name_or_path="xlm-mlm-100-1280", head_input_size=200):
        super().__init__(config)

        self.transformer = XLMModel.from_pretrained(encoder_name_or_path)
        self.qa_outputs = nn.Linear(head_input_size, config.num_labels)

        # Initialise only the freshly created head. A plain ``self.init_weights()``
        # would apply the random initialiser to every submodule, including the
        # encoder that was just loaded, and so throw away its pretrained weights.
        self.qa_outputs.apply(self._init_weights)

        # ``init_weights()`` also prunes the attention heads listed in the config.
        # Keep that part of its behaviour.
        pruned_heads = getattr(self.config, "pruned_heads", None)
        if pruned_heads:
            self.prune_heads(pruned_heads)

    def forward(
            self,
            input_ids=None,
            token_type_ids=None,
            attention_mask=None,
            start_positions=None,
            end_positions=None,
            position_ids=None,
            head_mask=None,
            dis_model=None,
            langs=None,
            lengths=None,
            cache=None,
            inputs_embeds=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        r"""
        Compute span start/end logits and, when labels are given, the span loss.

        dis_model:
            Required despite the ``None`` default. ``dis_model.mean1`` is applied
            to the encoder's sequence output, and its result feeds the span layer.
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for the position (index) of the start of the labelled span, used
            to compute the token classification loss. Positions are clamped to the
            length of the sequence (:obj:`sequence_length`). Positions outside the
            sequence are not taken into account when computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for the position (index) of the end of the labelled span, used
            to compute the token classification loss. Positions are clamped to the
            length of the sequence (:obj:`sequence_length`). Positions outside the
            sequence are not taken into account when computing the loss.
        return_dict:
            Accepted for signature compatibility but ignored. The encoder is always
            called with ``return_dict=False``, and a tuple is returned.

        Returns:
            ``(total_loss, start_logits, end_logits, *rest)`` when both
            ``start_positions`` and ``end_positions`` are given. Otherwise it
            returns ``(start_logits, end_logits, *rest)``. ``rest`` is
            ``transformer_outputs[1:]``, the encoder outputs that follow the
            sequence output.

        Raises:
            ValueError: if ``dis_model`` is not supplied.
        """
        if dis_model is None:
            raise ValueError(
                "XLMForQuestionAnswering.forward requires `dis_model` "
                "(an object exposing `mean1`)."
            )

        # The outputs below are indexed positionally, so always request a tuple.
        return_dict = False

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = transformer_outputs[0]

        sequence_output_sem = dis_model.mean1(sequence_output)

        logits = self.qa_outputs(sequence_output_sem)
        # Exactly two label channels are expected: start and end.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # On multi-GPU the gathered labels can carry an extra trailing
            # dimension; squeeze it away.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Start/end positions outside the model inputs are clamped to
            # ``ignored_index``, which the loss then skips. Clamp out of place so
            # the caller's label tensors are not modified.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        output = (start_logits, end_logits) + transformer_outputs[1:]
        return ((total_loss,) + output) if total_loss is not None else output
