콘텐츠로 이동
Data Prep
상세

MeanFlow: One-step Generative Modeling

메타정보

항목 내용
논문 Mean Flows for One-step Generative Modeling
저자 Zhengyang Geng (CMU), Mingyang Deng, Xingjian Bai, J. Zico Kolter, Kaiming He (MIT)
발표 NeurIPS 2025 Oral
arXiv 2505.13447
코드 github.com/haidog-yaqub/MeanFlow (unofficial)

개요

MeanFlow는 한 번의 forward pass로 고품질 이미지를 생성하는 프레임워크다. Flow Matching이 순간 속도(instantaneous velocity)를 모델링하는 것과 달리, 평균 속도(average velocity) 개념을 도입하여 one-step generation을 가능하게 한다.

핵심 성과: - ImageNet 256x256에서 FID 3.43 (1-NFE, single function evaluation) - 기존 one-step 모델 대비 50-70% 성능 향상 - Pre-training, distillation, curriculum learning 불필요


배경: Flow Matching의 한계

Flow Matching 복습

Flow Matching은 prior 분포를 data 분포로 변환하는 velocity field를 학습한다:

z_t = a_t * x + b_t * epsilon
v_t = dz_t/dt = a'_t * x + b'_t * epsilon

일반적인 스케줄: a_t = 1-t, b_t = t -> v_t = epsilon - x

문제점

  1. 곡선 궤적: Conditional flow가 직선이어도 marginal velocity field는 곡선 궤적 생성
  2. 다단계 샘플링 필수: ODE solver로 반복 계산 필요
  3. Coarse discretization 오류: 적은 step에서 정확도 급락

MeanFlow 핵심 아이디어

평균 속도 (Average Velocity)

순간 속도 대신 시간 구간 [r, t]에서의 평균 속도를 정의:

u(z_t, r, t) = (1 / (t-r)) * integral_r^t v(z_tau, tau) d_tau

여기서: - u: 평균 속도 (average velocity) - v: 순간 속도 (instantaneous velocity) - [r, t]: 시간 구간

핵심 특성

1. 경계 조건

lim(r->t) u = v  (평균 속도 -> 순간 속도)

2. 자연스러운 일관성 (Consistency)

(t-r) * u(z_t, r, t) = (s-r) * u(z_s, r, s) + (t-s) * u(z_t, s, t)

큰 step 하나 = 작은 step 두 개의 합 (적분의 가법성에서 유도)

3. One-step Generation

x_generated = z_1 - u_theta(z_1, 0, 1)

u(epsilon, 0, 1)만 계산하면 전체 궤적을 한 번에 근사


방법론

학습 목표

평균 속도와 순간 속도 사이의 항등식을 만족하도록 학습:

u(z_t, r, t) = (1 / (t-r)) * integral_r^t v(z_tau, tau) d_tau

Loss Function

Self-consistency loss 기반: - 평균 속도 network u_theta(z, r, t) 학습 - 순간 속도는 r -> t 극한에서 도출 - 외부 teacher model이나 distillation 불필요

아키텍처

  • DiT (Diffusion Transformer) 기반
  • 추가 시간 조건 r 입력 처리
  • Classifier-Free Guidance (CFG) 내장 가능

CFG 통합

평균 속도 field에 CFG를 직접 통합: - 샘플링 시 추가 비용 없음 - Multi-step 모델처럼 별도 CFG 계산 불필요


실험 결과

ImageNet 256x256 벤치마크

Method NFE FID
iCT 1 10.3
Shortcut 1 7.8
IMM (2-NFE guidance) 2 5.1
MeanFlow 1 3.43

주요 비교

  • iCT (improved Consistency Training): FID 10.3 -> MeanFlow 대비 3배 열등
  • Shortcut Models: FID 7.8 -> MeanFlow 대비 2배 이상 열등
  • IMM (Inductive Moment Matching): 2-NFE 사용해도 MeanFlow 1-NFE보다 열등

Multi-step과의 격차 해소

Model Type FID
Multi-step Flow Matching (50-NFE) ~2.5
MeanFlow (1-NFE) 3.43

One-step과 multi-step 사이의 성능 격차를 크게 줄임


Consistency Models과의 비교

측면 Consistency Models MeanFlow
기반 네트워크 행동 제약 Ground-truth field 학습
학습 Curriculum learning 필수 Curriculum 불필요
안정성 불안정할 수 있음 더 안정적
이론적 기반 Heuristic constraint 수학적 항등식

MeanFlow는 ground-truth target field가 존재하여 최적 해가 네트워크에 독립적


Python 구현 예시

평균 속도 개념

import torch
import torch.nn as nn

