콘텐츠로 이동
Data Prep
상세

Flow Matching

메타 정보

항목 내용
원 논문 Flow Matching for Generative Modeling (Lipman et al., 2022)
발표 ICLR 2023 (Spotlight)
arXiv 2210.02747
분야 Generative Models, Continuous Normalizing Flows
키워드 CNF, Optimal Transport, Diffusion, Vector Field Regression

개요

Flow Matching은 Continuous Normalizing Flows (CNFs)를 효율적으로 학습하기 위한 새로운 패러다임이다. 기존 CNF 학습의 핵심 문제였던 시뮬레이션 비용을 제거하고, 조건부 확률 경로의 벡터 필드를 직접 회귀하는 방식을 제안한다.

핵심 기여

  1. Simulation-free 학습: ODE 시뮬레이션 없이 CNF 학습 가능
  2. 일반화된 프레임워크: Diffusion models를 특수 케이스로 포함
  3. Optimal Transport 경로: 더 효율적인 확률 경로 설계
  4. 확장성: ImageNet 스케일에서 학습 가능

배경 지식

Continuous Normalizing Flows (CNF)

CNF는 ODE를 통해 간단한 분포(노이즈)를 복잡한 분포(데이터)로 변환한다:

\[ \frac{dx}{dt} = v_\theta(x, t), \quad t \in [0, 1] \]
  • \(x_0 \sim p_0\) (노이즈 분포, 보통 가우시안)
  • \(x_1 \sim p_1\) (데이터 분포)
  • \(v_\theta\): 학습 가능한 벡터 필드

기존 학습 방법의 한계

방법 문제점
Maximum Likelihood ODE 시뮬레이션 필요 (계산 비용 높음)
Score Matching Diffusion 경로에 한정
FFJORD 큰 스케일에서 불안정

Flow Matching 방법론

1. 조건부 확률 경로

데이터 샘플 \(x_1\)이 주어졌을 때, 조건부 확률 경로 \(p_t(x|x_1)\)를 정의:

\[ p_t(x|x_1) = \mathcal{N}(x; \mu_t(x_1), \sigma_t^2 I) \]

가장 간단한 형태 (선형 보간):

\[ \mu_t(x_1) = t \cdot x_1, \quad \sigma_t = 1 - t \]

2. 조건부 벡터 필드

조건부 경로를 생성하는 벡터 필드:

\[ u_t(x|x_1) = \frac{x_1 - x}{1 - t} \]

3. Flow Matching 목적 함수

\[ \mathcal{L}_{FM}(\theta) = \mathbb{E}_{t, x_1, x} \left[ \| v_\theta(x, t) - u_t(x|x_1) \|^2 \right] \]
  • \(t \sim \text{Uniform}(0, 1)\)
  • \(x_1 \sim p_\text{data}\)
  • \(x \sim p_t(x|x_1)\)

4. Optimal Transport 경로

Diffusion 대신 OT 기반 직선 경로 사용:

경로 유형 특징
Diffusion 곡선 경로, 많은 스텝 필요
OT (Optimal Transport) 직선 경로, 적은 스텝으로 충분

OT 경로 정의:

\[ x_t = (1-t) x_0 + t x_1 \]

Diffusion Models와의 비교

┌─────────────────────────────────────────────────────────────┐
│                    생성 모델 프레임워크                       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Diffusion Models          Flow Matching                    │
│  ┌─────────────────┐      ┌─────────────────┐              │
│  │ Forward: SDE    │      │ Forward: ODE    │              │
│  │ (노이즈 추가)    │      │ (직선 보간)     │              │
│  └────────┬────────┘      └────────┬────────┘              │
│           │                        │                        │
│           ▼                        ▼                        │
│  ┌─────────────────┐      ┌─────────────────┐              │
│  │ Score Matching  │      │ Vector Field    │              │
│  │ ∇log p_t(x)     │      │ v_t(x)          │              │
│  └────────┬────────┘      └────────┬────────┘              │
│           │                        │                        │
│           ▼                        ▼                        │
│  ┌─────────────────┐      ┌─────────────────┐              │
│  │ Reverse SDE     │      │ ODE Solver      │              │
│  │ (많은 스텝)      │      │ (적은 스텝)     │              │
│  └─────────────────┘      └─────────────────┘              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

정량적 비교 (ImageNet 256x256)

모델 FID NFE (스텝 수) 학습 안정성
DDPM 3.17 1000 높음
Score SDE 2.20 2000 중간
Flow Matching (OT) 2.08 ~100 높음

알고리즘

학습 알고리즘

