콘텐츠로 이동
Data Prep
상세

Dataset Distillation

메타 정보

항목 내용
분류 Data-Centric AI / Efficient Learning
핵심 논문 "Dataset Distillation" (Wang et al., 2018 - 최초 제안), "Dataset Condensation with Gradient Matching" (Zhao et al., ICLR 2021 - DC), "Dataset Distillation by Matching Training Trajectories" (Cazenavette et al., CVPR 2022 - MTT), "Distribution Matching for Dataset Distillation" (Zhao & Bilen, 2023 - DM)
주요 저자 Tongzhou Wang (MIT), Bo Zhao (Edinburgh), George Cazenavette (CMU), Hakan Bilen
핵심 개념 대규모 학습 데이터셋을 소규모 합성 데이터셋으로 압축하여, 합성 데이터로 학습한 모델이 원본 데이터로 학습한 것과 유사한 성능을 달성하도록 하는 기법
관련 분야 Coreset Selection, Knowledge Distillation, Data-Centric AI, Continual Learning, Privacy

정의

Dataset Distillation(DD)은 대규모 학습 데이터셋 \(\mathcal{T} = \{(x_i, y_i)\}_{i=1}^{N}\)를 소규모 합성 데이터셋 \(\mathcal{S} = \{(\tilde{x}_j, \tilde{y}_j)\}_{j=1}^{M}\) (\(M \ll N\))으로 압축하는 기법이다. 핵심 목표는 \(\mathcal{S}\)로 학습한 모델의 성능이 \(\mathcal{T}\)로 학습한 모델의 성능에 근접하도록 하는 것이다.

왜 필요한가

문제: 데이터가 커질수록 학습 비용이 기하급수적으로 증가

  원본 데이터: ImageNet (1.28M 이미지, ~150GB)
  학습 시간: GPU 수백 시간
  NAS 등 반복 실험: 수천 번 학습 필요

  --> 원본과 동등한 정보를 담은 소규모 데이터셋이 있다면?
  --> 학습 비용 대폭 절감 가능

Coreset Selection과의 차이

항목 Coreset Selection Dataset Distillation
데이터 소스 원본에서 부분집합 선택 새로운 합성 데이터 생성
데이터 형태 실제 데이터 포인트 합성(최적화된) 데이터
정보 밀도 원본 수준 원본보다 높음 (압축)
해석 가능성 높음 (실제 샘플) 낮을 수 있음 (추상적 패턴)
압축률 제한적 극단적 압축 가능 (1-50 IPC)

*IPC = Images Per Class


문제 정의

이중 최적화 (Bi-level Optimization)

Dataset Distillation의 일반적 정의:

\[ \mathcal{S}^* = \arg\min_{\mathcal{S}} \mathcal{L}(\theta^*(\mathcal{S}), \mathcal{T}) \]
\[ \text{s.t.} \quad \theta^*(\mathcal{S}) = \arg\min_{\theta} \mathcal{L}(\theta, \mathcal{S}) \]
  • 외부 루프: 합성 데이터 \(\mathcal{S}\)를 최적화 (원본 데이터에서의 성능 최대화)
  • 내부 루프: \(\mathcal{S}\)로 모델 파라미터 \(\theta\) 학습

주요 매칭 전략

Dataset Distillation 방법론 분류:

+------------------------------------------------------------------+
|                                                                  |
|  1. Performance Matching (Meta-Learning)                        |
|  ============================================                    |
|  - 합성 데이터로 학습한 모델이                                   |
|    원본 데이터에서 좋은 성능을 내도록 최적화                     |
|  - 대표: DD (Wang et al., 2018), KIP (Nguyen et al., 2021)     |
|  - 계산 비용 높음 (이중 최적화 필요)                            |
|                                                                  |
+------------------------------------------------------------------+
          |
          v
