#!/usr/bin/env python
import torch
import torch.nn as nn


class MLP1(nn.Module):
    """Two-layer perceptron: Linear -> activation -> dropout -> Linear.

    Args:
        input_dim: ``in_features`` of the first linear layer.
        hidden_dim: ``out_features`` of the first linear layer and
            ``in_features`` of the second.
        output_dim: ``out_features`` of the second linear layer.
        act_fn: callable applied to the output of the first linear layer.
            The default is a single ``nn.LeakyReLU`` instance created when
            the class is defined, so it is shared by every model built
            with the default.
        dropout_rate: value passed as ``p`` to ``nn.Dropout``, which is
            applied after the activation.
    """

    def __init__(
        self,
        input_dim=300,
        hidden_dim=300,
        output_dim=300,
        act_fn=nn.LeakyReLU(),
        dropout_rate=0.0,
    ):
        super(MLP1, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        self.act_fn = act_fn
        # Build the dropout layer from the caller-supplied rate rather than
        # a fixed constant, so the ``dropout_rate`` argument takes effect.
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return ``fc2(dropout(act_fn(fc1(x))))``."""
        hidden = self.fc1(x)
        hidden = self.act_fn(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

class MLP(nn.Module):
    """Dropout followed by a single linear projection.

    Args:
        input_dim: ``in_features`` of the linear layer.
        hidden_dim: accepted but not used by the current single-layer
            configuration.
        output_dim: ``out_features`` of the linear layer.
        keep_dropout: value passed as ``p`` to ``nn.Dropout``. Despite the
            name, ``nn.Dropout`` treats ``p`` as the probability of
            zeroing an element, not of keeping it.
    """

    def __init__(self, input_dim=300, hidden_dim=512, output_dim=300, keep_dropout=0.3):
        super(MLP, self).__init__()
        input_dropout = nn.Dropout(p=keep_dropout)
        projection = nn.Linear(in_features=input_dim, out_features=output_dim)
        # Positions inside the Sequential are kept (0: dropout, 1: linear)
        # so parameter names remain ``fn.1.weight`` / ``fn.1.bias``.
        self.fn = nn.Sequential(input_dropout, projection)

    def forward(self, x):
        """Apply dropout and then the linear projection to ``x``."""
        return self.fn(x)


class ACT(nn.Module):
    """Single linear layer followed by an activation.

    Args:
        input_dim: ``in_features`` of the linear layer.
        hidden_dim: ``out_features`` of the linear layer.
        act_fn: callable applied to the linear layer's output. The default
            is a single ``nn.LeakyReLU`` instance created when the class
            is defined, so it is shared by every model built with the
            default.
        dropout_rate: value passed as ``p`` to the ``nn.Dropout`` stored
            on ``self.dropout``. That layer is not applied in ``forward``.
    """

    def __init__(self, input_dim=300, hidden_dim=300, act_fn=nn.LeakyReLU(), dropout_rate=0.0):
        super(ACT, self).__init__()

        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.act_fn = act_fn
        # Registered as a submodule, but forward() never calls it.
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        """Return ``act_fn(fc1(x))``; no dropout is applied."""
        return self.act_fn(self.fc1(x))