Algorithm: Flow Matching Training
─────────────────────────────────
Input: 데이터셋 D, 신경망 v_θ
Output: 학습된 파라미터 θ

1. repeat
2.   x₁ ~ D                          # 데이터 샘플링
3.   x₀ ~ N(0, I)                    # 노이즈 샘플링  
4.   t ~ Uniform(0, 1)               # 시간 샘플링
5.   xₜ = (1-t)x₀ + t·x₁             # 보간 (OT 경로)
6.   uₜ = x₁ - x₀                    # 타겟 벡터 필드
7.   L = ||v_θ(xₜ, t) - uₜ||²        # 손실 계산
8.   θ ← θ - η∇_θL                   # 파라미터 업데이트
9. until converged

샘플링 알고리즘

Algorithm: Flow Matching Sampling
─────────────────────────────────
Input: 학습된 v_θ, 스텝 수 N
Output: 생성된 샘플 x₁

1. x₀ ~ N(0, I)                      # 초기 노이즈
2. dt = 1/N
3. for t = 0 to 1-dt by dt do
4.   x_{t+dt} = xₜ + v_θ(xₜ, t)·dt   # Euler 적분
5. end for
6. return x₁

Python 구현 예시

기본 Flow Matching

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class VectorFieldNet(nn.Module):
    """벡터 필드를 예측하는 신경망"""
    def __init__(self, dim: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden_dim),  # +1 for time
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # t: (batch,) -> (batch, 1)
        t = t.unsqueeze(-1) if t.dim() == 1 else t
        return self.net(torch.cat([x, t], dim=-1))


class FlowMatching:
    """Flow Matching 학습 및 샘플링"""

    def __init__(self, dim: int, device: str = 'cuda'):
        self.dim = dim
        self.device = device
        self.model = VectorFieldNet(dim).to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def compute_loss(self, x1: torch.Tensor) -> torch.Tensor:
        """Flow Matching 손실 계산"""
        batch_size = x1.shape[0]

        # 노이즈 및 시간 샘플링
        x0 = torch.randn_like(x1)
        t = torch.rand(batch_size, 1, device=self.device)

        # OT 경로를 따라 보간
        xt = (1 - t) * x0 + t * x1

        # 타겟 벡터 필드 (OT 경로의 미분)
        ut = x1 - x0

        # 예측 벡터 필드
        vt = self.model(xt, t)

        # MSE 손실
        loss = ((vt - ut) ** 2).mean()
        return loss

    def train_step(self, x1: torch.Tensor) -> float:
        """단일 학습 스텝"""
        self.optimizer.zero_grad()
        loss = self.compute_loss(x1)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    @torch.no_grad()
    def sample(self, n_samples: int, n_steps: int = 100) -> torch.Tensor:
        """Euler 방법으로 샘플 생성"""
        self.model.eval()

        # 초기 노이즈
        x = torch.randn(n_samples, self.dim, device=self.device)
        dt = 1.0 / n_steps

        # ODE 적분
        for i in range(n_steps):
            t = torch.full((n_samples, 1), i * dt, device=self.device)
            v = self.model(x, t)
            x = x + v * dt

        self.model.train()
        return x


# 사용 예시
if __name__ == "__main__":
    # 2D Gaussian Mixture 데이터 생성
    def sample_gmm(n: int) -> torch.Tensor:
        """4개의 가우시안 혼합 샘플링"""
        centers = torch.tensor([[-2, -2], [-2, 2], [2, -2], [2, 2]], dtype=torch.float32)
        idx = torch.randint(0, 4, (n,))
        samples = centers[idx] + 0.3 * torch.randn(n, 2)
        return samples

    # 학습
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fm = FlowMatching(dim=2, device=device)

    for epoch in range(1000):
        data = sample_gmm(256).to(device)
        loss = fm.train_step(data)
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss:.4f}")

    # 샘플링
    samples = fm.sample(500, n_steps=50)
    print(f"Generated samples shape: {samples.shape}")

조건부 Flow Matching (Class-Conditional)