class MeanFlowModel(nn.Module):
    """
    MeanFlow: Average velocity field modeling for one-step generation

    u(z_t, r, t) = displacement / (t - r)
    where displacement = integral_r^t v(z_tau, tau) d_tau
    """
    def __init__(self, hidden_dim=512, num_layers=6):
        super().__init__()
        # 입력: z_t (latent), r (start time), t (end time)
        self.time_embed = nn.Sequential(
            nn.Linear(2, hidden_dim),  # (r, t) 두 시간 조건
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.backbone = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.SiLU()
            ) for _ in range(num_layers)
        ])

        self.output = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, z_t, r, t):
        """
        Args:
            z_t: Latent at time t, shape (B, D)
            r: Start time, shape (B, 1)
            t: End time, shape (B, 1)
        Returns:
            u: Average velocity, shape (B, D)
        """
        # 시간 임베딩 (r, t 모두 조건으로 사용)
        time_cond = torch.cat([r, t], dim=-1)
        time_emb = self.time_embed(time_cond)

        h = z_t + time_emb
        for layer in self.backbone:
            h = layer(h) + h  # Residual

        return self.output(h)


def one_step_generate(model, noise):
    """
    One-step generation using MeanFlow

    x = z_1 - (1 - 0) * u(z_1, 0, 1)
      = z_1 - u(z_1, 0, 1)
    """
    batch_size = noise.shape[0]
    r = torch.zeros(batch_size, 1, device=noise.device)
    t = torch.ones(batch_size, 1, device=noise.device)

    # 평균 속도 예측
    avg_velocity = model(noise, r, t)

    # One-step generation
    x_generated = noise - avg_velocity

    return x_generated

Self-consistency Loss

def meanflow_loss(model, x_data, noise):
    """
    MeanFlow self-consistency loss

    핵심 아이디어:
    - 랜덤 시간 (r, s, t) 샘플링 (r < s < t)
    - 큰 step의 평균 속도 = 작은 step들의 가중 평균

    Loss:
    (t-r) * u(z_t, r, t) should equal
    (s-r) * u(z_s, r, s) + (t-s) * u(z_t, s, t)
    """
    batch_size = x_data.shape[0]
    device = x_data.device

    # 시간 샘플링: 0 < r < s < t < 1
    times = torch.rand(batch_size, 3, device=device).sort(dim=1).values
    r, s, t = times[:, 0:1], times[:, 1:2], times[:, 2:3]

    # Flow path 구성 (linear interpolation)
    z_t = (1 - t) * x_data + t * noise
    z_s = (1 - s) * x_data + s * noise

    # 평균 속도 예측
    u_full = model(z_t, r, t)      # u(z_t, r, t)
    u_first = model(z_s, r, s)     # u(z_s, r, s)
    u_second = model(z_t, s, t)    # u(z_t, s, t)

    # Self-consistency: 가법성 검증
    # (t-r) * u_full = (s-r) * u_first + (t-s) * u_second
    lhs = (t - r) * u_full
    rhs = (s - r) * u_first + (t - s) * u_second

    loss = ((lhs - rhs) ** 2).mean()

    return loss


def boundary_loss(model, x_data, noise):
    """
    Boundary condition loss: lim(r->t) u(z_t, r, t) = v(z_t, t)

    When r is very close to t, average velocity should equal
    instantaneous velocity (tangent to the flow path)
    """
    batch_size = x_data.shape[0]
    device = x_data.device

    # r을 t에 가깝게 설정
    t = torch.rand(batch_size, 1, device=device) * 0.9 + 0.1
    delta = torch.rand(batch_size, 1, device=device) * 0.01  # 작은 delta
    r = t - delta

    z_t = (1 - t) * x_data + t * noise

    # 평균 속도
    u = model(z_t, r, t)

    # Ground-truth 순간 속도 (linear flow의 경우)
    v_gt = noise - x_data

    loss = ((u - v_gt) ** 2).mean()

    return loss

DiT 기반 구현

