
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer,AutoConfig

class IRLCBert(nn.Module):
    """Pretrained transformer encoder with entity-position pooling and two heads.

    The constructor loads a pretrained model, tokenizer and config from
    ``args.bert``, registers four marker tokens (``<e1>``, ``</e1>``, ``<e2>``,
    ``</e2>``) as additional special tokens and resizes the embedding matrix to
    the enlarged vocabulary.

    ``get_sentence_encoding`` gathers one hidden vector per entity and
    concatenates them, so the encoding width (``emb_size``) is twice the
    encoder's hidden size.  ``head`` and ``cluster_head`` are two-layer MLPs
    over an ``emb_size``-wide input; they are only constructed here and are not
    applied by any method visible in this class.
    """

    # Marker tokens added to the tokenizer vocabulary.
    ENTITY_MARKERS = ('<e1>', '</e1>', '<e2>', '</e2>')

    def __init__(self, args=None, instance_dim=256, num_clusters=10, dropout=0.6):
        """Build the encoder, tokenizer and projection heads.

        Args:
            args: Object exposing a ``bert`` attribute that is passed to
                ``from_pretrained`` (model name or path).  Required; the
                ``None`` default is kept only for signature compatibility.
            instance_dim: Output width of the instance-CL head.
            num_clusters: Output width of the cluster-CL head.
            dropout: Dropout probability applied to the pooled encoding.

        Raises:
            ValueError: If ``args`` is ``None``.
        """
        if args is None:
            # Fail early with a clear message instead of an AttributeError
            # on ``None.bert`` further down.
            raise ValueError(
                "IRLCBert requires `args` with a `bert` attribute "
                "(pretrained model name or path)."
            )
        super().__init__()

        self.transformer = AutoModel.from_pretrained(args.bert)
        self.tokenizer = AutoTokenizer.from_pretrained(args.bert)
        self.config = AutoConfig.from_pretrained(args.bert)

        # The new marker tokens need embedding rows, so the embedding matrix
        # is resized to the tokenizer's new length.
        special_tokens_dict = {'additional_special_tokens': list(self.ENTITY_MARKERS)}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.transformer.resize_token_embeddings(len(self.tokenizer))

        # Two entity vectors are concatenated, hence twice the hidden size.
        self.emb_size = self.config.hidden_size * 2
        self.layerNorm = nn.LayerNorm(self.emb_size)
        self.drop = nn.Dropout(dropout)

        # Instance-CL head
        self.head = nn.Sequential(
            nn.Linear(self.emb_size, self.emb_size),
            nn.ReLU(inplace=True),
            nn.Linear(self.emb_size, instance_dim))

        # Cluster-CL head
        self.cluster_head = nn.Sequential(
            nn.Linear(self.emb_size, self.emb_size),
            nn.ReLU(inplace=True),
            nn.Linear(self.emb_size, num_clusters))

    def get_sentence_encoding(self, features) -> torch.Tensor:
        """Encode a batch as the concatenation of its two entity vectors.

        Args:
            features: Mapping with ``input_ids``, ``attention_mask``,
                ``e1_idx`` and ``e2_idx``; ``token_type_ids`` is optional and
                forwarded to the encoder only when present.  Column 0 of
                ``e1_idx`` / ``e2_idx`` is used as the per-example token
                position to read (presumably the opening entity marker --
                confirm against the data pipeline).

        Returns:
            Tensor with one row per example and ``emb_size`` columns, after
            dropout and layer normalisation.
        """
        encoder_kwargs = {'attention_mask': features['attention_mask']}
        # Not every tokenizer emits segment ids and not every encoder accepts
        # them, so the argument is passed only when the batch carries it.
        try:
            token_type_ids = features['token_type_ids']
        except KeyError:
            token_type_ids = None
        if token_type_ids is not None:
            encoder_kwargs['token_type_ids'] = token_type_ids

        output = self.transformer(features['input_ids'], **encoder_kwargs)
        last_hidden = output[0]

        e1_idx = features['e1_idx']
        e2_idx = features['e2_idx']
        # Row selector built once, on the same device as the hidden states.
        rows = torch.arange(e1_idx.shape[0], device=last_hidden.device)
        e1_encoding = last_hidden[rows, e1_idx[:, 0]]
        e2_encoding = last_hidden[rows, e2_idx[:, 0]]

        sentence_encoding = torch.cat([e1_encoding, e2_encoding], dim=-1)
        sentence_encoding = self.drop(sentence_encoding)
        return self.layerNorm(sentence_encoding)
    