+------------------------------------------------------------------+
|                                                                  |
|  2. Parameter Matching (Gradient/Trajectory)                    |
|  ============================================                    |
|  - 학습 과정의 중간 상태를 매칭                                 |
|  - Gradient Matching: 원본/합성 데이터의 gradient 방향 일치     |
|  - Trajectory Matching: 학습 궤적 자체를 모방                   |
|  - 대표: DC (Zhao et al., 2021), MTT (Cazenavette et al., 2022)|
|                                                                  |
+------------------------------------------------------------------+
          |
          v
+------------------------------------------------------------------+
|                                                                  |
|  3. Distribution Matching                                       |
|  ============================================                    |
|  - 합성 데이터의 특성 분포를 원본 데이터와 일치                 |
|  - 특성 공간에서의 MMD (Maximum Mean Discrepancy) 최소화        |
|  - 내부 루프 불필요 -> 빠른 학습                                |
|  - 대표: DM (Zhao & Bilen, 2023), CAFE (Wang et al., 2022)     |
|                                                                  |
+------------------------------------------------------------------+
          |
          v
+------------------------------------------------------------------+
|                                                                  |
|  4. Generative Model 기반                                       |
|  ============================================                    |
|  - Diffusion Model 등을 활용하여 합성 데이터 생성               |
|  - 잠재 공간에서의 distillation                                 |
|  - 대표: GLaD (Cazenavette et al., 2023), IT-GAN (Zhao, 2022)  |
|                                                                  |
+------------------------------------------------------------------+

핵심 방법론

1. DD - Dataset Distillation (Wang et al., 2018)

최초의 Dataset Distillation 논문. Meta-learning 기반 접근.

목적 함수:

\[ \mathcal{S}^* = \arg\min_{\mathcal{S}} \mathbb{E}_{\theta_0} \left[ \mathcal{L}\left( \text{GD}(\theta_0, \mathcal{S}, \eta, T), \mathcal{T} \right) \right] \]

여기서 \(\text{GD}(\theta_0, \mathcal{S}, \eta, T)\)는 초기 파라미터 \(\theta_0\)에서 \(\mathcal{S}\)\(T\)스텝 학습한 결과.

한계: 이중 최적화의 Unrolled Gradient가 필요하여 메모리/연산 비용이 높음.

2. DC - Dataset Condensation with Gradient Matching (Zhao et al., ICLR 2021)

핵심 아이디어: 합성 데이터와 원본 데이터에서 계산한 gradient의 방향을 일치시킴.

목적 함수:

\[ \min_{\mathcal{S}} \mathbb{E}_{\theta_0 \sim P_{\theta_0}} \sum_{t=0}^{T-1} D\left( \nabla_\theta \mathcal{L}(\theta_t, \mathcal{S}), \nabla_\theta \mathcal{L}(\theta_t, \mathcal{B}_t) \right) \]

여기서 \(D\)는 gradient 간의 거리 함수 (보통 cosine similarity), \(\mathcal{B}_t\)는 원본 데이터의 미니배치.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

def gradient_matching_step(
    model: nn.Module,
    real_images: torch.Tensor,
    real_labels: torch.Tensor,
    syn_images: nn.Parameter,
    syn_labels: torch.Tensor,
    criterion: nn.Module
) -> torch.Tensor:
    """
    Gradient Matching: 합성 데이터의 gradient를
    원본 데이터의 gradient에 매칭
    """
    # 원본 데이터 gradient
    output_real = model(real_images)
    loss_real = criterion(output_real, real_labels)
    grad_real = torch.autograd.grad(loss_real, model.parameters(), create_graph=True)

    # 합성 데이터 gradient
    output_syn = model(syn_images)
    loss_syn = criterion(output_syn, syn_labels)
    grad_syn = torch.autograd.grad(loss_syn, model.parameters(), create_graph=True)

    # Cosine similarity 기반 매칭 손실
    matching_loss = 0.0
    for g_real, g_syn in zip(grad_real, grad_syn):
        g_real_flat = g_real.reshape(1, -1)
        g_syn_flat = g_syn.reshape(1, -1)
        cosine_sim = F.cosine_similarity(g_real_flat, g_syn_flat)
        matching_loss += (1 - cosine_sim).mean()

    return matching_loss


