Graph Neural Networks for Tabular Data¶
개요¶
테이블형 데이터에 GNN을 적용하는 것은 행(샘플) 간의 관계 또는 열(피처) 간의 관계를 그래프로 모델링하여 추가적인 구조 정보를 학습하는 접근법이다.
그래프 구성 방식¶
1. Sample-based Graph (행 기반)¶
┌─────────────────────────────────────────────────────────────┐
│ Sample-based Graph │
│ │
│ 각 행(샘플)이 노드, 샘플 간 유사도가 엣지 │
│ │
│ User A ──────── User B │
│ │ ╲ │ │
│ │ ╲ │ │
│ │ ╲User C │ │
│ │ ╲ ╲ │ │
│ │ ╲ ╲ │ │
│ User D ──────── User E │
│ │
│ 엣지 생성 방법: │
│ ├─ k-NN: 피처 공간에서 k개 최근접 이웃 연결 │
│ ├─ 임계값: 유사도 > threshold인 쌍 연결 │
│ └─ 도메인: 같은 그룹, 거래 관계 등 │
└─────────────────────────────────────────────────────────────┘
2. Feature-based Graph (열 기반)¶
┌─────────────────────────────────────────────────────────────┐
│ Feature-based Graph │
│ │
│ 각 피처가 노드, 피처 간 관계가 엣지 │
│ │
│ Age ─────── Income │
│ │╲ │ │
│ │ ╲ │ │
│ │ ╲ Education │
│ │ ╲ ╱ │ │
│ │ ╲ ╱ │ │
│ Gender ──────── Occupation │
│ │
│ 엣지 생성 방법: │
│ ├─ 상관관계: |corr| > threshold │
│ ├─ 상호정보량: MI > threshold │
│ └─ 도메인 지식: 관련 피처 직접 연결 │
└─────────────────────────────────────────────────────────────┘
3. Bipartite Graph (이분 그래프)¶
┌─────────────────────────────────────────────────────────────┐
│ Bipartite Graph │
│ │
│ 샘플 노드와 피처 노드를 분리 │
│ │
│ Samples Features │
│ ┌──────┐ ┌──────┐ │
│ │ S1 │────────│ F1 │ │
│ │ S2 │──┬─────│ F2 │ │
│ │ S3 │──┘ ┌──│ F3 │ │
│ │ S4 │─────┴──│ F4 │ │
│ └──────┘ └──────┘ │
│ │
│ 엣지 가중치: 피처 값 (정규화) │
│ → 메시지 패싱으로 피처↔샘플 정보 교환 │
└─────────────────────────────────────────────────────────────┘
주요 모델¶
1. GRAPE (Graph-based Prediction with Embeddings)¶
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
class GRAPE(nn.Module):
def __init__(self, n_features, hidden_dim, n_classes):
super().__init__()
# 샘플 임베딩
self.feature_encoder = nn.Linear(n_features, hidden_dim)
# GCN layers (샘플 그래프)
self.conv1 = GCNConv(hidden_dim, hidden_dim)
self.conv2 = GCNConv(hidden_dim, hidden_dim)
# 예측 head
self.classifier = nn.Linear(hidden_dim, n_classes)
def forward(self, x, edge_index):
# 피처 인코딩
h = self.feature_encoder(x)
# GNN 메시지 패싱
h = F.relu(self.conv1(h, edge_index))
h = F.relu(self.conv2(h, edge_index))
# 분류
return self.classifier(h)
2. BGNN (Boost Graph Neural Network)¶
┌─────────────────────────────────────────────────────────────┐
│ BGNN │
│ │
│ GBDT + GNN 결합: │
│ │
│ Input Features │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ GBDT │ → 예측 1 + 잔차 │
│ └────┬────┘ │
│ │ │
│ ▼ (잔차 + 그래프 정보) │
│ ┌─────────┐ │
│ │ GNN │ → 예측 2 + 잔차 │
│ └────┬────┘ │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ GBDT │ → 예측 3 │
│ └────┬────┘ │
│ │ │
│ ▼ │
│ 최종 예측 = Σ 예측 │
└─────────────────────────────────────────────────────────────┘
3. TabGNN¶
class TabGNN(nn.Module):
def __init__(self, n_features, hidden_dim, n_classes):
super().__init__()
# 피처 그래프용 GNN
self.feature_gnn = GAT(n_features, hidden_dim)
# 샘플 그래프용 GNN
self.sample_gnn = GAT(hidden_dim, hidden_dim)
# 양방향 attention
self.cross_attention = nn.MultiheadAttention(hidden_dim, n_heads=4)
def forward(self, x, feature_graph, sample_graph):
# 1. 피처 그래프 메시지 패싱
feature_embeds = self.feature_gnn(x.T, feature_graph) # (n_features, hidden)
# 2. 피처 임베딩으로 샘플 표현 구성
sample_embeds = x @ feature_embeds # (n_samples, hidden)
# 3. 샘플 그래프 메시지 패싱
refined = self.sample_gnn(sample_embeds, sample_graph)
return refined
그래프 구성 기법¶
k-NN 그래프¶
from sklearn.neighbors import NearestNeighbors
import numpy as np
def build_knn_graph(X, k=10, metric='cosine'):
"""k-NN 기반 그래프 구성"""
nn = NearestNeighbors(n_neighbors=k, metric=metric)
nn.fit(X)
distances, indices = nn.kneighbors(X)
# Edge list 생성
n_samples = X.shape[0]
src = np.repeat(np.arange(n_samples), k)
dst = indices.flatten()
edge_index = np.stack([src, dst], axis=0)
edge_weight = 1 / (distances.flatten() + 1e-6) # 거리 역수를 가중치로
return edge_index, edge_weight
적응적 그래프 학습¶
class AdaptiveGraphLearner(nn.Module):
"""학습 가능한 그래프 구조"""
def __init__(self, n_samples, hidden_dim, threshold=0.5):
super().__init__()
self.node_embeds = nn.Parameter(torch.randn(n_samples, hidden_dim))
self.threshold = threshold
def forward(self):
# 노드 임베딩 간 유사도 → 엣지 가중치
sim = F.cosine_similarity(
self.node_embeds.unsqueeze(1),
self.node_embeds.unsqueeze(0),
dim=-1
)
# 임계값 적용 (미분 가능하게)
adj = torch.sigmoid((sim - self.threshold) * 10)
return adj
도메인 기반 그래프¶
def build_domain_graph(df, domain_rules):
"""도메인 지식 기반 그래프"""
edges = []
for rule in domain_rules:
if rule['type'] == 'same_group':
# 같은 그룹 내 연결
groups = df.groupby(rule['column']).indices
for group_indices in groups.values():
for i in group_indices:
for j in group_indices:
if i != j:
edges.append((i, j))
elif rule['type'] == 'transaction':
# 거래 관계
for _, row in df.iterrows():
edges.append((row['sender_id'], row['receiver_id']))
return edges
# 예시: 금융 사기 탐지
domain_rules = [
{'type': 'same_group', 'column': 'merchant_category'},
{'type': 'same_group', 'column': 'device_id'},
{'type': 'transaction'} # 거래 연결
]
응용 사례¶
1. 금융 사기 탐지¶
┌─────────────────────────────────────────────────────────────┐
│ Fraud Detection with GNN │
│ │
│ 그래프 구성: │
│ ├─ 노드: 거래 (transaction) │
│ ├─ 엣지 1: 동일 카드 │
│ ├─ 엣지 2: 동일 가맹점 │
│ ├─ 엣지 3: 동일 IP │
│ └─ 엣지 4: 시간적 근접 (1시간 내) │
│ │
│ GNN 효과: │
│ ├─ 단일 거래로는 정상 → 연결된 거래 패턴으로 사기 탐지 │
│ └─ 사기 거래 주변의 "의심 전파" │
│ │
│ 성능 향상: AUC 0.85 → 0.92 (+7%) │
└─────────────────────────────────────────────────────────────┘
2. 추천 시스템¶
# User-Item 이분 그래프
class UserItemGNN(nn.Module):
def __init__(self, n_users, n_items, embed_dim):
self.user_embed = nn.Embedding(n_users, embed_dim)
self.item_embed = nn.Embedding(n_items, embed_dim)
self.gnn = LightGCN(embed_dim) # 경량 GNN
def forward(self, user_item_edges):
user_emb = self.user_embed.weight
item_emb = self.item_embed.weight
# 메시지 패싱
user_out, item_out = self.gnn(user_emb, item_emb, user_item_edges)
return user_out, item_out
3. 환자 유사도 기반 예측¶
성능 비교¶
Benchmark Results¶
| 데이터셋 | XGBoost | TabNet | GNN-only | BGNN |
|---|---|---|---|---|
| Adult | 0.876 | 0.871 | 0.868 | 0.882 |
| HIGGS | 0.753 | 0.748 | 0.745 | 0.759 |
| Fraud | 0.892 | 0.885 | 0.901 | 0.918 |
그래프 효과 분석¶
┌─────────────────────────────────────────────────────────────┐
│ When Does Graph Help? │
│ │
│ 효과 큼: │
│ ├─ 명시적 관계 존재 (소셜 네트워크, 거래) │
│ ├─ 레이블 희소 (semi-supervised) │
│ └─ 노이즈 많음 (이웃 정보로 보정) │
│ │
│ 효과 작음: │
│ ├─ IID 데이터 (샘플 간 독립) │
│ ├─ 충분한 레이블 │
│ └─ 저차원 데이터 │
└─────────────────────────────────────────────────────────────┘
구현 팁¶
1. 스케일링¶
# 대규모 그래프: 미니배치 학습
from torch_geometric.loader import NeighborLoader
loader = NeighborLoader(
data,
num_neighbors=[10, 5], # 1홉 10개, 2홉 5개 샘플링
batch_size=256,
shuffle=True
)
2. 그래프 품질 검증¶
def validate_graph(edge_index, labels):
"""좋은 그래프인지 검증"""
# 동질성 (Homophily): 같은 클래스끼리 연결 비율
src, dst = edge_index
same_label = (labels[src] == labels[dst]).float().mean()
# 너무 낮으면 (< 0.3): 그래프가 태스크에 도움 안됨
# 너무 높으면 (> 0.9): 그래프 없이도 잘 됨
print(f"Homophily: {same_label:.3f}")
return same_label
3. GBDT와 결합¶
# BGNN 스타일 파이프라인
from xgboost import XGBClassifier
# 1단계: GBDT
gbdt = XGBClassifier(n_estimators=100)
gbdt.fit(X_train, y_train)
gbdt_pred = gbdt.predict_proba(X_train)[:, 1]
# 2단계: 잔차 + GNN
residual = y_train - gbdt_pred
gnn_input = np.column_stack([X_train, gbdt_pred])
gnn_model.fit(gnn_input, edge_index, residual)
# 최종 예측
final_pred = gbdt_pred + gnn_model.predict(...)
참고 자료¶
- Ivanov & Prokhorenkova, "Boost then Convolve: Gradient Boosting Meets Graph Neural Networks", ICLR 2021
- Du et al., "GRAPE: GRaph neural network with Adaptive PEs for tabular data", NeurIPS 2021
- You et al., "Graph Structure Learning for Tabular Data", ICML 2022