콘텐츠로 이동
Data Prep
상세

World Models

메타 정보

항목 내용
서베이 논문 Understanding World or Predicting Future? A Comprehensive Survey of World Models (Ding et al., 2024)
발표 ACM Computing Surveys 2025
arXiv 2411.14499
분야 Embodied AI, Video Generation, Reinforcement Learning
키워드 World Model, Video Prediction, Simulation, AGI, Generative Models

개요

World Model은 에이전트가 환경의 동역학을 내부적으로 시뮬레이션하여 미래 상태를 예측하고 의사결정을 지원하는 학습된 모델이다. GPT-4와 같은 멀티모달 LLM과 Sora와 같은 비디오 생성 모델의 발전으로 인공지능의 핵심 개념으로 부상했다.

핵심 기여

  1. 환경 이해: 물리 법칙, 인과관계, 객체 상호작용 모델링
  2. 미래 예측: 행동에 따른 결과 시뮬레이션
  3. 계획 수립: 목표 달성을 위한 행동 시퀀스 탐색
  4. 데이터 효율성: 실제 환경 없이 가상 경험으로 학습

배경 지식

World Model의 정의

World Model은 환경의 전이 함수(transition function)를 근사하는 학습된 모델이다:

\[ \hat{s}_{t+1} = f_\theta(s_t, a_t) \]
  • \(s_t\): 현재 상태 (이미지, 센서 데이터 등)
  • \(a_t\): 에이전트의 행동
  • \(\hat{s}_{t+1}\): 예측된 다음 상태
  • \(f_\theta\): 학습된 전이 모델

역사적 맥락

시기 발전 대표 연구
1986 Internal Models 개념 제안 Craik (1943), Jordan & Rumelhart
2018 World Models 논문 Ha & Schmidhuber
2020 Dreamer 시리즈 Hafner et al.
2023 대규모 비디오 모델 Genie (ICML 2024)
2024 범용 World Model Sora (OpenAI)

기존 접근법의 한계

접근법 한계
Model-Free RL 데이터 비효율성, 실제 환경 필요
Physics Simulation 수작업 모델링, 일반화 어려움
Video Prediction 장기 일관성 부족, 제어 불가

분류 체계

1. 기능 기반 분류

┌─────────────────────────────────────────────────────────────┐
│                    World Model 분류                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Understanding-Oriented       Prediction-Oriented          │
│  ┌─────────────────────┐     ┌─────────────────────┐       │
│  │ 환경 구조 이해       │     │ 미래 상태 예측       │       │
│  │ - 인과 관계         │     │ - 시계열 예측        │       │
│  │ - 객체 관계         │     │ - 비디오 생성        │       │
│  │ - 물리 법칙         │     │ - 행동 결과 예측     │       │
│  └──────────┬──────────┘     └──────────┬──────────┘       │
│             │                           │                   │
│             └───────────┬───────────────┘                   │
│                         │                                   │
│                         ▼                                   │
│             ┌─────────────────────┐                        │
│             │ Decision Making     │                        │
│             │ - 계획 수립          │                        │
│             │ - 정책 최적화        │                        │
│             └─────────────────────┘                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2. Embodied AI 관점 (3축 분류)

분류 설명
기능 Decision-Coupled RL/Planning과 통합
General-Purpose 범용 비디오 생성
시간 모델링 Sequential 순차적 프레임 예측
Global Difference 전체 변화량 예측
공간 표현 Global Latent 전역 벡터 (VAE)
Token Sequence 토큰 시퀀스 (Transformer)
Spatial Grid 공간 격자 (CNN)
Decomposed 분해 렌더링 (NeRF, 3DGS)

3. 세대별 발전

세대 특징 대표 모델
1세대 짧은 클립, 단일 장면 VideoGPT, TATS
2세대 긴 비디오, 텍스트 조건부 Sora, Kling
3세대 상호작용, 물리 일관성 Genie, iVideoGPT
4세대 (목표) 실시간, 멀티모달, 계획 Genie 2, Cosmos

