콘텐츠로 이동
Data Prep
상세

Training Data Attribution (TDA)

학습 데이터 귀속: 모델 예측에 대한 학습 데이터의 기여도 측정


개요

항목 내용
분야 Interpretable ML, Data-Centric AI
핵심 질문 "이 예측 결과에 어떤 학습 데이터가 가장 큰 영향을 미쳤는가?"
주요 방법론 Influence Functions, TracIn, TRAK, AirRep
응용 데이터 디버깅, 데이터 가치 평가, 모델 해석, 저작권 추적

Training Data Attribution (TDA)은 모델의 특정 예측에 대해 각 학습 데이터 포인트가 얼마나 기여했는지를 정량화하는 기법이다. 데이터 품질 이슈 탐지, 모델 행동 해석, 데이터 가치 평가 등에 활용된다.


핵심 방법론

1. Influence Functions (ICML 2017)

논문: Koh & Liang, "Understanding Black-box Predictions via Influence Functions"

Leave-one-out retraining의 근사 방법. 특정 학습 샘플 \(z_i\)를 제거했을 때 테스트 샘플 \(z_{test}\)의 손실 변화를 추정한다.

수학적 정의:

\[ \mathcal{I}(z_i, z_{test}) = -\nabla_\theta L(z_{test}, \hat{\theta})^\top H_{\hat{\theta}}^{-1} \nabla_\theta L(z_i, \hat{\theta}) \]
  • \(H_{\hat{\theta}}\): Hessian 행렬 \(\nabla^2_\theta \sum_i L(z_i, \theta)\)
  • \(\nabla_\theta L\): 파라미터에 대한 손실의 그래디언트

특징:

  • 이론적으로 엄밀한 근사
  • Hessian 역행렬 계산 필요 (계산 비용 높음)
  • 비볼록 모델에서 근사 오차 발생 가능

2. TracIn (NeurIPS 2020)

논문: Pruthi et al., "Estimating Training Data Influence by Tracing Gradient Descent"

학습 과정 중 체크포인트에서의 그래디언트 내적을 합산하여 영향력 추정.

수학적 정의:

\[ \text{TracIn}(z_i, z_{test}) = \sum_{t: z_i \in B_t} \eta_t \nabla_\theta L(z_{test}, \theta_t) \cdot \nabla_\theta L(z_i, \theta_t) \]
  • \(B_t\): 시점 \(t\)의 미니배치
  • \(\eta_t\): 학습률
  • 체크포인트별 그래디언트 내적의 가중 합

특징:

  • Hessian 계산 불필요
  • 체크포인트 저장 필요
  • 계산 효율성과 정확성의 균형

3. TRAK (ICML 2023)

논문: Park et al., "TRAK: Attributing Model Behavior at Scale"

랜덤 프로젝션을 활용한 확장 가능한 TDA 방법.

핵심 아이디어:

  1. 그래디언트를 저차원 공간으로 프로젝션
  2. 여러 모델 체크포인트의 결과를 앙상블
  3. 선형 데이터스토어로 효율적 검색

수학적 정의:

\[ \text{TRAK}(z_i, z_{test}) = \phi(z_{test})^\top \left( \Phi^\top \Phi + \lambda I \right)^{-1} \phi(z_i) \]
  • \(\phi(z)\): 그래디언트의 랜덤 프로젝션
  • \(\Phi\): 모든 학습 샘플의 프로젝션 행렬

특징:

  • 대규모 데이터셋에 확장 가능
  • ImageNet, CIFAR-10 등에서 검증
  • 오픈소스 라이브러리 제공

4. AirRep (NeurIPS 2025)

논문: Sun et al., "Enhancing Training Data Attribution with Representational Optimization"

표현 학습 기반의 최신 TDA 방법. 학습 가능한 인코더로 태스크 특화 표현을 추출.

핵심 구성요소:

컴포넌트 역할
Trainable Encoder 귀속 품질을 위한 표현 학습
Attention Pooling 그룹 단위 영향력 추정
Ranking Objective 경험적 효과 기반 학습

특징:

  • Influence Functions 대비 약 100배 추론 효율
  • 그룹 단위 귀속 지원
  • 다양한 다운스트림 태스크에 적용 가능

방법론 비교

방법 계산 복잡도 Hessian 필요 체크포인트 확장성 정확도
Influence Functions O(np + p^3) Yes No Low High (이론적)
TracIn O(nkp) No Yes (k개) Medium Medium
TRAK O(np + d^3) No Yes High High
AirRep O(np) No No Very High High
  • n: 학습 샘플 수, p: 파라미터 수, k: 체크포인트 수, d: 프로젝션 차원