def distill_dataset(
    num_classes: int = 10,
    ipc: int = 10,
    image_size: int = 32,
    channels: int = 3,
    num_iterations: int = 1000,
    lr_syn: float = 0.1
):
    """
    DC 알고리즘의 단순화된 구현.

    Args:
        ipc: Images Per Class (클래스당 합성 이미지 수)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 합성 데이터 초기화 (학습 가능한 파라미터)
    syn_images = nn.Parameter(
        torch.randn(num_classes * ipc, channels, image_size, image_size, device=device)
    )
    syn_labels = torch.repeat_interleave(
        torch.arange(num_classes, device=device), ipc
    )

    optimizer_syn = torch.optim.SGD([syn_images], lr=lr_syn, momentum=0.5)
    criterion = nn.CrossEntropyLoss()

    # 원본 데이터 로더 (예: CIFAR-10)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    dataset = datasets.CIFAR10(root="./data", train=True, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)

    for iteration in range(num_iterations):
        # 랜덤 네트워크 초기화
        model = simple_convnet(channels, num_classes).to(device)

        real_images, real_labels = next(iter(loader))
        real_images, real_labels = real_images.to(device), real_labels.to(device)

        # Gradient matching loss 계산
        loss = gradient_matching_step(
            model, real_images, real_labels,
            syn_images, syn_labels, criterion
        )

        optimizer_syn.zero_grad()
        loss.backward()
        optimizer_syn.step()

        if iteration % 100 == 0:
            print(f"[{iteration}/{num_iterations}] Loss: {loss.item():.4f}")

    return syn_images.detach(), syn_labels


def simple_convnet(channels: int, num_classes: int) -> nn.Module:
    """간단한 3층 ConvNet"""
    return nn.Sequential(
        nn.Conv2d(channels, 128, 3, padding=1),
        nn.GroupNorm(8, 128),
        nn.ReLU(),
        nn.AvgPool2d(2),
        nn.Conv2d(128, 256, 3, padding=1),
        nn.GroupNorm(8, 256),
        nn.ReLU(),
        nn.AvgPool2d(2),
        nn.Flatten(),
        nn.Linear(256 * 8 * 8, num_classes)
    )

3. MTT - Matching Training Trajectories (Cazenavette et al., CVPR 2022)

핵심 아이디어: 단일 gradient가 아닌, 전체 학습 궤적(trajectory)을 매칭.

목적 함수:

\[ \min_{\mathcal{S}} \sum_{\tau \sim \mathcal{T}_{exp}} \left\| \theta_{t+M}^{\mathcal{S}} - \theta_{t+N}^{\tau} \right\|^2 \]
  • \(\theta_{t+M}^{\mathcal{S}}\): 합성 데이터로 \(M\)스텝 학습한 파라미터
  • \(\theta_{t+N}^{\tau}\): 원본 데이터의 전문가 궤적에서 \(N\)스텝 후의 파라미터

장점: Gradient Matching보다 긴 학습 동태를 반영하여 더 높은 성능.

def trajectory_matching_loss(
    model: nn.Module,
    syn_images: nn.Parameter,
    syn_labels: torch.Tensor,
    expert_trajectory: list,  # 사전 저장된 전문가 학습 궤적
    start_epoch: int,
    syn_steps: int = 50,
    expert_steps: int = 1,
    lr_model: float = 0.01
):
    """
    MTT: 전문가 궤적과 합성 데이터 학습 궤적의 거리 최소화
    """
    # 전문가 궤적의 시작점에서 모델 초기화
    starting_params = expert_trajectory[start_epoch]
    target_params = expert_trajectory[start_epoch + expert_steps]

    # 모델에 시작 파라미터 로드
    load_params(model, starting_params)
    criterion = nn.CrossEntropyLoss()

    # 합성 데이터로 syn_steps만큼 학습
    for _ in range(syn_steps):
        output = model(syn_images)
        loss = criterion(output, syn_labels)
        grad = torch.autograd.grad(loss, model.parameters())
        with torch.no_grad():
            for p, g in zip(model.parameters(), grad):
                p.data -= lr_model * g

    # 전문가 궤적과의 거리 (파라미터 공간에서)
    student_params = get_params(model)
    trajectory_loss = sum(
        (s - t).pow(2).sum()
        for s, t in zip(student_params, target_params)
    )

    # 정규화: 시작-목표 거리로 나눔
    normalizer = sum(
        (s - t).pow(2).sum()
        for s, t in zip(starting_params, target_params)
    )

    return trajectory_loss / (normalizer + 1e-6)

4. DM - Distribution Matching (Zhao & Bilen, 2023)

핵심 아이디어: 특성 공간에서 합성 데이터와 원본 데이터의 분포를 직접 매칭.

목적 함수 (MMD 기반):

\[ \min_{\mathcal{S}} \mathbb{E}_{\psi \sim \Psi} \left\| \frac{1}{|\mathcal{T}|} \sum_{x \in \mathcal{T}} \psi(x) - \frac{1}{|\mathcal{S}|} \sum_{\tilde{x} \in \mathcal{S}} \psi(\tilde{x}) \right\|^2 \]

여기서 \(\psi\)는 랜덤 초기화된 네트워크의 특성 추출기.

장점: 내부 루프 없이 단일 수준 최적화 -> 속도가 빠름.

def distribution_matching_step(
    feature_extractor: nn.Module,
    real_images: torch.Tensor,
    syn_images: nn.Parameter,
    per_class: bool = True,
    labels_real: torch.Tensor = None,
    labels_syn: torch.Tensor = None,
    num_classes: int = 10
) -> torch.Tensor:
    """
    Distribution Matching: 특성 공간에서 분포 거리 최소화
    """
    if per_class and labels_real is not None:
        total_loss = 0.0
        for c in range(num_classes):
            # 클래스별 매칭
            mask_real = labels_real == c
            mask_syn = labels_syn == c

            feat_real = feature_extractor(real_images[mask_real])
            feat_syn = feature_extractor(syn_images[mask_syn])

            # 클래스별 평균 특성 간의 MMD
            mean_real = feat_real.mean(dim=0)
            mean_syn = feat_syn.mean(dim=0)
            total_loss += (mean_real - mean_syn).pow(2).sum()

        return total_loss / num_classes
    else:
        feat_real = feature_extractor(real_images)
        feat_syn = feature_extractor(syn_images)
        mean_real = feat_real.mean(dim=0)
        mean_syn = feat_syn.mean(dim=0)
        return (mean_real - mean_syn).pow(2).sum()

성능 비교

CIFAR-10 벤치마크 (IPC = 10, ConvNet-3)

방법 정확도 (%) 연산 비용 발표
Random Selection 26.0 - -
DD (Wang et al.) 36.8 높음 2018
DC (Zhao et al.) 44.9 중간 ICLR 2021
DSA (Zhao & Bilen) 52.1 중간 ICML 2021
DM (Zhao & Bilen) 48.9 낮음 2023
MTT (Cazenavette et al.) 65.3 높음 CVPR 2022
TESLA (Cui et al.) 66.4 중간 ICML 2023
전체 데이터셋 84.8 매우 높음 -

*IPC 10 = 클래스당 10개 이미지 (총 100개 vs 원본 50,000개, 압축률 500배)

CIFAR-10 벤치마크 (IPC = 50, ConvNet-3)

방법 정확도 (%) 비고
DC 53.9
DSA 60.6
DM 63.0
MTT 71.6
전체 데이터셋 84.8

고급 기법

Factorization 기반 방법

합성 이미지를 직접 최적화하는 대신, 기저(basis)와 계수(coefficient)로 분해:

\[ \tilde{X} = \text{Decode}(Z), \quad Z \in \mathbb{R}^{M \times d} \]
  • HaBa (Liu et al., NeurIPS 2022): Hallucinator + Base로 분해
  • LinBa (Deng & Russakovsky, NeurIPS 2022): Linear base 활용

Data Augmentation 통합

  • DSA (Zhao & Bilen, ICML 2021): Differentiable Siamese Augmentation
  • 합성 데이터와 원본 데이터에 동일한 augmentation을 미분 가능하게 적용
  • DC 대비 성능 향상

Cross-Architecture Generalization

DD의 주요 과제 중 하나: 특정 아키텍처에서 distill한 데이터가 다른 아키텍처에서도 잘 작동하는가?

학습 아키텍처 평가 아키텍처 DC MTT DM
ConvNet ConvNet 44.9 65.3 48.9
ConvNet ResNet-18 25.2 47.7 36.1
ConvNet VGG-11 29.7 41.2 34.5

DM이 Cross-Architecture 일반화에서 상대적으로 강점을 보임 (내부 루프 모델 비의존적).


응용 분야

1. Neural Architecture Search (NAS)

기존 NAS:
  각 후보 아키텍처를 전체 데이터로 학습 -> 평가
  --> 수천~수만 회 학습 필요 (비용 막대)

DD + NAS:
  전체 데이터를 distill (1회)
  각 후보 아키텍처를 합성 데이터로 학습 -> 평가 (빠름)
  --> 탐색 시간 대폭 단축

2. Continual Learning

  • 이전 태스크의 데이터를 distill하여 저장
  • 새 태스크 학습 시 이전 distilled 데이터와 함께 학습
  • 메모리 효율적 catastrophic forgetting 방지

3. Privacy-Preserving ML

  • 원본 데이터 대신 distilled 데이터를 공유
  • 합성 데이터에서 개인 정보 추출이 어려움
  • Federated Learning과 결합 가능

4. Data Marketplace

  • 데이터 가치를 보존하면서 소량의 "미리보기" 데이터 제공
  • 구매자가 distilled 데이터로 모델 성능을 평가

한계와 과제

현재 한계

한계 설명
확장성 고해상도 이미지(ImageNet-1K)에서 성능 급락
아키텍처 의존성 distill 시 사용한 모델에 편향
라벨 정보 비지도학습/자기지도학습에 적용 어려움
합성 데이터 해석 생성된 이미지가 비현실적일 수 있음
평가 프로토콜 표준화된 벤치마크 부족

연구 방향 (2024-2026)

현재 주요 연구 방향:

  1. Large-Scale DD
     - ImageNet, LAION 등 대규모 데이터셋에 적용
     - 효율적 최적화 알고리즘 개발
     - 생성 모델(Diffusion) 활용 (GLaD, SRe2L)

  2. Text/Multimodal DD
     - NLP 데이터셋에 대한 distillation
     - Vision-Language 데이터 압축
     - ICLR 2025: "Dataset Distillation via Knowledge Distillation"

  3. DD 이론
     - 정보론적 관점의 분석
     - 최적 합성 데이터의 이론적 한계
     - 일반화 보장

  4. 신뢰성/강건성
     - DD-RobustBench: 적대적 강건성 벤치마크
     - 분포 이동(distribution shift) 하에서의 DD 성능

관련 문서

주제 링크
Knowledge Distillation ../knowledge-distillation/summary
Data-Centric AI ../data-centric-ai/summary
Continual Learning ../continual-learning/summary
Training Data Attribution ../training-data-attribution/summary
Synthetic Data Generation ../synthetic-data-generation/summary

참고

  • Wang, T. et al. (2018). "Dataset Distillation." arXiv:1811.10959
  • Zhao, B. et al. (2021). "Dataset Condensation with Gradient Matching." ICLR 2021
  • Zhao, B. & Bilen, H. (2021). "Dataset Condensation with Differentiable Siamese Augmentation." ICML 2021
  • Cazenavette, G. et al. (2022). "Dataset Distillation by Matching Training Trajectories." CVPR 2022
  • Zhao, B. & Bilen, H. (2023). "Dataset Condensation with Distribution Matching." WACV 2023
  • Lei, S. & Tao, D. (2023). "A Comprehensive Survey of Dataset Distillation." IEEE TPAMI 2023
  • DC-Bench: https://github.com/justincui03/dc_benchmark
  • Awesome-Dataset-Distillation: https://github.com/Guang000/Awesome-Dataset-Distillation