import torch
from torch import nn
from models.common import GradReverse


class TextRNNR(nn.Module):
    """Recurrent text encoder: frozen embedding lookup followed by an LSTM.

    ``forward`` returns the final hidden state of every LSTM layer and
    direction, concatenated per batch item into one feature vector of size
    ``hidden_layers * num_directions * hidden_size``.
    """

    def __init__(self, cfg, embeddings):
        super().__init__()
        model_cfg = cfg.model
        self.config = model_cfg

        # Create a table of matching size, then swap in the supplied matrix as
        # a frozen (requires_grad=False) weight so it is not trained.
        table = nn.Embedding(*embeddings.shape)
        table.weight = nn.Parameter(embeddings, requires_grad=False)
        self.embeddings = table

        self.lstm = nn.LSTM(
            input_size=model_cfg.embed_size,
            hidden_size=model_cfg.hidden_size,
            num_layers=model_cfg.hidden_layers,
            # NOTE(review): despite "keep" in its name, this value is passed
            # to nn.LSTM as the inter-layer dropout *probability* — confirm
            # the config stores a drop rate.
            dropout=model_cfg.dropout_keep,
            bidirectional=model_cfg.bidirectional,
        )

    def forward(self, x):
        """Encode a batch of token-id sequences.

        Args:
            x: token ids laid out as (max_sen_len, batch_size); the LSTM is
               built without ``batch_first``, so the sequence axis is first.

        Returns:
            Tensor of shape
            (batch_size, num_layers * num_directions * hidden_size).
        """
        _, (final_hidden, _) = self.lstm(self.embeddings(x))
        # final_hidden: (num_layers * num_directions, batch_size, hidden_size).
        # Bring the batch axis to the front, then flatten the rest.
        per_example = final_hidden.transpose(0, 1)
        return per_example.reshape(per_example.size(0), -1)


class TextRNNO(nn.Module):
    """Output head: dropout followed by a linear layer over encoder features.

    The input width is hidden_size * hidden_layers * num_directions, where
    num_directions is 2 when ``cfg.model.bidirectional`` is set and 1 otherwise.
    """

    def __init__(self, cfg):
        super().__init__()
        self.config = cfg.model
        num_directions = 1 + self.config.bidirectional
        in_features = self.config.hidden_size * self.config.hidden_layers * num_directions
        # NOTE(review): nn.Dropout takes a drop probability; confirm that
        # "dropout_keep" really stores a drop rate and not a keep rate.
        self.dropout = nn.Dropout(self.config.dropout_keep)
        self.fc = nn.Linear(in_features, self.config.output_size)

    def forward(self, x):
        """Apply dropout, then the linear layer; returns (..., output_size)."""
        regularised = self.dropout(x)
        return self.fc(regularised)

class u_TextRNNO(nn.Module):
    """Output head like TextRNNO that also carries a learnable ``logsigma``.

    ``logsigma`` is a single-element float32 parameter initialised to 1.0 and
    returned alongside the head output. NOTE(review): presumably a learned
    log-scale used for uncertainty-based loss weighting — confirm with the
    training code.
    """

    def __init__(self, cfg):
        super().__init__()
        self.config = cfg.model
        num_directions = 1 + self.config.bidirectional
        in_features = self.config.hidden_size * self.config.hidden_layers * num_directions
        # NOTE(review): nn.Dropout takes a drop probability; confirm that
        # "dropout_keep" really stores a drop rate and not a keep rate.
        self.dropout = nn.Dropout(self.config.dropout_keep)
        self.fc = nn.Linear(in_features, self.config.output_size)
        initial_logsigma = torch.FloatTensor([1.0])
        self.logsigma = nn.Parameter(initial_logsigma)

    def forward(self, x):
        """Return ``(fc(dropout(x)), logsigma)``."""
        prediction = self.fc(self.dropout(x))
        return prediction, self.logsigma

class TextRNND(nn.Module):
    """Task discriminator head: GradReverse -> dropout -> linear.

    Produces one output per entry of ``cfg.data['tasks']``. The input width
    is hidden_layers * num_directions * hidden_size, the same feature size
    TextRNNR produces and TextRNNO consumes.
    """

    def __init__(self, cfg):
        super(TextRNND, self).__init__()
        self.config = cfg.model
        # NOTE(review): nn.Dropout takes a drop probability; confirm that
        # "dropout_keep" really stores a drop rate and not a keep rate.
        self.dropout = nn.Dropout(self.config.dropout_keep)
        # Derive the direction factor from the config instead of hard-coding
        # 2. A fixed 2 only matches a bidirectional LSTM; with
        # bidirectional=False it made fc twice too wide for the encoder output.
        num_directions = 2 if self.config.bidirectional else 1
        in_features = self.config.hidden_layers * num_directions * self.config.hidden_size
        self.fc = nn.Linear(in_features, len(cfg.data['tasks']))

    def forward(self, x):
        """Pass ``x`` through GradReverse, then dropout and the linear layer.

        NOTE(review): GradReverse comes from models.common and is presumably
        a gradient-reversal autograd Function (identity forward, negated
        gradient backward), as used in adversarial training — confirm there.
        """
        return self.fc(self.dropout(GradReverse.apply(x)))