핵심 아키텍처

1. 전통적 구조 (Ha & Schmidhuber, 2018)

┌───────────────────────────────────────────────────────────┐
│                World Model Architecture                    │
├───────────────────────────────────────────────────────────┤
│                                                           │
│  Observation (o_t)                                        │
│        │                                                  │
│        ▼                                                  │
│  ┌───────────────┐                                        │
│  │   Vision (V)  │  VAE Encoder                          │
│  │   o_t → z_t   │  고차원 관측 → 압축된 잠재 상태         │
│  └───────┬───────┘                                        │
│          │ z_t                                            │
│          ▼                                                │
│  ┌───────────────┐                                        │
│  │   Memory (M)  │  MDN-RNN                              │
│  │   (z_t, a_t,  │  잠재 공간에서 동역학 모델링            │
│  │    h_t) → h'  │                                        │
│  └───────┬───────┘                                        │
│          │ h_{t+1}                                        │
│          ▼                                                │
│  ┌───────────────┐                                        │
│  │ Controller (C)│  Compact Policy                       │
│  │   h_t → a_t   │  잠재 상태에서 행동 결정                │
│  └───────────────┘                                        │
│                                                           │
└───────────────────────────────────────────────────────────┘

2. Dreamer 구조 (Hafner et al.)

┌───────────────────────────────────────────────────────────┐
│                 Dreamer Architecture                       │
├───────────────────────────────────────────────────────────┤
│                                                           │
│  Real Experience          Imagined Experience             │
│  ┌─────────────┐         ┌─────────────┐                 │
│  │ Environment │ ───────→│ World Model │                 │
│  │ (Online)    │         │ (Offline)   │                 │
│  └──────┬──────┘         └──────┬──────┘                 │
│         │                       │                         │
│         ▼                       ▼                         │
│  ┌──────────────────────────────────────────┐            │
│  │              Latent Dynamics              │            │
│  │                                           │            │
│  │  Encoder: e_θ(o_t) → z_t                 │            │
│  │  Prior:   p_θ(z_t | h_t)                 │            │
│  │  Posterior: q_θ(z_t | h_t, o_t)          │            │
│  │  Transition: h_{t+1} = f_θ(h_t, z_t, a_t)│            │
│  │  Decoder: d_θ(h_t, z_t) → o_t            │            │
│  │  Reward: r_θ(h_t, z_t) → r_t             │            │
│  │                                           │            │
│  └──────────────────────────────────────────┘            │
│                         │                                 │
│                         ▼                                 │
│              ┌─────────────────────┐                     │
│              │   Actor-Critic      │                     │
│              │   Policy Learning   │                     │
│              └─────────────────────┘                     │
│                                                           │
└───────────────────────────────────────────────────────────┘

3. Video World Model (Sora-style)

┌───────────────────────────────────────────────────────────┐
│              Video World Model Pipeline                    │
├───────────────────────────────────────────────────────────┤
│                                                           │
│  Input: Text Prompt / Initial Frame / Action Sequence     │
│                         │                                 │
│                         ▼                                 │
│  ┌───────────────────────────────────────────┐           │
│  │         Latent Encoder (3D VAE)            │           │
│  │   Video → Spatiotemporal Latent Tokens     │           │
│  └───────────────────────────────────────────┘           │
│                         │                                 │
│                         ▼                                 │
│  ┌───────────────────────────────────────────┐           │
│  │     Diffusion Transformer (DiT)            │           │
│  │   - Spatial-Temporal Attention             │           │
│  │   - Text Cross-Attention                   │           │
│  │   - Action Conditioning                    │           │
│  └───────────────────────────────────────────┘           │
│                         │                                 │
│                         ▼                                 │
│  ┌───────────────────────────────────────────┐           │
│  │         Latent Decoder (3D VAE)            │           │
│  │   Spatiotemporal Tokens → Video Frames     │           │
│  └───────────────────────────────────────────┘           │
│                         │                                 │
│                         ▼                                 │
│               Output: Generated Video                     │
│                                                           │
└───────────────────────────────────────────────────────────┘

