import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import BertPreTrainedModel,RobertaModel, BertModel, RobertaForMaskedLM, AutoModel


class DMBert(nn.Module):
    """Pretrained encoder + dynamic multi-pooling + MLP classifier.

    The token representations produced by the encoder are max-pooled twice,
    once under ``maskL`` and once under ``maskR``; the two pooled vectors are
    concatenated, passed through dropout and classified by a 3-layer MLP.

    Args:
        model_name: Identifier/path handed to ``AutoModel.from_pretrained``.
        dropout: Drop probability applied to the pooled features.
        tokenizer_size: Target size for the encoder's token embedding matrix.
        num_labels: Width of the final classifier layer.
    """

    def __init__(self, model_name, dropout, tokenizer_size,  num_labels=5):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Keep the embedding matrix in sync with the tokenizer's vocabulary size.
        self.bert.resize_token_embeddings(tokenizer_size)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_labels)
        )
        self.num_labels = num_labels

    @staticmethod
    def _masked_max_pool(hidden, mask):
        """Max-pool ``hidden`` over the token axis after multiplying by ``mask``.

        ``hidden`` is indexed as (batch, seq_len, hidden_size). ``mask`` must
        broadcast against (hidden_size, batch, seq_len) -- presumably a
        (batch, seq_len) 0/1 tensor; confirm against the data pipeline.

        The max runs over the whole token axis, so any sequence length is
        accepted rather than one fixed padded length.

        Returns a (batch, hidden_size) tensor that still carries the +1 shift;
        the caller removes it.
        """
        masked = (hidden.permute(2, 0, 1) * mask).transpose(0, 1)
        # Positions zeroed by the mask become exactly 1 after the shift.
        masked = masked + torch.ones_like(masked)
        return masked.max(dim=-1).values

    def forward(self, input_ids=None, attention_mask=None, maskL=None, maskR=None):
        """Return logits of shape (-1, num_labels).

        Args:
            input_ids: Token ids fed to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            maskL: Multiplicative mask selecting the first pooled segment.
            maskR: Multiplicative mask selecting the second pooled segment.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
        )
        # First encoder output: per-token representations.
        hidden = outputs[0]
        pooled_left = self._masked_max_pool(hidden, maskL)
        pooled_right = self._masked_max_pool(hidden, maskR)
        pooled = torch.cat((pooled_left, pooled_right), 1)
        # Undo the +1 shift applied before pooling.
        pooled = pooled - torch.ones_like(pooled)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)
        return logits.view(-1, self.num_labels)
    

class RawBert(nn.Module):
    """Pretrained encoder whose position-0 token vector feeds an MLP classifier.

    ``dropout`` is part of the constructor signature but is not used by this
    class; ``maskL``/``maskR`` are likewise accepted by ``forward`` and ignored.
    """

    def __init__(self, model_name, dropout, tokenizer_size,  num_labels=5):
        super().__init__()
        encoder = AutoModel.from_pretrained(model_name)
        # Match the embedding matrix to the tokenizer's vocabulary size.
        encoder.resize_token_embeddings(tokenizer_size)
        self.bert = encoder
        width = encoder.config.hidden_size
        head = [
            nn.Linear(width, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_labels),
        ]
        self.classifier = nn.Sequential(*head)
        self.num_labels = num_labels

    def forward(self, input_ids=None, attention_mask=None, maskL=None, maskR=None):
        """Encode ``input_ids`` and classify the vector at token position 0.

        Returns logits viewed as (-1, num_labels).
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        # encoded[0] holds per-token vectors; index 0 on the token axis is
        # typically the [CLS]-style token for BERT-like tokenizers.
        first_token = encoded[0][:, 0, :]
        scores = self.classifier(first_token)
        return scores.view(-1, self.num_labels)


