# Consistency Models
## Metadata

| Item | Content |
|---|---|
| Paper | Consistency Models |
| Authors | Yang Song, Prafulla Dhariwal, Mark Chen, Ilya Sutskever (OpenAI) |
| Venue | ICML 2023 |
| Follow-ups | Improved CM (2023), LCM (2023), CTM (2023), ECT (2024), sCM (2024), PCM (2024) |
| arXiv | 2303.01469 |
| Keywords | Generative Models, Diffusion, One-Step Generation, Distillation, Score Matching |
## Overview

Consistency Models (CMs) are generative models that map noise directly to data, addressing the slow sampling of diffusion models. They generate high-quality samples in a single step or a small number of steps.

Key insights:

- Self-consistency property: every point on a PF-ODE trajectory maps to the same data point
- Can be trained by distillation from a pretrained diffusion model, or trained standalone
- Supports zero-shot editing (inpainting, colorization, super-resolution)
- 1-step FID of 3.55 on CIFAR-10 (state of the art at the time)
## Background: Limitations of Diffusion Models

### The Diffusion Sampling Process

```
Forward process:
x_0 -> x_1 -> x_2 -> ... -> x_T  (noise)

Reverse process (sampling):
x_T -> x_{T-1} -> ... -> x_1 -> x_0  (data)
```
### Problems

| Aspect | Description |
|---|---|
| Slow sampling | Hundreds to thousands of steps needed (DDPM: 1000 steps) |
| Compute cost | One neural-network forward pass per step |
| Real-time use | Tens of seconds for high-resolution images |
### Limitations of Prior Acceleration Methods

- DDIM: reduces sampling to 50-100 steps, but still slow
- DPM-Solver: 10-20 steps, with some quality degradation
- Knowledge distillation: complex, unstable training
## Core Idea of Consistency Models

### Probability Flow ODE (PF-ODE)

The forward/reverse processes of a diffusion model can be expressed as an ODE:

```
dx = [f(x, t) - (1/2) g(t)^2 * score(x, t)] dt
```

where:

- f(x, t): drift coefficient
- g(t): diffusion coefficient
- score(x, t) = grad_x log p_t(x)
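Under the variance-exploding parameterization used in the paper (f = 0, g(t) = sqrt(2t)), the PF-ODE reduces to dx/dt = -t * score(x, t) and can be integrated numerically. A minimal numpy sketch of a single Euler step; `score_fn` is a stand-in for a trained score model:

```python
import numpy as np

def pf_ode_euler_step(x, t, t_next, score_fn):
    """One Euler step of the PF-ODE under the VE parameterization,
    where dx/dt = -t * score(x, t)."""
    dx_dt = -t * score_fn(x, t)
    return x + (t_next - t) * dx_dt
```

As a sanity check, for a Gaussian p_t = N(0, t^2 I) the score is -x/t^2, so PF-ODE trajectories are straight lines x(t) proportional to t; integrating from t = 2 down to t = 1 halves x.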
### Self-Consistency Property

Key definition: every point on a PF-ODE trajectory maps back to the same initial data point.

This property means:

- Data can be recovered from any time t in a single step
- Noise -> data mapping without intermediate steps
### Visual Intuition

```
Diffusion (sequential):
noise --[step 1]--> ... --[step N]--> data

Consistency (direct):
noise --[single step]--> data
         |
         +-- same endpoint from any point on the trajectory
```
## Mathematical Definition

### Consistency Function

A consistency function f: (x, t) -> x_epsilon satisfies two conditions:

1. Boundary condition: f(x, epsilon) = x (identity function at t = epsilon)
2. Self-consistency: f(x_t, t) = f(x_{t'}, t') for all points on the same trajectory
### Model Parameterization

```python
def consistency_model(x, t, F_theta):
    """
    The skip connection enforces the boundary condition:
    c_skip(t): skip coefficient, c_skip(epsilon) ≈ 1
    c_out(t):  output coefficient, c_out(epsilon) ≈ 0
    (the paper's exact form uses (t - epsilon) in place of t,
    making the boundary condition hold exactly at t = epsilon)
    """
    c_skip = sigma_data**2 / (t**2 + sigma_data**2)
    c_out = t * sigma_data / sqrt(t**2 + sigma_data**2)
    return c_skip * x + c_out * F_theta(x, t)
```
### Loss Functions

Consistency Distillation (CD) loss:

```python
def cd_loss(x_0, t_n, t_n1, theta, theta_minus):
    """
    theta_minus: EMA of theta (target network)
    """
    # Noise the data to the larger timestep t_{n+1}
    noise = torch.randn_like(x_0)
    x_n1 = x_0 + t_n1 * noise
    # One ODE-solver step with the teacher score model,
    # from t_{n+1} down to t_n
    x_n_hat = ode_solver(x_n1, t_n1, t_n, score_model)
    # Consistency loss
    loss = distance(
        f_theta(x_n1, t_n1),          # online network
        f_theta_minus(x_n_hat, t_n)   # target network (stop grad)
    )
    return loss
```
Consistency Training (CT) loss:

```python
def ct_loss(x_0, t_n, t_n1, theta, theta_minus):
    """
    Learn directly from data, without a score model
    """
    # Forward diffusion: use the SAME noise at both timesteps!
    noise = torch.randn_like(x_0)
    x_n = x_0 + t_n * noise
    x_n1 = x_0 + t_n1 * noise
    loss = distance(
        f_theta(x_n1, t_n1),
        f_theta_minus(x_n, t_n)
    )
    return loss
```
## Training Methods

### 1. Consistency Distillation (CD)

Knowledge distillation from a pretrained diffusion model:

```python
class ConsistencyDistillation:
    def __init__(self, pretrained_diffusion):
        self.teacher = pretrained_diffusion
        self.student = ConsistencyModel()
        self.student_ema = copy.deepcopy(self.student)

    def train_step(self, x_0):
        # Sample an adjacent timestep pair
        n = torch.randint(0, N - 1, (1,)).item()
        t_n, t_n1 = schedule[n], schedule[n + 1]
        # Noise the data to the larger timestep
        noise = torch.randn_like(x_0)
        x_n1 = x_0 + t_n1 * noise
        # Teacher denoising (one ODE step from t_{n+1} down to t_n)
        score = self.teacher(x_n1, t_n1)
        x_n_hat = ode_step(x_n1, t_n1, t_n, score)
        # Consistency loss
        pred_online = self.student(x_n1, t_n1)
        pred_target = self.student_ema(x_n_hat, t_n).detach()
        loss = F.mse_loss(pred_online, pred_target)
        # EMA update
        ema_update(self.student_ema, self.student, mu=0.999)
        return loss
```
### 2. Consistency Training (CT)

Learn directly from data, without a score model:

```python
class ConsistencyTraining:
    def train_step(self, x_0):
        # Sample an adjacent timestep pair
        n = torch.randint(0, N - 1, (1,)).item()
        t_n, t_n1 = schedule[n], schedule[n + 1]
        # Create samples at both timesteps with the SAME noise
        noise = torch.randn_like(x_0)
        x_n = x_0 + t_n * noise
        x_n1 = x_0 + t_n1 * noise
        # Consistency loss
        pred_online = self.model(x_n1, t_n1)
        pred_target = self.model_ema(x_n, t_n).detach()
        loss = F.mse_loss(pred_online, pred_target)
        return loss
```
## Key Hyperparameters

| Parameter | CD (recommended) | CT (recommended) | Description |
|---|---|---|---|
| N (schedule length) | 18 | 150 | Number of time-discretization steps |
| mu (EMA rate) | 0.9999 | adaptive | Target-network EMA decay |
| distance | LPIPS | L2 / LPIPS | Similarity metric |
| sigma_data | 0.5 | 0.5 | Data standard deviation |
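The schedule length N above refers to a discretization of [T_min, T_max]; the paper uses the Karras et al. schedule, which spaces timesteps densely near T_min. A minimal numpy sketch of how the N = 18 schedule would be built (rho = 7, descending order assumed here):

```python
import numpy as np

def karras_schedule(N, t_min=0.002, t_max=80.0, rho=7.0):
    """Karras et al. discretization of [t_min, t_max], descending."""
    i = np.arange(N)
    inv_rho = 1.0 / rho
    return (t_max**inv_rho + i / (N - 1) * (t_min**inv_rho - t_max**inv_rho))**rho
```

With rho = 7 the step sizes shrink rapidly toward t_min, spending most of the discretization budget where the trajectory curves most.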
## Sampling

### 1-Step Sampling

```python
def sample_one_step(model, batch_size, device):
    """
    Fastest sampling: a direct noise -> data mapping
    """
    # Start from pure noise
    x_T = torch.randn(batch_size, C, H, W, device=device) * T_max
    # Single forward pass
    x_0 = model(x_T, T_max)
    return x_0
```
### Multi-Step Sampling (better quality)

```python
def sample_multi_step(model, batch_size, timesteps, device):
    """
    Improve quality with more steps.
    timesteps: [T_max, t_1, t_2, ..., epsilon]
    """
    x = torch.randn(batch_size, C, H, W, device=device) * timesteps[0]
    for i in range(len(timesteps) - 1):
        # Denoise to data
        x_0 = model(x, timesteps[i])
        # Re-noise to the next timestep (if not last)
        if i < len(timesteps) - 2:
            noise = torch.randn_like(x)
            x = x_0 + timesteps[i + 1] * noise
    return x_0
```
### Sampling Steps vs. Quality (CIFAR-10)
| Steps | FID (CD) | FID (CT) |
|---|---|---|
| 1 | 3.55 | 7.46 |
| 2 | 2.93 | 5.22 |
| 4 | 2.61 | 4.67 |
## Follow-up Work

### Latent Consistency Models (LCM, 2023)

Applies consistency training to latent diffusion models such as Stable Diffusion:

```python
# LCM: learn consistency in latent space
# 2-4 step 768x768 generation after ~32 A100 GPU-hours of training
class LatentConsistencyModel:
    def __init__(self, vae, ldm_teacher):
        self.vae = vae
        self.cm = ConsistencyModel()

    def sample(self, prompt):
        z_T = torch.randn(...)             # latent noise
        z_0 = self.cm(z_T, T_max, prompt)  # 1-4 steps
        x_0 = self.vae.decode(z_0)
        return x_0
```
### Improved Consistency Training (iCT, 2023)

Achieves SOTA without distillation:

| Improvement | Description |
|---|---|
| Adaptive schedule | N increased progressively during training |
| Pseudo-Huber loss | Distance more robust than L2 |
| Variance reduction | Techniques to reduce noise variance |

Result: 1-step FID of 2.51 on CIFAR-10
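The pseudo-Huber distance used by iCT interpolates between squared-L2 behaviour near zero and L1 behaviour for large errors, which damps the influence of outlier pairs. A numpy sketch of the per-sample form; iCT sets c = 0.00054 * sqrt(d), where d is the data dimensionality:

```python
import numpy as np

def pseudo_huber(x, y, c):
    """d(x, y) = sqrt(||x - y||^2 + c^2) - c, computed per sample."""
    diff = (x - y).reshape(len(x), -1)
    return np.sqrt((diff**2).sum(axis=1) + c * c) - c

# For CIFAR-10, d = 3 * 32 * 32 = 3072, so c = 0.00054 * sqrt(3072) ≈ 0.03
```

For small residuals the loss behaves like ||x - y||^2 / (2c) (L2-like); for large residuals it grows like ||x - y|| (L1-like).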
### Consistency Trajectory Models (CTM, 2023)

Learns arbitrary segments of the PF-ODE:

Advantages:

- Enables long-jump sampling
- Can recover the score function
- More flexible sampling strategies
### Easy Consistency Tuning (ECT, 2024)

A simplified training recipe:

- 2-step FID of 2.73 on CIFAR-10 with 1 A100 GPU-hour
- Efficient conversion from a pretrained diffusion model
- Adaptive EMA schedule
### Simplified and Stabilized CM (sCM, 2024)

Stabilizes large-scale image generation:

| Dataset | Steps | FID |
|---|---|---|
| ImageNet 512x512 | 2 | 1.88 |
| ImageNet 64x64 | 1 | 1.48 |

Key techniques:

- TrigFlow parameterization
- Adaptive weighting
- Stabilized training dynamics
### Phased Consistency Models (PCM, 2024)

Splits the time range into phases:

Advantage: each phase is trained independently, yielding better quality
## Applications

### 1. Image Generation

```python
# Text-to-Image with LCM-LoRA
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")

# High-quality generation in 4 steps
image = pipe(
    "A photo of a cat",
    num_inference_steps=4,
    guidance_scale=1.0
).images[0]
```
### 2. Video Generation

```python
# Motion Consistency Models
# Generate video frames in a few steps
class VideoConsistencyModel:
    def generate_video(self, prompt, num_frames=16):
        # Generate while preserving temporal consistency
        z_T = torch.randn(num_frames, C, H, W)
        frames = self.model(z_T, T_max, prompt)  # 4 steps
        return frames
```
### 3. Audio/Music Generation

```python
# CoMoSpeech: 1-step TTS
# 150x faster than real-time
class ConsistencyTTS:
    def synthesize(self, text):
        mel_noise = torch.randn(...)
        mel = self.model(mel_noise, T_max, text)  # 1 step
        audio = self.vocoder(mel)
        return audio
```
### 4. Robotics

```python
# Consistency Policy for robot control
# 10x faster than Diffusion Policy
class ConsistencyPolicy:
    def get_action(self, observation):
        noise = torch.randn(action_dim)
        action = self.policy(noise, T_max, observation)  # 1-2 steps
        return action
```
### 5. Zero-Shot Image Editing

```python
def inpainting(model, image, mask, T=1.0):
    """
    Inpainting without any additional training
    """
    # Forward-diffuse the masked region
    noise = torch.randn_like(image)
    x_T = image * (1 - mask) + (image + T * noise) * mask
    # Restore with the consistency model
    x_0 = model(x_T, T)
    # Preserve the unmasked region
    result = image * (1 - mask) + x_0 * mask
    return result
```
## Full Implementation Example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy


class SinusoidalEmbedding(nn.Module):
    """Sinusoidal time embedding"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)


class ConsistencyUNet(nn.Module):
    """Simplified U-Net architecture"""
    def __init__(self, in_channels=3, base_channels=64, sigma_data=0.5):
        super().__init__()
        self.sigma_data = sigma_data
        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(base_channels),
            nn.Linear(base_channels, base_channels * 4),
            nn.SiLU(),
            nn.Linear(base_channels * 4, base_channels * 4)
        )
        # Projects the time embedding into the bottleneck features
        self.time_proj = nn.Linear(base_channels * 4, base_channels * 4)
        # Encoder
        self.enc1 = self._block(in_channels, base_channels)
        self.enc2 = self._block(base_channels, base_channels * 2)
        self.enc3 = self._block(base_channels * 2, base_channels * 4)
        # Middle
        self.mid = self._block(base_channels * 4, base_channels * 4)
        # Decoder
        self.dec3 = self._block(base_channels * 8, base_channels * 2)
        self.dec2 = self._block(base_channels * 4, base_channels)
        self.dec1 = self._block(base_channels * 2, base_channels)
        # Output
        self.out = nn.Conv2d(base_channels, in_channels, 3, padding=1)
        # Pooling and upsampling
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear')

    def _block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.SiLU()
        )

    def forward(self, x, t):
        # Skip-connection coefficients (boundary condition)
        c_skip = self.sigma_data**2 / (t**2 + self.sigma_data**2)
        c_out = t * self.sigma_data / torch.sqrt(t**2 + self.sigma_data**2)
        c_skip = c_skip.view(-1, 1, 1, 1)
        c_out = c_out.view(-1, 1, 1, 1)
        # Time embedding
        t_emb = self.time_embed(t)
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        # Middle, conditioned on the time embedding
        m = self.mid(self.pool(e3) + self.time_proj(t_emb)[:, :, None, None])
        # Decoder with skip connections
        d3 = self.dec3(torch.cat([self.up(m), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up(d2), e1], dim=1))
        # Output with boundary condition
        F_theta = self.out(d1)
        return c_skip * x + c_out * F_theta


class ConsistencyTrainer:
    """Consistency Training implementation"""
    def __init__(
        self,
        model: nn.Module,
        T_max: float = 80.0,
        T_min: float = 0.002,
        N: int = 150,
        mu_init: float = 0.9,
        lr: float = 1e-4
    ):
        self.model = model
        self.model_ema = copy.deepcopy(model)
        self.T_max = T_max
        self.T_min = T_min
        self.N = N
        self.mu = mu_init
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        # Time schedule (Karras schedule), descending from T_max to T_min
        self.schedule = self._karras_schedule(N, T_max, T_min)

    def _karras_schedule(self, N, T_max, T_min, rho=7.0):
        """Karras et al. time schedule"""
        step_indices = torch.arange(N + 1)
        t = (T_max ** (1/rho) + step_indices / N * (T_min ** (1/rho) - T_max ** (1/rho))) ** rho
        return t

    @torch.no_grad()
    def _ema_update(self):
        """EMA update of the target network"""
        for p, p_ema in zip(self.model.parameters(), self.model_ema.parameters()):
            p_ema.data.mul_(self.mu).add_(p.data, alpha=1 - self.mu)

    def train_step(self, x_0: torch.Tensor) -> torch.Tensor:
        """Single training step"""
        device = x_0.device
        batch_size = x_0.shape[0]
        # Sample timestep indices on CPU (the schedule lives on CPU);
        # the schedule is descending, so schedule[n] is the larger timestep
        n = torch.randint(0, self.N, (batch_size,))
        t_hi = self.schedule[n].to(device)
        t_lo = self.schedule[n + 1].to(device)
        # Use the SAME noise at both timesteps
        noise = torch.randn_like(x_0)
        x_hi = x_0 + t_hi.view(-1, 1, 1, 1) * noise
        x_lo = x_0 + t_lo.view(-1, 1, 1, 1) * noise
        # Online network at the noisier point, EMA target at the cleaner one
        pred_online = self.model(x_hi, t_hi)
        with torch.no_grad():
            pred_target = self.model_ema(x_lo, t_lo)
        # Pseudo-Huber loss (elementwise variant)
        c = 0.00054 * torch.sqrt(torch.tensor(x_0.numel() / batch_size))
        loss = torch.sqrt((pred_online - pred_target)**2 + c**2) - c
        loss = loss.mean()
        # Optimization
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # EMA update
        self._ema_update()
        return loss

    @torch.no_grad()
    def sample(self, batch_size: int, device: torch.device, steps: int = 1) -> torch.Tensor:
        """Generate samples"""
        self.model_ema.eval()
        # Pick timesteps for multi-step sampling
        if steps == 1:
            timesteps = [self.T_max]
        else:
            indices = torch.linspace(0, self.N - 1, steps).long()
            timesteps = self.schedule[indices].tolist()
        # Start from noise
        x = torch.randn(batch_size, 3, 32, 32, device=device) * self.T_max
        for i, t in enumerate(timesteps):
            t_tensor = torch.full((batch_size,), t, device=device)
            x = self.model_ema(x, t_tensor)
            # Re-noise for the next step (except the last)
            if i < len(timesteps) - 1:
                noise = torch.randn_like(x)
                x = x + timesteps[i + 1] * noise
        return x.clamp(-1, 1)


# Usage example
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model initialization
    model = ConsistencyUNet(in_channels=3, base_channels=64).to(device)
    trainer = ConsistencyTrainer(model, N=150)

    # Training loop (dataloader: a DataLoader yielding images scaled to [-1, 1])
    for epoch in range(100):
        for batch in dataloader:
            x_0 = batch.to(device)
            loss = trainer.train_step(x_0)
        # Adaptive N schedule
        if epoch % 10 == 0:
            trainer.N = min(trainer.N + 10, 200)
            trainer.schedule = trainer._karras_schedule(
                trainer.N, trainer.T_max, trainer.T_min
            )
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    # Sampling
    samples = trainer.sample(batch_size=16, device=device, steps=1)
    print(f"Generated {samples.shape}")
```
## Diffusion vs. Consistency Comparison

| Aspect | Diffusion Models | Consistency Models |
|---|---|---|
| Sampling steps | 50-1000 | 1-4 |
| Sampling time | Tens of seconds | Milliseconds |
| Training stability | High | Moderate |
| Peak quality | Very high | High |
| Zero-shot editing | Limited | Supported |
| Real-time use | Difficult | Feasible |
## Key Takeaways

| Item | Content |
|---|---|
| Core idea | Learn the self-consistency of PF-ODE trajectories to enable single-step generation |
| Training | Distillation (CD) or standalone training (CT) |
| Strengths | 1-step generation, zero-shot editing, fast sampling |
| Limitation | Slightly lower peak quality than diffusion |
| Key follow-ups | LCM (latent space), CTM (trajectories), sCM (stability) |
## References
- Song, Y., et al. "Consistency Models." ICML 2023. arXiv:2303.01469
- Song, Y., et al. "Improved Techniques for Training Consistency Models." arXiv:2310.14189
- Luo, S., et al. "Latent Consistency Models." arXiv:2310.04378
- Kim, D., et al. "Consistency Trajectory Models." arXiv:2310.02279
- Geng, Z., et al. "Consistency Models Made Easy." ICLR 2025. arXiv:2406.14548
- Lu, Y., et al. "Simplifying, Stabilizing and Scaling Continuous-Time Consistency Models." arXiv:2410.11081
- Wang, F., et al. "Phased Consistency Models." arXiv:2405.18407