콘텐츠로 이동
Data Prep
상세

Graph Neural Networks for Tabular Data

개요

테이블형 데이터에 GNN을 적용하는 것은 행(샘플) 간의 관계 또는 열(피처) 간의 관계를 그래프로 모델링하여 추가적인 구조 정보를 학습하는 접근법이다.

그래프 구성 방식

1. Sample-based Graph (행 기반)

┌─────────────────────────────────────────────────────────────┐
│               Sample-based Graph                            │
│                                                             │
│  각 행(샘플)이 노드, 샘플 간 유사도가 엣지                  │
│                                                             │
│       User A ──────── User B                               │
│         │  ╲           │                                   │
│         │   ╲          │                                   │
│         │    ╲User C   │                                   │
│         │     ╲    ╲   │                                   │
│         │      ╲    ╲  │                                   │
│       User D ──────── User E                               │
│                                                             │
│  엣지 생성 방법:                                            │
│  ├─ k-NN: 피처 공간에서 k개 최근접 이웃 연결                │
│  ├─ 임계값: 유사도 > threshold인 쌍 연결                   │
│  └─ 도메인: 같은 그룹, 거래 관계 등                        │
└─────────────────────────────────────────────────────────────┘

2. Feature-based Graph (열 기반)

┌─────────────────────────────────────────────────────────────┐
│               Feature-based Graph                           │
│                                                             │
│  각 피처가 노드, 피처 간 관계가 엣지                        │
│                                                             │
│       Age ─────── Income                                   │
│        │╲          │                                       │
│        │ ╲         │                                       │
│        │  ╲ Education                                      │
│        │   ╲   ╱   │                                       │
│        │    ╲ ╱    │                                       │
│    Gender ──────── Occupation                              │
│                                                             │
│  엣지 생성 방법:                                            │
│  ├─ 상관관계: |corr| > threshold                          │
│  ├─ 상호정보량: MI > threshold                            │
│  └─ 도메인 지식: 관련 피처 직접 연결                       │
└─────────────────────────────────────────────────────────────┘

3. Bipartite Graph (이분 그래프)

┌─────────────────────────────────────────────────────────────┐
│                 Bipartite Graph                             │
│                                                             │
│  샘플 노드와 피처 노드를 분리                               │
│                                                             │
│  Samples          Features                                 │
│  ┌──────┐        ┌──────┐                                  │
│  │ S1   │────────│ F1   │                                  │
│  │ S2   │──┬─────│ F2   │                                  │
│  │ S3   │──┘  ┌──│ F3   │                                  │
│  │ S4   │─────┴──│ F4   │                                  │
│  └──────┘        └──────┘                                  │
│                                                             │
│  엣지 가중치: 피처 값 (정규화)                              │
│  → 메시지 패싱으로 피처↔샘플 정보 교환                     │
└─────────────────────────────────────────────────────────────┘

주요 모델

1. GRAPE (Graph-based Prediction with Embeddings)

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

class GRAPE(nn.Module):
    def __init__(self, n_features, hidden_dim, n_classes):
        super().__init__()
        # 샘플 임베딩
        self.feature_encoder = nn.Linear(n_features, hidden_dim)

        # GCN layers (샘플 그래프)
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

        # 예측 head
        self.classifier = nn.Linear(hidden_dim, n_classes)

    def forward(self, x, edge_index):
        # 피처 인코딩
        h = self.feature_encoder(x)

        # GNN 메시지 패싱
        h = F.relu(self.conv1(h, edge_index))
        h = F.relu(self.conv2(h, edge_index))

        # 분류
        return self.classifier(h)

2. BGNN (Boost Graph Neural Network)

