Dataset Distillation¶
메타 정보¶
| 항목 | 내용 |
|---|---|
| 분류 | Data-Centric AI / Efficient Learning |
| 핵심 논문 | "Dataset Distillation" (Wang et al., 2018 - 최초 제안), "Dataset Condensation with Gradient Matching" (Zhao et al., ICLR 2021 - DC), "Dataset Distillation by Matching Training Trajectories" (Cazenavette et al., CVPR 2022 - MTT), "Distribution Matching for Dataset Distillation" (Zhao & Bilen, 2023 - DM) |
| 주요 저자 | Tongzhou Wang (MIT), Bo Zhao (Edinburgh), George Cazenavette (CMU), Hakan Bilen |
| 핵심 개념 | 대규모 학습 데이터셋을 소규모 합성 데이터셋으로 압축하여, 합성 데이터로 학습한 모델이 원본 데이터로 학습한 것과 유사한 성능을 달성하도록 하는 기법 |
| 관련 분야 | Coreset Selection, Knowledge Distillation, Data-Centric AI, Continual Learning, Privacy |
정의¶
Dataset Distillation(DD)은 대규모 학습 데이터셋 \(\mathcal{T} = \{(x_i, y_i)\}_{i=1}^{N}\)를 소규모 합성 데이터셋 \(\mathcal{S} = \{(\tilde{x}_j, \tilde{y}_j)\}_{j=1}^{M}\) (\(M \ll N\))으로 압축하는 기법이다. 핵심 목표는 \(\mathcal{S}\)로 학습한 모델의 성능이 \(\mathcal{T}\)로 학습한 모델의 성능에 근접하도록 하는 것이다.
왜 필요한가¶
문제: 데이터가 커질수록 학습 비용이 기하급수적으로 증가
원본 데이터: ImageNet (1.28M 이미지, ~150GB)
학습 시간: GPU 수백 시간
NAS 등 반복 실험: 수천 번 학습 필요
--> 원본과 동등한 정보를 담은 소규모 데이터셋이 있다면?
--> 학습 비용 대폭 절감 가능
Coreset Selection과의 차이¶
| 항목 | Coreset Selection | Dataset Distillation |
|---|---|---|
| 데이터 소스 | 원본에서 부분집합 선택 | 새로운 합성 데이터 생성 |
| 데이터 형태 | 실제 데이터 포인트 | 합성(최적화된) 데이터 |
| 정보 밀도 | 원본 수준 | 원본보다 높음 (압축) |
| 해석 가능성 | 높음 (실제 샘플) | 낮을 수 있음 (추상적 패턴) |
| 압축률 | 제한적 | 극단적 압축 가능 (1-50 IPC) |
*IPC = Images Per Class
문제 정의¶
이중 최적화 (Bi-level Optimization)¶
Dataset Distillation의 일반적 정의:
- 외부 루프: 합성 데이터 \(\mathcal{S}\)를 최적화 (원본 데이터에서의 성능 최대화)
- 내부 루프: \(\mathcal{S}\)로 모델 파라미터 \(\theta\) 학습
주요 매칭 전략¶
Dataset Distillation 방법론 분류:
+------------------------------------------------------------------+
| |
| 1. Performance Matching (Meta-Learning) |
| ============================================ |
| - 합성 데이터로 학습한 모델이 |
| 원본 데이터에서 좋은 성능을 내도록 최적화 |
| - 대표: DD (Wang et al., 2018), KIP (Nguyen et al., 2021) |
| - 계산 비용 높음 (이중 최적화 필요) |
| |
+------------------------------------------------------------------+
|
v
+------------------------------------------------------------------+
| |
| 2. Parameter Matching (Gradient/Trajectory) |
| ============================================ |
| - 학습 과정의 중간 상태를 매칭 |
| - Gradient Matching: 원본/합성 데이터의 gradient 방향 일치 |
| - Trajectory Matching: 학습 궤적 자체를 모방 |
| - 대표: DC (Zhao et al., 2021), MTT (Cazenavette et al., 2022)|
| |
+------------------------------------------------------------------+
|
v
+------------------------------------------------------------------+
| |
| 3. Distribution Matching |
| ============================================ |
| - 합성 데이터의 특성 분포를 원본 데이터와 일치 |
| - 특성 공간에서의 MMD (Maximum Mean Discrepancy) 최소화 |
| - 내부 루프 불필요 -> 빠른 학습 |
| - 대표: DM (Zhao & Bilen, 2023), CAFE (Wang et al., 2022) |
| |
+------------------------------------------------------------------+
|
v
+------------------------------------------------------------------+
| |
| 4. Generative Model 기반 |
| ============================================ |
| - Diffusion Model 등을 활용하여 합성 데이터 생성 |
| - 잠재 공간에서의 distillation |
| - 대표: GLaD (Cazenavette et al., 2023), IT-GAN (Zhao, 2022) |
| |
+------------------------------------------------------------------+
핵심 방법론¶
1. DD - Dataset Distillation (Wang et al., 2018)¶
최초의 Dataset Distillation 논문. Meta-learning 기반 접근.
목적 함수:
여기서 \(\text{GD}(\theta_0, \mathcal{S}, \eta, T)\)는 초기 파라미터 \(\theta_0\)에서 \(\mathcal{S}\)로 \(T\)스텝 학습한 결과.
한계: 이중 최적화의 Unrolled Gradient가 필요하여 메모리/연산 비용이 높음.
2. DC - Dataset Condensation with Gradient Matching (Zhao et al., ICLR 2021)¶
핵심 아이디어: 합성 데이터와 원본 데이터에서 계산한 gradient의 방향을 일치시킴.
목적 함수:
여기서 \(D\)는 gradient 간의 거리 함수 (보통 cosine similarity), \(\mathcal{B}_t\)는 원본 데이터의 미니배치.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
def gradient_matching_step(
model: nn.Module,
real_images: torch.Tensor,
real_labels: torch.Tensor,
syn_images: nn.Parameter,
syn_labels: torch.Tensor,
criterion: nn.Module
) -> torch.Tensor:
"""
Gradient Matching: 합성 데이터의 gradient를
원본 데이터의 gradient에 매칭
"""
# 원본 데이터 gradient
output_real = model(real_images)
loss_real = criterion(output_real, real_labels)
grad_real = torch.autograd.grad(loss_real, model.parameters(), create_graph=True)
# 합성 데이터 gradient
output_syn = model(syn_images)
loss_syn = criterion(output_syn, syn_labels)
grad_syn = torch.autograd.grad(loss_syn, model.parameters(), create_graph=True)
# Cosine similarity 기반 매칭 손실
matching_loss = 0.0
for g_real, g_syn in zip(grad_real, grad_syn):
g_real_flat = g_real.reshape(1, -1)
g_syn_flat = g_syn.reshape(1, -1)
cosine_sim = F.cosine_similarity(g_real_flat, g_syn_flat)
matching_loss += (1 - cosine_sim).mean()
return matching_loss
def distill_dataset(
num_classes: int = 10,
ipc: int = 10,
image_size: int = 32,
channels: int = 3,
num_iterations: int = 1000,
lr_syn: float = 0.1
):
"""
DC 알고리즘의 단순화된 구현.
Args:
ipc: Images Per Class (클래스당 합성 이미지 수)
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 합성 데이터 초기화 (학습 가능한 파라미터)
syn_images = nn.Parameter(
torch.randn(num_classes * ipc, channels, image_size, image_size, device=device)
)
syn_labels = torch.repeat_interleave(
torch.arange(num_classes, device=device), ipc
)
optimizer_syn = torch.optim.SGD([syn_images], lr=lr_syn, momentum=0.5)
criterion = nn.CrossEntropyLoss()
# 원본 데이터 로더 (예: CIFAR-10)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.CIFAR10(root="./data", train=True, transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)
for iteration in range(num_iterations):
# 랜덤 네트워크 초기화
model = simple_convnet(channels, num_classes).to(device)
real_images, real_labels = next(iter(loader))
real_images, real_labels = real_images.to(device), real_labels.to(device)
# Gradient matching loss 계산
loss = gradient_matching_step(
model, real_images, real_labels,
syn_images, syn_labels, criterion
)
optimizer_syn.zero_grad()
loss.backward()
optimizer_syn.step()
if iteration % 100 == 0:
print(f"[{iteration}/{num_iterations}] Loss: {loss.item():.4f}")
return syn_images.detach(), syn_labels
def simple_convnet(channels: int, num_classes: int) -> nn.Module:
"""간단한 3층 ConvNet"""
return nn.Sequential(
nn.Conv2d(channels, 128, 3, padding=1),
nn.GroupNorm(8, 128),
nn.ReLU(),
nn.AvgPool2d(2),
nn.Conv2d(128, 256, 3, padding=1),
nn.GroupNorm(8, 256),
nn.ReLU(),
nn.AvgPool2d(2),
nn.Flatten(),
nn.Linear(256 * 8 * 8, num_classes)
)
3. MTT - Matching Training Trajectories (Cazenavette et al., CVPR 2022)¶
핵심 아이디어: 단일 gradient가 아닌, 전체 학습 궤적(trajectory)을 매칭.
목적 함수:
- \(\theta_{t+M}^{\mathcal{S}}\): 합성 데이터로 \(M\)스텝 학습한 파라미터
- \(\theta_{t+N}^{\tau}\): 원본 데이터의 전문가 궤적에서 \(N\)스텝 후의 파라미터
장점: Gradient Matching보다 긴 학습 동태를 반영하여 더 높은 성능.
def trajectory_matching_loss(
model: nn.Module,
syn_images: nn.Parameter,
syn_labels: torch.Tensor,
expert_trajectory: list, # 사전 저장된 전문가 학습 궤적
start_epoch: int,
syn_steps: int = 50,
expert_steps: int = 1,
lr_model: float = 0.01
):
"""
MTT: 전문가 궤적과 합성 데이터 학습 궤적의 거리 최소화
"""
# 전문가 궤적의 시작점에서 모델 초기화
starting_params = expert_trajectory[start_epoch]
target_params = expert_trajectory[start_epoch + expert_steps]
# 모델에 시작 파라미터 로드
load_params(model, starting_params)
criterion = nn.CrossEntropyLoss()
# 합성 데이터로 syn_steps만큼 학습
for _ in range(syn_steps):
output = model(syn_images)
loss = criterion(output, syn_labels)
grad = torch.autograd.grad(loss, model.parameters())
with torch.no_grad():
for p, g in zip(model.parameters(), grad):
p.data -= lr_model * g
# 전문가 궤적과의 거리 (파라미터 공간에서)
student_params = get_params(model)
trajectory_loss = sum(
(s - t).pow(2).sum()
for s, t in zip(student_params, target_params)
)
# 정규화: 시작-목표 거리로 나눔
normalizer = sum(
(s - t).pow(2).sum()
for s, t in zip(starting_params, target_params)
)
return trajectory_loss / (normalizer + 1e-6)
4. DM - Distribution Matching (Zhao & Bilen, 2023)¶
핵심 아이디어: 특성 공간에서 합성 데이터와 원본 데이터의 분포를 직접 매칭.
목적 함수 (MMD 기반):
여기서 \(\psi\)는 랜덤 초기화된 네트워크의 특성 추출기.
장점: 내부 루프 없이 단일 수준 최적화 -> 속도가 빠름.
def distribution_matching_step(
feature_extractor: nn.Module,
real_images: torch.Tensor,
syn_images: nn.Parameter,
per_class: bool = True,
labels_real: torch.Tensor = None,
labels_syn: torch.Tensor = None,
num_classes: int = 10
) -> torch.Tensor:
"""
Distribution Matching: 특성 공간에서 분포 거리 최소화
"""
if per_class and labels_real is not None:
total_loss = 0.0
for c in range(num_classes):
# 클래스별 매칭
mask_real = labels_real == c
mask_syn = labels_syn == c
feat_real = feature_extractor(real_images[mask_real])
feat_syn = feature_extractor(syn_images[mask_syn])
# 클래스별 평균 특성 간의 MMD
mean_real = feat_real.mean(dim=0)
mean_syn = feat_syn.mean(dim=0)
total_loss += (mean_real - mean_syn).pow(2).sum()
return total_loss / num_classes
else:
feat_real = feature_extractor(real_images)
feat_syn = feature_extractor(syn_images)
mean_real = feat_real.mean(dim=0)
mean_syn = feat_syn.mean(dim=0)
return (mean_real - mean_syn).pow(2).sum()
성능 비교¶
CIFAR-10 벤치마크 (IPC = 10, ConvNet-3)¶
| 방법 | 정확도 (%) | 연산 비용 | 발표 |
|---|---|---|---|
| Random Selection | 26.0 | - | - |
| DD (Wang et al.) | 36.8 | 높음 | 2018 |
| DC (Zhao et al.) | 44.9 | 중간 | ICLR 2021 |
| DSA (Zhao & Bilen) | 52.1 | 중간 | ICML 2021 |
| DM (Zhao & Bilen) | 48.9 | 낮음 | 2023 |
| MTT (Cazenavette et al.) | 65.3 | 높음 | CVPR 2022 |
| TESLA (Cui et al.) | 66.4 | 중간 | ICML 2023 |
| 전체 데이터셋 | 84.8 | 매우 높음 | - |
*IPC 10 = 클래스당 10개 이미지 (총 100개 vs 원본 50,000개, 압축률 500배)
CIFAR-10 벤치마크 (IPC = 50, ConvNet-3)¶
| 방법 | 정확도 (%) | 비고 |
|---|---|---|
| DC | 53.9 | |
| DSA | 60.6 | |
| DM | 63.0 | |
| MTT | 71.6 | |
| 전체 데이터셋 | 84.8 |
고급 기법¶
Factorization 기반 방법¶
합성 이미지를 직접 최적화하는 대신, 기저(basis)와 계수(coefficient)로 분해:
- HaBa (Liu et al., NeurIPS 2022): Hallucinator + Base로 분해
- LinBa (Deng & Russakovsky, NeurIPS 2022): Linear base 활용
Data Augmentation 통합¶
- DSA (Zhao & Bilen, ICML 2021): Differentiable Siamese Augmentation
- 합성 데이터와 원본 데이터에 동일한 augmentation을 미분 가능하게 적용
- DC 대비 성능 향상
Cross-Architecture Generalization¶
DD의 주요 과제 중 하나: 특정 아키텍처에서 distill한 데이터가 다른 아키텍처에서도 잘 작동하는가?
| 학습 아키텍처 | 평가 아키텍처 | DC | MTT | DM |
|---|---|---|---|---|
| ConvNet | ConvNet | 44.9 | 65.3 | 48.9 |
| ConvNet | ResNet-18 | 25.2 | 47.7 | 36.1 |
| ConvNet | VGG-11 | 29.7 | 41.2 | 34.5 |
DM이 Cross-Architecture 일반화에서 상대적으로 강점을 보임 (내부 루프 모델 비의존적).
응용 분야¶
1. Neural Architecture Search (NAS)¶
기존 NAS:
각 후보 아키텍처를 전체 데이터로 학습 -> 평가
--> 수천~수만 회 학습 필요 (비용 막대)
DD + NAS:
전체 데이터를 distill (1회)
각 후보 아키텍처를 합성 데이터로 학습 -> 평가 (빠름)
--> 탐색 시간 대폭 단축
2. Continual Learning¶
- 이전 태스크의 데이터를 distill하여 저장
- 새 태스크 학습 시 이전 distilled 데이터와 함께 학습
- 메모리 효율적 catastrophic forgetting 방지
3. Privacy-Preserving ML¶
- 원본 데이터 대신 distilled 데이터를 공유
- 합성 데이터에서 개인 정보 추출이 어려움
- Federated Learning과 결합 가능
4. Data Marketplace¶
- 데이터 가치를 보존하면서 소량의 "미리보기" 데이터 제공
- 구매자가 distilled 데이터로 모델 성능을 평가
한계와 과제¶
현재 한계¶
| 한계 | 설명 |
|---|---|
| 확장성 | 고해상도 이미지(ImageNet-1K)에서 성능 급락 |
| 아키텍처 의존성 | distill 시 사용한 모델에 편향 |
| 라벨 정보 | 비지도학습/자기지도학습에 적용 어려움 |
| 합성 데이터 해석 | 생성된 이미지가 비현실적일 수 있음 |
| 평가 프로토콜 | 표준화된 벤치마크 부족 |
연구 방향 (2024-2026)¶
현재 주요 연구 방향:
1. Large-Scale DD
- ImageNet, LAION 등 대규모 데이터셋에 적용
- 효율적 최적화 알고리즘 개발
- 생성 모델(Diffusion) 활용 (GLaD, SRe2L)
2. Text/Multimodal DD
- NLP 데이터셋에 대한 distillation
- Vision-Language 데이터 압축
- ICLR 2025: "Dataset Distillation via Knowledge Distillation"
3. DD 이론
- 정보론적 관점의 분석
- 최적 합성 데이터의 이론적 한계
- 일반화 보장
4. 신뢰성/강건성
- DD-RobustBench: 적대적 강건성 벤치마크
- 분포 이동(distribution shift) 하에서의 DD 성능
관련 문서¶
| 주제 | 링크 |
|---|---|
| Knowledge Distillation | ../knowledge-distillation/summary |
| Data-Centric AI | ../data-centric-ai/summary |
| Continual Learning | ../continual-learning/summary |
| Training Data Attribution | ../training-data-attribution/summary |
| Synthetic Data Generation | ../synthetic-data-generation/summary |
참고¶
- Wang, T. et al. (2018). "Dataset Distillation." arXiv:1811.10959
- Zhao, B. et al. (2021). "Dataset Condensation with Gradient Matching." ICLR 2021
- Zhao, B. & Bilen, H. (2021). "Dataset Condensation with Differentiable Siamese Augmentation." ICML 2021
- Cazenavette, G. et al. (2022). "Dataset Distillation by Matching Training Trajectories." CVPR 2022
- Zhao, B. & Bilen, H. (2023). "Dataset Condensation with Distribution Matching." WACV 2023
- Lei, S. & Tao, D. (2023). "A Comprehensive Survey of Dataset Distillation." IEEE TPAMI 2023
- DC-Bench: https://github.com/justincui03/dc_benchmark
- Awesome-Dataset-Distillation: https://github.com/Guang000/Awesome-Dataset-Distillation