주요 모델

비디오 생성 World Models

모델 조직 발표 특징
Sora OpenAI 2024 DiT 기반, 1분 고품질 비디오
Sora 2 OpenAI 2025 20초 비디오, 향상된 물리
Genie DeepMind ICML 2024 상호작용 가능, 게임 생성
Genie 2 DeepMind 2024 3D 환경 시뮬레이션
Kling Kuaishou 2024 중국, 긴 비디오 생성
Cosmos NVIDIA 2025 자율주행/로봇 특화
MovieGen Meta 2024 영화급 품질, 오디오 동기화
LWM UC Berkeley ICLR 2025 Million-length 비디오

Embodied AI World Models

모델 적용 분야 발표 특징
Dreamer v3 범용 RL 2023 이미지 기반 RL SOTA
TD-MPC2 로봇 제어 2023 모델 예측 제어
UniSim 자율주행 2023 NVIDIA, 센서 시뮬레이션
GAIA-1 자율주행 Wayve 2023 9B 파라미터
iVideoGPT 상호작용 NeurIPS 2024 실시간 상호작용

학습 목표

1. Reconstruction Loss

미래 상태/프레임 재구성:

\[ \mathcal{L}_{recon} = \mathbb{E} \left[ \| o_t - \hat{o}_t \|^2 \right] \]

2. Latent Dynamics Loss

잠재 공간에서 전이 예측:

\[ \mathcal{L}_{dyn} = \mathbb{E} \left[ D_{KL}(q(z_{t+1}|o_{t+1}, h_{t+1}) \| p(z_{t+1}|h_{t+1})) \right] \]

3. Reward Prediction Loss

보상 예측 (RL 응용):

\[ \mathcal{L}_{reward} = \mathbb{E} \left[ (r_t - \hat{r}_t)^2 \right] \]

4. Contrastive Loss

시공간 표현 학습:

\[ \mathcal{L}_{contrast} = -\log \frac{\exp(sim(z_t, z_{t+k})/\tau)}{\sum_{j} \exp(sim(z_t, z_j)/\tau)} \]

Python 구현 예시

기본 World Model (VAE + RNN)

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional

