콘텐츠로 이동
Data Prep
상세

Neural ODE (Neural Ordinary Differential Equations)

메타 정보

항목 내용
분류 Deep Learning / Continuous-Depth Models
원논문 "Neural Ordinary Differential Equations" (NeurIPS 2018, Best Paper)
저자 Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud (Univ. of Toronto)
핵심 개념 신경망의 hidden state 변환을 연속 시간 ODE로 모델링
관련 분야 Dynamical Systems, Normalizing Flows, Time Series, Generative Models

정의

Neural ODE는 ResNet 등 이산 레이어의 변환을 연속 시간 상미분방정식(ODE)의 관점에서 재해석한 프레임워크이다. 이산 레이어 h_{t+1} = h_t + f(h_t, theta)를 연속 극한으로 보면 다음 ODE가 된다:

dh(t)
----- = f(h(t), t, theta)
 dt

여기서 f는 신경망으로 매개변수화된 벡터장(vector field)이고, hidden state h(t)는 초기값 문제(IVP)를 풀어 임의의 시점에서 평가할 수 있다.

핵심 아이디어

ResNet에서 Neural ODE로

이산 (ResNet)                          연속 (Neural ODE)
h_{t+1} = h_t + f(h_t, theta_t)       dh/dt = f(h(t), t, theta)
     |                                      |
고정된 레이어 수 (T층)                  적응적 깊이 (ODE solver가 결정)
O(T) 메모리                            O(1) 메모리 (adjoint method)

핵심 관찰: ResNet의 잔차 연결(residual connection)은 Euler method로 ODE를 1스텝 적분한 것과 동일하다. 이를 일반화하면 고정된 레이어 수 대신 연속적인 깊이를 가진 모델을 구성할 수 있다.

Adjoint Sensitivity Method

Neural ODE의 역전파는 adjoint sensitivity method로 수행한다. 순방향 ODE를 역시간(reverse-time)으로 풀면서 gradient를 계산한다.

Forward:  h(0) --[ODE solve]--> h(T) --[Loss]--> L

Backward: a(T) = dL/dh(T)
          da/dt = -a(t)^T * (df/dh)       (adjoint equation)
          dL/dtheta = integral_T^0 a(t)^T * (df/dtheta) dt
방법 메모리 시간
직접 역전파 (backprop through solver) O(L * D) O(L)
Adjoint method O(D) O(L)
Checkpointing O(sqrt(L) * D) O(L * sqrt(L))

여기서 L은 solver steps, D는 hidden dimension이다. Adjoint method는 중간 상태를 저장하지 않으므로 메모리가 상수이다.

주요 변형

1. Augmented Neural ODE (NeurIPS 2019)

Dupont et al.의 연구. 원래 Neural ODE는 위상(topology)을 보존하는 homeomorphism만 학습 가능하다는 한계가 있다 (궤적이 교차할 수 없음). 이를 해결하기 위해 hidden state에 추가 차원을 확장한다:

원래:    dh/dt = f(h(t), t, theta)          -- R^d 위의 흐름
확장:    d[h, a]/dt = f([h(t), a(t)], t, theta)   -- R^(d+p) 위의 흐름

추가 차원 a(t)가 흐름의 교차를 허용하여 표현력이 크게 향상된다.

2. Latent ODE (NeurIPS 2019)

Rubanova et al.의 연구. 불규칙 시계열(irregularly-sampled time series) 모델링에 특화된 변형이다:

관측: x(t_1), x(t_3), x(t_7), ...  (불규칙 간격)
      |
      v
Encoder (ODE-RNN) --> z(t_0)  (잠재 초기 상태)
      |
      v
ODE Solver: dz/dt = f(z(t), t, theta)
      |
      v
Decoder --> x_hat(t)  (임의의 시점에서 재구성)

특징: - 관측 간격이 불규칙해도 자연스럽게 처리 - 임의의 시점에서 보간/외삽 가능 - 결측값 처리가 내재적

3. FFJORD (ICLR 2019)

Grathwohl et al.의 연구. Continuous Normalizing Flow (CNF)를 자유 형식 야코비안으로 확장한 모델이다. 밀도 변환을 다음과 같이 계산한다:

d log p(z(t))
------------- = -tr(df/dz)
     dt

Hutchinson's trace estimator를 사용하여 야코비안의 trace를 O(D)에 추정한다. 기존 Normalizing Flow 대비:

비교 항목 기존 Normalizing Flow FFJORD
변환 구조 제한적 (삼각, 결합 등) 자유 형식
야코비안 계산 O(D^3) 또는 구조적 O(D) 추정
가역성 구조적으로 보장 ODE 자체가 가역
연속 시간 아님