class RawBertRelation(nn.Module):
    """Classifier over position-0 vectors of the input and two relation texts.

    The same encoder processes the main input, the ``cause`` sequence and the
    ``precondition`` sequence; the three position-0 vectors are concatenated
    and fed to a 3-layer MLP. ``dropout`` is accepted but not used here.
    """

    def __init__(self, model_name, dropout, tokenizer_size,  num_labels=5):
        super().__init__()
        encoder = AutoModel.from_pretrained(model_name)
        # Match the embedding matrix to the tokenizer's vocabulary size.
        encoder.resize_token_embeddings(tokenizer_size)
        self.bert = encoder
        # Three concatenated encoder vectors form the classifier input.
        in_features = encoder.config.hidden_size * 3
        head = [
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_labels),
        ]
        self.classifier = nn.Sequential(*head)
        self.num_labels = num_labels

    def forward(self, input_ids=None, attention_mask=None, cause_ids=None, precondition_ids=None, cause_mask=None, precondition_mask=None, maskL=None, maskR=None):
        """Return logits viewed as (-1, num_labels).

        ``maskL`` and ``maskR`` are accepted for signature compatibility and
        are not used.
        """
        # Encoder calls are kept in the order: input, cause, precondition.
        text_out = self.bert(input_ids, attention_mask=attention_mask)
        cause_out = self.bert(cause_ids, attention_mask=cause_mask)
        cause_vec = cause_out[0][:, 0, :]
        precondition_out = self.bert(precondition_ids, attention_mask=precondition_mask)
        precondition_vec = precondition_out[0][:, 0, :]

        text_vec = text_out[0][:, 0, :]
        features = torch.cat((text_vec, cause_vec, precondition_vec), 1)
        scores = self.classifier(features)
        return scores.view(-1, self.num_labels)