class MeanFlowDiT(nn.Module):
    """
    DiT-based MeanFlow model

    Key modification: Two time conditions (r, t) instead of one
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000
    ):
        super().__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size

        # Patch embedding
        self.x_embedder = PatchEmbed(
            input_size, patch_size, in_channels, hidden_size
        )

        # Two time embeddings (r and t)
        self.r_embedder = TimestepEmbedder(hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)

        # Class embedding for CFG
        self.y_embedder = LabelEmbedder(
            num_classes, hidden_size, class_dropout_prob
        )

        # Transformer blocks
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio)
            for _ in range(depth)
        ])

        # Output projection
        self.final_layer = FinalLayer(
            hidden_size, patch_size, in_channels
        )

    def forward(self, z_t, r, t, y):
        """
        Args:
            z_t: Noisy latent (B, C, H, W)
            r: Start timestep (B,)
            t: End timestep (B,)
            y: Class labels (B,)
        """
        # Embeddings
        x = self.x_embedder(z_t)
        r_emb = self.r_embedder(r)
        t_emb = self.t_embedder(t)
        y_emb = self.y_embedder(y)

        # Combine conditions
        c = r_emb + t_emb + y_emb

        # Transformer
        for block in self.blocks:
            x = block(x, c)

        # Unpatchify
        u = self.final_layer(x, c)

        return u


class TimestepEmbedder(nn.Module):
    """Sinusoidal timestep embedding"""
    def __init__(self, hidden_size, freq_embed_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(freq_embed_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size)
        )
        self.freq_embed_size = freq_embed_size

    def forward(self, t):
        freqs = torch.exp(
            -torch.arange(0, self.freq_embed_size, 2, device=t.device)
            * (torch.log(torch.tensor(10000.0)) / self.freq_embed_size)
        )
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return self.mlp(embedding)

학습 루프

def train_meanflow(
    model,
    dataloader,
    optimizer,
    epochs=100,
    alpha=0.1  # Boundary loss weight
):
    """
    MeanFlow training loop
    """
    model.train()

    for epoch in range(epochs):
        epoch_loss = 0.0

        for batch in dataloader:
            x_data = batch['images']
            labels = batch['labels']

            # Prior sample
            noise = torch.randn_like(x_data)

            # Consistency loss (main)
            loss_cons = meanflow_consistency_loss(
                model, x_data, noise, labels
            )

            # Boundary loss (regularization)
            loss_bound = meanflow_boundary_loss(
                model, x_data, noise, labels
            )

            loss = loss_cons + alpha * loss_bound

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch}: Loss = {epoch_loss / len(dataloader):.4f}")


def meanflow_consistency_loss(model, x_data, noise, labels):
    """Full consistency loss with class conditioning"""
    batch_size = x_data.shape[0]
    device = x_data.device

    # Sample three ordered times
    times = torch.rand(batch_size, 3, device=device)
    times = times.sort(dim=1).values
    r = times[:, 0]
    s = times[:, 1]
    t = times[:, 2]

    # Construct latents at different times
    z_t = (1 - t.view(-1, 1, 1, 1)) * x_data + t.view(-1, 1, 1, 1) * noise
    z_s = (1 - s.view(-1, 1, 1, 1)) * x_data + s.view(-1, 1, 1, 1) * noise

    # Predict average velocities
    with torch.no_grad():
        # Stop gradient for targets (similar to EMA target)
        u_first_target = model(z_s, r, s, labels).detach()
        u_second_target = model(z_t, s, t, labels).detach()

    u_full = model(z_t, r, t, labels)

    # Consistency target
    target = (
        (s - r).view(-1, 1, 1, 1) * u_first_target +
        (t - s).view(-1, 1, 1, 1) * u_second_target
    ) / (t - r).view(-1, 1, 1, 1)

    loss = nn.functional.mse_loss(u_full, target)

    return loss

핵심 인사이트

왜 MeanFlow가 작동하는가

  1. Ground-truth Field 존재: 평균 속도 u는 순간 속도 v로부터 유도된 well-defined field
  2. 네트워크 독립적 목표: 최적 해가 네트워크 구조에 독립적
  3. 자연스러운 일관성: 적분의 가법성에서 consistency가 자동으로 유도

Flow Matching vs MeanFlow

Flow Matching:
  - 학습: v(z_t, t) 근사
  - 샘플링: z_{t+dt} = z_t + dt * v(z_t, t) [반복]
  - 문제: 곡선 궤적에서 다단계 필요

MeanFlow:
  - 학습: u(z_t, r, t) 근사 (평균 속도)
  - 샘플링: x = z_1 - u(z_1, 0, 1) [한 번]
  - 장점: 전체 궤적을 한 번에 근사

실용적 의의

  • 추론 속도: 1-NFE로 multi-step 수준 품질 달성
  • 학습 안정성: Curriculum learning 없이 scratch 학습
  • CFG 효율성: 별도 guidance 계산 없이 CFG 내장

한계 및 향후 연구

현재 한계

  • ImageNet 256x256 외 다른 도메인 검증 필요
  • 고해상도(512x512+) 확장 연구 진행 중
  • Video, 3D 등 다른 modality 적용 미검증

향후 방향

  1. Scaling: 더 큰 모델/데이터셋에서의 성능
  2. Multi-modal: Text-to-image, video generation 적용
  3. 이론적 분석: 수렴 보장, 최적성 분석

참고 문헌

  1. Geng et al. "Mean Flows for One-step Generative Modeling" NeurIPS 2025
  2. Lipman et al. "Flow Matching for Generative Modeling" ICLR 2023
  3. Song et al. "Consistency Models" ICML 2023
  4. Song et al. "Improved Techniques for Consistency Training" ICLR 2024
  5. Frans et al. "Shortcut Models" 2024
  6. Zhou et al. "Inductive Moment Matching" 2025