┌─────────────────────────────────────────────────────────────┐
│                      BGNN                                   │
│                                                             │
│  GBDT + GNN 결합:                                          │
│                                                             │
│  Input Features                                            │
│       │                                                     │
│       ▼                                                     │
│  ┌─────────┐                                               │
│  │  GBDT   │ → 예측 1 + 잔차                               │
│  └────┬────┘                                               │
│       │                                                     │
│       ▼ (잔차 + 그래프 정보)                               │
│  ┌─────────┐                                               │
│  │  GNN    │ → 예측 2 + 잔차                               │
│  └────┬────┘                                               │
│       │                                                     │
│       ▼                                                     │
│  ┌─────────┐                                               │
│  │  GBDT   │ → 예측 3                                      │
│  └────┬────┘                                               │
│       │                                                     │
│       ▼                                                     │
│  최종 예측 = Σ 예측                                        │
└─────────────────────────────────────────────────────────────┘

3. TabGNN

class TabGNN(nn.Module):
    def __init__(self, n_features, hidden_dim, n_classes):
        super().__init__()

        # 피처 그래프용 GNN
        self.feature_gnn = GAT(n_features, hidden_dim)

        # 샘플 그래프용 GNN
        self.sample_gnn = GAT(hidden_dim, hidden_dim)

        # 양방향 attention
        self.cross_attention = nn.MultiheadAttention(hidden_dim, n_heads=4)

    def forward(self, x, feature_graph, sample_graph):
        # 1. 피처 그래프 메시지 패싱
        feature_embeds = self.feature_gnn(x.T, feature_graph)  # (n_features, hidden)

        # 2. 피처 임베딩으로 샘플 표현 구성
        sample_embeds = x @ feature_embeds  # (n_samples, hidden)

        # 3. 샘플 그래프 메시지 패싱
        refined = self.sample_gnn(sample_embeds, sample_graph)

        return refined

그래프 구성 기법

k-NN 그래프

from sklearn.neighbors import NearestNeighbors
import numpy as np

def build_knn_graph(X, k=10, metric='cosine'):
    """k-NN 기반 그래프 구성"""
    nn = NearestNeighbors(n_neighbors=k, metric=metric)
    nn.fit(X)

    distances, indices = nn.kneighbors(X)

    # Edge list 생성
    n_samples = X.shape[0]
    src = np.repeat(np.arange(n_samples), k)
    dst = indices.flatten()

    edge_index = np.stack([src, dst], axis=0)
    edge_weight = 1 / (distances.flatten() + 1e-6)  # 거리 역수를 가중치로

    return edge_index, edge_weight

적응적 그래프 학습

class AdaptiveGraphLearner(nn.Module):
    """학습 가능한 그래프 구조"""
    def __init__(self, n_samples, hidden_dim, threshold=0.5):
        super().__init__()
        self.node_embeds = nn.Parameter(torch.randn(n_samples, hidden_dim))
        self.threshold = threshold

    def forward(self):
        # 노드 임베딩 간 유사도 → 엣지 가중치
        sim = F.cosine_similarity(
            self.node_embeds.unsqueeze(1),
            self.node_embeds.unsqueeze(0),
            dim=-1
        )

        # 임계값 적용 (미분 가능하게)
        adj = torch.sigmoid((sim - self.threshold) * 10)

        return adj

도메인 기반 그래프

def build_domain_graph(df, domain_rules):
    """도메인 지식 기반 그래프"""
    edges = []

    for rule in domain_rules:
        if rule['type'] == 'same_group':
            # 같은 그룹 내 연결
            groups = df.groupby(rule['column']).indices
            for group_indices in groups.values():
                for i in group_indices:
                    for j in group_indices:
                        if i != j:
                            edges.append((i, j))

        elif rule['type'] == 'transaction':
            # 거래 관계
            for _, row in df.iterrows():
                edges.append((row['sender_id'], row['receiver_id']))

    return edges

# 예시: 금융 사기 탐지
domain_rules = [
    {'type': 'same_group', 'column': 'merchant_category'},
    {'type': 'same_group', 'column': 'device_id'},
    {'type': 'transaction'}  # 거래 연결
]

응용 사례

1. 금융 사기 탐지