class ConditionalVectorFieldNet(nn.Module):
    """클래스 조건부 벡터 필드"""
    def __init__(self, dim: int, n_classes: int, hidden_dim: int = 256):
        super().__init__()
        self.class_embed = nn.Embedding(n_classes, hidden_dim)
        self.net = nn.Sequential(
            nn.Linear(dim + 1 + hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, 
                y: torch.Tensor) -> torch.Tensor:
        t = t.unsqueeze(-1) if t.dim() == 1 else t
        c = self.class_embed(y)
        return self.net(torch.cat([x, t, c], dim=-1))


class ConditionalFlowMatching:
    """조건부 Flow Matching"""

    def __init__(self, dim: int, n_classes: int, device: str = 'cuda'):
        self.model = ConditionalVectorFieldNet(dim, n_classes).to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.device = device

    def compute_loss(self, x1: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        batch_size = x1.shape[0]
        x0 = torch.randn_like(x1)
        t = torch.rand(batch_size, 1, device=self.device)

        xt = (1 - t) * x0 + t * x1
        ut = x1 - x0
        vt = self.model(xt, t, y)

        return ((vt - ut) ** 2).mean()

    @torch.no_grad()
    def sample(self, y: torch.Tensor, n_steps: int = 100) -> torch.Tensor:
        """조건부 샘플링"""
        self.model.eval()
        n_samples = y.shape[0]
        dim = list(self.model.parameters())[0].shape[1] - 1 - \
              self.model.class_embed.embedding_dim

        x = torch.randn(n_samples, dim, device=self.device)
        dt = 1.0 / n_steps

        for i in range(n_steps):
            t = torch.full((n_samples, 1), i * dt, device=self.device)
            v = self.model(x, t, y)
            x = x + v * dt

        self.model.train()
        return x

torchdyn을 활용한 고급 ODE 솔버

# pip install torchdyn

from torchdyn.core import NeuralODE

class FlowMatchingWithTorchdyn:
    """torchdyn을 활용한 Flow Matching"""

    def __init__(self, model: nn.Module, device: str = 'cuda'):
        self.device = device

        # 벡터 필드를 torchdyn 형식으로 래핑
        class VectorFieldWrapper(nn.Module):
            def __init__(self, base_model):
                super().__init__()
                self.base_model = base_model

            def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
                batch_size = x.shape[0]
                t_batch = t.expand(batch_size, 1)
                return self.base_model(x, t_batch)

        self.wrapper = VectorFieldWrapper(model)
        self.ode = NeuralODE(
            self.wrapper,
            solver='dopri5',        # Adaptive step solver
            sensitivity='adjoint',  # Memory efficient
            atol=1e-5,
            rtol=1e-5
        ).to(device)

    @torch.no_grad()
    def sample(self, n_samples: int, dim: int) -> torch.Tensor:
        """적응형 스텝 크기로 샘플링"""
        x0 = torch.randn(n_samples, dim, device=self.device)
        t_span = torch.linspace(0, 1, 2, device=self.device)

        trajectory = self.ode(x0, t_span)
        return trajectory[-1]  # t=1에서의 샘플

확장 및 응용

1. Rectified Flow

Flow Matching의 후속 연구로, 경로를 더욱 직선화하여 샘플링 효율 극대화:

def rectify_flow(model, data_loader, n_iterations: int = 10):
    """Rectified Flow: 경로 직선화를 통한 1-step 생성"""
    for _ in range(n_iterations):
        # 기존 모델로 (x0, x1) 쌍 생성
        pairs = []
        for x1 in data_loader:
            x0 = torch.randn_like(x1)
            # ODE를 통해 x0 -> x1 매핑
            x1_pred = model.sample_from_noise(x0)
            pairs.append((x0, x1_pred))

        # 새로운 직선 경로로 재학습
        model.train_on_pairs(pairs)

2. 적용 분야

분야 응용
이미지 생성 DALL-E 3, Stable Diffusion 3
분자 생성 단백질 구조 예측, 신약 개발
오디오 음성 합성, 음악 생성
비디오 프레임 보간, 비디오 생성
3D Point Cloud, NeRF

장단점

장점

  • 학습 효율성: Simulation-free로 빠른 학습
  • 샘플링 효율성: OT 경로로 적은 스텝에 고품질 샘플
  • 유연성: 다양한 확률 경로 설계 가능
  • 안정성: Diffusion보다 학습 안정성 우수
  • 이론적 기반: Optimal Transport 이론과 연결

단점

  • 조건부 생성: 복잡한 조건부 생성에서 추가 설계 필요
  • Discrete 데이터: 연속 데이터에 최적화, 이산 데이터는 추가 처리 필요
  • 스케일링: 매우 큰 모델에서는 추가 기법 필요

주요 후속 연구

논문 기여 년도
Rectified Flow 1-step 생성을 위한 경로 직선화 2022
Stochastic Interpolants 확률적 보간 이론적 통합 2023
Stable Diffusion 3 Flow Matching 기반 대규모 이미지 모델 2024
Physics-Constrained FM 물리 제약 조건 적용 2025

참고 자료


마지막 업데이트: 2026-02-06