
Consistency Models

Metadata

Item | Content
Paper | Consistency Models
Authors | Yang Song, Prafulla Dhariwal, Mark Chen, Ilya Sutskever (OpenAI)
Venue | ICML 2023
Follow-ups | Improved CM (2023), LCM (2023), CTM (2023), ECT (2024), sCM (2024), PCM (2024)
arXiv | 2303.01469
Keywords | Generative Models, Diffusion, One-Step Generation, Distillation, Score Matching

Overview

Consistency Models (CM) are generative models that map noise directly to data, addressing the slow sampling of diffusion models. They generate high-quality samples in a single step or a small number of steps.

Key insights:
  • Self-consistency property: every point on a PF-ODE trajectory maps to the same data point
  • Trainable by distillation from a pretrained diffusion model, or from scratch
  • Supports zero-shot editing (inpainting, colorization, super-resolution)
  • 1-step FID 3.55 on CIFAR-10 (SOTA at the time)


Background: Limitations of Diffusion Models

Sampling in diffusion models

Forward Process:
x_0 -> x_1 -> x_2 -> ... -> x_T (noise)

Reverse Process (sampling):
x_T -> x_{T-1} -> ... -> x_1 -> x_0 (data)

Problems

Aspect | Description
Slow sampling | Hundreds to thousands of steps required (DDPM: 1000 steps)
Compute cost | One neural-network forward pass per step
Real-time use | Tens of seconds for high-resolution image generation
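To make the cost concrete, a back-of-the-envelope calculation (the 30 ms per forward pass is an assumed figure for illustration, not a measurement; real latency depends on model size and hardware):

```python
# Illustrative sampling-latency arithmetic.
# 30 ms per U-Net forward pass is an assumption, not a benchmark.
ms_per_forward_pass = 30

ddpm_steps = 1000  # typical DDPM sampling
cm_steps = 1       # consistency model sampling

ddpm_latency_s = ddpm_steps * ms_per_forward_pass / 1000  # seconds per image
cm_latency_s = cm_steps * ms_per_forward_pass / 1000
print(ddpm_latency_s, cm_latency_s)
```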

Limitations of existing acceleration techniques

DDIM: reduces sampling to 50-100 steps, still slow
DPM-Solver: 10-20 steps, with quality degradation
Knowledge Distillation: complicated training pipelines, unstable

Core Idea of Consistency Models

Probability Flow ODE (PF-ODE)

The forward/reverse processes of a diffusion model can be expressed as an ODE:

dx = [f(x,t) - (1/2)g(t)^2 * score(x,t)] dt

where:
- f(x,t): drift coefficient
- g(t): diffusion coefficient  
- score(x,t) = grad_x log p_t(x)
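As a concrete sanity check, consider toy 1-D Gaussian data x_0 ~ N(0, sigma_data^2) with the VE-style perturbation x_t = x_0 + t*z (an illustrative assumption, not from the paper). Then p_t = N(0, sigma_data^2 + t^2), the score is -x / (sigma_data^2 + t^2), and with f = 0, g(t) = sqrt(2t) the PF-ODE becomes dx/dt = -t * score(x, t), which can be integrated numerically and compared to its closed-form solution:

```python
import math

sigma_data = 0.5

def score(x, t):
    # Analytic score of p_t = N(0, sigma_data^2 + t^2) for Gaussian toy data
    return -x / (sigma_data**2 + t**2)

def euler_pf_ode(x_T, T=80.0, eps=1e-3, n_steps=200_000):
    """Euler integration of dx/dt = -t * score(x, t) from t=T down to t=eps."""
    dt = (eps - T) / n_steps  # negative: integrating backward in time
    x, t = x_T, T
    for _ in range(n_steps):
        x += dt * (-t * score(x, t))
        t += dt
    return x

x_T = 4.0
x_eps = euler_pf_ode(x_T)
# Closed form for this toy: x(t) = x(T) * sqrt((sigma_data^2 + t^2) / (sigma_data^2 + T^2))
x_exact = x_T * math.sqrt((sigma_data**2 + 1e-3**2) / (sigma_data**2 + 80.0**2))
print(x_eps, x_exact)
```

The Euler solution tracks the analytic trajectory closely, illustrating that the PF-ODE deterministically transports noise-scale samples back to the data distribution.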

Self-Consistency Property

Key definition: every point on a PF-ODE trajectory converges to the same initial data point.

f_theta(x_t, t) = f_theta(x_{t'}, t')  for all t, t' on same trajectory

Goal: f_theta: (x_t, t) -> x_0

This property means:
  • data can be recovered in a single step from any timestep t
  • noise maps directly to data, with no intermediate steps

Visual intuition

Diffusion (sequential):
noise --[step 1]--> ... --[step N]--> data

Consistency (direct):
noise --[single step]--> data
   |
   +-- every point on the trajectory reaches the same endpoint

Mathematical Definition

Consistency Function

A consistency function f: (x, t) -> x_epsilon satisfies:

1. Boundary condition:
   f(x, epsilon) = x    (identity function at t = epsilon)

2. Self-consistency:
   f(x_t, t) = f(x_{t'}, t')    (all points on the same trajectory)
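For the illustrative toy case of 1-D Gaussian data x_0 ~ N(0, sigma_data^2) with x_t = x_0 + t*z (an assumption for this sketch, not a trained model), the consistency function has a closed form, and both conditions can be checked numerically:

```python
import math

sigma_data = 0.5

def f(x, t):
    # Closed-form consistency function for 1-D Gaussian toy data:
    # the PF-ODE trajectory through (x, t) reaches
    # x * sigma_data / sqrt(sigma_data^2 + t^2) as t -> 0
    return x * sigma_data / math.sqrt(sigma_data**2 + t**2)

# 1. Boundary condition: f(x, epsilon) ~= x as epsilon -> 0
assert abs(f(3.3, 0.0) - 3.3) < 1e-12

# 2. Self-consistency: every point on one trajectory maps to the same output
x_T, T = 4.0, 80.0
def on_trajectory(t):
    # Analytic PF-ODE trajectory through (x_T, T) for this toy distribution
    return x_T * math.sqrt((sigma_data**2 + t**2) / (sigma_data**2 + T**2))

outputs = [f(on_trajectory(t), t) for t in (0.01, 1.0, 10.0, 80.0)]
assert max(outputs) - min(outputs) < 1e-9
```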

Model parameterization

import math

def consistency_model(x, t, F_theta, sigma_data=0.5):
    """
    Satisfy the boundary condition via a skip connection.

    c_skip(t): skip coefficient, c_skip(epsilon) = 1
    c_out(t): output coefficient, c_out(epsilon) = 0
    """
    c_skip = sigma_data**2 / (t**2 + sigma_data**2)
    c_out = t * sigma_data / math.sqrt(t**2 + sigma_data**2)

    return c_skip * x + c_out * F_theta(x, t)
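A quick numeric check of the coefficients confirms the boundary behavior: at t = 0 the model is the identity, while at large t the output comes almost entirely from the network:

```python
import math

sigma_data = 0.5

def coeffs(t):
    # Skip/output coefficients of the CM parameterization
    c_skip = sigma_data**2 / (t**2 + sigma_data**2)
    c_out = t * sigma_data / math.sqrt(t**2 + sigma_data**2)
    return c_skip, c_out

c_skip0, c_out0 = coeffs(0.0)    # boundary: identity -> (1.0, 0.0)
c_skipT, c_outT = coeffs(80.0)   # large t: nearly all network output
print(c_skip0, c_out0, c_skipT, c_outT)
```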

Loss Function

Consistency Distillation (CD) Loss:

def cd_loss(x, t_n, t_n1, theta, theta_minus):
    """
    theta_minus: EMA of theta (target network)
    """
    # One ODE-solver step from x at t_n to t_{n+1}
    x_next = ode_solver(x, t_n, t_n1, score_model)

    # Consistency loss
    loss = distance(
        f_theta(x_next, t_n1),      # online network
        f_theta_minus(x, t_n)       # target network (stop grad)
    )
    return loss

Consistency Training (CT) Loss:

def ct_loss(x_0, t_n, t_n1, theta, theta_minus):
    """
    Train directly on data, without a score model
    """
    # Forward diffusion
    noise = torch.randn_like(x_0)
    x_n = x_0 + t_n * noise
    x_n1 = x_0 + t_n1 * noise  # same noise at both timesteps!

    loss = distance(
        f_theta(x_n1, t_n1),
        f_theta_minus(x_n, t_n)
    )
    return loss

Training Methods

1. Consistency Distillation (CD)

Knowledge distillation from a pretrained diffusion model:

class ConsistencyDistillation:
    def __init__(self, pretrained_diffusion):
        self.teacher = pretrained_diffusion
        self.student = ConsistencyModel()
        self.student_ema = copy.deepcopy(self.student)

    def train_step(self, x_0):
        # Sample a timestep pair
        n = torch.randint(0, N, (1,)).item()
        t_n, t_n1 = schedule[n], schedule[n + 1]

        # Add noise
        noise = torch.randn_like(x_0)
        x_n = x_0 + t_n * noise

        # Teacher denoising (one ODE step)
        score = self.teacher(x_n, t_n)
        x_n1 = ode_step(x_n, t_n, t_n1, score)

        # Consistency loss
        pred_online = self.student(x_n1, t_n1)
        pred_target = self.student_ema(x_n, t_n).detach()

        loss = F.mse_loss(pred_online, pred_target)

        # EMA update
        ema_update(self.student_ema, self.student, mu=0.999)

        return loss

2. Consistency Training (CT)

Train directly on data, without a score model:

class ConsistencyTraining:
    def train_step(self, x_0):
        # Sample a timestep pair
        n = torch.randint(0, N, (1,)).item()
        t_n, t_n1 = schedule[n], schedule[n + 1]

        # Samples at both timesteps share the same noise
        noise = torch.randn_like(x_0)
        x_n = x_0 + t_n * noise
        x_n1 = x_0 + t_n1 * noise

        # Consistency loss
        pred_online = self.model(x_n1, t_n1)
        pred_target = self.model_ema(x_n, t_n).detach()

        loss = F.mse_loss(pred_online, pred_target)

        return loss

Key Hyperparameters

Parameter | CD recommended | CT recommended | Description
N (schedule length) | 18 | 150 | Number of time-discretization steps
mu (EMA rate) | 0.9999 | adaptive | Target-network EMA rate
distance | LPIPS | L2 / LPIPS | Distance metric
sigma_data | 0.5 | 0.5 | Data standard deviation
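One way to read the EMA rate mu (a rule of thumb, not a claim from the paper): the target network averages over roughly 1 / (1 - mu) recent training steps, so larger mu means a slower-moving, more stable target:

```python
# Rough effective averaging horizon of an EMA target network: 1 / (1 - mu)
for mu in (0.9, 0.999, 0.9999):
    horizon = 1.0 / (1.0 - mu)
    print(f"mu={mu}: ~{horizon:.0f} steps")
```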

Sampling

1-Step Sampling

def sample_one_step(model, batch_size, device, shape=(3, 32, 32), T_max=80.0):
    """
    Fastest sampling: map noise -> data in a single forward pass
    """
    # Start from pure noise at the largest timestep
    x_T = torch.randn(batch_size, *shape, device=device) * T_max

    # Single forward pass
    x_0 = model(x_T, T_max)

    return x_0

Multi-Step Sampling (higher quality)

def sample_multi_step(model, batch_size, timesteps, device, shape=(3, 32, 32)):
    """
    Trade extra steps for quality
    timesteps: [T_max, t_1, t_2, ..., epsilon]
    """
    x = torch.randn(batch_size, *shape, device=device) * timesteps[0]

    for i in range(len(timesteps) - 1):
        # Denoise to data
        x_0 = model(x, timesteps[i])

        # Re-noise to the next timestep (if not last)
        if i < len(timesteps) - 2:
            noise = torch.randn_like(x)
            x = x_0 + timesteps[i + 1] * noise

    return x_0

Sampling steps vs. quality (CIFAR-10)

Steps | FID (CD) | FID (CT)
1 | 3.55 | 7.46
2 | 2.93 | 5.22
4 | 2.61 | 4.67

Follow-up Work

Latent Consistency Models (LCM, 2023)

Applies consistency training to latent diffusion models such as Stable Diffusion:

# LCM: learn consistency in the latent space
# 2-4 step 768x768 generation after ~32 A100-hours of training

class LatentConsistencyModel:
    def __init__(self, vae, ldm_teacher):
        self.vae = vae
        self.cm = ConsistencyModel()

    def sample(self, prompt):
        z_T = torch.randn(...)  # Latent noise
        z_0 = self.cm(z_T, T_max, prompt)  # 1-4 steps
        x_0 = self.vae.decode(z_0)
        return x_0

Improved Consistency Training (iCT, 2023)

Reaches SOTA without distillation:

Improvement | Description
Adaptive schedule | Gradually increase N during training
Pseudo-Huber loss | More robust distance than L2
Variance reduction | Techniques to reduce noise variance

Result: CIFAR-10 1-step FID 2.51
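The pseudo-Huber distance used by iCT, d(x, y) = sqrt(||x - y||^2 + c^2) - c, behaves like a scaled L2 near zero and like L1 for large residuals. A scalar sketch (c = 0.03 is an arbitrary value for illustration; iCT actually ties c to the data dimension):

```python
import math

def pseudo_huber(diff, c=0.03):
    """Pseudo-Huber distance from iCT: sqrt(diff^2 + c^2) - c."""
    return math.sqrt(diff**2 + c**2) - c

small = pseudo_huber(1e-4)   # ~ diff^2 / (2c): smooth, L2-like near zero
large = pseudo_huber(10.0)   # ~ |diff| - c: robust, L1-like in the tails
print(small, large)
```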

Consistency Trajectory Models (CTM, 2023)

Learns arbitrary segments of the PF-ODE:

Vanilla CM: (x_t, t) -> x_0
CTM:        (x_t, t, s) -> x_s    (arbitrary s < t)

Advantages:
  • Long-jump sampling
  • Recovers the score function
  • More flexible sampling strategies
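The anytime-to-anytime map can be illustrated on toy 1-D Gaussian data x_0 ~ N(0, sigma_data^2) with x_t = x_0 + t*z, where the PF-ODE trajectory is known analytically (this is an illustration of the CTM interface, not a trained CTM):

```python
import math

sigma_data = 0.5

def g(x, t, s):
    # Analytic (x_t, t) -> x_s jump along the PF-ODE trajectory of
    # 1-D Gaussian toy data; a trained CTM approximates such a map.
    return x * math.sqrt((sigma_data**2 + s**2) / (sigma_data**2 + t**2))

x, t = 4.0, 80.0
assert g(x, t, t) == x                      # s = t is the identity
cm_output = g(x, t, 0.0)                    # s = 0 recovers the plain CM
composed = g(g(x, t, 10.0), 10.0, 0.0)      # long jumps compose
assert abs(composed - cm_output) < 1e-12
```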

Easy Consistency Tuning (ECT, 2024)

A simplified training recipe:

  • CIFAR-10 2-step FID 2.73 after 1 A100-hour of training
  • Efficient conversion from a pretrained diffusion model
  • Adaptive EMA schedule

Simplified and Stabilized CM (sCM, 2024)

Stabilizes large-scale image generation:

Dataset | Steps | FID
ImageNet 512x512 | 2 | 1.88
ImageNet 64x64 | 1 | 1.48

Key techniques:
  • TrigFlow parameterization
  • Adaptive weighting
  • Stabilized training dynamics

Phased Consistency Models (PCM, 2024)

Splits the time interval into phases:

Phase 1: t in [T, T/2]  -> CM_1
Phase 2: t in [T/2, T/4] -> CM_2
Phase 3: t in [T/4, 0]   -> CM_3

Advantage: each phase is trained independently, yielding better quality
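A minimal sketch of how a timestep could be routed to its phase-specific model, following the three-phase split above (the boundaries and function name are illustrative assumptions):

```python
def phase_index(t, T=80.0):
    """Route timestep t to a phase-specific consistency model (PCM sketch)."""
    if t > T / 2:
        return 1   # handled by CM_1 on [T, T/2]
    if t > T / 4:
        return 2   # handled by CM_2 on (T/2, T/4]
    return 3       # handled by CM_3 on (T/4, 0]

print(phase_index(80.0), phase_index(30.0), phase_index(5.0))
```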


Applications

1. Image generation

# Text-to-Image with LCM-LoRA
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")

# High-quality generation in 4 steps
image = pipe(
    "A photo of a cat",
    num_inference_steps=4,
    guidance_scale=1.0
).images[0]

2. Video generation

# Motion Consistency Models
# Generate video frames in just a few steps

class VideoConsistencyModel:
    def generate_video(self, prompt, num_frames=16):
        # Generate frames while preserving temporal consistency
        z_T = torch.randn(num_frames, C, H, W)
        frames = self.model(z_T, T_max, prompt)  # 4 steps
        return frames

3. Audio/music generation

# CoMoSpeech: 1-step TTS
# 150x faster than real-time

class ConsistencyTTS:
    def synthesize(self, text):
        mel_noise = torch.randn(...)
        mel = self.model(mel_noise, T_max, text)  # 1 step
        audio = self.vocoder(mel)
        return audio

4. Robotics

# Consistency Policy for Robot Control
# 10x faster than Diffusion Policy

class ConsistencyPolicy:
    def get_action(self, observation):
        noise = torch.randn(action_dim)
        action = self.policy(noise, T_max, observation)  # 1-2 steps
        return action

5. Zero-shot image editing

def inpainting(model, image, mask, T=1.0):
    """
    Inpainting without any additional training
    """
    # Forward diffusion on the masked region only
    noise = torch.randn_like(image)
    x_T = image * (1 - mask) + (image + T * noise) * mask

    # Restore with the consistency model
    x_0 = model(x_T, T)

    # Preserve the unmasked region
    result = image * (1 - mask) + x_0 * mask
    return result

Full Implementation Example

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import List

class SinusoidalEmbedding(nn.Module):
    """시간 임베딩"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)


class ConsistencyUNet(nn.Module):
    """단순화된 U-Net 아키텍처"""
    def __init__(self, in_channels=3, base_channels=64, sigma_data=0.5):
        super().__init__()
        self.sigma_data = sigma_data

        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(base_channels),
            nn.Linear(base_channels, base_channels * 4),
            nn.SiLU(),
            nn.Linear(base_channels * 4, base_channels * 4)
        )

        # Encoder
        self.enc1 = self._block(in_channels, base_channels)
        self.enc2 = self._block(base_channels, base_channels * 2)
        self.enc3 = self._block(base_channels * 2, base_channels * 4)

        # Middle
        self.mid = self._block(base_channels * 4, base_channels * 4)

        # Decoder
        self.dec3 = self._block(base_channels * 8, base_channels * 2)
        self.dec2 = self._block(base_channels * 4, base_channels)
        self.dec1 = self._block(base_channels * 2, base_channels)

        # Output
        self.out = nn.Conv2d(base_channels, in_channels, 3, padding=1)

        # Pooling and upsampling
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear')

    def _block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.SiLU()
        )

    def forward(self, x, t):
        # Skip connection coefficients (boundary condition)
        c_skip = self.sigma_data**2 / (t**2 + self.sigma_data**2)
        c_out = t * self.sigma_data / torch.sqrt(t**2 + self.sigma_data**2)
        c_skip = c_skip.view(-1, 1, 1, 1)
        c_out = c_out.view(-1, 1, 1, 1)

        # Time embedding
        t_emb = self.time_embed(t)

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        # Middle
        m = self.mid(self.pool(e3))

        # Decoder with skip connections
        d3 = self.dec3(torch.cat([self.up(m), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up(d2), e1], dim=1))

        # Output with boundary condition
        F_theta = self.out(d1)
        return c_skip * x + c_out * F_theta


class ConsistencyTrainer:
    """Consistency Training 구현"""

    def __init__(
        self,
        model: nn.Module,
        T_max: float = 80.0,
        T_min: float = 0.002,
        N: int = 150,
        mu_init: float = 0.9,
        lr: float = 1e-4
    ):
        self.model = model
        self.model_ema = copy.deepcopy(model)
        self.T_max = T_max
        self.T_min = T_min
        self.N = N
        self.mu = mu_init

        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        # Time schedule (Karras schedule)
        self.schedule = self._karras_schedule(N, T_max, T_min)

    def _karras_schedule(self, N, T_max, T_min, rho=7.0):
        """Karras et al. time schedule, decreasing from T_max to T_min"""
        step_indices = torch.arange(N + 1)
        t = (T_max ** (1/rho) + step_indices / N * (T_min ** (1/rho) - T_max ** (1/rho))) ** rho
        return t

    @torch.no_grad()
    def _ema_update(self):
        """EMA update of target network"""
        for p, p_ema in zip(self.model.parameters(), self.model_ema.parameters()):
            p_ema.data.mul_(self.mu).add_(p.data, alpha=1 - self.mu)

    def train_step(self, x_0: torch.Tensor) -> torch.Tensor:
        """Single training step"""
        device = x_0.device
        batch_size = x_0.shape[0]

        # Sample timestep indices
        n = torch.randint(0, self.N, (batch_size,), device=device)
        t_n = self.schedule[n].to(device)
        t_n1 = self.schedule[n + 1].to(device)

        # Sample noise and create noisy samples
        noise = torch.randn_like(x_0)
        x_n = x_0 + t_n.view(-1, 1, 1, 1) * noise
        x_n1 = x_0 + t_n1.view(-1, 1, 1, 1) * noise

        # Online network at the noisier timestep, EMA target at the
        # less-noisy one (the Karras schedule here decreases with n)
        pred_online = self.model(x_n, t_n)

        with torch.no_grad():
            pred_target = self.model_ema(x_n1, t_n1)

        # Pseudo-Huber loss
        c = 0.00054 * torch.sqrt(torch.tensor(x_0.numel() / batch_size))
        loss = torch.sqrt((pred_online - pred_target)**2 + c**2) - c
        loss = loss.mean()

        # Optimization
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # EMA update
        self._ema_update()

        return loss

    @torch.no_grad()
    def sample(self, batch_size: int, device: torch.device, steps: int = 1) -> torch.Tensor:
        """Generate samples"""
        self.model_ema.eval()

        # Sample timesteps for multi-step
        if steps == 1:
            timesteps = [self.T_max]
        else:
            indices = torch.linspace(0, self.N - 1, steps).long()
            timesteps = self.schedule[indices].tolist()

        # Start from noise
        x = torch.randn(batch_size, 3, 32, 32, device=device) * self.T_max

        for i, t in enumerate(timesteps):
            t_tensor = torch.full((batch_size,), t, device=device)
            x = self.model_ema(x, t_tensor)

            # Re-noise for next step (except last)
            if i < len(timesteps) - 1:
                noise = torch.randn_like(x)
                x = x + timesteps[i + 1] * noise

        return x.clamp(-1, 1)


# Usage example
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model initialization
    model = ConsistencyUNet(in_channels=3, base_channels=64).to(device)
    trainer = ConsistencyTrainer(model, N=150)

    # Training loop (assumes `dataloader` yields image batches scaled to [-1, 1])
    for epoch in range(100):
        for batch in dataloader:
            x_0 = batch.to(device)
            loss = trainer.train_step(x_0)

        # Adaptive N schedule
        if epoch % 10 == 0:
            trainer.N = min(trainer.N + 10, 200)
            trainer.schedule = trainer._karras_schedule(
                trainer.N, trainer.T_max, trainer.T_min
            )

        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    # Sampling
    samples = trainer.sample(batch_size=16, device=device, steps=1)
    print(f"Generated {samples.shape}")

Diffusion vs. Consistency Comparison

Aspect | Diffusion Models | Consistency Models
Sampling steps | 50-1000 | 1-4
Sampling time | Tens of seconds | Milliseconds
Training stability | High | Medium
Quality (best) | Very high | High
Zero-shot editing | Limited | Supported
Real-time use | Difficult | Feasible

Key Takeaways

Item | Content
Core idea | Learn the self-consistency of PF-ODE trajectories for single-step generation
Training | Distillation (CD) or standalone training (CT)
Strengths | 1-step generation, zero-shot editing, fast sampling
Limitations | Slightly lower peak quality than diffusion
Key follow-ups | LCM (latent space), CTM (trajectory), sCM (stability)

References

  1. Song, Y., et al. "Consistency Models." ICML 2023. arXiv:2303.01469
  2. Song, Y., et al. "Improved Techniques for Training Consistency Models." arXiv:2310.14189
  3. Luo, S., et al. "Latent Consistency Models." arXiv:2310.04378
  4. Kim, D., et al. "Consistency Trajectory Models." arXiv:2310.02279
  5. Geng, Z., et al. "Consistency Models Made Easy." ICLR 2025. arXiv:2406.14548
  6. Lu, Y., et al. "Simplifying, Stabilizing and Scaling Continuous-Time Consistency Models." arXiv:2410.11081
  7. Wang, F., et al. "Phased Consistency Models." arXiv:2405.18407