4. Neural SDE (NeurIPS 2020)

Li et al.의 연구. ODE에 확률적 노이즈를 추가한 확률 미분방정식(SDE) 형태:

dh(t) = f(h(t), t, theta) dt + g(h(t), t, theta) dW(t)

불확실성 정량화가 자연스러우며, 생성 모델(score-based diffusion)과도 연결된다.

ODE Solver 선택

Neural ODE의 성능은 ODE solver에 크게 의존한다:

Solver 종류 적응적 스텝 적합한 상황
Euler 고정 스텝 X 빠른 프로토타이핑, 단순한 역학
RK4 (Runge-Kutta 4) 고정 스텝 X 중간 정확도
Dopri5 (Dormand-Prince) 적응적 스텝 O 범용, torchdiffeq 기본값
Adams 다단계 O 매끄러운 역학, 장기 적분
Implicit methods 암묵적 O Stiff system

적응적 스텝 크기의 의미

입력 복잡도 높음 --> 스텝 크기 줄임 --> 더 많은 연산 (더 깊은 네트워크)
입력 복잡도 낮음 --> 스텝 크기 늘림 --> 더 적은 연산 (더 얕은 네트워크)

즉, 입력에 따라 연산량(깊이)이 자동 조절된다. 이는 고정 아키텍처에 없는 고유한 장점이다.

이론적 기반

Universal Approximation

Neural ODE가 충분히 넓은 hidden dimension을 가지면 임의의 연속 변환을 근사할 수 있다. 단, 원래 차원 d에서는 homeomorphism 클래스로 제한되며, Augmented Neural ODE로 확장하면 이 제약이 해소된다.

Dynamical Systems 관점

Neural ODE는 자율 시스템(autonomous system)의 학습으로 볼 수 있다:

개념 역학계 Neural ODE
상태 시스템의 상태 hidden state h(t)
벡터장 물리 법칙 신경망 f(h, t, theta)
궤적 시스템 진화 순방향 계산
안정점 평형 상태 수렴한 표현

학습 시 실용적 고려사항

정규화 기법

기법 목적 참조
Kinetic energy regularization NFE(solver 호출 수) 감소 Finlay et al., 2020
Jacobian norm regularization 역학 복잡도 제한 Kelly et al., 2020
STEER (Stiff ODE regularizer) Stiffness 완화 Ghosh et al., 2020
Spectral regularization 리아프노프 안정성 Massaroli et al., 2021

훈련 안정성 문제

  1. NFE 폭발: 학습이 진행되면서 ODE solver가 필요한 함수 평가 횟수(NFE)가 급증할 수 있다. 정규화로 완화한다.
  2. Stiff dynamics: 벡터장의 스케일이 크게 달라지면 solver가 매우 작은 스텝을 사용해야 하므로 느려진다.
  3. 수치적 오차 누적: Adjoint method의 역시간 적분에서 오차가 누적되어 gradient가 부정확해질 수 있다.
  4. 해결 전략: Seminorm adjoint (Kidger, 2021), interpolated adjoint, 또는 checkpointing 사용.

응용 분야

1. 불규칙 시계열 모델링

가장 영향력 있는 응용. 관측 시점이 불균일한 데이터에 자연스럽게 적합하다:

  • 전자 의무 기록 (EHR): 환자 방문 간격이 불규칙
  • 기후/환경 센서: 결측과 비정기 관측
  • 금융 틱 데이터: 거래 간격이 가변적

2. 생성 모델 (Continuous Normalizing Flows)

FFJORD를 기반으로 한 밀도 추정 및 생성:

  • 이미지 생성
  • 분자 구조 생성
  • 이상 탐지 (log-likelihood 기반)

3. 물리 시뮬레이션

물리 법칙을 선험 지식으로 활용하는 Physics-Informed Neural ODE:

  • 분자 역학 시뮬레이션
  • 유체 역학
  • 기후 모델링

4. 제어 및 강화학습

연속 시간 제어 정책 학습:

  • 로봇 공학
  • 자율 시스템
  • 최적 제어 문제

Flow Matching과의 관계

Neural ODE의 Continuous Normalizing Flow는 이후 Flow Matching 프레임워크의 이론적 기반이 된다:

Neural ODE (2018)
  |
  +--> CNF / FFJORD (2019) -- 최대우도 학습, trace 추정 필요
         |
         +--> Flow Matching (2023) -- 조건부 경로로 학습 단순화
                |
                +--> MeanFlow (2024) -- 1-step 생성

Flow Matching은 CNF의 simulation-free 학습을 가능하게 하여 Neural ODE 기반 생성 모델의 확장성을 크게 개선했다.

