#!/usr/bin/env python
import torch
import torch.nn as nn


class GraphConvolution(nn.Module):
    """Graph convolution layer with a gated (highway-style) skip connection.

    One weight matrix ``W{i}`` is registered per entry of ``support``. The
    forward pass sums ``support[i] @ (x @ W{i})`` over all supports (or
    ``support[i] @ W{i}`` when ``featureless``), applies ``act_func``, mixes
    the result with the input through a sigmoid gate (non-featureless mode
    only) and finally applies dropout.

    Args:
        input_dim: Row count of each randomly initialised ``W{i}``.
        output_dim: Column count of each randomly initialised ``W{i}``; also
            the in/out size of the gate's linear layers.
        support: Sequence of matrices exposing ``.mm``; one weight is created
            per entry.
        act_func: Callable applied to the aggregated output, or ``None`` to
            skip the activation.
        featureless: If true, ``x`` is ignored and ``W{i}`` itself is
            propagated; the gated skip connection is disabled as well.
        dropout_rate: Probability handed to ``nn.Dropout``.
        bias: If true, a ``(1, output_dim)`` parameter ``b`` is registered.
            NOTE(review): ``b`` is not used in ``forward``; confirm whether
            it was meant to be added before the activation.
        pre: Optional NumPy array used as the initial value of every
            ``W{i}`` instead of random initialisation (its shape then
            replaces ``(input_dim, output_dim)``).
    """

    def __init__(self, input_dim,
                 output_dim,
                 support,
                 act_func=nn.LeakyReLU(),
                 featureless=False,
                 dropout_rate=0.,
                 bias=False,
                 pre=None):
        super(GraphConvolution, self).__init__()
        self.support = support
        self.featureless = featureless
        hidden_dim = output_dim
        # g_1 is not used in forward; it is kept so that parameter names and
        # the RNG consumption order stay compatible with existing checkpoints.
        self.g_1 = nn.Linear(hidden_dim, hidden_dim)
        self.t_1 = nn.Linear(hidden_dim, hidden_dim)

        for i in range(len(self.support)):
            if pre is None:
                print("Random initialize ...")
                weight = torch.randn(input_dim, output_dim)
            else:
                print("Load from glove ...")
                # clone(): for a float32 array, from_numpy(...).float() is a
                # view of the caller's buffer, so without a copy every W{i}
                # (and every other layer built from the same `pre`) would
                # share storage and be updated together.
                weight = torch.from_numpy(pre).float().clone()
            setattr(self, 'W{}'.format(i), nn.Parameter(weight))

        if bias:
            self.b = nn.Parameter(torch.zeros(1, output_dim))

        self.act_func = act_func
        self.dropout = nn.Dropout(dropout_rate)
        self.hw_act = nn.LeakyReLU()  # not used in forward

    def forward(self, x):
        """Propagate ``x`` over every support and return the new embedding.

        Args:
            x: 2-D feature matrix (it is multiplied with ``W{i}`` via
                ``mm``). Ignored when the layer is ``featureless``. In
                non-featureless mode its last dimension must equal
                ``output_dim`` because it is fed to the gate and mixed
                element-wise with the output.

        Returns:
            The output tensor; it is also stored on ``self.embedding``.

        Raises:
            ValueError: If ``self.support`` is empty.
        """
        if len(self.support) == 0:
            raise ValueError(
                "GraphConvolution needs at least one support matrix")

        out = None
        for i, adjacency in enumerate(self.support):
            weight = getattr(self, 'W{}'.format(i))
            pre_sup = weight if self.featureless else x.mm(weight)
            propagated = adjacency.mm(pre_sup)
            out = propagated if out is None else out + propagated

        if self.act_func is not None:
            out = self.act_func(out)

        if not self.featureless:
            # Gated skip connection: the sigmoid gate decides, per element,
            # how much of the convolved signal replaces the input.
            gate = torch.sigmoid(self.t_1(x))
            out = gate * out + (1.0 - gate) * x

        out = self.dropout(out)
        self.embedding = out
        return out