┌─────────────────────────────────────────────────────────────┐
│            Fraud Detection with GNN                         │
│                                                             │
│  그래프 구성:                                               │
│  ├─ 노드: 거래 (transaction)                               │
│  ├─ 엣지 1: 동일 카드                                      │
│  ├─ 엣지 2: 동일 가맹점                                    │
│  ├─ 엣지 3: 동일 IP                                        │
│  └─ 엣지 4: 시간적 근접 (1시간 내)                         │
│                                                             │
│  GNN 효과:                                                  │
│  ├─ 단일 거래로는 정상 → 연결된 거래 패턴으로 사기 탐지    │
│  └─ 사기 거래 주변의 "의심 전파"                           │
│                                                             │
│  성능 향상: AUC 0.85 → 0.92 (+7%)                          │
└─────────────────────────────────────────────────────────────┘

2. 추천 시스템

# User-Item 이분 그래프
class UserItemGNN(nn.Module):
    def __init__(self, n_users, n_items, embed_dim):
        self.user_embed = nn.Embedding(n_users, embed_dim)
        self.item_embed = nn.Embedding(n_items, embed_dim)
        self.gnn = LightGCN(embed_dim)  # 경량 GNN

    def forward(self, user_item_edges):
        user_emb = self.user_embed.weight
        item_emb = self.item_embed.weight

        # 메시지 패싱
        user_out, item_out = self.gnn(user_emb, item_emb, user_item_edges)

        return user_out, item_out

3. 환자 유사도 기반 예측

환자 그래프:
- 노드: 환자
- 엣지: 진단 코드 유사도, 인구통계 유사도

효과:
- 희귀 질환 예측 개선 (유사 환자로부터 정보 전파)
- 콜드스타트 환자 처리 가능

성능 비교

Benchmark Results

데이터셋 XGBoost TabNet GNN-only BGNN
Adult 0.876 0.871 0.868 0.882
HIGGS 0.753 0.748 0.745 0.759
Fraud 0.892 0.885 0.901 0.918

그래프 효과 분석

┌─────────────────────────────────────────────────────────────┐
│         When Does Graph Help?                               │
│                                                             │
│  효과 큼:                                                   │
│  ├─ 명시적 관계 존재 (소셜 네트워크, 거래)                 │
│  ├─ 레이블 희소 (semi-supervised)                          │
│  └─ 노이즈 많음 (이웃 정보로 보정)                         │
│                                                             │
│  효과 작음:                                                 │
│  ├─ IID 데이터 (샘플 간 독립)                              │
│  ├─ 충분한 레이블                                          │
│  └─ 저차원 데이터                                          │
└─────────────────────────────────────────────────────────────┘

구현 팁

1. 스케일링

# 대규모 그래프: 미니배치 학습
from torch_geometric.loader import NeighborLoader

loader = NeighborLoader(
    data,
    num_neighbors=[10, 5],  # 1홉 10개, 2홉 5개 샘플링
    batch_size=256,
    shuffle=True
)

2. 그래프 품질 검증

def validate_graph(edge_index, labels):
    """좋은 그래프인지 검증"""
    # 동질성 (Homophily): 같은 클래스끼리 연결 비율
    src, dst = edge_index
    same_label = (labels[src] == labels[dst]).float().mean()

    # 너무 낮으면 (< 0.3): 그래프가 태스크에 도움 안됨
    # 너무 높으면 (> 0.9): 그래프 없이도 잘 됨
    print(f"Homophily: {same_label:.3f}")

    return same_label

3. GBDT와 결합

# BGNN 스타일 파이프라인
from xgboost import XGBClassifier

# 1단계: GBDT
gbdt = XGBClassifier(n_estimators=100)
gbdt.fit(X_train, y_train)
gbdt_pred = gbdt.predict_proba(X_train)[:, 1]

# 2단계: 잔차 + GNN
residual = y_train - gbdt_pred
gnn_input = np.column_stack([X_train, gbdt_pred])
gnn_model.fit(gnn_input, edge_index, residual)

# 최종 예측
final_pred = gbdt_pred + gnn_model.predict(...)

참고 자료

  • Ivanov & Prokhorenkova, "Boost then Convolve: Gradient Boosting Meets Graph Neural Networks", ICLR 2021
  • Du et al., "GRAPE: GRaph neural network with Adaptive PEs for tabular data", NeurIPS 2021
  • You et al., "Graph Structure Learning for Tabular Data", ICML 2022