TabM: Advancing Tabular Deep Learning with Parameter-Efficient Ensembling
| 항목 |
내용 |
| 논문 |
TabM: Advancing Tabular Deep Learning with Parameter-Efficient Ensembling |
| 저자 |
Yury Gorishniy et al. (Yandex Research) |
| 학회 |
ICLR 2025 |
| arXiv |
2410.24210 |
| GitHub |
yandex-research/tabm |
핵심 아이디어
TabM은 Parameter-Efficient Ensembling을 통해 단일 MLP 모델이 여러 MLP의 앙상블처럼 동작하도록 설계된 tabular deep learning 아키텍처다.
기존 앙상블의 문제점
| 방식 |
문제점 |
| Traditional Deep Ensemble |
각 모델 독립 학습, 메모리/연산 비용 k배 |
| Dropout Ensemble |
학습/추론 시 행동 불일치 |
| Bagging |
데이터 서브샘플링으로 인한 정보 손실 |
TabM의 해결책
┌─────────────────────────────────────────────────────┐
│ TabM │
│ │
│ Input x ──┬──> MLP_1 (shared params) ──> pred_1 │
│ ├──> MLP_2 (shared params) ──> pred_2 │
│ ├──> MLP_3 (shared params) ──> pred_3 │
│ └──> ... ──> ... │
│ │
│ Output: mean(pred_1, pred_2, ..., pred_k) │
└─────────────────────────────────────────────────────┘
핵심 특징:
1. 동시 학습: k개의 implicit MLP가 동시에 학습
2. 파라미터 공유: 대부분의 가중치를 공유하여 효율성 확보
3. 정규화 효과: 파라미터 공유가 자연스러운 정규화로 작용
아키텍처 상세
EnsembleView와 Weight Sharing
TabM의 핵심 구성 요소:
| 컴포넌트 |
역할 |
| EnsembleView |
입력 텐서를 (B, D) -> (B, k, D)로 변환 |
| LinearEnsemble |
k개의 linear layer를 효율적으로 구현 |
| make_tabm_backbone |
MLP backbone 생성 |
아키텍처 변형 (arch_type)
| Type |
설명 |
특징 |
tabm |
기본 TabM |
가장 단순, 빠른 학습 |
tabm-mini |
MiniEnsemble 사용 |
더 적은 파라미터 공유 |
tabm-batch |
BatchEnsemble 사용 |
중간 수준 공유 |
Feature Embeddings
수치형 피처에 임베딩을 적용하면 성능이 향상된다:
| 임베딩 타입 |
설명 |
| LinearReLUEmbeddings |
단순, 빠름 |
| PiecewiseLinearEmbeddings |
비선형 변환, 더 좋은 성능 |
| PeriodicEmbeddings |
주기적 패턴에 효과적 |
성능 비교
Traditional Benchmarks (95개 데이터셋)
| 모델 |
평균 Rank |
특징 |
| TabM (with embeddings) |
1-2 |
Tabular DL SOTA |
| XGBoost |
2-3 |
여전히 강력한 baseline |
| CatBoost |
3-4 |
범주형에 강함 |
| TabR |
4-5 |
Retrieval 기반 |
| FT-Transformer |
5-6 |
Attention 기반 |
| MLP |
6-7 |
단순 baseline |
TabReD Benchmark (산업용 실데이터)
TabReD는 시간에 따른 분포 변화(distribution drift)와 수백 개의 피처를 가진 도전적인 벤치마크:
| 모델 |
평균 성능 |
| TabM |
0.82 |
| XGBoost |
0.79 |
| TabR |
0.77 |
| FT-Transformer |
0.73 |
효율성 비교
| 모델 |
학습 시간 (상대) |
추론 처리량 (상대) |
| MLP |
1x |
10x |
| TabM |
2-3x |
3-5x |
| TabR |
10x |
0.5x |
| FT-Transformer |
15x |
0.3x |
Python 구현
설치
기본 사용법
import torch
from tabm import TabM
# 데이터 설정
n_num_features = 24
cat_cardinalities = [3, 7] # 범주형 피처의 카디널리티
d_out = 1 # 회귀 태스크
# TabM 모델 생성
model = TabM.make(
n_num_features=n_num_features,
cat_cardinalities=cat_cardinalities,
d_out=d_out,
)
# 입력 데이터
batch_size = 256
x_num = torch.randn(batch_size, n_num_features)
x_cat = torch.column_stack([
torch.randint(0, c, (batch_size,))
for c in cat_cardinalities
])
# 예측 (k개의 앙상블 예측)
y_pred = model(x_num, x_cat)
# shape: (batch_size, k, d_out)
# 최종 예측: k개의 평균
final_pred = y_pred.mean(dim=1)
Feature Embeddings 적용
from tabm import TabM
from rtdl_num_embeddings import PiecewiseLinearEmbeddings
# 임베딩이 적용된 TabM (더 좋은 성능)
model = TabM.make(
n_num_features=n_num_features,
num_embeddings=PiecewiseLinearEmbeddings(
n_num_features,
n_bins=48, # 구간 개수
),
cat_cardinalities=cat_cardinalities,
d_out=d_out,
)
학습 루프
import torch.nn.functional as F
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for epoch in range(100):
model.train()
optimizer.zero_grad()
y_pred = model(x_num, x_cat) # (B, k, d_out)
# 중요: k개의 예측에 대해 독립적으로 손실 계산
# mean loss를 최적화, not loss of mean prediction
loss = F.mse_loss(y_pred, y_target.unsqueeze(1).expand_as(y_pred))
loss.backward()
optimizer.step()
분류 태스크 추론
model.eval()
with torch.no_grad():
logits = model(x_num, x_cat) # (B, k, n_classes)
# 분류: 확률을 평균 (logits 아님!)
probs = F.softmax(logits, dim=-1)
avg_probs = probs.mean(dim=1) # (B, n_classes)
predictions = avg_probs.argmax(dim=-1)
하이퍼파라미터 가이드
기본 설정
# 대부분의 경우 잘 작동하는 기본값
model = TabM.make(
n_num_features=n_features,
d_out=d_out,
# 기본 하이퍼파라미터
# k=32 (앙상블 크기)
# arch_type='tabm'
# n_blocks=3
# d_block=256
)
튜닝 가이드
| 하이퍼파라미터 |
범위 |
영향 |
| k |
16-64 |
클수록 안정적, 느림 |
| n_blocks |
2-4 |
깊이 |
| d_block |
128-512 |
폭 |
| dropout |
0.0-0.3 |
정규화 |
| lr |
1e-4 - 1e-2 |
학습률 |
데이터셋 크기별 권장 설정
| 데이터 크기 |
arch_type |
k |
num_embeddings |
| < 10K |
tabm |
32 |
LinearReLU |
| 10K - 100K |
tabm |
32 |
PiecewiseLinear |
| > 100K |
tabm-mini |
16 |
PiecewiseLinear |
실제 적용 사례
Kaggle 우승 솔루션
- UM Game Playing Strength (2025): TabM 기반 1위
- CIBMTR HCT Survival (2025): TabM 단독으로 25위/3300+ 달성
적합한 사용 케이스
| 케이스 |
적합도 |
이유 |
| 중소규모 tabular 분류/회귀 |
높음 |
SOTA 성능 |
| 대규모 데이터 (10M+) |
중간 |
학습 시간 고려 필요 |
| 실시간 추론 |
중간 |
MLP보다 느리지만 합리적 |
| 해석 가능성 필요 |
낮음 |
블랙박스 모델 |
관련 연구
| 모델 |
특징 |
비교 |
| TabPFN |
In-context learning |
작은 데이터에 강함, 확장성 제한 |
| TabR |
Retrieval 기반 |
성능 좋으나 추론 느림 |
| FT-Transformer |
Attention 기반 |
범용적이나 비효율적 |
| XGBoost/CatBoost |
GBDT |
여전히 강력한 baseline |
참고 자료