class DMBertRelation(nn.Module):
    """Dynamic multi-pooling classifier augmented with two relation encodings.

    The main input is max-pooled under ``maskL`` and ``maskR``; the position-0
    vectors of the ``cause`` and ``precondition`` sequences (encoded by the
    same encoder) are appended, giving ``4 * hidden_size`` features for the
    MLP classifier.

    Args:
        model_name: Identifier/path handed to ``AutoModel.from_pretrained``.
        dropout: Drop probability applied to the concatenated features.
        tokenizer_size: Target size for the encoder's token embedding matrix.
        num_labels: Width of the final classifier layer.
    """

    def __init__(self, model_name, dropout, tokenizer_size,  num_labels=5):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Keep the embedding matrix in sync with the tokenizer's vocabulary size.
        self.bert.resize_token_embeddings(tokenizer_size)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_labels)
        )
        self.num_labels = num_labels

    @staticmethod
    def _masked_max_pool(hidden, mask):
        """Max-pool ``hidden`` over the token axis after multiplying by ``mask``.

        ``hidden`` is indexed as (batch, seq_len, hidden_size). ``mask`` must
        broadcast against (hidden_size, batch, seq_len) -- presumably a
        (batch, seq_len) 0/1 tensor; confirm against the data pipeline.

        The max runs over the whole token axis, so any sequence length is
        accepted rather than one fixed padded length.

        Returns a (batch, hidden_size) tensor that still carries the +1 shift;
        the caller removes it.
        """
        masked = (hidden.permute(2, 0, 1) * mask).transpose(0, 1)
        # Positions zeroed by the mask become exactly 1 after the shift.
        masked = masked + torch.ones_like(masked)
        return masked.max(dim=-1).values

    def forward(self, input_ids, cause_ids, precondition_ids, cause_mask, precondition_mask, attention_mask=None, maskL=None, maskR=None):
        """Return logits of shape (-1, num_labels).

        Args:
            input_ids: Token ids of the main input.
            cause_ids: Token ids of the cause sequence.
            precondition_ids: Token ids of the precondition sequence.
            cause_mask: Attention mask for ``cause_ids``.
            precondition_mask: Attention mask for ``precondition_ids``.
            attention_mask: Attention mask for ``input_ids``.
            maskL: Multiplicative mask selecting the first pooled segment.
            maskR: Multiplicative mask selecting the second pooled segment.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
        )
        cause_outputs = self.bert(
            cause_ids,
            attention_mask=cause_mask,
        )
        # Vector at token position 0 summarises each auxiliary sequence.
        pooled_cause = cause_outputs[0][:, 0, :]

        precondition_outputs = self.bert(
            precondition_ids,
            attention_mask=precondition_mask,
        )
        pooled_precondition = precondition_outputs[0][:, 0, :]

        hidden = outputs[0]
        pooled_left = self._masked_max_pool(hidden, maskL)
        pooled_right = self._masked_max_pool(hidden, maskR)
        pooled = torch.cat((pooled_left, pooled_right), 1)
        # Undo the +1 shift applied before pooling.
        pooled = pooled - torch.ones_like(pooled)
        pooled = torch.cat((pooled, pooled_cause, pooled_precondition), 1)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)
        return logits.view(-1, self.num_labels)


class RawBertArg(nn.Module):
    """Classifier over the position-0 vectors of the input and an argument text.

    Both sequences go through the same encoder; their position-0 vectors are
    concatenated (``2 * hidden_size`` features) and passed to a 3-layer MLP.
    ``dropout`` is accepted by the constructor but not used by this class.
    """

    def __init__(self, model_name, dropout, tokenizer_size, num_labels=5):
        super().__init__()
        encoder = AutoModel.from_pretrained(model_name)
        # Match the embedding matrix to the tokenizer's vocabulary size.
        encoder.resize_token_embeddings(tokenizer_size)
        self.bert = encoder
        in_features = encoder.config.hidden_size * 2
        head = [
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_labels),
        ]
        self.classifier = nn.Sequential(*head)
        self.num_labels = num_labels

    def forward(self, input_ids=None, attention_mask=None, arg_ids=None, arg_mask=None, maskL=None, maskR=None):
        """Return logits viewed as (-1, num_labels).

        ``maskL`` and ``maskR`` are accepted for signature compatibility and
        are not used.
        """
        # Encoder calls are kept in the order: input, then argument.
        text_out = self.bert(input_ids, attention_mask=attention_mask)
        arg_out = self.bert(arg_ids, attention_mask=arg_mask)
        arg_vec = arg_out[0][:, 0, :]
        text_vec = text_out[0][:, 0, :]
        features = torch.cat((text_vec, arg_vec), 1)
        scores = self.classifier(features)
        return scores.view(-1, self.num_labels)
    

class DMBertArg(nn.Module):
    """Dynamic multi-pooling classifier augmented with an argument encoding.

    The main input is max-pooled under ``maskL`` and ``maskR``; the position-0
    vector of the argument sequence (encoded by the same encoder) is appended,
    giving ``3 * hidden_size`` features for the MLP classifier.

    Args:
        model_name: Identifier/path handed to ``AutoModel.from_pretrained``.
        dropout: Drop probability applied to the concatenated features.
        tokenizer_size: Target size for the encoder's token embedding matrix.
        num_labels: Width of the final classifier layer.
    """

    def __init__(self, model_name, dropout, tokenizer_size, num_labels=5):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Keep the embedding matrix in sync with the tokenizer's vocabulary size.
        self.bert.resize_token_embeddings(tokenizer_size)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size * 3, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_labels)
        )
        self.num_labels = num_labels

    @staticmethod
    def _masked_max_pool(hidden, mask):
        """Max-pool ``hidden`` over the token axis after multiplying by ``mask``.

        ``hidden`` is indexed as (batch, seq_len, hidden_size). ``mask`` must
        broadcast against (hidden_size, batch, seq_len) -- presumably a
        (batch, seq_len) 0/1 tensor; confirm against the data pipeline.

        The max runs over the whole token axis, so any sequence length is
        accepted rather than one fixed padded length.

        Returns a (batch, hidden_size) tensor that still carries the +1 shift;
        the caller removes it.
        """
        masked = (hidden.permute(2, 0, 1) * mask).transpose(0, 1)
        # Positions zeroed by the mask become exactly 1 after the shift.
        masked = masked + torch.ones_like(masked)
        return masked.max(dim=-1).values

    def forward(self, input_ids=None, attention_mask=None, arg_ids=None, arg_mask=None, maskL=None, maskR=None):
        """Return logits of shape (-1, num_labels).

        Args:
            input_ids: Token ids of the main input.
            attention_mask: Attention mask for ``input_ids``.
            arg_ids: Token ids of the argument sequence.
            arg_mask: Attention mask for ``arg_ids``.
            maskL: Multiplicative mask selecting the first pooled segment.
            maskR: Multiplicative mask selecting the second pooled segment.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
        )
        arg_outputs = self.bert(
            arg_ids,
            attention_mask=arg_mask,
        )
        # Vector at token position 0 summarises the argument sequence.
        pooled_arg = arg_outputs[0][:, 0, :]

        hidden = outputs[0]
        pooled_left = self._masked_max_pool(hidden, maskL)
        pooled_right = self._masked_max_pool(hidden, maskR)
        pooled = torch.cat((pooled_left, pooled_right), 1)
        # Undo the +1 shift applied before pooling.
        pooled = pooled - torch.ones_like(pooled)
        pooled = torch.cat((pooled, pooled_arg), 1)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)
        return logits.view(-1, self.num_labels)
    

