import os.path as osp
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, ChebConv  # noqa


class GCN(nn.Module):
    """Two-layer graph convolutional network built from ``GCNConv`` layers.

    Args:
        node_input_channels: Input feature width handed to the first
            ``GCNConv`` layer.
        hidden_channels: Output width of both convolution layers; this is
            also the feature width of the tensor returned by ``forward``.
        dropout: Dropout probability applied between the two convolutions
            while the module is in training mode.
    """

    def __init__(self, node_input_channels, hidden_channels, dropout):
        super().__init__()
        # Keep the configured rate so ``forward`` can apply it; it was
        # previously discarded and the F.dropout default (0.5) always used.
        self.dropout = dropout
        self.conv1 = GCNConv(node_input_channels, hidden_channels,
                             cached=False, normalize=True)
        self.conv2 = GCNConv(hidden_channels, hidden_channels,
                             cached=False, normalize=True)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run both convolutions over the graph and return the result.

        Args:
            x: Node feature tensor passed to the first convolution
                (presumably ``(num_nodes, node_input_channels)`` -- confirm
                against the caller).
            edge_index: Graph connectivity passed to both convolutions
                (presumably ``(2, num_edges)`` -- confirm against the caller).

        Returns:
            The raw output of the second convolution. No activation or
            (log-)softmax is applied to it here.
        """
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is only active in training mode (``self.training``).
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, edge_index)


def train():
    """Run a single optimisation step on the nodes selected by ``train_mask``.

    Relies on the module-level names ``model``, ``optimizer`` and ``data``,
    which are defined outside this function.
    """
    model.train()
    optimizer.zero_grad()

    # NOTE(review): ``model`` is called with no arguments and its output is
    # fed to ``nll_loss``, which expects log-probabilities. The ``GCN`` class
    # in this file takes ``(x, edge_index)`` and returns raw convolution
    # output -- confirm what ``model`` actually is at the call site.
    out = model()
    train_mask = data.train_mask
    loss = F.nll_loss(out[train_mask], data.y[train_mask])

    loss.backward()
    optimizer.step()

@torch.no_grad()
def test():
    """Compute accuracy on the train, validation and test masks.

    Relies on the module-level names ``model`` and ``data``, which are
    defined outside this function. Gradient tracking is disabled by the
    ``torch.no_grad`` decorator.

    Returns:
        A list of three floats: the fraction of correctly predicted nodes
        for ``train_mask``, ``val_mask`` and ``test_mask``, in that order.
    """
    model.eval()
    logits = model()

    accs = []
    for _key, mask in data('train_mask', 'val_mask', 'test_mask'):
        # ``max`` along dim 1 yields (values, indices); the indices are the
        # predicted class per node.
        _scores, pred = logits[mask].max(1)
        n_correct = pred.eq(data.y[mask]).sum().item()
        n_total = mask.sum().item()
        accs.append(n_correct / n_total)
    return accs