from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import BertPreTrainedModel, BertModel
from transformers.file_utils import ModelOutput


class KoldModel(BertPreTrainedModel):
    """BERT encoder with a sequence-level head and a token-level head.

    * The sequence ("pooled") head applies dropout + a linear layer to BERT's
      pooled output and produces ``num_pooled_labels`` logits per example.
    * The token head applies dropout + a linear layer to every position of the
      last hidden state and produces ``config.num_labels`` logits per token.

    When both label sets are passed to ``forward``, the loss is the unweighted
    sum of the two cross-entropy losses.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Class count for the sequence-level head is carried in the config's
        # task-specific section.
        self.num_pooled_labels = config.task_specific_params["num_pooled_labels"]

        # The pooling layer is required: the sequence head consumes the pooled output.
        self.bert = BertModel(config, add_pooling_layer=True)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        # A single dropout module is shared by both heads.
        self.dropout = nn.Dropout(classifier_dropout)

        self.pooled_classifier = nn.Linear(config.hidden_size, self.num_pooled_labels)

        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        pooled_labels=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Run BERT and both classification heads, optionally computing the joint loss.

        pooled_labels (`torch.LongTensor`, *optional*):
            Labels for the sequence-level (pooled) classification loss; flattened to one
            label per example (presumably shape `(batch_size,)`). Indices should be in
            `[0, ..., config.task_specific_params["num_pooled_labels"] - 1]`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        The loss is computed only when BOTH `pooled_labels` and `labels` are given; it is
        `seq_loss + span_loss`. When `attention_mask` is given, positions whose mask value
        is not 1 are excluded from the token loss.

        Returns:
            `KoldOutput` when `return_dict` is true; otherwise the tuple
            `(loss, pooled_logits, logits, *extra_bert_outputs)`, with `loss` omitted when it
            was not computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index positionally: this works for the ModelOutput returned when
        # return_dict=True AND for the plain tuple returned when return_dict=False
        # (string keys such as 'pooler_output' would fail on the tuple).
        sequence_output = outputs[0]
        pooled_output = outputs[1]

        # Keep the original dropout order (pooled first, then sequence) so RNG use is unchanged.
        pooled_output = self.dropout(pooled_output)
        pooled_logits = self.pooled_classifier(pooled_output)

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if pooled_labels is not None and labels is not None:
            loss_fct = CrossEntropyLoss()
            seq_loss = loss_fct(pooled_logits.reshape(-1, self.num_pooled_labels), pooled_labels.reshape(-1))

            flat_logits = logits.reshape(-1, self.num_labels)  # (B, L, C) -> (BxL, C)
            flat_labels = labels.reshape(-1)  # (B, L) -> (BxL)
            if attention_mask is not None:
                # Only keep active parts of the loss: masked-out positions receive
                # ignore_index so CrossEntropyLoss skips them. full_like keeps the
                # filler on the same device/dtype as the labels.
                active_loss = attention_mask.reshape(-1) == 1
                flat_labels = torch.where(
                    active_loss,
                    flat_labels,
                    torch.full_like(flat_labels, loss_fct.ignore_index),
                )
            span_loss = loss_fct(flat_logits, flat_labels)

            loss = seq_loss + span_loss

        if not return_dict:
            output = (pooled_logits, logits,) + tuple(outputs[2:])
            return ((loss,) + output) if loss is not None else output

        return KoldOutput(
            loss=loss,
            pooled_logits=pooled_logits,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@dataclass
class KoldOutput(ModelOutput):
    """Output container returned by ``KoldModel.forward`` when ``return_dict`` is true.

    Field order matters: ``ModelOutput`` exposes its non-None fields positionally in
    declaration order, so do not reorder these fields.
    """

    # Sum of the sequence-level and token-level cross-entropy losses; None unless
    # forward() received both label sets.
    loss: Optional[torch.FloatTensor] = None
    # Logits from the sequence-level (pooled) head, one row per example.
    pooled_logits: Optional[torch.FloatTensor] = None
    # Logits from the token-level head, one vector per input position.
    logits: Optional[torch.FloatTensor] = None
    # Per-layer hidden states forwarded from BERT (when requested).
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer attention weights forwarded from BERT (when requested).
    attentions: Optional[Tuple[torch.FloatTensor]] = None