Python 구현 예제

기본 Neural ODE (torchdiffeq)

import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint

class ODEFunc(nn.Module):
    """ODE의 벡터장 f(h, t, theta)를 정의하는 신경망"""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, hidden_dim),
        )

    def forward(self, t, h):
        """
        Args:
            t: 현재 시간 (스칼라)
            h: hidden state [batch, hidden_dim]
        Returns:
            dh/dt [batch, hidden_dim]
        """
        return self.net(h)


class NeuralODEClassifier(nn.Module):
    """Neural ODE 기반 분류기"""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 solver: str = 'dopri5', use_adjoint: bool = True):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.odefunc = ODEFunc(hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.solver = solver
        self.integration_time = torch.tensor([0.0, 1.0])
        self.odeint = odeint_adjoint if use_adjoint else odeint

    def forward(self, x):
        # 입력을 hidden space로 변환
        h0 = self.input_layer(x)  # [batch, hidden_dim]

        # ODE 적분: h(0) -> h(1)
        t = self.integration_time.to(x.device)
        h_traj = self.odeint(
            self.odefunc, h0, t,
            method=self.solver,
            rtol=1e-4,
            atol=1e-4,
        )
        # h_traj: [2, batch, hidden_dim] (t=0, t=1)

        h_final = h_traj[-1]  # t=1에서의 상태
        return self.output_layer(h_final)


# 사용 예시
model = NeuralODEClassifier(
    input_dim=28*28,
    hidden_dim=64,
    output_dim=10,
    solver='dopri5',
    use_adjoint=True
)

x = torch.randn(32, 28*28)
logits = model(x)  # [32, 10]

Latent ODE (불규칙 시계열)

import torch
import torch.nn as nn
from torchdiffeq import odeint

class LatentODEFunc(nn.Module):
    """잠재 공간의 역학을 정의하는 ODE 함수"""

    def __init__(self, latent_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ELU(),
            nn.Linear(64, 64),
            nn.ELU(),
            nn.Linear(64, latent_dim),
        )

    def forward(self, t, z):
        return self.net(z)


class ODERNNEncoder(nn.Module):
    """
    불규칙 시계열을 역방향으로 읽어 잠재 초기 상태를 추론하는 인코더.
    각 관측 사이를 ODE로 보간한다.
    """

    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int):
        super().__init__()
        self.odefunc = LatentODEFunc(hidden_dim)
        self.rnn_cell = nn.GRUCell(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.hidden_dim = hidden_dim

    def forward(self, obs_times, obs_values, obs_mask):
        """
        Args:
            obs_times: [seq_len] 관측 시점
            obs_values: [batch, seq_len, input_dim] 관측값
            obs_mask: [batch, seq_len, input_dim] 관측 마스크 (1=관측, 0=결측)
        Returns:
            mu, logvar: [batch, latent_dim] 잠재 초기 상태의 분포 파라미터
        """
        batch_size = obs_values.shape[0]
        h = torch.zeros(batch_size, self.hidden_dim, device=obs_values.device)

        # 역방향으로 시계열 처리
        for i in reversed(range(len(obs_times))):
            if i < len(obs_times) - 1:
                # 다음 관측까지 ODE로 보간 (역방향)
                dt = obs_times[i+1] - obs_times[i]
                t_span = torch.tensor([0.0, dt], device=h.device)
                h = odeint(self.odefunc, h, t_span, method='euler')[-1]

            # 관측이 있는 위치에서 GRU 업데이트
            x_i = obs_values[:, i, :] * obs_mask[:, i, :]
            h = self.rnn_cell(x_i, h)

        return self.fc_mu(h), self.fc_logvar(h)


class LatentODE(nn.Module):
    """
    Latent ODE: 불규칙 시계열을 위한 VAE + Neural ODE 모델

    Reference: Rubanova et al., "Latent ODEs for Irregularly-Sampled
    Time Series" (NeurIPS 2019)
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64,
                 latent_dim: int = 16):
        super().__init__()
        self.encoder = ODERNNEncoder(input_dim, hidden_dim, latent_dim)
        self.dynamics = LatentODEFunc(latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, input_dim),
        )
        self.latent_dim = latent_dim

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, obs_times, obs_values, obs_mask, eval_times):
        """
        Args:
            obs_times: [seq_len] 관측 시점 (정렬됨)
            obs_values: [batch, seq_len, input_dim]
            obs_mask: [batch, seq_len, input_dim]
            eval_times: [eval_len] 예측할 시점
        Returns:
            pred: [batch, eval_len, input_dim] 예측값
            mu, logvar: 잠재 분포 파라미터
        """
        # 인코더: 관측 -> 잠재 초기 상태
        mu, logvar = self.encoder(obs_times, obs_values, obs_mask)
        z0 = self.reparameterize(mu, logvar)  # [batch, latent_dim]

        # ODE 적분: z(t_0) -> z(eval_times)
        z_traj = odeint(
            self.dynamics, z0, eval_times,
            method='dopri5', rtol=1e-3, atol=1e-3
        )
        # z_traj: [eval_len, batch, latent_dim]

        # 디코더: 잠재 상태 -> 관측 공간
        z_traj = z_traj.permute(1, 0, 2)  # [batch, eval_len, latent_dim]
        pred = self.decoder(z_traj)  # [batch, eval_len, input_dim]

        return pred, mu, logvar

    def loss(self, pred, target, mask, mu, logvar, beta: float = 1.0):
        """ELBO 손실 = 재구성 손실 + beta * KL divergence"""
        # 재구성 손실 (관측된 위치만)
        recon = ((pred - target) ** 2 * mask).sum() / mask.sum()

        # KL divergence
        kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        return recon + beta * kl


# 사용 예시: 불규칙 시계열
batch_size = 16
input_dim = 3
seq_len = 20

model = LatentODE(input_dim=input_dim, hidden_dim=64, latent_dim=16)

# 불규칙 관측 시점
obs_times = torch.sort(torch.rand(seq_len))[0]
obs_values = torch.randn(batch_size, seq_len, input_dim)
obs_mask = (torch.rand(batch_size, seq_len, input_dim) > 0.3).float()

# 예측할 시점 (연속적)
eval_times = torch.linspace(0, 1, 50)

pred, mu, logvar = model(obs_times, obs_values, obs_mask, eval_times)
loss = model.loss(pred[:, :seq_len, :], obs_values, obs_mask, mu, logvar)

NFE 모니터링 유틸리티

class NFECounter:
    """Neural ODE의 함수 평가 횟수(NFE)를 추적하는 래퍼"""

    def __init__(self, odefunc: nn.Module):
        self.odefunc = odefunc
        self.nfe = 0

    def __call__(self, t, h):
        self.nfe += 1
        return self.odefunc(t, h)

    def reset(self):
        self.nfe = 0

# 사용
counter = NFECounter(model.odefunc)
# ... ODE 적분 실행 ...
print(f"Forward NFE: {counter.nfe}")
# NFE가 과도하면 정규화 필요

기존 기법과의 비교

비교 항목 ResNet (이산) Neural ODE (연속) Transformer
깊이 고정 적응적 (입력 의존) 고정
메모리 (역전파) O(L) O(1) adjoint O(L^2) attention
불규칙 시계열 보간 후 처리 직접 처리 위치 인코딩
물리적 해석 제한적 역학계로 해석 가능 제한적
학습 속도 빠름 ODE solver 오버헤드 빠름 (병렬화)
파라미터 수 레이어당 별도 전 시간대 공유 레이어당 별도

한계 및 주의사항

  1. 계산 비용: ODE solver 호출로 인해 이산 모델보다 학습/추론이 느림
  2. Homeomorphism 제약: 원래 차원에서 위상 변환 불가 (Augmented로 해결)
  3. NFE 증가: 복잡한 역학 학습 시 solver step 수가 급증 가능
  4. Stiff dynamics: 빠른/느린 역학이 혼재하면 solver가 어려워함
  5. 수치 오차: Adjoint method의 역시간 적분에서 gradient 정확도 저하 가능

핵심 논문 목록

연도 제목 저자 학회/저널
2018 Neural Ordinary Differential Equations Chen et al. NeurIPS 2018 (Best Paper)
2019 FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models Grathwohl et al. ICLR 2019
2019 Latent ODEs for Irregularly-Sampled Time Series Rubanova et al. NeurIPS 2019
2019 Augmented Neural ODEs Dupont et al. NeurIPS 2019
2020 How to Train Your Neural ODE Finlay et al. ICML 2020
2020 Neural SDEs as Infinite-Dimensional GANs Kidger et al. ICML 2021
2020 Scalable Gradients for Stochastic Differential Equations Li et al. AISTATS 2020
2021 "Hey, That's Not an ODE": Faster ODE Adjoints via Seminorms Kidger et al. ICML 2021
2022 Neural Controlled Differential Equations for Irregular Time Series Kidger et al. NeurIPS 2020
2023 Enhanced Distribution Modelling via Augmented Architectures (AFFJORD) Lim et al. NeurIPS 2023
2024 Training Stiff Neural ODEs with Implicit Single-Step Methods Kim et al. arXiv 2024

관련 주제