State Space Models (SSM)¶
Transformer의 대안으로 부상한 시퀀스 모델링 아키텍처. 선형 시간 복잡도로 긴 시퀀스를 효율적으로 처리한다.
Meta Information¶
| 항목 | 내용 |
|---|---|
| 분류 | Deep Learning / Sequence Modeling |
| 핵심 논문 | S4 (NeurIPS 2021), Mamba (arXiv 2312.00752) |
| 주요 저자 | Albert Gu, Tri Dao (Carnegie Mellon, Princeton) |
| 관련 기법 | RNN, Transformer, Linear Attention |
| 응용 분야 | NLP, Audio, Vision, Time-series, Genomics |
1. 개요¶
1.1 배경¶
| 모델 | 시간 복잡도 | 메모리 | 장거리 의존성 |
|---|---|---|---|
| RNN | O(L) | O(1) | Vanishing gradient 문제 |
| Transformer | O(L^2) | O(L^2) | 완전한 attention |
| SSM | O(L) 또는 O(L log L) | O(L) | 구조화된 상태 공간 |
L = 시퀀스 길이
1.2 핵심 아이디어¶
State Space Models은 연속 시간 동역학 시스템을 이산화하여 시퀀스를 모델링한다:
연속 시스템:
x'(t) = Ax(t) + Bu(t)
y(t) = Cx(t) + Du(t)
이산화 (Zero-Order Hold):
x_k = Ā x_{k-1} + B̄ u_k
y_k = C x_k + D u_k
여기서:
- x: 잠재 상태 (hidden state)
- u: 입력 시퀀스
- y: 출력 시퀀스
- A, B, C, D: 학습 가능한 파라미터
2. 주요 발전 과정¶
2.1 S4 (Structured State Space Sequence Model)¶
NeurIPS 2021 Outstanding Paper
핵심 기여: - HiPPO (High-order Polynomial Projection Operators) 행렬로 A 초기화 - 장거리 의존성 문제 해결 - FFT 기반 컨볼루션으로 병렬 학습
2.2 S5 (Simplified S4)¶
- 대각 상태 공간으로 단순화
- MIMO (Multi-Input Multi-Output) 확장
- 학습 안정성 개선
2.3 Mamba (2023)¶
핵심 혁신: Selective State Spaces
기존 SSM의 한계: 입력에 관계없이 동일한 A, B, C 사용 (content-agnostic)
Mamba의 해결책:
Hardware-Aware 알고리즘: - 커널 퓨전으로 메모리 I/O 최소화 - Recurrent 모드에서도 병렬 처리 - FlashAttention 스타일 최적화
2.4 Mamba-2 (ICML 2024)¶
State Space Duality (SSD)
Mamba-2는 SSM과 Attention의 이론적 연결을 밝힘: - 특정 조건에서 Structured SSM = Structured Attention - 더 효율적인 행렬 분해 알고리즘 - 8배 빠른 학습 속도
2.5 Jamba (AI21, 2024)¶
Mamba + Transformer 하이브리드: - Mamba 레이어와 Attention 레이어 혼합 - MoE (Mixture of Experts) 통합 - 256K 컨텍스트 지원
3. 아키텍처 상세¶
3.1 Mamba Block¶
┌─────────────────────────────────────────┐
│ Input (B, L, D) │
├─────────────────────────────────────────┤
│ Linear Projection │
│ D → 2*E │
├──────────────────┬──────────────────────┤
│ Branch 1 │ Branch 2 │
│ Conv1D + SiLU │ SiLU │
├──────────────────┤ │
│ Selective SSM │ │
├──────────────────┴──────────────────────┤
│ Element-wise Multiply │
├─────────────────────────────────────────┤
│ Linear Projection │
│ E → D │
├─────────────────────────────────────────┤
│ Residual Connection │
└─────────────────────────────────────────┘
3.2 Selective Scan Algorithm¶
# Pseudo-code
def selective_scan(u, delta, A, B, C):
"""
u: (B, L, D) 입력
delta: (B, L, D) 시간 스텝
A: (D, N) 상태 행렬
B: (B, L, N) 입력 행렬
C: (B, L, N) 출력 행렬
"""
# 이산화
deltaA = exp(einsum('bld,dn->bldn', delta, A))
deltaB_u = einsum('bld,bln,bld->bldn', delta, B, u)
# 순차 스캔 (실제로는 병렬화)
x = zeros(B, D, N)
ys = []
for i in range(L):
x = deltaA[:, i] * x + deltaB_u[:, i]
y = einsum('bdn,bn->bd', x, C[:, i])
ys.append(y)
return stack(ys, dim=1)
4. 성능 비교¶
4.1 언어 모델링 (Perplexity, 낮을수록 좋음)¶
| 모델 | 파라미터 | The Pile (PPL) |
|---|---|---|
| Transformer | 2.7B | 8.21 |
| Mamba | 2.8B | 7.82 |
| Transformer | 6.9B | 7.50 |
| Mamba | 2.8B | 7.82 (2.5x 작은 크기로 근접) |
4.2 처리량 (Throughput)¶
시퀀스 길이에 따른 토큰/초 (A100 GPU):
| 길이 | Transformer | Mamba | 속도 향상 |
|---|---|---|---|
| 2K | 53K | 90K | 1.7x |
| 8K | 32K | 85K | 2.7x |
| 32K | 12K | 78K | 6.5x |
| 128K | OOM | 72K | - |
4.3 Long Range Arena 벤치마크¶
| Task | Transformer | S4 | Mamba |
|---|---|---|---|
| ListOps | 36.4 | 59.6 | 61.2 |
| Text | 64.3 | 86.8 | 87.5 |
| Retrieval | 57.5 | 90.9 | 91.3 |
| Image | 42.4 | 88.7 | 89.1 |
| Pathfinder | 71.4 | 94.2 | 94.8 |
| Path-X | FAIL | 96.4 | 97.1 |
| Average | 54.4 | 86.1 | 86.8 |
5. Python 구현¶
5.1 기본 SSM 레이어¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, einsum
class SimpleSSM(nn.Module):
"""단순화된 State Space Model 구현"""
def __init__(self, d_model: int, d_state: int = 16):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# SSM 파라미터
self.A = nn.Parameter(torch.randn(d_model, d_state))
self.B = nn.Parameter(torch.randn(d_model, d_state))
self.C = nn.Parameter(torch.randn(d_model, d_state))
self.D = nn.Parameter(torch.ones(d_model))
# 시간 스텝
self.dt = nn.Parameter(torch.ones(d_model) * 0.1)
self._init_parameters()
def _init_parameters(self):
"""HiPPO 스타일 초기화"""
with torch.no_grad():
# A를 음수로 초기화 (안정성)
self.A.copy_(-torch.abs(self.A))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, length, d_model)
Returns:
y: (batch, length, d_model)
"""
batch, length, _ = x.shape
# 이산화 (Zero-Order Hold)
dt = F.softplus(self.dt)
A_discrete = torch.exp(self.A * dt.unsqueeze(-1))
B_discrete = self.B * dt.unsqueeze(-1)
# Recurrent 계산
state = torch.zeros(batch, self.d_model, self.d_state, device=x.device)
outputs = []
for t in range(length):
# x_t: (batch, d_model)
x_t = x[:, t, :]
# 상태 업데이트: h_t = A * h_{t-1} + B * x_t
state = state * A_discrete + x_t.unsqueeze(-1) * B_discrete
# 출력: y_t = C * h_t + D * x_t
y_t = (state * self.C).sum(dim=-1) + self.D * x_t
outputs.append(y_t)
return torch.stack(outputs, dim=1)
class SelectiveSSM(nn.Module):
"""Mamba 스타일 Selective SSM"""
def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# 입력 의존적 파라미터 생성
self.proj_b = nn.Linear(d_model, d_state, bias=False)
self.proj_c = nn.Linear(d_model, d_state, bias=False)
self.proj_dt = nn.Linear(d_model, d_model, bias=True)
# 고정 파라미터
self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1).float()))
self.D = nn.Parameter(torch.ones(d_model))
# 1D Convolution
self.conv1d = nn.Conv1d(
d_model, d_model,
kernel_size=d_conv,
padding=d_conv - 1,
groups=d_model
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch, length, d = x.shape
# Conv1D
x_conv = self.conv1d(x.transpose(1, 2))[:, :, :length].transpose(1, 2)
x_conv = F.silu(x_conv)
# Selective 파라미터 생성
B = self.proj_b(x_conv) # (batch, length, d_state)
C = self.proj_c(x_conv) # (batch, length, d_state)
dt = F.softplus(self.proj_dt(x_conv)) # (batch, length, d_model)
# A 이산화
A = -torch.exp(self.A_log) # (d_state,)
# Selective scan
y = self._selective_scan(x_conv, dt, A, B, C)
# Skip connection
return y + self.D * x
def _selective_scan(self, u, delta, A, B, C):
"""Selective scan 알고리즘 (단순 버전)"""
batch, length, d_model = u.shape
d_state = A.shape[0]
# 이산화
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, D, N)
deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)
# 순차 스캔
state = torch.zeros(batch, d_model, d_state, device=u.device)
outputs = []
for t in range(length):
state = deltaA[:, t] * state + deltaB_u[:, t]
y = (state * C[:, t].unsqueeze(1)).sum(dim=-1)
outputs.append(y)
return torch.stack(outputs, dim=1)
5.2 Mamba Block 구현¶
class MambaBlock(nn.Module):
"""완전한 Mamba 블록"""
def __init__(
self,
d_model: int,
d_state: int = 16,
expand: int = 2,
d_conv: int = 4
):
super().__init__()
self.d_inner = d_model * expand
# 입력 프로젝션
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
# Selective SSM
self.ssm = SelectiveSSM(self.d_inner, d_state, d_conv)
# 출력 프로젝션
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
# Layer Norm
self.norm = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.norm(x)
# Split into two branches
xz = self.in_proj(x)
x, z = xz.chunk(2, dim=-1)
# SSM branch
x = self.ssm(x)
# Gating
x = x * F.silu(z)
# Output projection
x = self.out_proj(x)
return x + residual
class MambaModel(nn.Module):
"""Mamba 언어 모델"""
def __init__(
self,
vocab_size: int,
d_model: int = 768,
n_layers: int = 12,
d_state: int = 16,
expand: int = 2
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([
MambaBlock(d_model, d_state, expand)
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
# Weight tying
self.lm_head.weight = self.embedding.weight
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
x = self.embedding(input_ids)
for layer in self.layers:
x = layer(x)
x = self.norm(x)
return self.lm_head(x)
# 사용 예시
if __name__ == "__main__":
model = MambaModel(
vocab_size=32000,
d_model=768,
n_layers=12
)
# 파라미터 수
params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {params / 1e6:.1f}M")
# Forward pass
x = torch.randint(0, 32000, (2, 1024))
logits = model(x)
print(f"Input: {x.shape}, Output: {logits.shape}")
5.3 시계열 예측 적용¶
class MambaForTimeSeries(nn.Module):
"""시계열 예측을 위한 Mamba 모델"""
def __init__(
self,
input_dim: int,
hidden_dim: int = 128,
n_layers: int = 4,
d_state: int = 16,
forecast_horizon: int = 24
):
super().__init__()
self.input_proj = nn.Linear(input_dim, hidden_dim)
self.layers = nn.ModuleList([
MambaBlock(hidden_dim, d_state)
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(hidden_dim)
self.output_proj = nn.Linear(hidden_dim, forecast_horizon)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, seq_len, input_dim)
Returns:
forecast: (batch, forecast_horizon)
"""
x = self.input_proj(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x)
# 마지막 타임스텝에서 예측
return self.output_proj(x[:, -1, :])
# 시계열 학습 예시
def train_timeseries():
import numpy as np
# 합성 데이터 생성
np.random.seed(42)
T = 1000
t = np.linspace(0, 100, T)
data = np.sin(t) + 0.5 * np.sin(3 * t) + 0.1 * np.random.randn(T)
# 시퀀스 생성
seq_len, horizon = 96, 24
X, y = [], []
for i in range(len(data) - seq_len - horizon):
X.append(data[i:i+seq_len])
y.append(data[i+seq_len:i+seq_len+horizon])
X = torch.tensor(np.array(X), dtype=torch.float32).unsqueeze(-1)
y = torch.tensor(np.array(y), dtype=torch.float32)
# 모델 학습
model = MambaForTimeSeries(input_dim=1, forecast_horizon=horizon)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(100):
pred = model(X)
loss = F.mse_loss(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch + 1) % 20 == 0:
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
return model
if __name__ == "__main__":
train_timeseries()
6. 실무 적용 가이드¶
6.1 Mamba vs Transformer 선택 기준¶
| 상황 | 권장 |
|---|---|
| 시퀀스 < 2K | Transformer |
| 시퀀스 > 8K | Mamba |
| 긴 문서/책 처리 | Mamba |
| 추론 지연 중요 | Mamba |
| 기존 생태계 활용 | Transformer |
| 시계열/오디오/유전체 | Mamba |
6.2 하이퍼파라미터 권장값¶
# 기본 설정
d_model: 768 # 모델 차원
d_state: 16 # 상태 차원 (16-64)
expand: 2 # 확장 비율
d_conv: 4 # 컨볼루션 커널 크기
# 스케일링
small: d_model=768, n_layers=12 # ~125M params
medium: d_model=1024, n_layers=24 # ~350M params
large: d_model=2048, n_layers=48 # ~1.3B params
# 학습
learning_rate: 1e-4 to 5e-4
warmup_steps: 1000
weight_decay: 0.1
batch_size: 시퀀스 길이에 따라 조정
6.3 주의사항¶
- 초기화 중요성: A 행렬 초기화가 안정성에 큰 영향
- 수치 안정성: 긴 시퀀스에서 overflow 주의 (fp32/bf16 권장)
- 하드웨어 최적화: 공식 CUDA 커널 사용 시 성능 크게 향상
- 하이브리드 고려: 일부 태스크에서 Attention과 혼합이 효과적
7. 관련 자료¶
논문¶
- S4: Efficiently Modeling Long Sequences with Structured State Spaces (NeurIPS 2021)
- Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023)
- Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality (ICML 2024)
- From S4 to Mamba: A Comprehensive Survey on SSMs (2025)
구현¶
- Official Mamba: 공식 구현
- mamba.py: 최소 구현
- S4 Repository: S4 공식 구현
튜토리얼¶
- The Annotated S4: 상세 설명 포함 구현
- Mamba Explained: 직관적 설명
최종 업데이트: 2026-02-08