class VAEEncoder(nn.Module):
    """관측을 잠재 공간으로 인코딩"""
    def __init__(self, obs_dim: int, latent_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.mu = nn.Linear(hidden_dim, latent_dim)
        self.logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = self.encoder(x)
        return self.mu(h), self.logvar(h)

    def sample(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std


class VAEDecoder(nn.Module):
    """잠재 상태를 관측으로 디코딩"""
    def __init__(self, latent_dim: int, obs_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, obs_dim)
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)


class DynamicsModel(nn.Module):
    """잠재 공간에서의 동역학 모델 (RSSM 스타일)"""
    def __init__(
        self, 
        latent_dim: int, 
        action_dim: int, 
        hidden_dim: int = 256,
        rnn_hidden: int = 256
    ):
        super().__init__()
        self.rnn_hidden = rnn_hidden

        # Deterministic path (RNN)
        self.rnn = nn.GRUCell(latent_dim + action_dim, rnn_hidden)

        # Prior: p(z_t | h_t)
        self.prior = nn.Sequential(
            nn.Linear(rnn_hidden, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)  # mu, logvar
        )

        # Posterior: q(z_t | h_t, o_t) - used during training
        self.posterior = nn.Sequential(
            nn.Linear(rnn_hidden + latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)
        )

    def forward(
        self, 
        z: torch.Tensor, 
        action: torch.Tensor, 
        h: torch.Tensor,
        obs_embed: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            z: 이전 잠재 상태
            action: 수행한 행동
            h: RNN 히든 상태
            obs_embed: 관측 인코딩 (학습 시에만 사용)

        Returns:
            next_h: 다음 RNN 히든 상태
            prior_mu, prior_logvar: Prior 분포
            post_mu, post_logvar: Posterior 분포 (obs_embed 있을 때)
        """
        # RNN 업데이트
        rnn_input = torch.cat([z, action], dim=-1)
        next_h = self.rnn(rnn_input, h)

        # Prior
        prior_stats = self.prior(next_h)
        prior_mu, prior_logvar = prior_stats.chunk(2, dim=-1)

        # Posterior (학습 시)
        if obs_embed is not None:
            post_input = torch.cat([next_h, obs_embed], dim=-1)
            post_stats = self.posterior(post_input)
            post_mu, post_logvar = post_stats.chunk(2, dim=-1)
        else:
            post_mu, post_logvar = prior_mu, prior_logvar

        return next_h, prior_mu, prior_logvar, post_mu, post_logvar

    def init_hidden(self, batch_size: int, device: str) -> torch.Tensor:
        return torch.zeros(batch_size, self.rnn_hidden, device=device)


class RewardPredictor(nn.Module):
    """잠재 상태에서 보상 예측"""
    def __init__(self, latent_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, z: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([z, h], dim=-1))


class WorldModel(nn.Module):
    """통합 World Model"""
    def __init__(
        self, 
        obs_dim: int, 
        action_dim: int, 
        latent_dim: int = 32,
        hidden_dim: int = 256,
        rnn_hidden: int = 256
    ):
        super().__init__()
        self.latent_dim = latent_dim

        self.encoder = VAEEncoder(obs_dim, latent_dim, hidden_dim)
        self.decoder = VAEDecoder(latent_dim + rnn_hidden, obs_dim, hidden_dim)
        self.dynamics = DynamicsModel(latent_dim, action_dim, hidden_dim, rnn_hidden)
        self.reward_pred = RewardPredictor(latent_dim, rnn_hidden)

    def encode(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.encoder(obs)

    def decode(self, z: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        return self.decoder(torch.cat([z, h], dim=-1))

    def imagine(
        self, 
        initial_z: torch.Tensor, 
        initial_h: torch.Tensor,
        actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """행동 시퀀스에 따른 미래 상상

        Args:
            initial_z: 초기 잠재 상태 (batch, latent_dim)
            initial_h: 초기 RNN 상태 (batch, rnn_hidden)
            actions: 행동 시퀀스 (batch, horizon, action_dim)

        Returns:
            z_seq: 상상된 잠재 상태 시퀀스
            h_seq: RNN 상태 시퀀스
            reward_seq: 예측 보상 시퀀스
        """
        batch_size, horizon, _ = actions.shape
        device = actions.device

        z_seq = [initial_z]
        h_seq = [initial_h]
        reward_seq = []

        z, h = initial_z, initial_h

        for t in range(horizon):
            # 동역학 모델로 다음 상태 예측
            h, prior_mu, prior_logvar, _, _ = self.dynamics(
                z, actions[:, t], h, obs_embed=None
            )
            z = self.encoder.sample(prior_mu, prior_logvar)

            # 보상 예측
            reward = self.reward_pred(z, h)

            z_seq.append(z)
            h_seq.append(h)
            reward_seq.append(reward)

        return (
            torch.stack(z_seq, dim=1),
            torch.stack(h_seq, dim=1),
            torch.stack(reward_seq, dim=1).squeeze(-1)
        )


class WorldModelTrainer:
    """World Model 학습기"""
    def __init__(self, model: WorldModel, lr: float = 1e-3, kl_weight: float = 1.0):
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.kl_weight = kl_weight

    def compute_loss(
        self,
        observations: torch.Tensor,  # (batch, seq_len, obs_dim)
        actions: torch.Tensor,       # (batch, seq_len, action_dim)
        rewards: torch.Tensor        # (batch, seq_len)
    ) -> Tuple[torch.Tensor, dict]:
        batch_size, seq_len, _ = observations.shape
        device = observations.device

        # 초기 히든 상태
        h = self.model.dynamics.init_hidden(batch_size, device)

        recon_loss = 0
        kl_loss = 0
        reward_loss = 0

        for t in range(seq_len):
            obs = observations[:, t]

            # 관측 인코딩
            mu, logvar = self.model.encode(obs)
            z = self.model.encoder.sample(mu, logvar)

            if t > 0:
                # 동역학 모델 업데이트
                h, prior_mu, prior_logvar, post_mu, post_logvar = self.model.dynamics(
                    prev_z, actions[:, t-1], h, obs_embed=mu
                )

                # KL divergence (posterior vs prior)
                kl = -0.5 * torch.sum(
                    1 + post_logvar - prior_logvar - 
                    (post_mu - prior_mu).pow(2) / prior_logvar.exp() -
                    post_logvar.exp() / prior_logvar.exp(),
                    dim=-1
                )
                kl_loss = kl_loss + kl.mean()

                # Posterior에서 샘플링 (학습 시)
                z = self.model.encoder.sample(post_mu, post_logvar)

            # 디코딩 (재구성)
            obs_recon = self.model.decode(z, h)
            recon_loss = recon_loss + F.mse_loss(obs_recon, obs)

            # 보상 예측
            reward_pred = self.model.reward_pred(z, h)
            reward_loss = reward_loss + F.mse_loss(reward_pred.squeeze(-1), rewards[:, t])

            prev_z = z

        # 총 손실
        total_loss = recon_loss + self.kl_weight * kl_loss + reward_loss

        metrics = {
            'recon_loss': recon_loss.item(),
            'kl_loss': kl_loss.item() if isinstance(kl_loss, torch.Tensor) else kl_loss,
            'reward_loss': reward_loss.item(),
            'total_loss': total_loss.item()
        }

        return total_loss, metrics

    def train_step(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor
    ) -> dict:
        self.optimizer.zero_grad()
        loss, metrics = self.compute_loss(observations, actions, rewards)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 100.0)
        self.optimizer.step()
        return metrics


# 사용 예시
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # 모델 초기화
    obs_dim = 64
    action_dim = 4
    model = WorldModel(obs_dim, action_dim).to(device)
    trainer = WorldModelTrainer(model)

    # 더미 데이터
    batch_size = 32
    seq_len = 50

    observations = torch.randn(batch_size, seq_len, obs_dim, device=device)
    actions = torch.randn(batch_size, seq_len, action_dim, device=device)
    rewards = torch.randn(batch_size, seq_len, device=device)

    # 학습
    for epoch in range(100):
        metrics = trainer.train_step(observations, actions, rewards)
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: {metrics}")

    # 미래 상상
    initial_obs = observations[:, 0]
    mu, logvar = model.encode(initial_obs)
    initial_z = model.encoder.sample(mu, logvar)
    initial_h = model.dynamics.init_hidden(batch_size, device)

    future_actions = torch.randn(batch_size, 20, action_dim, device=device)
    z_seq, h_seq, reward_seq = model.imagine(initial_z, initial_h, future_actions)

    print(f"Imagined {z_seq.shape[1]} future states")
    print(f"Predicted rewards: {reward_seq.mean():.4f}")

간단한 Video World Model (Transformer 기반)

import torch
import torch.nn as nn
from einops import rearrange

class PatchEmbedding(nn.Module):
    """이미지를 패치 토큰으로 변환"""
    def __init__(self, img_size: int, patch_size: int, in_channels: int, embed_dim: int):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) -> (B, N, D)
        x = self.proj(x)  # (B, D, H/P, W/P)
        x = rearrange(x, 'b d h w -> b (h w) d')
        return x


class TemporalAttention(nn.Module):
    """시간 축 어텐션"""
    def __init__(self, embed_dim: int, n_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, N, D) - batch, time, patches, dim
        B, T, N, D = x.shape

        # 패치별로 시간 축 어텐션
        x = rearrange(x, 'b t n d -> (b n) t d')
        x_norm = self.norm(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_out
        x = rearrange(x, '(b n) t d -> b t n d', b=B, n=N)

        return x


class SpatialAttention(nn.Module):
    """공간 축 어텐션"""
    def __init__(self, embed_dim: int, n_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, N, D)
        B, T, N, D = x.shape

        # 프레임별로 공간 축 어텐션
        x = rearrange(x, 'b t n d -> (b t) n d')
        x_norm = self.norm(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_out
        x = rearrange(x, '(b t) n d -> b t n d', b=B, t=T)

        return x


class VideoWorldModelBlock(nn.Module):
    """시공간 Transformer 블록"""
    def __init__(self, embed_dim: int, n_heads: int = 8):
        super().__init__()
        self.temporal_attn = TemporalAttention(embed_dim, n_heads)
        self.spatial_attn = SpatialAttention(embed_dim, n_heads)
        self.ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.temporal_attn(x)
        x = self.spatial_attn(x)
        x = x + self.ffn(x)
        return x


class SimpleVideoWorldModel(nn.Module):
    """간단한 Video World Model"""
    def __init__(
        self,
        img_size: int = 64,
        patch_size: int = 8,
        in_channels: int = 3,
        embed_dim: int = 256,
        n_layers: int = 4,
        n_heads: int = 8,
        action_dim: int = 4
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.n_patches = self.patch_embed.n_patches

        # 위치 임베딩
        self.pos_embed = nn.Parameter(torch.randn(1, 1, self.n_patches, embed_dim) * 0.02)
        self.time_embed = nn.Parameter(torch.randn(1, 100, 1, embed_dim) * 0.02)  # max 100 frames

        # 행동 조건화
        self.action_embed = nn.Linear(action_dim, embed_dim)

        # Transformer 블록
        self.blocks = nn.ModuleList([
            VideoWorldModelBlock(embed_dim, n_heads) for _ in range(n_layers)
        ])

        # 출력 프로젝션 (다음 프레임 예측)
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, patch_size * patch_size * in_channels)
        )

        self.patch_size = patch_size
        self.img_size = img_size
        self.in_channels = in_channels

    def forward(
        self, 
        frames: torch.Tensor,    # (B, T, C, H, W)
        actions: torch.Tensor    # (B, T, action_dim)
    ) -> torch.Tensor:
        B, T, C, H, W = frames.shape

        # 패치 임베딩
        x = rearrange(frames, 'b t c h w -> (b t) c h w')
        x = self.patch_embed(x)  # (B*T, N, D)
        x = rearrange(x, '(b t) n d -> b t n d', b=B, t=T)

        # 위치 + 시간 임베딩
        x = x + self.pos_embed + self.time_embed[:, :T]

        # 행동 조건화 (각 시점에 추가)
        action_emb = self.action_embed(actions)  # (B, T, D)
        x = x + action_emb.unsqueeze(2)  # 모든 패치에 브로드캐스트

        # Transformer
        for block in self.blocks:
            x = block(x)

        # 다음 프레임 예측 (마지막 프레임의 각 패치에서)
        x_last = x[:, -1]  # (B, N, D)
        pred_patches = self.head(x_last)  # (B, N, patch_size^2 * C)

        # 패치를 이미지로 재구성
        pred_patches = rearrange(
            pred_patches, 
            'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
            h=H//self.patch_size, w=W//self.patch_size,
            p1=self.patch_size, p2=self.patch_size, c=C
        )

        return pred_patches

    @torch.no_grad()
    def generate(
        self,
        initial_frames: torch.Tensor,  # (B, T_init, C, H, W)
        actions: torch.Tensor,          # (B, T_total, action_dim)
        n_frames: int = 16
    ) -> torch.Tensor:
        """자기회귀적으로 미래 프레임 생성"""
        B = initial_frames.shape[0]
        frames = initial_frames.clone()

        for t in range(n_frames):
            # 현재까지의 프레임으로 다음 프레임 예측
            current_t = frames.shape[1]
            next_frame = self(frames, actions[:, :current_t])
            next_frame = next_frame.unsqueeze(1)
            frames = torch.cat([frames, next_frame], dim=1)

        return frames


# 사용 예시
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = SimpleVideoWorldModel(
        img_size=64,
        patch_size=8,
        embed_dim=256,
        n_layers=4,
        action_dim=4
    ).to(device)

    # 더미 입력
    batch_size = 4
    n_frames = 8
    action_dim = 4

    frames = torch.randn(batch_size, n_frames, 3, 64, 64, device=device)
    actions = torch.randn(batch_size, n_frames, action_dim, device=device)

    # 다음 프레임 예측
    next_frame = model(frames, actions)
    print(f"Predicted frame shape: {next_frame.shape}")

    # 자기회귀 생성
    initial = frames[:, :4]
    future_actions = torch.randn(batch_size, 20, action_dim, device=device)
    generated = model.generate(initial, future_actions, n_frames=16)
    print(f"Generated video shape: {generated.shape}")

평가 지표

비디오 품질

지표 측정 대상 값 범위
FVD (Frechet Video Distance) 시간적 일관성 + 품질 낮을수록 좋음
FID (Frechet Inception Distance) 개별 프레임 품질 낮을수록 좋음
PSNR 픽셀 수준 정확도 높을수록 좋음
SSIM 구조적 유사도 0-1, 높을수록 좋음
LPIPS 지각적 유사도 낮을수록 좋음

물리적 일관성

지표 측정 대상
Physics Violation Rate 물리 법칙 위반 빈도
Object Permanence 객체 영속성 유지
Collision Realism 충돌 물리 정확도
Gravity Consistency 중력 일관성

제어 성능 (Embodied AI)

지표 측정 대상
Task Success Rate 태스크 성공률
Sample Efficiency 학습에 필요한 환경 상호작용 수
Planning Horizon 효과적 계획 수립 가능 기간

응용 분야

1. 자율주행

  • 센서 시뮬레이션 (LiDAR, 카메라)
  • 희귀 시나리오 생성 (코너 케이스)
  • 폐쇄 루프 테스트

2. 로봇공학

  • 조작 정책 사전 학습
  • 시뮬레이션-현실 전이 (Sim2Real)
  • 안전한 탐색

3. 게임/시뮬레이션

  • 절차적 콘텐츠 생성
  • 플레이어 행동 기반 적응
  • NPC 지능

4. 비디오 생성

  • 텍스트-투-비디오
  • 이미지 애니메이션
  • 영화/광고 제작

5. 과학 연구

  • 기후 모델링
  • 단백질 동역학
  • 사회 시뮬레이션

장단점

장점

  • 데이터 효율성: 가상 경험으로 학습, 실제 환경 의존도 감소
  • 안전한 탐색: 위험한 시나리오를 시뮬레이션으로 테스트
  • 계획 수립: 미래 결과 예측 기반 의사결정
  • 전이 학습: 다양한 태스크에 적용 가능한 범용 표현

단점

  • 모델 오류 누적: 장기 예측 시 오류 축적 (compounding error)
  • 계산 비용: 대규모 비디오 모델은 막대한 자원 필요
  • 물리 일관성: 복잡한 물리 상호작용 모델링 어려움
  • 실시간 제약: 로봇 제어에 필요한 속도 달성 어려움

주요 도전 과제

1. 장기 시간 일관성

  • 프레임 간 일관성 유지
  • 객체 영속성
  • 누적 오류 완화

2. 물리적 정확성

  • 강체/유체 역학
  • 접촉 및 충돌
  • 재료 특성

3. 실시간 추론

  • 모델 경량화
  • 효율적 아키텍처
  • 하드웨어 최적화

4. 평가 표준화

  • 픽셀 정확도 vs 물리적 정확성
  • 도메인별 벤치마크
  • 인간 평가와의 정렬

참고 자료

서베이 논문

핵심 논문

리소스


마지막 업데이트: 2026-02-19