메시지 패싱 (message passing)이란?
본 포스팅에서는 메시지 패싱을 그래프 신경망 관점에서 알아보겠습니다.
메시지 패싱은 그래프에서 노드 간에 정보를 전달하고 상호작용하는 방법을 의미합니다. 각 노드는 고유한 노드 임베딩을 학습하기 위해 주변 노드로부터 메시지를 받아들이고, 이를 기반으로 자신의 상태를 업데이트합니다.
예를 들어, GCN이나 GAT와 같은 GNN 기반 모델들은, 노드 임베딩을 계산하기 위한 세부적인 방법은 다르더라도 그 전체적인 흐름은 메시지 패싱의 형태를 따릅니다. 즉 우리는 다양한 형태의 GNN 레이어를 Message Passing Neural Network (MPNN or MP-GNN)으로 일반화 할 수 있습니다. Gilmer et al. (2017)에 따르면 노드에 메시지를 전달하는 메시지 패싱은 다음과 같은 원리로 작동 됩니다 [1].
- Message: 각 노드는 연결된 이웃마다 전달할 메시지를 생성하기 위해 메시지 함수를 사용합니다.
- Aggregate: 각 노드는 이웃들로부터 받은 메시지를 집계합니다.
- Update: 각 노드는 현재 노드 feature와 집계된 메시지를 결합하는 함수를 사용하여 새로운 특성을 업데이트합니다.
이 세 함수는 아래와 같은 식으로 나타낼 수 있습니다. \(h_i\)는 i 노드의 임베딩, \(e_{ij}\)는 j -> i 방향의 엣지, \(\phi\)는 메시지를 전달하는 함수, \(\oplus\)는 aggregation 함수, \(\gamma\)는 노드 피처를 업데이트 하는 업데이트 함수입니다.
$$ \acute{h_i} = \gamma (h_i,\oplus {j\in N_i} \phi({h_i, h_j, e_{j,i}})) $$
메시지 패싱에 대한 일반적인 흐름을 나타낸 식이라 이해하기 어렵지 않을 것이라고 생각됩니다. 아래 그림은 노드 A의 임베딩을 업데이트 하기 위해, 이웃 노드의 메시지들을 aggregate하는 과정을 시각화한 모습입니다. 여기서는 2-hop 이웃들의 정보까지 활용한 모습을 볼 수 있습니다.
이러한 메시지 패싱을 통한 노드 임베딩은 노드 classification이나 그래프 classification과 같은 작업을 수행하기 위해 진행되었습니다. 이후 각 노드가 최종적으로 업데이트 된 상황에서, 노드들의 정보를 특정하게 불러오는 것을 readout phase라고 합니다. 이러한 readout 함수 역시 간단하게 표현할 수 있습니다.
$$ \hat{y} = R(\{h^T_v | v \in G \} $$
메시지 전달 -> aggregate -> update 하는 메시지 패싱 과정을 T번 반복한 후에 최종적으로 업데이트 된 임베딩 정보들을, readout 함수인 \(R\)에 전달하여 원하는 상위 작업을 수행하게 됩니다.
2. Message passing 코드로 구현하기
PyG에서 제공하는 MeggagePassing 클래스를 이용하여 GCN 모델을 구현해보겠습니다.
먼저 필요한 라이브러리를 불러옵니다.
import torch
!pip install -q torch-scatter~=2.1.0 torch-sparse~=0.6.16 torch-cluster~=1.6.0 torch-spline-conv~=1.2.1 torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-{torch.__version__}.html
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
아래 코드를 통해 GCNConv층을 불러옵니다.
import numpy as np
np.random.seed(0)
import torch
torch.manual_seed(0)
from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, dim_in, dim_h):
super().__init__(aggr='add')
self.linear = Linear(dim_in, dim_h, bias=False)
def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
x = self.linear(x)
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
out = self.propagate(edge_index, x=x, norm=norm)
return out
def message(self, x, norm):
return norm.view(-1, 1) * x
conv = GCNConv(16, 32)
- GCNConv 클래스는 torch_geometric.nn의 MessagePassing을 상속 받습니다.
- __init__은 입력 노드의 차원과 은닉층의 차원을 입력 받습니다.
- super().__init__(agg='add')를 통해, 이웃 노드의 정보를 aggregate 할 때, add 연산을 적용했습니다.
- forward() 함수는 노드의 feature x와 edge_index를 입력받아 linear 연산을 수행합니다.
- 이때 torch_geometrics.utils에서 add_self_loops를 적용하여 인접 행렬에서 자기 자신을 고려하도록 만들어 줍니다.
- GCN은 이웃 노드의 정보를 aggregate하는 과정에서 노드 degree에 따라 정규화를 적용합니다. \( \frac{1}{\sqrt{deg(i)} \sqrt{deg(j)}} \)
- self.propagate(): self-loop를 포함한 edge_index와 norm 변수에 저장된 normalization 텐서, 노드 피처 x를 사용하여 update를 수행합니다. 이때, (aggr='add')로 설정했기 때문에, aggreate function은 add 연산을 수행합니다.
- message() 함수는 노드 피처 x와 norm을 입력받아 이웃 노드들을 정규화 합니다.
Reference
Hands-On Graph Neural Networks Using Python, published by Packt.
[1] Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., & Dahl, G. E. (2017, July). Neural message passing for quantum chemistry. In International conference on machine learning (pp. 1263-1272). PMLR.
'Graph > Graph Theory' 카테고리의 다른 글
그래프와 행렬: 라플라시안 행렬(Laplacian matrix) (0) | 2023.07.06 |
---|