class GraphConvolution1(nn.Module):
    """Graph convolution variant that gates the running output per support.

    The first support is propagated as ``support[0] @ (x @ W0)`` (or
    ``support[0] @ W0`` when ``featureless``). For every further support the
    running output is first mixed with ``x`` through a sigmoid gate and the
    propagation of that support is then added. After the activation, a final
    gated skip connection is applied in non-featureless mode.

    Gate allocation: ``t_1`` holds one linear layer per support. ``t_1[i]``
    (``i >= 1``) is used inside the per-support loop and ``t_1[0]`` drives
    the final skip connection.

    Args:
        input_dim: Row count of each randomly initialised ``W{i}``.
        output_dim: Column count of each randomly initialised ``W{i}``; also
            the in/out size of every gate layer.
        support: Sequence of matrices exposing ``.mm``.
        act_func: Callable applied to the aggregated output, or ``None``.
        featureless: If true, ``W0`` itself is propagated and the final
            skip connection is disabled.
        dropout_rate: Probability handed to ``nn.Dropout``.
        bias: If true, a ``(1, output_dim)`` parameter ``b`` is registered.
        pre: Optional NumPy array used as the initial value of every
            ``W{i}`` instead of random initialisation.

    NOTE(review): ``b``, ``dropout``, ``g_1``, ``hw_act`` and ``W{i}`` for
    ``i >= 1`` are created but never used in ``forward``; confirm whether
    the loop was meant to use per-support weights, bias and dropout.
    """

    def __init__(self, input_dim,
                 output_dim,
                 support,
                 act_func=nn.LeakyReLU(),
                 featureless=False,
                 dropout_rate=0.,
                 bias=False,
                 pre=None):
        super(GraphConvolution1, self).__init__()
        self.support = support
        self.featureless = featureless
        self.act_func = act_func
        self.dropout = nn.Dropout(dropout_rate)
        self.hw_act = nn.LeakyReLU()
        hidden_dim = output_dim
        num_supports = len(self.support)
        # ModuleList must receive Module instances; the previous generator of
        # one-element lists raised TypeError as soon as support was non-empty.
        self.g_1 = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_supports)])
        self.t_1 = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_supports)])

        for i in range(num_supports):
            if pre is None:
                print("Random initialize ...")
                weight = torch.randn(input_dim, output_dim)
            else:
                print("Load from glove ...")
                # clone(): avoid sharing storage with the caller's array (and
                # between the W{i}) when `pre` is already float32.
                weight = torch.from_numpy(pre).float().clone()
            setattr(self, 'W{}'.format(i), nn.Parameter(weight))

        if bias:
            self.b = nn.Parameter(torch.zeros(1, output_dim))

    def forward(self, x):
        """Run the gated propagation and return the new embedding.

        Args:
            x: 2-D feature matrix. It is mixed element-wise with the output,
                so it must be broadcast-compatible with it; in
                non-featureless mode its last dimension must equal
                ``output_dim``.

        Returns:
            The output tensor; it is also stored on ``self.embedding``.

        Raises:
            ValueError: If ``self.support`` is empty.
        """
        if len(self.support) == 0:
            raise ValueError(
                "GraphConvolution1 needs at least one support matrix")

        first_weight = getattr(self, 'W{}'.format(0))
        pre_sup = first_weight if self.featureless else x.mm(first_weight)

        out = self.support[0].mm(pre_sup)

        for i in range(1, len(self.support)):
            # A ModuleList is not callable; pick this support's own gate.
            gate = torch.sigmoid(self.t_1[i](pre_sup))
            out = gate * out + (1.0 - gate) * x
            # The W0 projection is reused for every support, as in the
            # original arithmetic (see the class-level review note).
            out = out + self.support[i].mm(pre_sup)

        if self.act_func is not None:
            out = self.act_func(out)

        if not self.featureless:
            gate = torch.sigmoid(self.t_1[0](x))
            out = gate * out + (1.0 - gate) * x

        self.embedding = out
        return out



