import torch
import torch.nn as nn

class S4DecoderBatch(nn.Module):
    """Batched token decoder conditioned on one embedding vector per sample.

    The decoder embeds the target token ids and runs them through either a
    GRU (``config["model"] == "gru"``) or a Transformer decoder
    (``config["model"] == "transformer"``), then projects to the vocabulary
    and returns per-token log-probabilities.

    Args:
        hidden_size: Width of the token embeddings and of the recurrent /
            attention layers. The conditioning embedding passed to
            ``forward`` must contain exactly ``hidden_size`` values per
            sample.
        output_size: Vocabulary size (rows of the embedding table and width
            of the output projection).
        padding_idx: Token id forwarded to ``nn.Embedding`` as padding index.
        num_layers: Number of stacked GRU layers. The transformer variant
            ignores it and always stacks 4 decoder layers.
        bidirectional: Whether the GRU is bidirectional (GRU variant only).
        dropout: Dropout between GRU layers (GRU variant only).
        config: Mapping with a required ``"model"`` entry (``"gru"`` or
            ``"transformer"``) and an optional boolean ``"no_encoder"``
            entry; when the latter is truthy, ``forward`` replaces the
            conditioning embedding with Gaussian noise.

    Raises:
        ValueError: If ``config`` is missing or names an unsupported model.
    """

    def __init__(self, hidden_size, output_size, padding_idx, num_layers=1, bidirectional=False, dropout=0.0, config=None):
        # nn.Module bookkeeping has to exist before attributes are assigned.
        super().__init__()
        if config is None:
            raise ValueError(
                'config is required and must provide a "model" entry '
                '("gru" or "transformer")'
            )
        self.config = config
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.padding_idx = padding_idx
        self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=padding_idx)

        num_directions = 2 if bidirectional else 1
        # Leading dimension of the GRU hidden state: D * num_layers.
        self.hidden_times = num_layers * num_directions

        model_type = self.config["model"]
        if model_type == "gru":
            self.model = nn.GRU(
                input_size=hidden_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                batch_first=True,
                dropout=dropout,
                bidirectional=bidirectional,
            )
            self.bn1 = nn.BatchNorm1d(hidden_size * num_directions)
            self.out = nn.Linear(hidden_size * num_directions, output_size)
        elif model_type == "transformer":
            decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_size, nhead=1, batch_first=True)
            # Depth is fixed at 4 here; ``num_layers`` only affects the GRU.
            self.model = nn.TransformerDecoder(decoder_layer, num_layers=4)
            self.out = nn.Linear(hidden_size, output_size)
        else:
            # Fail at construction time instead of with an UnboundLocalError
            # deep inside forward().
            raise ValueError(
                f'unsupported config["model"]: {model_type!r} '
                '(expected "gru" or "transformer")'
            )
        # Log-probabilities over the vocabulary (last) dimension.
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, input, embeddings):
        """Decode a batch of token sequences.

        Args:
            input: Integer token ids of shape ``(bsz, max_len)``.
            embeddings: Conditioning tensor holding ``bsz * hidden_size``
                values (it is reshaped internally). Ignored, and may be
                ``None``, when ``config["no_encoder"]`` is truthy.

        Returns:
            Log-probabilities of shape ``(bsz, max_len, output_size)``.
        """
        bsz = input.size(0)
        # Follow the module's own device rather than relying on a global.
        device = self.embedding.weight.device

        if self.config.get("no_encoder", False):
            # Ablation: condition on noise instead of an encoder output.
            embeddings = torch.randn(bsz, 1, self.hidden_size, device=device)
        embeddings = embeddings.to(device)

        model_type = self.config["model"]
        if model_type == "gru":
            gru_input = torch.relu(self.embedding(input))  # (bsz, max_len, dim)
            # Use the same conditioning vector as the initial hidden state of
            # every layer/direction: (D * num_layers, bsz, dim).
            gru_hidden = (
                embeddings.reshape(1, bsz, -1)
                .expand(self.hidden_times, -1, -1)
                .contiguous()
            )
            output, _ = self.model(gru_input, gru_hidden)  # (bsz, max_len, D * dim)
            # BatchNorm1d normalises over dim 1, so move features there and back.
            output = self.bn1(output.transpose(1, 2)).transpose(1, 2)
            output = self.out(output)  # (bsz, max_len, output_size)
        elif model_type == "transformer":
            tgt = self.embedding(input)  # (bsz, max_len, dim)
            # The conditioning vector acts as a length-1 memory sequence.
            memory = embeddings.reshape(bsz, 1, -1).contiguous()  # (bsz, 1, dim)
            # NOTE(review): no causal tgt_mask is passed, so every position can
            # attend to later target tokens - confirm this is intended.
            output = self.model(tgt, memory)
            output = self.out(output)  # (bsz, max_len, output_size)
        else:
            raise ValueError(
                f'unsupported config["model"]: {model_type!r} '
                '(expected "gru" or "transformer")'
            )
        return self.softmax(output)

    def initHidden(self, bsz):
        """Return a uniform-random tensor of shape ``(bsz, num_layers, hidden_size)``.

        ``forward`` does not call this. The tensor is batch-first and is
        created on the default device, so it is not directly usable as an
        ``nn.GRU`` initial state, which expects
        ``(D * num_layers, bsz, hidden_size)``.
        """
        return torch.rand(bsz, self.num_layers, self.hidden_size)