콘텐츠로 이동
Data Prep
상세

TabM: Advancing Tabular Deep Learning with Parameter-Efficient Ensembling

항목 내용
논문 TabM: Advancing Tabular Deep Learning with Parameter-Efficient Ensembling
저자 Yury Gorishniy et al. (Yandex Research)
학회 ICLR 2025
arXiv 2410.24210
GitHub yandex-research/tabm

핵심 아이디어

TabM은 Parameter-Efficient Ensembling을 통해 단일 MLP 모델이 여러 MLP의 앙상블처럼 동작하도록 설계된 tabular deep learning 아키텍처다.

기존 앙상블의 문제점

방식 문제점
Traditional Deep Ensemble 각 모델 독립 학습, 메모리/연산 비용 k배
Dropout Ensemble 학습/추론 시 행동 불일치
Bagging 데이터 서브샘플링으로 인한 정보 손실

TabM의 해결책

┌─────────────────────────────────────────────────────┐
│                      TabM                            │
│                                                      │
│   Input x ──┬──> MLP_1 (shared params) ──> pred_1   │
│             ├──> MLP_2 (shared params) ──> pred_2   │
│             ├──> MLP_3 (shared params) ──> pred_3   │
│             └──> ...                   ──> ...       │
│                                                      │
│   Output: mean(pred_1, pred_2, ..., pred_k)         │
└─────────────────────────────────────────────────────┘

핵심 특징: 1. 동시 학습: k개의 implicit MLP가 동시에 학습 2. 파라미터 공유: 대부분의 가중치를 공유하여 효율성 확보 3. 정규화 효과: 파라미터 공유가 자연스러운 정규화로 작용


아키텍처 상세

EnsembleView와 Weight Sharing

TabM의 핵심 구성 요소:

컴포넌트 역할
EnsembleView 입력 텐서를 (B, D) -> (B, k, D)로 변환
LinearEnsemble k개의 linear layer를 효율적으로 구현
make_tabm_backbone MLP backbone 생성

아키텍처 변형 (arch_type)

Type 설명 특징
tabm 기본 TabM 가장 단순, 빠른 학습
tabm-mini MiniEnsemble 사용 더 적은 파라미터 공유
tabm-batch BatchEnsemble 사용 중간 수준 공유

Feature Embeddings

수치형 피처에 임베딩을 적용하면 성능이 향상된다:

임베딩 타입 설명
LinearReLUEmbeddings 단순, 빠름
PiecewiseLinearEmbeddings 비선형 변환, 더 좋은 성능
PeriodicEmbeddings 주기적 패턴에 효과적

성능 비교

Traditional Benchmarks (95개 데이터셋)

모델 평균 Rank 특징
TabM (with embeddings) 1-2 Tabular DL SOTA
XGBoost 2-3 여전히 강력한 baseline
CatBoost 3-4 범주형에 강함
TabR 4-5 Retrieval 기반
FT-Transformer 5-6 Attention 기반
MLP 6-7 단순 baseline

TabReD Benchmark (산업용 실데이터)

TabReD는 시간에 따른 분포 변화(distribution drift)와 수백 개의 피처를 가진 도전적인 벤치마크:

모델 평균 성능
TabM 0.82
XGBoost 0.79
TabR 0.77
FT-Transformer 0.73

효율성 비교

모델 학습 시간 (상대) 추론 처리량 (상대)
MLP 1x 10x
TabM 2-3x 3-5x
TabR 10x 0.5x
FT-Transformer 15x 0.3x

Python 구현

설치

pip install tabm

기본 사용법

import torch
from tabm import TabM

# 데이터 설정
n_num_features = 24
cat_cardinalities = [3, 7]  # 범주형 피처의 카디널리티
d_out = 1  # 회귀 태스크

# TabM 모델 생성
model = TabM.make(
    n_num_features=n_num_features,
    cat_cardinalities=cat_cardinalities,
    d_out=d_out,
)

# 입력 데이터
batch_size = 256
x_num = torch.randn(batch_size, n_num_features)
x_cat = torch.column_stack([
    torch.randint(0, c, (batch_size,)) 
    for c in cat_cardinalities
])

# 예측 (k개의 앙상블 예측)
y_pred = model(x_num, x_cat)
# shape: (batch_size, k, d_out)

# 최종 예측: k개의 평균
final_pred = y_pred.mean(dim=1)

Feature Embeddings 적용

from tabm import TabM
from rtdl_num_embeddings import PiecewiseLinearEmbeddings

# 임베딩이 적용된 TabM (더 좋은 성능)
model = TabM.make(
    n_num_features=n_num_features,
    num_embeddings=PiecewiseLinearEmbeddings(
        n_num_features, 
        n_bins=48,  # 구간 개수
    ),
    cat_cardinalities=cat_cardinalities,
    d_out=d_out,
)

학습 루프

import torch.nn.functional as F

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for epoch in range(100):
    model.train()
    optimizer.zero_grad()

    y_pred = model(x_num, x_cat)  # (B, k, d_out)

    # 중요: k개의 예측에 대해 독립적으로 손실 계산
    # mean loss를 최적화, not loss of mean prediction
    loss = F.mse_loss(y_pred, y_target.unsqueeze(1).expand_as(y_pred))

    loss.backward()
    optimizer.step()

분류 태스크 추론

model.eval()
with torch.no_grad():
    logits = model(x_num, x_cat)  # (B, k, n_classes)

    # 분류: 확률을 평균 (logits 아님!)
    probs = F.softmax(logits, dim=-1)
    avg_probs = probs.mean(dim=1)  # (B, n_classes)

    predictions = avg_probs.argmax(dim=-1)

하이퍼파라미터 가이드

기본 설정

# 대부분의 경우 잘 작동하는 기본값
model = TabM.make(
    n_num_features=n_features,
    d_out=d_out,
    # 기본 하이퍼파라미터
    # k=32 (앙상블 크기)
    # arch_type='tabm'
    # n_blocks=3
    # d_block=256
)

튜닝 가이드

하이퍼파라미터 범위 영향
k 16-64 클수록 안정적, 느림
n_blocks 2-4 깊이
d_block 128-512
dropout 0.0-0.3 정규화
lr 1e-4 - 1e-2 학습률

데이터셋 크기별 권장 설정

데이터 크기 arch_type k num_embeddings
< 10K tabm 32 LinearReLU
10K - 100K tabm 32 PiecewiseLinear
> 100K tabm-mini 16 PiecewiseLinear

실제 적용 사례

Kaggle 우승 솔루션

  1. UM Game Playing Strength (2025): TabM 기반 1위
  2. CIBMTR HCT Survival (2025): TabM 단독으로 25위/3300+ 달성

적합한 사용 케이스

케이스 적합도 이유
중소규모 tabular 분류/회귀 높음 SOTA 성능
대규모 데이터 (10M+) 중간 학습 시간 고려 필요
실시간 추론 중간 MLP보다 느리지만 합리적
해석 가능성 필요 낮음 블랙박스 모델

관련 연구

모델 특징 비교
TabPFN In-context learning 작은 데이터에 강함, 확장성 제한
TabR Retrieval 기반 성능 좋으나 추론 느림
FT-Transformer Attention 기반 범용적이나 비효율적
XGBoost/CatBoost GBDT 여전히 강력한 baseline

참고 자료