Python 구현 예시

Influence Functions (간소화 버전)

import torch
import torch.nn.functional as F
from torch.autograd import grad

def compute_gradient(model, loss_fn, x, y):
    """단일 샘플에 대한 그래디언트 계산"""
    model.zero_grad()
    output = model(x.unsqueeze(0))
    loss = loss_fn(output, y.unsqueeze(0))
    grads = grad(loss, model.parameters(), create_graph=False)
    return torch.cat([g.flatten() for g in grads])

def influence_function(model, loss_fn, train_data, test_sample, 
                       damping=0.01, num_samples=100):
    """
    Influence Function 근사 계산
    LiSSA (Linear time Stochastic Second-order Algorithm) 사용
    """
    x_test, y_test = test_sample
    test_grad = compute_gradient(model, loss_fn, x_test, y_test)

    # Hessian-vector product를 LiSSA로 근사
    ihvp = test_grad.clone()
    for _ in range(num_samples):
        idx = torch.randint(len(train_data), (1,)).item()
        x_train, y_train = train_data[idx]
        train_grad = compute_gradient(model, loss_fn, x_train, y_train)

        # Hessian-vector product 근사
        hvp = compute_hvp(model, loss_fn, x_train, y_train, ihvp)
        ihvp = test_grad + (1 - damping) * ihvp - hvp / len(train_data)

    # 각 학습 샘플에 대한 영향력 계산
    influences = []
    for x_train, y_train in train_data:
        train_grad = compute_gradient(model, loss_fn, x_train, y_train)
        influence = -torch.dot(ihvp, train_grad).item()
        influences.append(influence)

    return influences

def compute_hvp(model, loss_fn, x, y, v):
    """Hessian-vector product 계산"""
    model.zero_grad()
    output = model(x.unsqueeze(0))
    loss = loss_fn(output, y.unsqueeze(0))

    grads = grad(loss, model.parameters(), create_graph=True)
    flat_grads = torch.cat([g.flatten() for g in grads])

    grad_v = torch.dot(flat_grads, v)
    hvp = grad(grad_v, model.parameters())

    return torch.cat([h.flatten() for h in hvp])

TracIn 구현

import torch
from typing import List, Tuple

class TracInAttributor:
    """TracIn 기반 Training Data Attribution"""

    def __init__(self, model_class, checkpoint_paths: List[str], 
                 learning_rates: List[float]):
        self.checkpoints = []
        self.learning_rates = learning_rates

        for path in checkpoint_paths:
            model = model_class()
            model.load_state_dict(torch.load(path))
            model.eval()
            self.checkpoints.append(model)

    def compute_gradient(self, model, loss_fn, x, y):
        """그래디언트 벡터 계산"""
        model.zero_grad()
        output = model(x.unsqueeze(0))
        loss = loss_fn(output, y.unsqueeze(0))
        loss.backward()

        grads = []
        for param in model.parameters():
            if param.grad is not None:
                grads.append(param.grad.flatten())
        return torch.cat(grads)

    def attribute(self, loss_fn, train_data, test_sample) -> List[float]:
        """
        각 학습 샘플의 테스트 샘플에 대한 영향력 계산

        Returns:
            List[float]: 각 학습 샘플의 TracIn 점수
        """
        x_test, y_test = test_sample
        influences = [0.0] * len(train_data)

        for ckpt_idx, model in enumerate(self.checkpoints):
            lr = self.learning_rates[ckpt_idx]

            # 테스트 샘플 그래디언트
            test_grad = self.compute_gradient(model, loss_fn, x_test, y_test)

            # 각 학습 샘플과의 그래디언트 내적
            for i, (x_train, y_train) in enumerate(train_data):
                train_grad = self.compute_gradient(model, loss_fn, 
                                                   x_train, y_train)
                influence = lr * torch.dot(test_grad, train_grad).item()
                influences[i] += influence

        return influences

# 사용 예시
# attributor = TracInAttributor(MyModel, checkpoint_paths, learning_rates)
# scores = attributor.attribute(F.cross_entropy, train_dataset, test_sample)
# top_influential = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[:10]

TRAK 라이브러리 활용

# pip install traker
from trak import TRAKer
from trak.projectors import CudaProjector

