import torch
import torch.nn as nn
import torch.nn.functional as F
from models.GAT_layers import GraphAttentionLayer, SpGraphAttentionLayer


class GAT(nn.Module):
    """Dense multi-head graph attention module.

    Builds ``nheads`` independent ``GraphAttentionLayer`` heads and
    concatenates their outputs along dimension 1. No output attention layer
    is constructed by this class; see ``forward`` for how an externally
    attached ``out_att`` is handled.
    """

    def __init__(self, model_opt, nheads=8):
        """Dense version of GAT.

        Args:
            model_opt: Mapping that provides the keys ``'GAT_dropout'``,
                ``'GAT_alpha'``, ``'GAT_hidden'`` and ``'enc_rnn_size'``.
                Values are coerced with ``float``/``int``.
            nheads: Number of attention heads to build.
        """
        super(GAT, self).__init__()
        self.dropout = float(model_opt['GAT_dropout'])
        # Forwarded unchanged to every head as its ``alpha`` argument.
        self.alpha = float(model_opt['GAT_alpha'])
        self.nhid = int(model_opt['GAT_hidden'])
        # Input width is twice the encoder RNN size
        # (presumably a bidirectional encoder -- TODO confirm with caller).
        self.nfeat = int(model_opt['enc_rnn_size']) * 2

        self.attentions = [
            GraphAttentionLayer(
                self.nfeat,
                self.nhid,
                dropout=self.dropout,
                alpha=self.alpha,
                concat=True,
            )
            for _ in range(nheads)
        ]
        # A plain Python list is not tracked by nn.Module, so each head is
        # registered explicitly. The ``attention_{i}`` names are kept so
        # existing state_dict keys stay valid.
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # NOTE: no ``out_att`` layer is created here because ``model_opt``
        # carries no output/class dimension. ``forward`` therefore treats
        # ``out_att`` as optional.

    def forward(self, x, adj):
        """Run all attention heads and concatenate their outputs.

        Args:
            x: Node feature tensor passed to every head
                (presumably ``(num_nodes, nfeat)`` -- TODO confirm).
            adj: Adjacency information passed unchanged to every head.

        Returns:
            If an ``out_att`` callable has been attached to this instance,
            ``log_softmax(elu(out_att(h, adj)), dim=1)``. Otherwise ``h``
            itself, where ``h`` is the dropout-regularised concatenation of
            the head outputs along dimension 1.
        """
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)

        # ``out_att`` is never assigned in __init__; accessing it
        # unconditionally raised AttributeError on every forward pass.
        out_att = getattr(self, 'out_att', None)
        if out_att is None:
            return x

        x = F.elu(out_att(x, adj))
        return F.log_softmax(x, dim=1)


class SpGAT(nn.Module):
    """Sparse multi-head graph attention module.

    Builds ``nheads`` independent ``SpGraphAttentionLayer`` heads and
    concatenates their outputs along dimension 1. No output attention layer
    is constructed by this class; see ``forward`` for how an externally
    attached ``out_att`` is handled.
    """

    def __init__(self, model_opt, nheads=8):
        """Sparse version of GAT.

        Args:
            model_opt: Mapping that provides the keys ``'GAT_dropout'``,
                ``'GAT_alpha'``, ``'GAT_hidden'`` and ``'enc_rnn_size'``.
                Values are coerced with ``float``/``int``.
            nheads: Number of attention heads to build.
        """
        super(SpGAT, self).__init__()
        self.dropout = float(model_opt['GAT_dropout'])
        # Forwarded unchanged to every head as its ``alpha`` argument.
        self.alpha = float(model_opt['GAT_alpha'])
        self.nhid = int(model_opt['GAT_hidden'])
        # Input width is twice the encoder RNN size
        # (presumably a bidirectional encoder -- TODO confirm with caller).
        self.nfeat = int(model_opt['enc_rnn_size']) * 2

        self.attentions = [
            SpGraphAttentionLayer(
                self.nfeat,
                self.nhid,
                dropout=self.dropout,
                alpha=self.alpha,
                concat=True,
            )
            for _ in range(nheads)
        ]
        # A plain Python list is not tracked by nn.Module, so each head is
        # registered explicitly. The ``attention_{i}`` names are kept so
        # existing state_dict keys stay valid.
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # NOTE: no ``out_att`` layer is created here because ``model_opt``
        # carries no output/class dimension. ``forward`` therefore treats
        # ``out_att`` as optional.

    def forward(self, x, adj):
        """Run all attention heads and concatenate their outputs.

        Args:
            x: Node feature tensor passed to every head
                (presumably ``(num_nodes, nfeat)`` -- TODO confirm).
            adj: Adjacency information passed unchanged to every head.

        Returns:
            If an ``out_att`` callable has been attached to this instance,
            ``log_softmax(elu(out_att(h, adj)), dim=1)``. Otherwise ``h``
            itself, where ``h`` is the dropout-regularised concatenation of
            the head outputs along dimension 1.
        """
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)

        # ``out_att`` is never assigned in __init__; accessing it
        # unconditionally raised AttributeError on every forward pass.
        out_att = getattr(self, 'out_att', None)
        if out_att is None:
            return x

        x = F.elu(out_att(x, adj))
        return F.log_softmax(x, dim=1)

