Contents
1.GraphSAGE
1.1 이웃 샘플링
1.2 Aggregation
2. GraphSAGE 코드 구현
3. Inductive VS Transductive Learning
그래프나 추천 시스템 분야에서는 모델에 대한 scalability 문제가 중요합니다. 아무리 정확한 모델이더라도, 연산이 너무 많아 모델이 무거워지면 현실에 존재하는 빅데이터에 적용하기 힘들어집니다. GraphSage는 lager 그래프를 다루기 위해 제안된 모델입니다. 특히 e-commerce 분야에서 수백만명의 유저와 아이템이 존재하기 때문에, 추천 시스템에서도 GraphSage 알고리즘을 이용하는 경우가 많습니다. 이번 포스팅에서는 큰 그래프는 다룰 때 scalability를 해결하기 위한 GrapgSage를 알아보겠습니다. 간단히 개념을 소개한 후에, 실제 데이터를 통해 node classification 문제와 multi-label classification 문제를 해결하기 위한 코드를 소개하겠습니다.
1. GraphSAGE
2017년에 발표된 GraphSAGE는 대규모 그래프에 대한 inductive 학습을 위한 프레임워크입니다. GraphSAGE의 목표도 역시, 노드에 대한 임베딩을 생성하는 것입니다. GraphSAGE는 GCN과 GAT의 두 가지 문제를 해결합니다. 바로 대규모 그래프에 대한 확장성 문제와, 새로운 데이터에 대한 일반화 문제입니다. 이 섹션에서는 GraphSAGE의 두 가지 주요 구성 요소를 기준으로 설명드리겠습니다.
1.1 이웃 샘플링
보통 딥러닝 프레임워크에서는 전체 데이터세트에 대해 mini-batch 단위로 학습을 수행합니다. 그러나 그래프에서는 아무 노드를 batch로 불러오게 되면, 중요한 connection을 잃게 되는 문제가 발생할 것입니다. 이때 GNN의 기본 개념을 다시 생각해볼 수 있습니다. GNN은 기본적으로 이웃의 정보를 바탕으로 노드 임베딩을 학습합니다. 이때 직접적으로(direct) 이웃한 노드들을 1 hop에 있다고 합니다. 이 1 hop 이웃 노드들의 이웃을 2 hop이라고 표현합니다. 생각해 보면 이는 GNN layer를 두 개 쌓은 것과 마찬가지입니다.
이러한 아이디어는 그래프 구조에서는 batch 단위에 대한 연산을 가능하게 합니다. 아래 그림은 각 노드들의 2 hop 이웃에 대한 mini-batch를 생성하는 그림입니다.
노드 0의 임베딩을 얻기 위해서 2 hop 이웃 노드들을 aggregation 합니다. 이때 그림의 회색 박스가 aggregation이 들어가는 자리입니다. 그러나 이러한 방식은 hop의 크기가 커질수록 그래프의 연산 크기가 기하급수적으로 커진다는 단점이 있습니다. 또한 예를 들어, 소셜 네트워크 그래프에서, 인플루언서 노드의 경우에는 이웃 노드가 매우 많기 때문에, 특정 노드의 연산이 매우 커질 수도 있다는 문제가 발생합니다. 이러한 문제를 해결하기 위해, GraphSAGE는 neighbor sampling이라는 기법을 소개했습니다. 이는 모든 이웃을 추가하는 대신, 미리 정의된 수의 이웃만을 샘플링한다는 아이디어입니다. 예를 들어, 첫 번째 hop 동안 (최대) 3개의 이웃만 유지하고, 두 번째 hop 동안 5개의 이웃을 선택합니다. 따라서 이 경우에는 계산 그래프가 3 × 5 = 15개의 노드를 초과할 수 없습니다.
만약 샘플링 수가 낮으면, 연산이 효율적이지만 그래프에 대한 무작위성이 증가하여 분산이 높아집니다. 따라서 이러한 이웃 샘플링 기법은 대규모 그래프의 복잡성을 줄여주는 대신, 중요한 정보를 제거하여 정확도에 부정적인 영향을 미칠 수도 있다는 트레이드오프를 야기할 수 있습니다.
1.2 Aggregation
GraphSAGE에서는 이웃(-hop) 노드들의 정보를 반영할 때 세 가지 aggregation을 제안합니다.
- A mean aggregator
- A LSTM aggregator
- A pooling aggregator
가장 쉬운 예로, mean aggreation 같은 경우는 다음과 같은 식으로 계산됩니다.
$$ \acute{h_i} = \sigma (W \bullet mean_{j\in\tilde{N_i}}(h_j)) $$
LSTM은 인풋 데이터를 시퀀스 형태로 받기 때문에, 노드에 시퀀스를 부여하기 위해 랜덤 하게 permutation을 진행합니다.
2. GraphSAGE 코드 구현
앞서 설명한 대로, 대규모 그래프에 GNN을 적용하기 위한 GraphSAGE를 PyG의 코드로 구현해 보겠습니다. 먼저 노드 분류를 위해 데이터 세트는 MIT에서 제공하는 PubMed 데이터를 사용합니다. 그래프는 19,797개의 노드와 88,648개의 엣지로 구성되었고, 노드 feature는 TF-IDF 가중치 벡터인 500차원으로 구성되어 있습니다. 본 코드에서는 각 노드를 총 3가지의 카테고리로 분류하겠습니다.
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='.', name="Pubmed")
data = dataset[0]
# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
# Print information about the graph
print(f'\nGraph:')
print('------')
print(f'Training nodes: {sum(data.train_mask).item()}')
print(f'Evaluation nodes: {sum(data.val_mask).item()}')
print(f'Test nodes: {sum(data.test_mask).item()}')
print(f'Edges are directed: {data.is_directed()}')
print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')
print(f'Graph has loops: {data.has_self_loops()}')
### result ###
Processing...
Done!
Dataset: Pubmed()
-------------------
Number of graphs: 1
Number of nodes: 19717
Number of features: 500
Number of classes: 3
Graph:
------
Training nodes: 60
Evaluation nodes: 500
Test nodes: 1000
Edges are directed: False
Graph has isolated nodes: False
Graph has loops: False
GraphSAGE에서 첫 번째 단계는 이웃 샘플링입니다. 이는 PyG에서 제공하는 NeighborLoader클래스를 통해 구현합니다. 이웃 샘플링은 첫 번째 이웃에는 5개, 두 번째 이웃에는 10개로 지정하겠습니다. 총 60개의 타겟 노드에 대해, 배치 사이즈를 16개로 구성하겠습니다. 따라서 총 4개의 배치 단위가 만들어지는 것입니다.
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_networkx
# Create batches with neighbor sampling
train_loader = NeighborLoader(
data,
num_neighbors=[5, 10],
batch_size=16,
input_nodes=data.train_mask,
)
# Print each subgraph
for i, subgraph in enumerate(train_loader):
print(f'Subgraph {i}: {subgraph}')
### result ###
Subgraph 0: Data(x=[391, 500], edge_index=[2, 439], y=[391], train_mask=[391], val_mask=[391], test_mask=[391], input_id=[16], batch_size=16)
Subgraph 1: Data(x=[257, 500], edge_index=[2, 303], y=[257], train_mask=[257], val_mask=[257], test_mask=[257], input_id=[16], batch_size=16)
Subgraph 2: Data(x=[260, 500], edge_index=[2, 299], y=[260], train_mask=[260], val_mask=[260], test_mask=[260], input_id=[16], batch_size=16)
Subgraph 3: Data(x=[192, 500], edge_index=[2, 231], y=[192], train_mask=[192], val_mask=[192], test_mask=[192], input_id=[12], batch_size=12)
각 배치 단위별 서브 그래프는 아래와 같습니다.
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
# Plot each subgraph
fig = plt.figure(figsize=(16,16))
for idx, (subdata, pos) in enumerate(zip(train_loader, [221, 222, 223, 224])):
G = to_networkx(subdata, to_undirected=True)
ax = fig.add_subplot(pos)
ax.set_title(f'Subgraph {idx}', fontsize=24)
plt.axis('off')
nx.draw_networkx(G,
pos=nx.spring_layout(G, seed=0),
with_labels=False,
node_color=subdata.y,
)
plt.show()
GNN 아키텍처를 위해 두 개의 SAGEConv 레이어를 사용했습니다. 이때, aggregator는 default인 mean aggregator가 적용됩니다.
def accuracy(pred_y, y):
"""정확도 계산"""
return ((pred_y == y).sum() / len(y)).item()
import torch
torch.manual_seed(-1)
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
"""GraphSAGE"""
def __init__(self, dim_in, dim_h, dim_out):
super().__init__()
self.sage1 = SAGEConv(dim_in, dim_h) # default = mean aggregator
self.sage2 = SAGEConv(dim_h, dim_out)
def forward(self, x, edge_index):
h = self.sage1(x, edge_index)
h = torch.relu(h)
h = F.dropout(h, p=0.5, training=self.training)
h = self.sage2(h, edge_index)
return F.log_softmax(h, dim=1)
def fit(self, data, epochs):
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
self.train()
for epoch in range(epochs+1):
total_loss = 0
acc = 0
val_loss = 0
val_acc = 0
# Train on batches
for batch in train_loader:
optimizer.zero_grad()
out = self(batch.x, batch.edge_index)
loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
total_loss += loss.item()
acc += accuracy(out[batch.train_mask].argmax(dim=1), batch.y[batch.train_mask])
loss.backward()
optimizer.step()
# Validation
val_loss += criterion(out[batch.val_mask], batch.y[batch.val_mask])
val_acc += accuracy(out[batch.val_mask].argmax(dim=1), batch.y[batch.val_mask])
# Print metrics every 10 epochs
if epoch % 20 == 0:
print(f'Epoch {epoch:>3} | Train Loss: {loss/len(train_loader):.3f} | Train Acc: {acc/len(train_loader)*100:>6.2f}% | Val Loss: {val_loss/len(train_loader):.2f} | Val Acc: {val_acc/len(train_loader)*100:.2f}%')
@torch.no_grad()
def test(self, data):
self.eval()
out = self(data.x, data.edge_index)
acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
return acc
# GraphSAGE 생성
graphsage = GraphSAGE(dataset.num_features, 64, dataset.num_classes)
print(graphsage)
# Train
graphsage.fit(data, 200)
# Test
acc = graphsage.test(data)
print(f'GraphSAGE test accuracy: {acc*100:.2f}%')
### result ###
GraphSAGE(
(sage1): SAGEConv(500, 64, aggr=mean)
(sage2): SAGEConv(64, 3, aggr=mean)
)
Epoch 0 | Train Loss: 0.317 | Train Acc: 29.50% | Val Loss: 1.12 | Val Acc: 27.08%
Epoch 20 | Train Loss: 0.002 | Train Acc: 100.00% | Val Loss: 0.49 | Val Acc: 75.32%
Epoch 40 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.61 | Val Acc: 77.08%
Epoch 60 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.42 | Val Acc: 89.93%
Epoch 80 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.66 | Val Acc: 70.21%
Epoch 100 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.50 | Val Acc: 78.60%
Epoch 120 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.54 | Val Acc: 86.13%
Epoch 140 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.63 | Val Acc: 75.00%
Epoch 160 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.52 | Val Acc: 77.68%
Epoch 180 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.52 | Val Acc: 82.92%
Epoch 200 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.41 | Val Acc: 85.00%
GraphSAGE test accuracy: 74.30%
테스트 세트에 적용한 결과 74.30%의 정확도를 얻었습니다. 이는 GCN과 GAT에 비해 크게 향상된 결과는 아닙니다. 심지어 결과가 더 낮을 수도 있습니다. 그러나 GraphSAGE는 대규모 그래프 연산을 위해 효율적이고 빠른 학습을 진행한다는 장점이 있습니다. 실제로 PubMed 데이터 세트에 GCN, GAT를 적용하면 연산이 상당히 오래 걸리는 것을 알 수 있습니다.
3. Inductive VS Transductive Learning
GNN에서는 대상 그래프를 학습할 때, inductive와 transductive 학습 방식으로 구분할 수 있다.
- Inductive Learning: GNN과정에서 훈련 세트의 데이터만을 보는 방식입니다. 이는 훈련에 사용된 그래프와 그래프의 레이블에 대해서만 임베딩을 생성하고 예측하는 것입니다.
- Transductive Learning: GNN과정에서 훈련 세트와 테스트 세트의 데이터를 모두 볼 수 있습니다. 그러나 훈련 과정에서는 훈련 노드의 상태만을 학습하지, 테스트 노드에 대한 업데이트는 이뤄지지 않습니다.
GraphSAGE는 이웃 샘플링을 통해 pruning 된 그래프에 대해 예측을 하도록 설계되었기 때문에, inductive 프레임워크로 간주될 수 있습니다. 이러한 설명에 대한 예제로, 단백질 간의 상호작용(PPI) 네트워크 데이터셋을 사용하겠습니다. PPI 네트워크는 24개의 그래프로 구성되었으며, 노드는 21,557개, 엣지는 345,353개입니다. 이 데이터셋의 목표는 121개의 레이블로 노드를 분류하는 것입니다. 학습 후 모든 노드는 121개 중 하나의 클래스만을 할당받습니다.
Training 데이터는 20개, 검증 데이터와 테스트 데이터는 각각 2개의 그래프로 구성됩니다.
이웃 샘플링은 테스트 세트에만 적용합니다. 첫 번째 이웃은 20개, 두 번째 이웃은 10개를 샘플링합니다.
Batch.from_data_list()를 사용하여 20개의 그래프를 하나의 싱글 셋으로 만들어 줍니다.
f1 score를 통해 정확도를 확인합니다.
from torch_geometric.datasets import PPI
# 데이터 로딩
train_dataset = PPI(root=".", split='train')
val_dataset = PPI(root=".", split='val')
test_dataset = PPI(root=".", split='test')
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader, NeighborLoader
# 훈련세트에 이웃 샘플링 적용
train_data = Batch.from_data_list(train_dataset)
loader = NeighborLoader(train_data, batch_size=2048, shuffle=True, num_neighbors=[20, 10], num_workers=2, persistent_workers=True)
# 데이터 로더
train_loader = DataLoader(train_dataset, batch_size=2)
val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)
import torch
from sklearn.metrics import f1_score
from torch_geometric.nn import GraphSAGE
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphSAGE(
in_channels=train_dataset.num_features,
hidden_channels=512,
num_layers=2,
out_channels=train_dataset.num_classes,
).to(device)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
def fit():
model.train()
total_loss = 0
for data in train_loader:
data = data.to(device)
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = criterion(out, data.y)
total_loss += loss.item() * data.num_graphs
loss.backward()
optimizer.step()
return total_loss / len(train_loader.dataset)
@torch.no_grad()
def test(loader):
model.eval()
data = next(iter(loader))
out = model(data.x.to(device), data.edge_index.to(device))
preds = (out > 0).float().cpu()
y, pred = data.y.numpy(), preds.numpy()
return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0
for epoch in range(301):
loss = fit()
val_f1 = test(val_loader)
if epoch % 50 == 0:
print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Val F1-score: {val_f1:.4f}')
print(f'Test F1-score: {test(test_loader):.4f}')
### result ###
Epoch 0 | Train Loss: 0.587 | Val F1-score: 0.4202
Epoch 50 | Train Loss: 0.191 | Val F1-score: 0.8426
Epoch 100 | Train Loss: 0.141 | Val F1-score: 0.8806
Epoch 150 | Train Loss: 0.122 | Val F1-score: 0.8939
Epoch 200 | Train Loss: 0.105 | Val F1-score: 0.9052
Epoch 250 | Train Loss: 0.097 | Val F1-score: 0.9119
Epoch 300 | Train Loss: 0.088 | Val F1-score: 0.9156
Test F1-score: 0.9357
Inductive 세팅을 통해 높은 f1 스코어를 기록한 것을 볼 수 있습니다. 코드를 살펴보면 마스킹이 포함되지 않습니다. 이는 각 그래프를 학습, 검증, 테스트 세트로 나누고, inductive learning을 수행했기 때문입니다. 또한 dataloader도 각자 다릅니다. 만약 transductive learning을 진행하고 싶으면, Batch.from_data_list()를 통해 모든 그래프를 하나의 셋으로 만들 수 있습니다.
이상으로 대규모 그래프 학습을 위한 GraphSAGE에 대한 포스팅을 마치겠습니다. 감사합니다.
Reference
Hands-On Graph Neural Networks Using Python, published by Packt.
'Graph > Graph with Code' 카테고리의 다른 글
9. 그래프를 활용한 링크 예측(Link Prediction): 기초 방법론부터 GAE까지 (0) | 2023.07.07 |
---|---|
8. GIN (Graph isomorphism network) 기초 및 코드 구현 (0) | 2023.07.03 |
6. Graph Attention Networks (GATs) 기초 및 코드 구현 (0) | 2023.06.29 |
5. GCN 기초 및 코드 구현 (0) | 2023.06.28 |
4. GNN 기초 (직접 코드로 구현하기) (0) | 2023.06.27 |