# TRAK 초기화
traker = TRAKer(
    model=model,
    task='image_classification',
    train_set_size=len(train_dataset),
    projector=CudaProjector(
        grad_dim=sum(p.numel() for p in model.parameters()),
        proj_dim=2048,
        seed=42
    ),
    proj_dim=2048,
    save_dir='./trak_results'
)

# 학습 데이터 피처 계산
traker.load_checkpoint(checkpoint_path)
for batch_idx, (inputs, targets) in enumerate(train_loader):
    traker.featurize(
        batch=batch_idx,
        inputs=inputs,
        targets=targets
    )

# Attribution 점수 계산
traker.finalize_features()

# 테스트 샘플에 대한 귀속
for batch_idx, (inputs, targets) in enumerate(test_loader):
    scores = traker.score(
        batch=batch_idx,
        inputs=inputs,
        targets=targets
    )

응용 분야

1. 데이터 디버깅

def find_mislabeled_samples(model, train_data, val_data, loss_fn, 
                            top_k=100):
    """
    검증 오류에 가장 큰 영향을 미치는 학습 샘플 탐지
    (잠재적 레이블 오류)
    """
    attributor = TracInAttributor(model, checkpoints, learning_rates)

    problematic_samples = {}

    for val_sample in val_data:
        if is_misclassified(model, val_sample):
            scores = attributor.attribute(loss_fn, train_data, val_sample)

            # 높은 긍정적 영향력 = 오분류에 기여
            top_indices = sorted(enumerate(scores), 
                                key=lambda x: x[1], 
                                reverse=True)[:top_k]

            for idx, score in top_indices:
                problematic_samples[idx] = problematic_samples.get(idx, 0) + 1

    # 빈도 기준 정렬
    return sorted(problematic_samples.items(), 
                  key=lambda x: x[1], 
                  reverse=True)

2. 데이터 가치 평가 (Data Valuation)

def compute_data_shapley_approx(model, train_data, val_data, 
                                 loss_fn, n_permutations=100):
    """
    TDA 기반 Data Shapley 근사
    """
    n = len(train_data)
    shapley_values = [0.0] * n

    attributor = TracInAttributor(model, checkpoints, learning_rates)

    # 검증 데이터 전체에 대한 평균 영향력
    for val_sample in val_data:
        scores = attributor.attribute(loss_fn, train_data, val_sample)
        for i, score in enumerate(scores):
            shapley_values[i] += score / len(val_data)

    return shapley_values

3. 저작권/출처 추적

LLM 출력에 대해 어떤 학습 문서가 영향을 미쳤는지 추적:

def trace_generation_sources(llm, train_corpus, generated_text, 
                             top_k=10):
    """
    생성된 텍스트에 가장 큰 영향을 미친 학습 문서 추적
    """
    # 생성 텍스트에 대한 귀속 점수 계산
    scores = compute_attribution_scores(llm, train_corpus, generated_text)

    # 상위 k개 소스 문서 반환
    top_sources = sorted(enumerate(scores), 
                         key=lambda x: x[1], 
                         reverse=True)[:top_k]

    return [(train_corpus[idx], score) for idx, score in top_sources]

평가 지표

TDA 방법의 품질을 평가하는 주요 지표:

지표 설명 계산 방법
LDS (Linear Datamodeling Score) 귀속 점수와 실제 영향의 상관관계 부분집합 제거 후 성능 변화와 예측값 비교
LOO Correlation Leave-One-Out과의 상관관계 실제 LOO 결과와 추정값의 Spearman 상관
Mislabel Detection 레이블 오류 탐지 정확도 의도적 오염 후 탐지율 측정

한계 및 주의사항

  1. 근사 오차: 대부분의 방법이 정확한 계산이 아닌 근사
  2. 모델 의존성: 비볼록 손실 함수에서 이론적 보장 약화
  3. 계산 비용: 대규모 모델/데이터셋에서 여전히 비용 높음
  4. 그룹 효과: 개별 샘플이 아닌 데이터 그룹의 시너지 효과 포착 어려움

참고 자료

논문 학회 연도
Understanding Black-box Predictions via Influence Functions ICML 2017
Estimating Training Data Influence by Tracing Gradient Descent NeurIPS 2020
TRAK: Attributing Model Behavior at Scale ICML 2023
Enhancing Training Data Attribution with Representational Optimization NeurIPS 2025
Data Shapley: Equitable Valuation of Data for Machine Learning ICML 2019

관련 문서: Data-Centric AI | Explainability