콘텐츠로 이동
Data Prep
상세

Diffusion Memorization Theory

개요

NeurIPS 2025 Best Paper로 선정된 연구. Diffusion 모델이 학습 데이터를 언제 "생성"하고 언제 "암기"하는지에 대한 이론적 프레임워크를 제시한다.

핵심 발견

두 가지 시간 스케일

Diffusion 모델 학습에는 두 가지 구분되는 시간 스케일이 존재한다:

시간 스케일 기호 의미 행동
생성 시간 τ_gen 유효한 샘플 생성 시작 분포 학습
암기 시간 τ_mem 학습 데이터 암기 시작 과적합
학습 진행
─────────────────────────────────────────────────────────▶
    │                    │
    │                    │
    ▼                    ▼
  τ_gen               τ_mem

[랜덤 노이즈]  →  [유효 샘플]  →  [훈련 데이터 복제]
   초기           일반화            과적합

이론적 분석

생성 시간 τ_gen: $$ \tau_{gen} \sim \frac{d}{n \cdot \text{SNR}} $$

  • d: 데이터 차원
  • n: 학습 샘플 수
  • SNR: 신호 대 잡음비

암기 시간 τ_mem: $$ \tau_{mem} \sim \frac{n}{d} \cdot \tau_{gen} $$

핵심 비율: $$ \frac{\tau_{mem}}{\tau_{gen}} \sim \frac{n}{d} $$

→ 데이터셋이 크고(n↑), 차원이 낮을수록(d↓) 암기까지 더 오래 걸림

실험 검증

1. 합성 데이터 실험

import torch
import torch.nn as nn
import numpy as np

class SimpleDiffusion(nn.Module):
    """간단한 Diffusion 모델"""
    def __init__(self, dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden_dim),  # +1 for time
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x, t):
        t_embed = t.unsqueeze(-1)
        return self.net(torch.cat([x, t_embed], dim=-1))

def measure_memorization(model, train_data, generated_samples):
    """암기 정도 측정"""
    # 가장 가까운 훈련 샘플까지의 거리
    distances = []
    for sample in generated_samples:
        dists = torch.norm(train_data - sample, dim=-1)
        min_dist = dists.min().item()
        distances.append(min_dist)

    return {
        'avg_min_distance': np.mean(distances),
        'memorized_ratio': np.mean([d < 0.1 for d in distances])
    }

2. 학습 단계별 분석

학습 에포크 FID 암기율 상태
100 150 0% 미숙
500 45 0% 생성 (τ_gen)
2000 20 2% 최적
5000 18 15% 과적합 시작
10000 22 45% 암기 (τ_mem)

3. 데이터셋 크기별 영향

암기율
100%│
    │                           ┌─── n=1000
    │                      ┌────┘
 50%│                 ┌────┘
    │            ┌────┘    ┌───────── n=10000
    │       ┌────┘    ┌────┘
    │  ┌────┘    ┌────┘
  0%│──┘────────┘────────────────────── n=100000
    └─────────────────────────────────▶
              학습 시간 (에포크)

실용적 함의

1. 최적 학습 중단점

def find_optimal_checkpoint(
    train_losses: list,
    fid_scores: list,
    memorization_rates: list
) -> int:
    """최적 체크포인트 찾기"""

    optimal_idx = 0
    best_score = float('inf')

    for i, (loss, fid, mem_rate) in enumerate(zip(
        train_losses, fid_scores, memorization_rates
    )):
        # FID 최소화 + 암기율 페널티
        score = fid + 10 * mem_rate  # 암기에 페널티

        if score < best_score:
            best_score = score
            optimal_idx = i

    return optimal_idx

def early_stopping_criterion(
    current_mem_rate: float,
    prev_mem_rate: float,
    threshold: float = 0.05
) -> bool:
    """암기 기반 조기 종료"""

    # 암기율 급증 감지
    if current_mem_rate - prev_mem_rate > threshold:
        return True

    # 절대 암기율 기준
    if current_mem_rate > 0.1:  # 10% 이상 암기
        return True

    return False

2. 데이터 증강의 효과

데이터 증강은 효과적으로 n을 증가시켜 τ_mem을 지연시킴:

증강 효과적 n τ_mem 지연
없음 N 기준
기본 (flip, crop) ~4N ~4x
강한 증강 ~10N ~10x
합성 데이터 추가 ~100N ~100x

3. 모델 크기와 암기

τ_mem / τ_gen 비율

큰 모델 │  ████████████████████████████
        │  (빠르게 암기)
중간 모델│  ████████████████████████████████████
        │  (적절한 균형)
작은 모델│  ████████████████████████████████████████████
        │  (느리게 암기)
        └──────────────────────────────────────────────▶

→ 큰 모델일수록 더 빨리 암기. 하지만 생성 품질도 더 높음.

암기 감지 방법

1. 최근접 이웃 거리

from sklearn.neighbors import NearestNeighbors

def detect_memorization_nn(
    generated_samples: np.ndarray,
    train_samples: np.ndarray,
    k: int = 1,
    threshold: float = 0.1
) -> dict:
    """최근접 이웃 기반 암기 감지"""

    nn = NearestNeighbors(n_neighbors=k)
    nn.fit(train_samples)

    distances, indices = nn.kneighbors(generated_samples)

    memorized = distances[:, 0] < threshold

    return {
        'memorization_rate': memorized.mean(),
        'avg_nn_distance': distances[:, 0].mean(),
        'memorized_indices': np.where(memorized)[0]
    }