class DMBertArgRelation(nn.Module):
    """Dynamic multi-pooling classifier with argument and relation encodings.

    The main input is max-pooled under ``maskL`` and ``maskR``; the position-0
    vectors of the argument, cause and precondition sequences (all encoded by
    the same encoder) are appended, giving ``5 * hidden_size`` features for
    the MLP classifier.

    Args:
        model_name: Identifier/path handed to ``AutoModel.from_pretrained``.
        dropout: Drop probability applied to the concatenated features.
        tokenizer_size: Target size for the encoder's token embedding matrix.
        num_labels: Width of the final classifier layer.
    """

    def __init__(self, model_name, dropout, tokenizer_size, num_labels=5):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Keep the embedding matrix in sync with the tokenizer's vocabulary size.
        self.bert.resize_token_embeddings(tokenizer_size)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size * 5, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_labels)
        )
        self.num_labels = num_labels

    @staticmethod
    def _masked_max_pool(hidden, mask):
        """Max-pool ``hidden`` over the token axis after multiplying by ``mask``.

        ``hidden`` is indexed as (batch, seq_len, hidden_size). ``mask`` must
        broadcast against (hidden_size, batch, seq_len) -- presumably a
        (batch, seq_len) 0/1 tensor; confirm against the data pipeline.

        The max runs over the whole token axis, so any sequence length is
        accepted rather than one fixed padded length.

        Returns a (batch, hidden_size) tensor that still carries the +1 shift;
        the caller removes it.
        """
        masked = (hidden.permute(2, 0, 1) * mask).transpose(0, 1)
        # Positions zeroed by the mask become exactly 1 after the shift.
        masked = masked + torch.ones_like(masked)
        return masked.max(dim=-1).values

    def forward(self, input_ids=None, attention_mask=None, arg_ids=None, arg_mask=None, cause_ids=None, precondition_ids=None, cause_mask=None, precondition_mask=None, maskL=None, maskR=None):
        """Return logits of shape (-1, num_labels).

        Args:
            input_ids: Token ids of the main input.
            attention_mask: Attention mask for ``input_ids``.
            arg_ids: Token ids of the argument sequence.
            arg_mask: Attention mask for ``arg_ids``.
            cause_ids: Token ids of the cause sequence.
            precondition_ids: Token ids of the precondition sequence.
            cause_mask: Attention mask for ``cause_ids``.
            precondition_mask: Attention mask for ``precondition_ids``.
            maskL: Multiplicative mask selecting the first pooled segment.
            maskR: Multiplicative mask selecting the second pooled segment.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
        )
        # Vector at token position 0 summarises each auxiliary sequence.
        arg_outputs = self.bert(
            arg_ids,
            attention_mask=arg_mask,
        )
        pooled_arg = arg_outputs[0][:, 0, :]
        cause_outputs = self.bert(
            cause_ids,
            attention_mask=cause_mask,
        )
        pooled_cause = cause_outputs[0][:, 0, :]
        precondition_outputs = self.bert(
            precondition_ids,
            attention_mask=precondition_mask,
        )
        pooled_precondition = precondition_outputs[0][:, 0, :]

        hidden = outputs[0]
        pooled_left = self._masked_max_pool(hidden, maskL)
        pooled_right = self._masked_max_pool(hidden, maskR)
        pooled = torch.cat((pooled_left, pooled_right), 1)
        # Undo the +1 shift applied before pooling.
        pooled = pooled - torch.ones_like(pooled)
        pooled = torch.cat((pooled, pooled_arg, pooled_cause, pooled_precondition), 1)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)
        return logits.view(-1, self.num_labels)