1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| import torch import torch.nn.functional as F from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
    """Two-layer graph attention network that returns per-node log-probabilities.

    The first ``GATConv`` layer runs ``heads`` attention heads of width
    ``hidden_channels``; the second layer is built to accept
    ``hidden_channels * heads`` input features and uses a single head.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Output width of each attention head in the first layer.
        out_channels: Number of output classes per node.
        heads: Number of attention heads in the first layer.
        dropout: Dropout probability applied to the input of each layer
            while the module is in training mode.

    The defaults reproduce the originally hard-coded configuration, so
    ``GAT()`` builds the same architecture as before.
    """

    def __init__(
        self,
        in_channels: int = 10,
        hidden_channels: int = 16,
        out_channels: int = 2,
        heads: int = 8,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        self.dropout = dropout
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        # The second layer's input width is the first layer's per-head width
        # multiplied by the number of heads.
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run both attention layers and return log-softmax scores.

        Args:
            x: Node feature tensor; passed to ``conv1`` after dropout.
            edge_index: Graph connectivity tensor forwarded unchanged to both
                convolution layers.

        Returns:
            ``log_softmax`` of the second layer's output, taken over ``dim=1``.
        """
        # F.dropout is a no-op when self.training is False (i.e. after .eval()).
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
# Toy inputs: a random 10x10 feature matrix and a 2x10 index tensor whose
# columns pair index i (first row) with (i + 1) % 10 (second row).
x = torch.randn(10, 10)
edge_index = torch.tensor(
    [
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
    ],
    dtype=torch.long,
)
# Alternating 0/1 targets, one per row of x. Built once here rather than on
# every iteration, because the values never change during training.
y = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])

model = GAT()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training mode enables the dropout calls inside GAT.forward.
model.train()
for epoch in range(100):
    optimizer.zero_grad()  # clear gradients accumulated by the previous step
    out = model(x, edge_index)
    # The model already returns log-probabilities, which is what nll_loss expects.
    loss = F.nll_loss(out, y)
    loss.backward()
    optimizer.step()
|