MeanFlow: One-step Generative Modeling¶
메타정보¶
| 항목 | 내용 |
|---|---|
| 논문 | Mean Flows for One-step Generative Modeling |
| 저자 | Zhengyang Geng (CMU), Mingyang Deng, Xingjian Bai, J. Zico Kolter, Kaiming He (MIT) |
| 발표 | NeurIPS 2025 Oral |
| arXiv | 2505.13447 |
| 코드 | github.com/haidog-yaqub/MeanFlow (unofficial) |
개요¶
MeanFlow는 한 번의 forward pass로 고품질 이미지를 생성하는 프레임워크다. Flow Matching이 순간 속도(instantaneous velocity)를 모델링하는 것과 달리, 평균 속도(average velocity) 개념을 도입하여 one-step generation을 가능하게 한다.
핵심 성과: - ImageNet 256x256에서 FID 3.43 (1-NFE, single function evaluation) - 기존 one-step 모델 대비 50-70% 성능 향상 - Pre-training, distillation, curriculum learning 불필요
배경: Flow Matching의 한계¶
Flow Matching 복습¶
Flow Matching은 prior 분포를 data 분포로 변환하는 velocity field를 학습한다:
일반적인 스케줄: a_t = 1-t, b_t = t -> v_t = epsilon - x
문제점¶
- 곡선 궤적: Conditional flow가 직선이어도 marginal velocity field는 곡선 궤적 생성
- 다단계 샘플링 필수: ODE solver로 반복 계산 필요
- Coarse discretization 오류: 적은 step에서 정확도 급락
MeanFlow 핵심 아이디어¶
평균 속도 (Average Velocity)¶
순간 속도 대신 시간 구간 [r, t]에서의 평균 속도를 정의:
여기서:
- u: 평균 속도 (average velocity)
- v: 순간 속도 (instantaneous velocity)
- [r, t]: 시간 구간
핵심 특성¶
1. 경계 조건
2. 자연스러운 일관성 (Consistency)
큰 step 하나 = 작은 step 두 개의 합 (적분의 가법성에서 유도)
3. One-step Generation
u(epsilon, 0, 1)만 계산하면 전체 궤적을 한 번에 근사
방법론¶
학습 목표¶
평균 속도와 순간 속도 사이의 항등식을 만족하도록 학습:
Loss Function¶
Self-consistency loss 기반:
- 평균 속도 network u_theta(z, r, t) 학습
- 순간 속도는 r -> t 극한에서 도출
- 외부 teacher model이나 distillation 불필요
아키텍처¶
- DiT (Diffusion Transformer) 기반
- 추가 시간 조건
r입력 처리 - Classifier-Free Guidance (CFG) 내장 가능
CFG 통합¶
평균 속도 field에 CFG를 직접 통합: - 샘플링 시 추가 비용 없음 - Multi-step 모델처럼 별도 CFG 계산 불필요
실험 결과¶
ImageNet 256x256 벤치마크¶
| Method | NFE | FID |
|---|---|---|
| iCT | 1 | 10.3 |
| Shortcut | 1 | 7.8 |
| IMM (2-NFE guidance) | 2 | 5.1 |
| MeanFlow | 1 | 3.43 |
주요 비교¶
- iCT (improved Consistency Training): FID 10.3 -> MeanFlow 대비 3배 열등
- Shortcut Models: FID 7.8 -> MeanFlow 대비 2배 이상 열등
- IMM (Inductive Moment Matching): 2-NFE 사용해도 MeanFlow 1-NFE보다 열등
Multi-step과의 격차 해소¶
| Model Type | FID |
|---|---|
| Multi-step Flow Matching (50-NFE) | ~2.5 |
| MeanFlow (1-NFE) | 3.43 |
One-step과 multi-step 사이의 성능 격차를 크게 줄임
Consistency Models과의 비교¶
| 측면 | Consistency Models | MeanFlow |
|---|---|---|
| 기반 | 네트워크 행동 제약 | Ground-truth field 학습 |
| 학습 | Curriculum learning 필수 | Curriculum 불필요 |
| 안정성 | 불안정할 수 있음 | 더 안정적 |
| 이론적 기반 | Heuristic constraint | 수학적 항등식 |
MeanFlow는 ground-truth target field가 존재하여 최적 해가 네트워크에 독립적
Python 구현 예시¶
평균 속도 개념¶
import torch
import torch.nn as nn
class MeanFlowModel(nn.Module):
"""
MeanFlow: Average velocity field modeling for one-step generation
u(z_t, r, t) = displacement / (t - r)
where displacement = integral_r^t v(z_tau, tau) d_tau
"""
def __init__(self, hidden_dim=512, num_layers=6):
super().__init__()
# 입력: z_t (latent), r (start time), t (end time)
self.time_embed = nn.Sequential(
nn.Linear(2, hidden_dim), # (r, t) 두 시간 조건
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim)
)
self.backbone = nn.ModuleList([
nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.SiLU()
) for _ in range(num_layers)
])
self.output = nn.Linear(hidden_dim, hidden_dim)
def forward(self, z_t, r, t):
"""
Args:
z_t: Latent at time t, shape (B, D)
r: Start time, shape (B, 1)
t: End time, shape (B, 1)
Returns:
u: Average velocity, shape (B, D)
"""
# 시간 임베딩 (r, t 모두 조건으로 사용)
time_cond = torch.cat([r, t], dim=-1)
time_emb = self.time_embed(time_cond)
h = z_t + time_emb
for layer in self.backbone:
h = layer(h) + h # Residual
return self.output(h)
def one_step_generate(model, noise):
"""
One-step generation using MeanFlow
x = z_1 - (1 - 0) * u(z_1, 0, 1)
= z_1 - u(z_1, 0, 1)
"""
batch_size = noise.shape[0]
r = torch.zeros(batch_size, 1, device=noise.device)
t = torch.ones(batch_size, 1, device=noise.device)
# 평균 속도 예측
avg_velocity = model(noise, r, t)
# One-step generation
x_generated = noise - avg_velocity
return x_generated
Self-consistency Loss¶
def meanflow_loss(model, x_data, noise):
"""
MeanFlow self-consistency loss
핵심 아이디어:
- 랜덤 시간 (r, s, t) 샘플링 (r < s < t)
- 큰 step의 평균 속도 = 작은 step들의 가중 평균
Loss:
(t-r) * u(z_t, r, t) should equal
(s-r) * u(z_s, r, s) + (t-s) * u(z_t, s, t)
"""
batch_size = x_data.shape[0]
device = x_data.device
# 시간 샘플링: 0 < r < s < t < 1
times = torch.rand(batch_size, 3, device=device).sort(dim=1).values
r, s, t = times[:, 0:1], times[:, 1:2], times[:, 2:3]
# Flow path 구성 (linear interpolation)
z_t = (1 - t) * x_data + t * noise
z_s = (1 - s) * x_data + s * noise
# 평균 속도 예측
u_full = model(z_t, r, t) # u(z_t, r, t)
u_first = model(z_s, r, s) # u(z_s, r, s)
u_second = model(z_t, s, t) # u(z_t, s, t)
# Self-consistency: 가법성 검증
# (t-r) * u_full = (s-r) * u_first + (t-s) * u_second
lhs = (t - r) * u_full
rhs = (s - r) * u_first + (t - s) * u_second
loss = ((lhs - rhs) ** 2).mean()
return loss
def boundary_loss(model, x_data, noise):
"""
Boundary condition loss: lim(r->t) u(z_t, r, t) = v(z_t, t)
When r is very close to t, average velocity should equal
instantaneous velocity (tangent to the flow path)
"""
batch_size = x_data.shape[0]
device = x_data.device
# r을 t에 가깝게 설정
t = torch.rand(batch_size, 1, device=device) * 0.9 + 0.1
delta = torch.rand(batch_size, 1, device=device) * 0.01 # 작은 delta
r = t - delta
z_t = (1 - t) * x_data + t * noise
# 평균 속도
u = model(z_t, r, t)
# Ground-truth 순간 속도 (linear flow의 경우)
v_gt = noise - x_data
loss = ((u - v_gt) ** 2).mean()
return loss
DiT 기반 구현¶
class MeanFlowDiT(nn.Module):
"""
DiT-based MeanFlow model
Key modification: Two time conditions (r, t) instead of one
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
num_classes=1000
):
super().__init__()
self.input_size = input_size
self.patch_size = patch_size
self.hidden_size = hidden_size
# Patch embedding
self.x_embedder = PatchEmbed(
input_size, patch_size, in_channels, hidden_size
)
# Two time embeddings (r and t)
self.r_embedder = TimestepEmbedder(hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
# Class embedding for CFG
self.y_embedder = LabelEmbedder(
num_classes, hidden_size, class_dropout_prob
)
# Transformer blocks
self.blocks = nn.ModuleList([
DiTBlock(hidden_size, num_heads, mlp_ratio)
for _ in range(depth)
])
# Output projection
self.final_layer = FinalLayer(
hidden_size, patch_size, in_channels
)
def forward(self, z_t, r, t, y):
"""
Args:
z_t: Noisy latent (B, C, H, W)
r: Start timestep (B,)
t: End timestep (B,)
y: Class labels (B,)
"""
# Embeddings
x = self.x_embedder(z_t)
r_emb = self.r_embedder(r)
t_emb = self.t_embedder(t)
y_emb = self.y_embedder(y)
# Combine conditions
c = r_emb + t_emb + y_emb
# Transformer
for block in self.blocks:
x = block(x, c)
# Unpatchify
u = self.final_layer(x, c)
return u
class TimestepEmbedder(nn.Module):
"""Sinusoidal timestep embedding"""
def __init__(self, hidden_size, freq_embed_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(freq_embed_size, hidden_size),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size)
)
self.freq_embed_size = freq_embed_size
def forward(self, t):
freqs = torch.exp(
-torch.arange(0, self.freq_embed_size, 2, device=t.device)
* (torch.log(torch.tensor(10000.0)) / self.freq_embed_size)
)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
return self.mlp(embedding)
학습 루프¶
def train_meanflow(
model,
dataloader,
optimizer,
epochs=100,
alpha=0.1 # Boundary loss weight
):
"""
MeanFlow training loop
"""
model.train()
for epoch in range(epochs):
epoch_loss = 0.0
for batch in dataloader:
x_data = batch['images']
labels = batch['labels']
# Prior sample
noise = torch.randn_like(x_data)
# Consistency loss (main)
loss_cons = meanflow_consistency_loss(
model, x_data, noise, labels
)
# Boundary loss (regularization)
loss_bound = meanflow_boundary_loss(
model, x_data, noise, labels
)
loss = loss_cons + alpha * loss_bound
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
print(f"Epoch {epoch}: Loss = {epoch_loss / len(dataloader):.4f}")
def meanflow_consistency_loss(model, x_data, noise, labels):
"""Full consistency loss with class conditioning"""
batch_size = x_data.shape[0]
device = x_data.device
# Sample three ordered times
times = torch.rand(batch_size, 3, device=device)
times = times.sort(dim=1).values
r = times[:, 0]
s = times[:, 1]
t = times[:, 2]
# Construct latents at different times
z_t = (1 - t.view(-1, 1, 1, 1)) * x_data + t.view(-1, 1, 1, 1) * noise
z_s = (1 - s.view(-1, 1, 1, 1)) * x_data + s.view(-1, 1, 1, 1) * noise
# Predict average velocities
with torch.no_grad():
# Stop gradient for targets (similar to EMA target)
u_first_target = model(z_s, r, s, labels).detach()
u_second_target = model(z_t, s, t, labels).detach()
u_full = model(z_t, r, t, labels)
# Consistency target
target = (
(s - r).view(-1, 1, 1, 1) * u_first_target +
(t - s).view(-1, 1, 1, 1) * u_second_target
) / (t - r).view(-1, 1, 1, 1)
loss = nn.functional.mse_loss(u_full, target)
return loss
핵심 인사이트¶
왜 MeanFlow가 작동하는가¶
- Ground-truth Field 존재: 평균 속도
u는 순간 속도v로부터 유도된 well-defined field - 네트워크 독립적 목표: 최적 해가 네트워크 구조에 독립적
- 자연스러운 일관성: 적분의 가법성에서 consistency가 자동으로 유도
Flow Matching vs MeanFlow¶
Flow Matching:
- 학습: v(z_t, t) 근사
- 샘플링: z_{t+dt} = z_t + dt * v(z_t, t) [반복]
- 문제: 곡선 궤적에서 다단계 필요
MeanFlow:
- 학습: u(z_t, r, t) 근사 (평균 속도)
- 샘플링: x = z_1 - u(z_1, 0, 1) [한 번]
- 장점: 전체 궤적을 한 번에 근사
실용적 의의¶
- 추론 속도: 1-NFE로 multi-step 수준 품질 달성
- 학습 안정성: Curriculum learning 없이 scratch 학습
- CFG 효율성: 별도 guidance 계산 없이 CFG 내장
한계 및 향후 연구¶
현재 한계¶
- ImageNet 256x256 외 다른 도메인 검증 필요
- 고해상도(512x512+) 확장 연구 진행 중
- Video, 3D 등 다른 modality 적용 미검증
향후 방향¶
- Scaling: 더 큰 모델/데이터셋에서의 성능
- Multi-modal: Text-to-image, video generation 적용
- 이론적 분석: 수렴 보장, 최적성 분석
참고 문헌¶
- Geng et al. "Mean Flows for One-step Generative Modeling" NeurIPS 2025
- Lipman et al. "Flow Matching for Generative Modeling" ICLR 2023
- Song et al. "Consistency Models" ICML 2023
- Song et al. "Improved Techniques for Consistency Training" ICLR 2024
- Frans et al. "Shortcut Models" 2024
- Zhou et al. "Inductive Moment Matching" 2025