Diffusion Memorization Theory¶
개요¶
NeurIPS 2025 Best Paper로 선정된 연구. Diffusion 모델이 학습 데이터를 언제 "생성"하고 언제 "암기"하는지에 대한 이론적 프레임워크를 제시한다.
핵심 발견¶
두 가지 시간 스케일¶
Diffusion 모델 학습에는 두 가지 구분되는 시간 스케일이 존재한다:
| 시간 스케일 | 기호 | 의미 | 행동 |
|---|---|---|---|
| 생성 시간 | τ_gen | 유효한 샘플 생성 시작 | 분포 학습 |
| 암기 시간 | τ_mem | 학습 데이터 암기 시작 | 과적합 |
학습 진행
─────────────────────────────────────────────────────────▶
│ │
│ │
▼ ▼
τ_gen τ_mem
[랜덤 노이즈] → [유효 샘플] → [훈련 데이터 복제]
초기 일반화 과적합
이론적 분석¶
생성 시간 τ_gen: $$ \tau_{gen} \sim \frac{d}{n \cdot \text{SNR}} $$
- d: 데이터 차원
- n: 학습 샘플 수
- SNR: 신호 대 잡음비
암기 시간 τ_mem: $$ \tau_{mem} \sim \frac{n}{d} \cdot \tau_{gen} $$
핵심 비율: $$ \frac{\tau_{mem}}{\tau_{gen}} \sim \frac{n}{d} $$
→ 데이터셋이 크고(n↑), 차원이 낮을수록(d↓) 암기까지 더 오래 걸림
실험 검증¶
1. 합성 데이터 실험¶
import torch
import torch.nn as nn
import numpy as np
class SimpleDiffusion(nn.Module):
"""간단한 Diffusion 모델"""
def __init__(self, dim, hidden_dim=256):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim + 1, hidden_dim), # +1 for time
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, dim)
)
def forward(self, x, t):
t_embed = t.unsqueeze(-1)
return self.net(torch.cat([x, t_embed], dim=-1))
def measure_memorization(model, train_data, generated_samples):
"""암기 정도 측정"""
# 가장 가까운 훈련 샘플까지의 거리
distances = []
for sample in generated_samples:
dists = torch.norm(train_data - sample, dim=-1)
min_dist = dists.min().item()
distances.append(min_dist)
return {
'avg_min_distance': np.mean(distances),
'memorized_ratio': np.mean([d < 0.1 for d in distances])
}
2. 학습 단계별 분석¶
| 학습 에포크 | FID | 암기율 | 상태 |
|---|---|---|---|
| 100 | 150 | 0% | 미숙 |
| 500 | 45 | 0% | 생성 (τ_gen) |
| 2000 | 20 | 2% | 최적 |
| 5000 | 18 | 15% | 과적합 시작 |
| 10000 | 22 | 45% | 암기 (τ_mem) |
3. 데이터셋 크기별 영향¶
암기율
100%│
│ ┌─── n=1000
│ ┌────┘
50%│ ┌────┘
│ ┌────┘ ┌───────── n=10000
│ ┌────┘ ┌────┘
│ ┌────┘ ┌────┘
0%│──┘────────┘────────────────────── n=100000
└─────────────────────────────────▶
학습 시간 (에포크)
실용적 함의¶
1. 최적 학습 중단점¶
def find_optimal_checkpoint(
train_losses: list,
fid_scores: list,
memorization_rates: list
) -> int:
"""최적 체크포인트 찾기"""
optimal_idx = 0
best_score = float('inf')
for i, (loss, fid, mem_rate) in enumerate(zip(
train_losses, fid_scores, memorization_rates
)):
# FID 최소화 + 암기율 페널티
score = fid + 10 * mem_rate # 암기에 페널티
if score < best_score:
best_score = score
optimal_idx = i
return optimal_idx
def early_stopping_criterion(
current_mem_rate: float,
prev_mem_rate: float,
threshold: float = 0.05
) -> bool:
"""암기 기반 조기 종료"""
# 암기율 급증 감지
if current_mem_rate - prev_mem_rate > threshold:
return True
# 절대 암기율 기준
if current_mem_rate > 0.1: # 10% 이상 암기
return True
return False
2. 데이터 증강의 효과¶
데이터 증강은 효과적으로 n을 증가시켜 τ_mem을 지연시킴:
| 증강 | 효과적 n | τ_mem 지연 |
|---|---|---|
| 없음 | N | 기준 |
| 기본 (flip, crop) | ~4N | ~4x |
| 강한 증강 | ~10N | ~10x |
| 합성 데이터 추가 | ~100N | ~100x |
3. 모델 크기와 암기¶
τ_mem / τ_gen 비율
큰 모델 │ ████████████████████████████
│ (빠르게 암기)
│
중간 모델│ ████████████████████████████████████
│ (적절한 균형)
│
작은 모델│ ████████████████████████████████████████████
│ (느리게 암기)
└──────────────────────────────────────────────▶
→ 큰 모델일수록 더 빨리 암기. 하지만 생성 품질도 더 높음.
암기 감지 방법¶
1. 최근접 이웃 거리¶
from sklearn.neighbors import NearestNeighbors
def detect_memorization_nn(
generated_samples: np.ndarray,
train_samples: np.ndarray,
k: int = 1,
threshold: float = 0.1
) -> dict:
"""최근접 이웃 기반 암기 감지"""
nn = NearestNeighbors(n_neighbors=k)
nn.fit(train_samples)
distances, indices = nn.kneighbors(generated_samples)
memorized = distances[:, 0] < threshold
return {
'memorization_rate': memorized.mean(),
'avg_nn_distance': distances[:, 0].mean(),
'memorized_indices': np.where(memorized)[0]
}
2. SSCD (Self-Supervised Copy Detection)¶
# Meta의 SSCD 모델 사용
# pip install sscd
def detect_copies_sscd(
generated_images: list,
train_images: list,
threshold: float = 0.9
) -> dict:
"""SSCD 기반 복제 감지"""
# SSCD 임베딩 추출 (pseudo-code)
gen_embeddings = sscd_model.encode(generated_images)
train_embeddings = sscd_model.encode(train_images)
# 코사인 유사도 계산
similarities = cosine_similarity(gen_embeddings, train_embeddings)
max_sims = similarities.max(axis=1)
copies = max_sims > threshold
return {
'copy_rate': copies.mean(),
'max_similarity_avg': max_sims.mean(),
'copy_indices': np.where(copies)[0]
}
3. 멤버십 추론 공격¶
def membership_inference(
model,
samples: np.ndarray,
labels: np.ndarray, # 1=train, 0=test
n_steps: int = 50
) -> dict:
"""멤버십 추론으로 암기 정도 측정"""
# 재구성 손실 기반
losses = []
for sample in samples:
# Diffusion 역과정 손실 계산
loss = compute_reconstruction_loss(model, sample, n_steps)
losses.append(loss)
losses = np.array(losses)
# 훈련/테스트 손실 분포 비교
train_losses = losses[labels == 1]
test_losses = losses[labels == 0]
# AUC로 구분 가능성 측정
from sklearn.metrics import roc_auc_score
auc = roc_auc_score(labels, -losses) # 낮은 손실 = 훈련 샘플
return {
'membership_auc': auc,
'train_loss_mean': train_losses.mean(),
'test_loss_mean': test_losses.mean(),
'memorization_signal': test_losses.mean() - train_losses.mean()
}
암기 완화 전략¶
1. 노이즈 정규화¶
class NoiseRegularizedDiffusion(nn.Module):
def __init__(self, base_model, noise_scale: float = 0.1):
super().__init__()
self.base_model = base_model
self.noise_scale = noise_scale
def forward(self, x, t, training=True):
if training:
# 학습 시 추가 노이즈
x = x + self.noise_scale * torch.randn_like(x)
return self.base_model(x, t)
2. 드롭아웃 스케줄링¶
def adaptive_dropout_schedule(
epoch: int,
total_epochs: int,
min_dropout: float = 0.0,
max_dropout: float = 0.5
) -> float:
"""학습 후반에 드롭아웃 증가"""
# 로지스틱 스케줄
progress = epoch / total_epochs
dropout = min_dropout + (max_dropout - min_dropout) * (
1 / (1 + np.exp(-10 * (progress - 0.6)))
)
return dropout
3. 데이터 샤딩¶
각 배치에서 전체 데이터셋의 일부만 사용:
class ShardedDataLoader:
def __init__(self, dataset, shard_ratio: float = 0.1):
self.dataset = dataset
self.shard_ratio = shard_ratio
self.shard_size = int(len(dataset) * shard_ratio)
def __iter__(self):
# 랜덤 샤드 선택
indices = np.random.choice(
len(self.dataset),
size=self.shard_size,
replace=False
)
for idx in indices:
yield self.dataset[idx]
응용¶
1. 프라이버시 보호¶
- 암기율 모니터링으로 개인정보 유출 방지
- τ_mem 이전에 학습 중단
2. 저작권 보호¶
- 생성물과 훈련 데이터 유사도 검사
- 복제 감지 파이프라인 구축
3. 모델 디버깅¶
- 과적합 조기 감지
- 데이터 품질 문제 파악
코드 참조¶
# 전체 파이프라인 예시
def train_with_memorization_monitoring(
model,
train_loader,
val_samples,
max_epochs: int = 1000,
mem_threshold: float = 0.1
):
"""암기 모니터링 포함 학습"""
optimizer = torch.optim.Adam(model.parameters())
prev_mem_rate = 0.0
for epoch in range(max_epochs):
# 학습
for batch in train_loader:
loss = diffusion_loss(model, batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 주기적 암기 체크
if epoch % 50 == 0:
generated = sample_from_model(model, n=1000)
mem_result = detect_memorization_nn(
generated.numpy(),
val_samples.numpy()
)
current_mem_rate = mem_result['memorization_rate']
print(f"Epoch {epoch}: Mem Rate = {current_mem_rate:.2%}")
# 조기 종료 체크
if early_stopping_criterion(current_mem_rate, prev_mem_rate, mem_threshold):
print(f"Early stopping at epoch {epoch}")
break
prev_mem_rate = current_mem_rate
return model
요약¶
핵심 포인트¶
- 두 시간 스케일: 생성(τ_gen)과 암기(τ_mem)는 구분되는 현상
- 비율 τ_mem/τ_gen ~ n/d: 데이터 많고 저차원일수록 암기 지연
- 실용적 함의: 조기 종료, 데이터 증강, 모델 크기 조절에 활용
- 감지 방법: 최근접 이웃, SSCD, 멤버십 추론
참고 자료¶
- 논문: Understanding Diffusion Model Memorization
- NeurIPS 2025 Best Paper - Diffusion 이론 부문
- 관련 연구: Carlini et al. "Extracting Training Data from Diffusion Models"
마지막 업데이트: 2026-03-04