
import numpy
import torch
from torch import nn
import torch.nn.functional as F

from torch.nn.modules.linear import Linear

from transformers import BertModel, BertConfig
#from transformers import  PreTrainedModel

class PretrainedBertEmbedder(nn.Module):
    """Embed token ids with a pretrained BERT model.

    Wraps ``BertModel.from_pretrained`` and returns its last hidden state,
    optionally gathered at caller-supplied positions (``offsets``).
    """

    def __init__(self, pretrained_model):
        """Load the BERT weights.

        Args:
            pretrained_model: Identifier passed straight through to
                ``BertModel.from_pretrained``.
        """
        super().__init__()
        self.bert_model = BertModel.from_pretrained(pretrained_model)
        self.output_dim = self.bert_model.config.hidden_size

    def get_output_dim(self):
        """Return the hidden size of the wrapped BERT model."""
        return self.output_dim

    def forward(self, input_ids, offsets=None):
        """Run BERT and return the last-layer hidden states.

        Args:
            input_ids: Integer tensor of token ids. Positions whose id is 0
                are masked out of attention (id 0 is treated as padding).
            offsets: Optional integer tensor of positions to select along the
                sequence axis, one row per batch element. When ``None`` the
                full sequence of hidden states is returned.

        Returns:
            The model's ``last_hidden_state`` when ``offsets`` is ``None``;
            otherwise ``last_hidden_state[i, offsets[i]]`` for every row ``i``.
        """
        # Attend only to non-padding positions (padding id assumed to be 0).
        input_mask = (input_ids != 0).long()
        bert_output = self.bert_model(
            input_ids=input_ids,
            attention_mask=input_mask,
            output_hidden_states=False,
            return_dict=True,
        )
        top_layer = bert_output['last_hidden_state']
        if offsets is None:
            return top_layer

        # Row index for advanced indexing, shaped (rows, 1) so it broadcasts
        # against ``offsets``. It is created on the same device as the hidden
        # states instead of being moved to CUDA whenever CUDA merely exists:
        # the old behaviour broke for CPU tensors on a CUDA-capable host.
        range_vector = torch.arange(
            offsets.size(0), dtype=torch.long, device=top_layer.device
        ).unsqueeze(1)
        # NOTE(review): assuming top_layer is (batch, seq, hidden) and offsets
        # is (batch, k), the result is (batch, k, hidden) - confirm with callers.
        return top_layer[range_vector, offsets]

