import torch
from torch import nn
from models.common import GradReverse


class TextR(nn.Module):
    """CNN sentence encoder over frozen pre-trained word embeddings.

    One convolution branch is built for each width in ``cfg.model.kernel_size``.
    Every branch is Conv1d -> ReLU -> max over the whole time axis, and the
    pooled outputs of all branches are concatenated along the feature axis.

    Args:
        cfg: config whose ``model`` attribute provides ``dropout``,
            ``embed_size``, ``num_channels`` and ``kernel_size`` (a sequence of
            convolution widths).
        embeddings: 2-D tensor of embedding vectors (one row per token index);
            it becomes the embedding weight with ``requires_grad=False``.
    """

    def __init__(self, cfg, embeddings):
        super(TextR, self).__init__()
        self.config = cfg.model
        self.embeddings = nn.Embedding(*embeddings.shape)
        # Replace the randomly initialised table with the given vectors, frozen.
        self.embeddings.weight = nn.Parameter(embeddings, requires_grad=False)
        self.dropout = nn.Dropout(self.config.dropout)
        # Build one branch per kernel width, so the output width is always
        # num_channels * len(kernel_size). That is the input width the
        # TextO / u_TextO / TextD heads are sized for. Branches keep the names
        # conv1, conv2, ... so state_dict keys stay compatible with checkpoints
        # saved by the original fixed three-branch version.
        self._num_convs = len(self.config.kernel_size)
        for idx, width in enumerate(self.config.kernel_size, start=1):
            self.add_module(f"conv{idx}", nn.Sequential(
                nn.Conv1d(in_channels=self.config.embed_size,
                          out_channels=self.config.num_channels,
                          kernel_size=width),
                nn.ReLU(),
                # Global max-pool over time. For inputs of length max_sen_len
                # this equals the former MaxPool1d(max_sen_len - width + 1).
                # Unlike that, it does not break on shorter sequences or
                # ignore the tail of longer ones.
                nn.AdaptiveMaxPool1d(1),
            ))

    def forward(self, x):
        """Encode a batch of token-index sequences.

        Args:
            x: integer tensor of token indices laid out as (seq_len, batch_size),
                i.e. time-major, as required by the ``permute`` below.

        Returns:
            Tensor of shape (batch_size, num_channels * len(kernel_size)).
        """
        # (seq_len, batch, embed) -> (batch, embed, seq_len): Conv1d is channels-first.
        embedded_sent = self.dropout(self.embeddings(x).permute(1, 2, 0))
        # Each branch yields (batch, num_channels, 1); squeeze(2) drops the pooled time axis.
        features = [
            getattr(self, f"conv{idx}")(embedded_sent).squeeze(2)
            for idx in range(1, self._num_convs + 1)
        ]
        return torch.cat(features, 1)


class TextO(nn.Module):
    """Output head: dropout followed by a single linear projection.

    Input width is ``cfg.model.num_channels * len(cfg.model.kernel_size)``;
    output width is ``cfg.model.output_size``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.config = cfg.model
        model_cfg = self.config
        feature_dim = model_cfg.num_channels * len(model_cfg.kernel_size)
        self.dropout = nn.Dropout(model_cfg.dropout)
        self.fc = nn.Linear(feature_dim, model_cfg.output_size)

    def forward(self, x):
        """Apply dropout to ``x`` and project it with the linear layer."""
        regularized = self.dropout(x)
        return self.fc(regularized)

class u_TextO(nn.Module):
    """Output head (dropout + linear) that also owns a learnable scalar ``logsigma``.

    ``forward`` returns the projection together with the ``logsigma``
    parameter. How the caller uses ``logsigma`` (e.g. for loss weighting) is
    not visible here.
    """

    def __init__(self, cfg):
        super().__init__()
        self.config = cfg.model
        in_dim = self.config.num_channels * len(self.config.kernel_size)
        self.dropout = nn.Dropout(self.config.dropout)
        self.fc = nn.Linear(in_dim, self.config.output_size)
        # One-element float32 parameter, initialised to 1.0.
        self.logsigma = nn.Parameter(torch.FloatTensor([1.0]))

    def forward(self, x):
        """Return ``(fc(dropout(x)), logsigma)``."""
        scores = self.fc(self.dropout(x))
        return scores, self.logsigma

class TextD(nn.Module):
    """Head scoring each task in ``cfg.data['tasks']``, applied after ``GradReverse``.

    Input width is ``cfg.model.num_channels * len(cfg.model.kernel_size)``;
    output width is the number of tasks. ``GradReverse`` comes from
    ``models.common``. Presumably it is a gradient-reversal op (identity in
    forward, negated gradient in backward); confirm against that module.
    """

    def __init__(self, cfg):
        super().__init__()
        self.config = cfg.model
        self.dropout = nn.Dropout(self.config.dropout)
        feature_dim = self.config.num_channels * len(self.config.kernel_size)
        n_tasks = len(cfg.data['tasks'])
        self.fc = nn.Linear(feature_dim, n_tasks)

    def forward(self, x):
        """Pass ``x`` through GradReverse, then dropout, then the linear layer."""
        reversed_x = GradReverse.apply(x)
        return self.fc(self.dropout(reversed_x))