class WGCN(nn.Module):
    """Stack of four ``GraphConvolution`` layers built on ``support[0]``.

    The first layer is featureless (it propagates its own weight matrix,
    optionally initialised from ``pre``); the remaining three map
    ``hidden_dim`` to ``hidden_dim``. ``forward`` returns the output of
    every layer.

    Args:
        input_dim: Input dimension of the first layer's weight.
        support: Sequence whose first entry is the adjacency matrix.
        dropout_rate: Dropout probability used by every layer.
        num_classes: Accepted for interface compatibility; not used here.
        hidden_dim: Output width of every layer.
        pre: Optional NumPy array forwarded to the first layer as its
            initial weight.
        masks: Optional sequence of tensors multiplied element-wise with
            ``support[0]``. The first product feeds layer 1 and the second
            feeds layers 2-4, so at least two masks are required.
    """

    def __init__(self, input_dim,
                 support,
                 dropout_rate=0.5,
                 num_classes=8,
                 hidden_dim=300,
                 pre=None,
                 masks=None):
        super(WGCN, self).__init__()

        base = support[0]
        if masks is not None:
            # Move each mask to the adjacency's device instead of forcing
            # CUDA, so the model can also be built on CPU-only machines or
            # on a non-default GPU.
            supports = [base * mask.to(base.device) for mask in masks]
        else:
            supports = [base, base]

        self.layer1 = GraphConvolution(
            input_dim, hidden_dim, [supports[0]],
            act_func=nn.LeakyReLU(), featureless=True,
            dropout_rate=dropout_rate, pre=pre)
        self.layer2 = GraphConvolution(
            hidden_dim, hidden_dim, [supports[1]],
            act_func=nn.LeakyReLU(), dropout_rate=dropout_rate)
        self.layer3 = GraphConvolution(
            hidden_dim, hidden_dim, [supports[1]],
            act_func=nn.LeakyReLU(), dropout_rate=dropout_rate)
        self.layer4 = GraphConvolution(
            hidden_dim, hidden_dim, [supports[1]],
            act_func=nn.LeakyReLU(), dropout_rate=dropout_rate)

        # The layer norms are registered (and therefore part of the
        # state_dict) but are not applied in forward.
        self.layer_norm1 = nn.LayerNorm(hidden_dim, 1e-6)
        self.layer_norm2 = nn.LayerNorm(hidden_dim, 1e-6)
        self.layer_norm3 = nn.LayerNorm(hidden_dim, 1e-6)
        self.layer_norm4 = nn.LayerNorm(hidden_dim, 1e-6)

    def forward(self, net):
        """Run the four layers in sequence.

        Args:
            net: Input handed to the first layer (which is featureless).

        Returns:
            List ``[embed_1, embed_2, embed_3, embed_4]`` with the output of
            each layer; the same list is stored on ``self.embeds``.
        """
        embed_1 = self.layer1(net)
        embed_2 = self.layer2(embed_1)
        embed_3 = self.layer3(embed_2)
        embed_4 = self.layer4(embed_3)
        self.embeds = [embed_1, embed_2, embed_3, embed_4]
        return self.embeds


class MWGCN(nn.Module):
    """Two parallel featureless ``GraphConvolution`` branches, concatenated.

    Branch 1 propagates over ``support[0]`` and branch 2 over ``support[1]``;
    each produces ``hidden_dim // 2`` columns when randomly initialised.

    Args:
        input_dim: Input dimension of each branch's weight.
        support: Sequence providing at least two adjacency matrices.
        dropout_rate: Dropout probability used inside both branches.
        num_classes: Accepted but not used by this class.
        hidden_dim: Combined width; each branch gets ``hidden_dim // 2``.
        pre: Optional NumPy array forwarded to both branches as the initial
            weight.
        masks: Accepted but not used by this class.
    """

    def __init__(self, input_dim,
                 support,
                 dropout_rate=0.5,
                 num_classes=8,
                 hidden_dim=300,
                 pre=None,
                 masks=None):
        super(MWGCN, self).__init__()

        branch_dim = hidden_dim // 2

        def build_branch(adjacency):
            # Each branch gets its own activation module instance.
            return GraphConvolution(
                input_dim,
                branch_dim,
                [adjacency],
                act_func=nn.LeakyReLU(),
                featureless=True,
                dropout_rate=dropout_rate,
                pre=pre,
            )

        self.layer1 = build_branch(support[0])
        self.layer2 = build_branch(support[1])
        # Registered with a fixed rate of 0.5; forward does not apply it.
        self.dropout = nn.Dropout(0.5)

    def forward(self, net):
        """Concatenate both branch outputs along dimension 1.

        Args:
            net: Input handed to both (featureless) branches.

        Returns:
            A list holding the same concatenated tensor three times; the
            list is also stored on ``self.embeds``.
        """
        branch_outputs = (self.layer1(net), self.layer2(net))
        merged = torch.cat(branch_outputs, dim=1)
        self.embeds = [merged] * 3
        return self.embeds
