콘텐츠로 이동
Data Prep
상세

Foundation Models for Tabular Data

개요

Tabular Foundation Models(TFM)은 NLP/Vision에서 성공한 대규모 사전 학습 패러다임을 테이블형 데이터에 적용한 모델이다. 다양한 테이블 데이터셋에서 학습하여 범용적인 표현을 획득하고, 새로운 태스크에 적은 데이터로 전이학습한다.

핵심 개념

Foundation Model이란?

대규모 데이터로 사전 학습되어 다양한 downstream 태스크에 적응 가능한 모델이다. GPT, BERT(NLP), CLIP, ViT(Vision)가 대표적이며, 테이블 데이터에도 이 패러다임을 적용하려는 시도가 활발하다.

특성 설명
사전 학습 대규모 데이터로 일반적 패턴 학습
전이 학습 새 태스크에 적은 데이터로 적응
Zero/Few-shot 학습 없이 또는 소량 데이터로 예측

패러다임 비교

foundation-models diagram


주요 모델

1. TABULA

테이블 데이터를 자연어 문장으로 변환하여 LLM으로 처리하는 접근법이다.

foundation-models diagram

장점 단점
기존 LLM 활용 가능 토큰 수 증가 (비효율적)
Zero-shot 예측 가능 수치형 데이터 정밀도 손실
컬럼 메타데이터 활용 대규모 테이블 처리 불가
구현이 간단함 추론 비용이 높음

실제 사용 예시:

def table_to_text(row: dict, schema: dict) -> str:
    """테이블 행을 자연어 문장으로 변환"""
    sentences = []
    for col, value in row.items():
        col_desc = schema.get(col, {}).get('description', col)
        sentences.append(f"The {col_desc} is {value}.")
    return " ".join(sentences)

# 예시
row = {"age": 35, "job": "sales", "income": 50000}
schema = {
    "age": {"description": "customer age"},
    "job": {"description": "occupation"},
    "income": {"description": "annual income in USD"}
}

text = table_to_text(row, schema)
# "The customer age is 35. The occupation is sales. The annual income in USD is 50000."

2. XTab

Cross-Table Pre-training을 통해 여러 테이블에서 공통 패턴을 학습하는 모델이다.

foundation-models diagram

핵심 아이디어:

  1. Column-wise Tokenization: 각 컬럼 타입에 맞는 임베딩 전략
  2. Cross-table Learning: 서로 다른 스키마의 테이블에서 공통 표현 학습
  3. Semantic Matching: 컬럼 이름/설명의 의미적 유사도로 매칭

3. TabTransformer

범주형 피처에 Transformer를 적용하여 컬럼 간 관계를 학습한다.

import torch
import torch.nn as nn

class TabTransformer(nn.Module):
    """
    TabTransformer: 범주형 피처에 Transformer 적용

    Args:
        cat_dims: 각 범주형 컬럼의 고유값 개수 리스트
        num_features: 수치형 피처 개수
        d_model: 임베딩 차원
        n_heads: 어텐션 헤드 수
        n_layers: Transformer 레이어 수
        n_classes: 출력 클래스 수
    """

    def __init__(
        self, 
        cat_dims: list[int],
        num_features: int,
        d_model: int = 32,
        n_heads: int = 8,
        n_layers: int = 6,
        n_classes: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()

        # 범주형 피처별 임베딩 레이어
        self.cat_embeddings = nn.ModuleList([
            nn.Embedding(dim + 1, d_model)  # +1 for unknown
            for dim in cat_dims
        ])

        # 컬럼 위치 임베딩 (어떤 컬럼인지 구분)
        self.col_embedding = nn.Embedding(len(cat_dims), d_model)

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=n_layers
        )

        # 수치형 피처 MLP
        self.num_mlp = nn.Sequential(
            nn.Linear(num_features, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model)
        )

        # 분류 헤드
        total_dim = d_model * (len(cat_dims) + 1)  # cat + num
        self.classifier = nn.Sequential(
            nn.Linear(total_dim, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_classes)
        )

    def forward(
        self, 
        cat_features: torch.Tensor,  # (batch, n_cat)
        num_features: torch.Tensor   # (batch, n_num)
    ) -> torch.Tensor:
        batch_size = cat_features.size(0)

        # 범주형: Embedding + Position
        cat_embeds = []
        for i, emb in enumerate(self.cat_embeddings):
            col_emb = emb(cat_features[:, i])  # (batch, d_model)
            pos_emb = self.col_embedding(
                torch.tensor(i, device=cat_features.device)
            )
            cat_embeds.append(col_emb + pos_emb)

        cat_embeds = torch.stack(cat_embeds, dim=1)  # (batch, n_cat, d_model)

        # Transformer로 컬럼 간 상호작용 학습
        cat_out = self.transformer(cat_embeds)  # (batch, n_cat, d_model)
        cat_out = cat_out.flatten(1)  # (batch, n_cat * d_model)

        # 수치형: MLP
        num_out = self.num_mlp(num_features)  # (batch, d_model)

        # 결합 및 분류
        combined = torch.cat([cat_out, num_out], dim=-1)
        logits = self.classifier(combined)

        return logits