2. SSCD (Self-Supervised Copy Detection)

# Meta의 SSCD 모델 사용
# pip install sscd

def detect_copies_sscd(
    generated_images: list,
    train_images: list,
    threshold: float = 0.9
) -> dict:
    """SSCD 기반 복제 감지"""

    # SSCD 임베딩 추출 (pseudo-code)
    gen_embeddings = sscd_model.encode(generated_images)
    train_embeddings = sscd_model.encode(train_images)

    # 코사인 유사도 계산
    similarities = cosine_similarity(gen_embeddings, train_embeddings)
    max_sims = similarities.max(axis=1)

    copies = max_sims > threshold

    return {
        'copy_rate': copies.mean(),
        'max_similarity_avg': max_sims.mean(),
        'copy_indices': np.where(copies)[0]
    }

3. 멤버십 추론 공격

def membership_inference(
    model,
    samples: np.ndarray,
    labels: np.ndarray,  # 1=train, 0=test
    n_steps: int = 50
) -> dict:
    """멤버십 추론으로 암기 정도 측정"""

    # 재구성 손실 기반
    losses = []
    for sample in samples:
        # Diffusion 역과정 손실 계산
        loss = compute_reconstruction_loss(model, sample, n_steps)
        losses.append(loss)

    losses = np.array(losses)

    # 훈련/테스트 손실 분포 비교
    train_losses = losses[labels == 1]
    test_losses = losses[labels == 0]

    # AUC로 구분 가능성 측정
    from sklearn.metrics import roc_auc_score
    auc = roc_auc_score(labels, -losses)  # 낮은 손실 = 훈련 샘플

    return {
        'membership_auc': auc,
        'train_loss_mean': train_losses.mean(),
        'test_loss_mean': test_losses.mean(),
        'memorization_signal': test_losses.mean() - train_losses.mean()
    }

암기 완화 전략

1. 노이즈 정규화

class NoiseRegularizedDiffusion(nn.Module):
    def __init__(self, base_model, noise_scale: float = 0.1):
        super().__init__()
        self.base_model = base_model
        self.noise_scale = noise_scale

    def forward(self, x, t, training=True):
        if training:
            # 학습 시 추가 노이즈
            x = x + self.noise_scale * torch.randn_like(x)

        return self.base_model(x, t)

2. 드롭아웃 스케줄링

def adaptive_dropout_schedule(
    epoch: int,
    total_epochs: int,
    min_dropout: float = 0.0,
    max_dropout: float = 0.5
) -> float:
    """학습 후반에 드롭아웃 증가"""

    # 로지스틱 스케줄
    progress = epoch / total_epochs
    dropout = min_dropout + (max_dropout - min_dropout) * (
        1 / (1 + np.exp(-10 * (progress - 0.6)))
    )

    return dropout

3. 데이터 샤딩

각 배치에서 전체 데이터셋의 일부만 사용:

class ShardedDataLoader:
    def __init__(self, dataset, shard_ratio: float = 0.1):
        self.dataset = dataset
        self.shard_ratio = shard_ratio
        self.shard_size = int(len(dataset) * shard_ratio)

    def __iter__(self):
        # 랜덤 샤드 선택
        indices = np.random.choice(
            len(self.dataset), 
            size=self.shard_size, 
            replace=False
        )

        for idx in indices:
            yield self.dataset[idx]

응용

1. 프라이버시 보호

  • 암기율 모니터링으로 개인정보 유출 방지
  • τ_mem 이전에 학습 중단

2. 저작권 보호

  • 생성물과 훈련 데이터 유사도 검사
  • 복제 감지 파이프라인 구축

3. 모델 디버깅

  • 과적합 조기 감지
  • 데이터 품질 문제 파악

코드 참조

# 전체 파이프라인 예시
def train_with_memorization_monitoring(
    model, 
    train_loader, 
    val_samples,
    max_epochs: int = 1000,
    mem_threshold: float = 0.1
):
    """암기 모니터링 포함 학습"""

    optimizer = torch.optim.Adam(model.parameters())
    prev_mem_rate = 0.0

    for epoch in range(max_epochs):
        # 학습
        for batch in train_loader:
            loss = diffusion_loss(model, batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 주기적 암기 체크
        if epoch % 50 == 0:
            generated = sample_from_model(model, n=1000)
            mem_result = detect_memorization_nn(
                generated.numpy(), 
                val_samples.numpy()
            )

            current_mem_rate = mem_result['memorization_rate']

            print(f"Epoch {epoch}: Mem Rate = {current_mem_rate:.2%}")

            # 조기 종료 체크
            if early_stopping_criterion(current_mem_rate, prev_mem_rate, mem_threshold):
                print(f"Early stopping at epoch {epoch}")
                break

            prev_mem_rate = current_mem_rate

    return model

요약

핵심 포인트

  1. 두 시간 스케일: 생성(τ_gen)과 암기(τ_mem)는 구분되는 현상
  2. 비율 τ_mem/τ_gen ~ n/d: 데이터 많고 저차원일수록 암기 지연
  3. 실용적 함의: 조기 종료, 데이터 증강, 모델 크기 조절에 활용
  4. 감지 방법: 최근접 이웃, SSCD, 멤버십 추론

참고 자료


마지막 업데이트: 2026-03-04