import torch
from torch import nn
from transformers import AutoModel, AutoModelForSequenceClassification

from utils.data_loader import ModelInput

class BertClassifier(nn.Module):
  """Sequence classifier/regressor built on a pretrained Hugging Face encoder.

  The encoder weights are loaded with ``AutoModel.from_pretrained`` and grafted
  into an ``AutoModelForSequenceClassification`` built from the same config, so
  the encoder is pretrained while the classification head is initialised from
  the config.

  Args:
    bert_model: Model name or path accepted by ``AutoModel.from_pretrained``.
    class_size: Number of outputs (``num_labels`` on the config).
    loss_type: ``'regression'`` configures ``problem_type='regression'``; any
      other value configures ``'single_label_classification'``. With
      ``'classify'``, ``forward`` additionally converts ``input.labels`` to
      class indices with ``argmax`` over the last dimension and returns softmax
      probabilities instead of raw logits.
  """

  def __init__(self, bert_model, class_size, loss_type='regression'):
    super().__init__()
    self.loss_type = loss_type
    problem_type = 'regression' if loss_type == 'regression' else 'single_label_classification'
    model = AutoModel.from_pretrained(bert_model, num_labels=class_size, problem_type=problem_type)
    self.classification = AutoModelForSequenceClassification.from_config(model.config)
    # Swap the freshly initialised encoder inside the classification model for
    # the pretrained one. ``base_model_prefix`` is the attribute name the
    # architecture itself uses for its encoder; guessing it from substrings of
    # the model name (e.g. 'bert' in 'albert-base-v2') can attach the weights to
    # an unused attribute and leave the real encoder random.
    prefix = getattr(self.classification, 'base_model_prefix', '')
    if not prefix or not hasattr(self.classification, prefix):
      # Fall back to the original name-based heuristic.
      if 'berta' in bert_model:
        prefix = 'roberta'
      elif 'bert' in bert_model:
        prefix = 'bert'
      else:
        prefix = 'transformer'
    setattr(self.classification, prefix, model)

  def forward(self, input: ModelInput):
    """Run the model on one batch.

    Args:
      input: Batch providing ``input_ids``, ``attention_mask``, ``segment_ids``
        (passed as ``token_type_ids``) and ``labels``.

    Returns:
      ``(logits, loss, None, None)``. When ``loss_type == 'classify'`` the
      logits are softmax probabilities over the last dimension. ``loss`` is
      ``None`` when ``input.labels`` is ``None``. The two trailing ``None``s
      keep the tuple the same length as the STMT model's return value.
    """
    labels = input.labels
    if self.loss_type == 'classify' and labels is not None:
      # The single-label classification loss takes class indices; reduce the
      # label vectors (presumably one-hot / soft targets — confirm) over the
      # last dimension.
      labels = torch.argmax(labels, -1)
    output = self.classification(
      input.input_ids,
      attention_mask=input.attention_mask,
      token_type_ids=input.segment_ids,
      labels=labels,
      return_dict=True,  # named fields: works with or without labels
    )
    logits = output.logits
    if self.loss_type == 'classify':
      logits = torch.softmax(logits, -1)
    return logits, output.loss, None, None