# 사용 예시
model = TabTransformer(
    cat_dims=[10, 5, 3],  # 3개 범주형 컬럼
    num_features=7,       # 7개 수치형 컬럼
    d_model=64,
    n_heads=4,
    n_layers=3,
    n_classes=2
)

# 입력
cat_x = torch.randint(0, 10, (32, 3))  # batch=32, 3 categorical
num_x = torch.randn(32, 7)              # batch=32, 7 numerical

# 추론
logits = model(cat_x, num_x)  # (32, 2)

4. SAINT (Self-Attention and Intersample Attention)

SAINT는 컬럼 간 관계(Column Attention)와 샘플 간 관계(Row Attention)를 모두 학습한다.

foundation-models diagram

Row Attention의 장점:

장점 설명
유사 샘플 참조 비슷한 특성의 샘플 정보 활용
Semi-supervised 라벨 없는 샘플도 문맥으로 활용
Noisy 데이터 강건 이상치의 영향 완화

사전 학습 전략

1. Masked Feature Modeling

BERT의 Masked Language Modeling을 테이블에 적용한다.

import torch
import torch.nn as nn

class MaskedFeatureModeling:
    """
    테이블 데이터를 위한 Masked Feature Modeling

    BERT처럼 일부 피처를 마스킹하고 복원하도록 학습
    """

    def __init__(
        self, 
        mask_ratio: float = 0.15,
        mask_token_id: int = -1
    ):
        self.mask_ratio = mask_ratio
        self.mask_token_id = mask_token_id

    def mask_features(
        self, 
        table: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        테이블의 일부 셀을 마스킹

        Args:
            table: (n_samples, n_features) 형태의 텐서

        Returns:
            masked_table: 마스킹된 테이블
            mask: 마스킹 위치 (True = 마스킹됨)
            targets: 마스킹된 원본 값
        """
        n_samples, n_features = table.shape

        # 랜덤 마스크 생성
        mask = torch.rand(n_samples, n_features) < self.mask_ratio

        # 마스킹 적용
        masked_table = table.clone()
        masked_table[mask] = self.mask_token_id

        # 복원 타겟
        targets = table[mask]

        return masked_table, mask, targets

    def compute_loss(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
        feature_types: list[str]
    ) -> torch.Tensor:
        """
        피처 타입에 따른 손실 계산

        수치형: MSE Loss
        범주형: Cross Entropy Loss
        """
        total_loss = 0

        for i, ftype in enumerate(feature_types):
            pred_i = predictions[:, i]
            target_i = targets[:, i]

            if ftype == 'numerical':
                loss = nn.MSELoss()(pred_i, target_i)
            else:  # categorical
                loss = nn.CrossEntropyLoss()(pred_i, target_i.long())

            total_loss += loss

        return total_loss / len(feature_types)


# 사용 예시
mfm = MaskedFeatureModeling(mask_ratio=0.15)

# 학습 루프
for batch in dataloader:
    masked, mask, targets = mfm.mask_features(batch)

    # 모델이 마스킹된 값 예측
    predictions = model(masked)

    # 손실 계산 (마스킹된 위치만)
    loss = mfm.compute_loss(predictions[mask], targets, feature_types)

    loss.backward()
    optimizer.step()

2. Contrastive Learning

같은 샘플의 다른 augmentation은 가깝게, 다른 샘플은 멀게 학습한다.

import torch
import torch.nn.functional as F

class TabularContrastive:
    """
    테이블 데이터를 위한 Contrastive Learning
    """

    def __init__(self, temperature: float = 0.07):
        self.temperature = temperature

    def augment(
        self, 
        table: torch.Tensor,
        aug_type: str = 'dropout'
    ) -> torch.Tensor:
        """데이터 증강"""
        if aug_type == 'dropout':
            # 일부 피처 드롭
            mask = torch.rand_like(table) > 0.1
            return table * mask
        elif aug_type == 'noise':
            # 가우시안 노이즈 추가
            noise = torch.randn_like(table) * 0.01
            return table + noise
        elif aug_type == 'swap':
            # 피처 값 일부 교환
            idx = torch.randperm(table.size(1))[:2]
            augmented = table.clone()
            augmented[:, idx[0]], augmented[:, idx[1]] = \
                table[:, idx[1]], table[:, idx[0]]
            return augmented
        else:
            return table

    def nt_xent_loss(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor
    ) -> torch.Tensor:
        """
        NT-Xent Loss (SimCLR)

        같은 샘플의 두 augmentation은 positive pair
        다른 샘플은 negative pair
        """
        batch_size = z1.size(0)

        # 정규화
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)

        # 유사도 행렬 계산
        z = torch.cat([z1, z2], dim=0)  # (2B, D)
        sim_matrix = torch.mm(z, z.t()) / self.temperature  # (2B, 2B)

        # 자기 자신과의 유사도 제외
        mask = torch.eye(2 * batch_size, device=z.device).bool()
        sim_matrix = sim_matrix.masked_fill(mask, -float('inf'))

        # Positive pairs: (i, i+B) and (i+B, i)
        pos_sim = torch.cat([
            torch.diag(sim_matrix, batch_size),
            torch.diag(sim_matrix, -batch_size)
        ])

        # Contrastive loss
        loss = -pos_sim + torch.logsumexp(sim_matrix, dim=1)

        return loss.mean()


# 사용 예시
contrastive = TabularContrastive(temperature=0.07)

for batch in dataloader:
    # 두 가지 augmentation 적용
    aug1 = contrastive.augment(batch, 'dropout')
    aug2 = contrastive.augment(batch, 'noise')

    # 인코더로 표현 추출
    z1 = encoder(aug1)
    z2 = encoder(aug2)

    # Contrastive loss
    loss = contrastive.nt_xent_loss(z1, z2)

    loss.backward()
    optimizer.step()

3. Cross-Table Pre-training

foundation-models diagram


성능 비교

OpenML Benchmarks

모델 평균 AUC 학습 시간 Zero-shot 튜닝 필요
XGBoost (튜닝) 0.847 중간 X O
TabTransformer 0.832 높음 X O
SAINT 0.841 높음 X O
TabPFN 0.852 매우 낮음 O X
XTab (fine-tuned) 0.855 중간 부분적 O

Few-shot Learning 성능

foundation-models diagram

핵심 인사이트:

  • 소량 데이터 (< 100): Foundation Model이 XGBoost를 크게 앞섬
  • 대량 데이터 (> 1000): XGBoost가 점차 따라잡거나 역전
  • Pre-training 효과: XTab이 가장 일관된 성능, 사전학습의 가치 입증

실무 적용 가이드

언제 Foundation Model을 사용할까?

상황 권장 모델 이유
데이터 < 100개 TabPFN 튜닝 없이 높은 성능
데이터 100-1000개 XTab, TabPFN 앙상블 Few-shot 강점
데이터 > 10000개 XGBoost, LightGBM 대용량에서 여전히 강력
빠른 프로토타입 TabPFN 1초 이내 학습
해석 필요 Tree 모델 SHAP 등 지원
레이블 비용 높음 Foundation Model 적은 라벨로 학습

주의사항

Foundation Model이 항상 최선은 아니다:

  1. 대규모 라벨 데이터: GBDT가 여전히 강력
  2. 도메인 특화 피처: 전통 ML + 피처 엔지니어링이 효과적
  3. 해석 필요: 선형 모델, 의사결정 트리 권장
  4. 실시간 추론: 경량 모델 필요 (TabPFN은 무거움)

흔한 에러와 해결책

에러 원인 해결책
CUDA out of memory TabPFN 배치 크기 초과 batch_size 줄이기, CPU 사용
낮은 성능 스키마 불일치 컬럼명 표준화, 설명 추가
느린 추론 Transformer 오버헤드 캐싱, 배치 처리
NaN 예측 범주 인코딩 오류 Unknown 토큰 처리 추가

구현 예시: Fine-tuning with LoRA

from transformers import AutoModel
from peft import LoraConfig, get_peft_model
import torch

def setup_tabular_lora(
    base_model: str = "meta-llama/Llama-2-7b-hf",
    r: int = 8,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05
):
    """
    TABULA 스타일: 테이블 → 텍스트 → LLM with LoRA
    """
    # LoRA 설정
    config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        lora_dropout=lora_dropout,
        task_type="CAUSAL_LM"
    )

    # 베이스 모델 로드 및 LoRA 적용
    base = AutoModel.from_pretrained(base_model, torch_dtype=torch.float16)
    model = get_peft_model(base, config)

    # 학습 가능 파라미터 확인
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

    return model


def table_to_prompt(row: dict, columns: list, label_col: str = None) -> str:
    """테이블 행을 프롬프트로 변환"""
    feature_text = " ".join([
        f"The {col} is {row[col]}." 
        for col in columns if col != label_col
    ])

    if label_col and label_col in row:
        return f"{feature_text} The {label_col} is {row[label_col]}."
    else:
        return f"{feature_text} The {label_col} is"


# 학습 예시
def train_tabular_llm(model, tokenizer, train_df, columns, label_col):
    from torch.utils.data import DataLoader

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    for epoch in range(3):
        for i, row in train_df.iterrows():
            prompt = table_to_prompt(row.to_dict(), columns, label_col)

            inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Epoch {epoch}, Step {i}, Loss: {loss.item():.4f}")

향후 전망

foundation-models diagram


참고 자료

논문 학회 핵심 기여
Huang et al., "XTab" ICML 2023 Cross-table pretraining
Gorishniy et al., "Revisiting Deep Learning for Tabular" NeurIPS 2021 Transformer vs GBDT 벤치마크
Somepalli et al., "SAINT" NeurIPS 2021 Workshop Inter-sample attention
Hegselmann et al., "TabLLM" AISTATS 2023 LLM for tabular few-shot
Hollmann et al., "TabPFN" ICLR 2023 In-context learning for tabular

추가 리소스: