
import torch
from torch import nn
from transformers.models.bert import BertPreTrainedModel, BertModel
from transformers.utils import logging


from model.model_output import RelationAwareClassificationModelOutput

# Module-level logger from transformers' logging utilities (not used by the
# code shown in this file).
logger = logging.get_logger(__name__)


class BertForRelationAwareClassification(BertPreTrainedModel):
    """BERT encoder with a sequence-level class head and a token-level span head.

    The pooled BERT output (``pooler_output``) is classified by
    ``class_outputs``; every token of ``last_hidden_state`` is classified by
    ``span_outputs``. When labels are given, the training loss is the class
    cross-entropy plus a weighted span cross-entropy.

    Optional config attributes (defaults reproduce the original hard-coded
    values, so existing checkpoints and configs keep working):
        num_class_labels (int, default 5): output size of ``class_outputs``.
        num_span_labels (int, default 3): output size of ``span_outputs``.
        span_loss_weight (float, default 1.0): multiplier on the span loss.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_class_labels: int = getattr(config, "num_class_labels", 5)
        self.num_span_labels: int = getattr(config, "num_span_labels", 3)
        self.span_loss_weight: float = getattr(config, "span_loss_weight", 1.0)

        # Modules are created in the same order and under the same names as
        # before so parameter initialisation and state-dict keys are unchanged.
        self.bert = BertModel(config, add_pooling_layer=True)
        self.class_outputs = nn.Linear(config.hidden_size, self.num_class_labels)
        self.span_outputs = nn.Linear(config.hidden_size, self.num_span_labels)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.model_type: str = config.model_type

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        class_label=None,
        span_label=None,
        p_mask=None,
    ) -> RelationAwareClassificationModelOutput:
        """Run BERT and both heads; compute the joint loss when labels are given.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds: forwarded unchanged to ``BertModel``.
            class_label: target indices for the class head (CrossEntropyLoss
                targets). When ``None``, no loss is computed.
            span_label: per-token target indices for the span head; flattened
                before the loss. Required when ``class_label`` is given.
            p_mask: per-token mask, flattened alongside ``span_label``. Only
                positions where ``p_mask == 0`` contribute to the span loss.
                Required when ``class_label`` is given.

        Returns:
            RelationAwareClassificationModelOutput with ``class_logits``,
            ``span_logits`` and, when labels are given, ``loss``, ``loss_cls``
            and ``loss_span`` (otherwise those three are ``None``).

        Raises:
            ValueError: if ``class_label`` is given without ``span_label`` or
                ``p_mask``.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=False,
            # Only the last layer and the pooler are used below, so don't
            # collect every intermediate hidden state.
            output_hidden_states=False,
            return_dict=True,
        )

        # Dropout is applied to the pooled output first, then the sequence
        # output, matching the original RNG consumption order.
        pooled_output = self.dropout(outputs.pooler_output)
        class_logits = self.class_outputs(pooled_output)

        sequence_output = self.dropout(outputs.last_hidden_state)
        span_logits = self.span_outputs(sequence_output)

        loss = loss_cls = loss_span = None
        if class_label is not None:
            # Explicit check instead of `assert`, which is stripped under -O.
            if span_label is None or p_mask is None:
                raise ValueError(
                    "span_label and p_mask must be provided together with class_label"
                )

            loss_fct = nn.CrossEntropyLoss()
            loss_cls = loss_fct(class_logits, class_label)

            # Positions with p_mask != 0 get ignore_index so they are excluded
            # from the span loss. NOTE: if every position is masked, the
            # mean-reduced cross-entropy is NaN.
            flat_span_logits = span_logits.reshape(-1, self.num_span_labels)
            flat_span_labels = span_label.reshape(-1).masked_fill(
                p_mask.reshape(-1) != 0, loss_fct.ignore_index
            )
            loss_span = loss_fct(flat_span_logits, flat_span_labels)

            loss = loss_cls + self.span_loss_weight * loss_span

        return RelationAwareClassificationModelOutput(
            loss=loss,
            loss_cls=loss_cls,
            loss_span=loss_span,
            class_logits=class_logits,
            span_logits=span_logits,
        )
