from typing import Optional

import torch
from torch import nn
from transformers import AutoModel


# Response Selection Model
class ResponseSelection(nn.Module):
    """
    Response Selection model for dialogue response selection task.

    Encodes the input with ``pretrained_model`` and projects the first-token
    representation (``hidden[:, 0]``, the [CLS] position for BERT-style
    encoders) to a single score with a bias-free linear layer.

    :param pretrained_model: pretrained encoder (e.g. an instance created via
        ``transformers.AutoModel``). It is called as
        ``pretrained_model(ids, attention_mask=mask, return_dict=False)`` and
        the first element of the returned tuple must be the last hidden state.
    :param hidden_size: width of the encoder's hidden states. When omitted it
        is read from ``pretrained_model.config.hidden_size`` if present,
        otherwise it falls back to 768.
    """

    def __init__(self, pretrained_model: nn.Module, hidden_size: Optional[int] = None):
        super().__init__()

        self.pretrained_model = pretrained_model
        if hidden_size is None:
            # Follow the encoder's own width instead of assuming 768, so larger or
            # smaller encoders do not hit a shape mismatch in the linear head.
            config = getattr(pretrained_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        self.linear = torch.nn.Linear(hidden_size, 1, bias=False)

    def _encode(self, ids, mask):
        """Run the encoder and return the first-token hidden state per sequence."""
        # With return_dict=False the encoder returns a tuple whose first element is
        # the last hidden state; how many further elements follow depends on the
        # architecture/config, so index rather than unpack a fixed-length tuple.
        outputs = self.pretrained_model(ids, attention_mask=mask, return_dict=False)
        # assumes the hidden state is (batch, seq_len, hidden) — TODO confirm for
        # non-standard encoders
        return outputs[0][:, 0]

    def forward(self, ids, mask, return_output=False):
        """
        Score each input sequence.

        :param ids: token ids passed to the encoder
        :param mask: attention mask passed to the encoder
        :param return_output: if True, also return the first-token representation
        :return: ``scores`` of shape ``(batch, 1)``, or ``(scores, cls_repr)`` when
            ``return_output`` is True
        """
        cls_ = self._encode(ids, mask)
        scores = self.linear(cls_)

        if return_output:
            return scores, cls_
        return scores

    def get_cls_repr(self, ids, mask):
        """Return the first-token (CLS) representation without scoring it."""
        return self._encode(ids, mask)


# Response Selection Model for Multi-task learning
class ResponseSelectionforMultiTask(nn.Module):
    """
    Response Selection model for (1) dialogue response selection task and (2) negative type classification task.

    Both heads are bias-free linear layers applied to the encoder's
    first-token representation (``hidden[:, 0]``, the [CLS] position for
    BERT-style encoders).

    :param pretrained_model: pretrained encoder (e.g. an instance created via
        ``transformers.AutoModel``). It is called as
        ``pretrained_model(ids, attention_mask=mask, return_dict=False)`` and
        the first element of the returned tuple must be the last hidden state.
    :param hidden_size: width of the encoder's hidden states. When omitted it
        is read from ``pretrained_model.config.hidden_size`` if present,
        otherwise it falls back to 768.
    :param num_neg_types: number of output classes of the negative-type
        classifier head (default 2).
    """

    def __init__(
        self,
        pretrained_model: nn.Module,
        hidden_size: Optional[int] = None,
        num_neg_types: int = 2,
    ):
        super().__init__()

        self.pretrained_model = pretrained_model
        if hidden_size is None:
            # Follow the encoder's own width instead of assuming 768, so larger or
            # smaller encoders do not hit a shape mismatch in the linear heads.
            config = getattr(pretrained_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        self.linear = torch.nn.Linear(hidden_size, 1, bias=False)
        self.negtypeclassifier = torch.nn.Linear(hidden_size, num_neg_types, bias=False)

    def _encode(self, ids, mask):
        """Run the encoder and return the first-token hidden state per sequence."""
        # With return_dict=False the encoder returns a tuple whose first element is
        # the last hidden state; how many further elements follow depends on the
        # architecture/config, so index rather than unpack a fixed-length tuple.
        outputs = self.pretrained_model(ids, attention_mask=mask, return_dict=False)
        # assumes the hidden state is (batch, seq_len, hidden) — TODO confirm for
        # non-standard encoders
        return outputs[0][:, 0]

    def forward(self, ids, mask, return_output=False):
        """
        Compute the selection score and negative-type logits for each sequence.

        :param ids: token ids passed to the encoder
        :param mask: attention mask passed to the encoder
        :param return_output: if True, also return the first-token representation
        :return: ``(scores, negtype_logits)`` with shapes ``(batch, 1)`` and
            ``(batch, num_neg_types)``; ``(scores, negtype_logits, cls_repr)``
            when ``return_output`` is True
        """
        cls_ = self._encode(ids, mask)
        scores = self.linear(cls_)
        negtype_logits = self.negtypeclassifier(cls_)

        if return_output:
            return scores, negtype_logits, cls_
        return scores, negtype_logits
