State Space Models (SSM) - Mamba¶
개요¶
State Space Models (SSM)은 연속 시간 시스템을 이산화하여 시퀀스 모델링에 적용하는 아키텍처다. Mamba는 SSM에 선택적 메커니즘(selective scan)을 도입하여 Transformer에 필적하는 성능을 달성한 모델로, 시퀀스 길이에 대해 선형 시간 복잡도를 가진다.
핵심 개념¶
State Space Model 기본¶
연속 시간 상태 공간 모델:
여기서: - x(t): 입력 신호 - h(t): 숨겨진 상태 (hidden state) - y(t): 출력 신호 - A, B, C, D: 시스템 행렬
이산화 (Discretization)¶
연속 시스템을 이산 시퀀스로 변환:
Zero-Order Hold (ZOH):
A_bar = exp(delta * A)
B_bar = (delta * A)^(-1) * (exp(delta * A) - I) * delta * B
Simplified (Euler method):
A_bar = I + delta * A
B_bar = delta * B
이산화된 재귀 관계:
S4 (Structured State Space)¶
S4는 HiPPO 행렬을 사용하여 장거리 의존성을 효과적으로 모델링:
Selective Scan (Mamba의 핵심)¶
기존 SSM은 입력에 무관하게 동일한 파라미터를 사용하지만, Mamba는 입력에 따라 B, C, delta를 동적으로 결정:
# Input-dependent parameters
B = Linear_B(x) # (batch, seq_len, state_dim)
C = Linear_C(x) # (batch, seq_len, state_dim)
delta = softplus(Linear_delta(x)) # (batch, seq_len, d_model)
# Discretization with input-dependent delta
A_bar = exp(-delta * A)
B_bar = delta * B
# Selective scan
for k in range(seq_len):
h[k] = A_bar[k] * h[k-1] + B_bar[k] * x[k]
y[k] = C[k] @ h[k]
아키텍처 다이어그램¶
S4 Layer¶
Input x
|
v
+--------------+
| Conv1D | (local context)
+------+-------+
|
v
+--------------+
| SSM |
| h' = Ah+Bx |
| y = Ch |
+------+-------+
|
v
+--------------+
| Dropout |
+------+-------+
|
v
Output y
Mamba Block¶
Input x
|
+-----------+-----------+
| |
v |
+-------------+ |
| Linear | |
| (expand) | |
+------+------+ |
| |
v |
+-------------+ |
| Conv1D | |
| (k=4) | |
+------+------+ |
| |
v |
+-------------+ |
| SiLU | |
+------+------+ |
| |
v |
+-------------+ |
| Selective | |
| SSM | |
| (S6 core) | |
+------+------+ |
| |
v v
+-------------+ +-------------+
| Element- |<--------| Linear |
| wise | | + SiLU |
| Multiply | +-------------+
+------+------+
|
v
+-------------+
| Linear |
| (project) |
+------+------+
|
v
Output
Selective SSM 상세¶
x (input)
|
+-------------+-------------+
| | |
v v v
+----------+ +----------+ +------------+
| Linear_B | | Linear_C | | Linear_dt |
+----+-----+ +----+-----+ +-----+------+
| | |
v v v
B C softplus(dt)
| | |
+-------------+-------+------+
|
v
+----------------+
| Discretize A,B |
| A_bar = e^(-dt*A)|
| B_bar = dt * B |
+-------+--------+
|
v
+----------------+
| Parallel Scan |
| h_k = A_bar*h |
| + B_bar*x |
+-------+--------+
|
v
+----------------+
| Output |
| y_k = C @ h_k |
+----------------+
|
v
Output
Mamba-2 (Structured State Space Duality)¶
Mamba-2: SSM = Attention with structured mask
SSM View: Attention View:
h_k = A*h_{k-1} + B*x_k y = softmax(Q K^T / sqrt(d)) V
y_k = C*h_k |
| |
+------------+----------------+
|
v
[Same Computation]
[Different Algorithm]
SSM: O(n) sequential scan
Attn: O(n^2) parallel matmul
Mamba-2 chooses optimal based on sequence length
대표 모델¶
| 모델 | 파라미터 | 아키텍처 | 특징 |
|---|---|---|---|
| Mamba | 130M~2.8B | Pure SSM | 최초의 선택적 SSM |
| Mamba-2 | 130M~2.8B | SSD (State Space Duality) | SSM-Attention 이중성 활용 |
| Falcon Mamba | 7B | Pure SSM | 최초의 대규모 순수 SSM |
| Zamba | 7B | Hybrid (Mamba + Attention) | 공유 어텐션 블록 |
| RWKV-6 | 1.6B~14B | RNN + Attention 특성 | Linear Attention 변형 |
시간/공간 복잡도¶
| 연산 | Transformer | Mamba |
|---|---|---|
| 학습 시간 | O(n^2 * d) | O(n * d * s) |
| 학습 메모리 | O(n^2 + n*d) | O(n * s) |
| 추론 시간 (토큰당) | O(n * d) | O(d * s) |
| 추론 메모리 | O(n * d) (KV cache) | O(d * s) (fixed) |
여기서: - n: 시퀀스 길이 - d: 모델 차원 - s: 상태 차원 (보통 16~64)
장단점¶
장점¶
- 선형 시간 복잡도: 시퀀스 길이에 대해 O(n)
- 고정 메모리 추론: KV 캐시 없이 상수 메모리 사용
- 무한 컨텍스트: 이론적으로 무한한 시퀀스 처리 가능
- 효율적 학습: 병렬 스캔으로 학습 속도 향상
단점¶
- In-context Learning 제한: 어텐션 대비 약한 검색 능력
- 성숙도 부족: Transformer 대비 생태계 미성숙
- 하드웨어 최적화: GPU 최적화가 Transformer만큼 발달하지 않음
- 정보 압축 손실: 고정 상태에 정보 압축 시 손실 가능
코드 예시¶
기본 SSM 구현¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class SSMLayer(nn.Module):
"""Basic State Space Model layer"""
def __init__(self, d_model: int, state_dim: int = 16):
super().__init__()
self.d_model = d_model
self.state_dim = state_dim
# SSM parameters (learnable)
self.A = nn.Parameter(torch.randn(d_model, state_dim))
self.B = nn.Parameter(torch.randn(d_model, state_dim))
self.C = nn.Parameter(torch.randn(d_model, state_dim))
self.D = nn.Parameter(torch.ones(d_model))
# Discretization step
self.log_delta = nn.Parameter(torch.zeros(d_model))
def discretize(self):
"""Zero-order hold discretization"""
delta = F.softplus(self.log_delta) # Ensure positive
# Simplified discretization (Euler method)
A_bar = torch.exp(-delta.unsqueeze(-1) * self.A)
B_bar = delta.unsqueeze(-1) * self.B
return A_bar, B_bar
def forward(self, x):
# x: (batch, seq_len, d_model)
batch_size, seq_len, _ = x.shape
A_bar, B_bar = self.discretize()
# Initialize state
h = torch.zeros(batch_size, self.d_model, self.state_dim, device=x.device)
outputs = []
for k in range(seq_len):
x_k = x[:, k, :] # (batch, d_model)
# State update: h = A_bar * h + B_bar * x
h = A_bar * h + B_bar * x_k.unsqueeze(-1)
# Output: y = C @ h + D * x
y_k = (self.C * h).sum(dim=-1) + self.D * x_k
outputs.append(y_k)
return torch.stack(outputs, dim=1) # (batch, seq_len, d_model)
class SelectiveSSM(nn.Module):
"""Mamba-style Selective State Space Model"""
def __init__(self, d_model: int, state_dim: int = 16, dt_rank: int = None):
super().__init__()
self.d_model = d_model
self.state_dim = state_dim
self.dt_rank = dt_rank or (d_model // 16)
# Static A parameter (log-spaced initialization)
A = torch.arange(1, state_dim + 1).float()
self.A_log = nn.Parameter(torch.log(A).unsqueeze(0).expand(d_model, -1))
# Input-dependent projections
self.proj_B = nn.Linear(d_model, state_dim, bias=False)
self.proj_C = nn.Linear(d_model, state_dim, bias=False)
# Delta (discretization step) projection
self.proj_dt = nn.Linear(d_model, self.dt_rank, bias=False)
self.dt_proj = nn.Linear(self.dt_rank, d_model)
# Skip connection
self.D = nn.Parameter(torch.ones(d_model))
def forward(self, x):
# x: (batch, seq_len, d_model)
batch_size, seq_len, _ = x.shape
# Input-dependent parameters
B = self.proj_B(x) # (batch, seq_len, state_dim)
C = self.proj_C(x) # (batch, seq_len, state_dim)
# Compute delta (discretization step)
dt = self.dt_proj(self.proj_dt(x)) # (batch, seq_len, d_model)
dt = F.softplus(dt)
# Get A (negative for stability)
A = -torch.exp(self.A_log) # (d_model, state_dim)
# Discretize
A_bar = torch.exp(dt.unsqueeze(-1) * A) # (batch, seq_len, d_model, state_dim)
B_bar = dt.unsqueeze(-1) * B.unsqueeze(2) # (batch, seq_len, d_model, state_dim)
# Parallel scan (simplified sequential for clarity)
h = torch.zeros(batch_size, self.d_model, self.state_dim, device=x.device)
outputs = []
for k in range(seq_len):
# State update
h = A_bar[:, k] * h + B_bar[:, k] * x[:, k].unsqueeze(-1)
# Output
y_k = (C[:, k].unsqueeze(1) * h).sum(dim=-1) + self.D * x[:, k]
outputs.append(y_k)
return torch.stack(outputs, dim=1)
class MambaBlock(nn.Module):
"""Complete Mamba block"""
def __init__(self, d_model: int, expand: int = 2, state_dim: int = 16):
super().__init__()
d_inner = d_model * expand
self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
self.conv = nn.Conv1d(
d_inner, d_inner,
kernel_size=4,
padding=3,
groups=d_inner
)
self.ssm = SelectiveSSM(d_inner, state_dim)
self.out_proj = nn.Linear(d_inner, d_model, bias=False)
self.norm = nn.LayerNorm(d_model)
def forward(self, x):
# x: (batch, seq_len, d_model)
residual = x
x = self.norm(x)
# Split into two paths
xz = self.in_proj(x)
x, z = xz.chunk(2, dim=-1)
# Conv path
x = x.transpose(1, 2) # (batch, d_inner, seq_len)
x = self.conv(x)[:, :, :x.shape[-1]] # Causal conv
x = x.transpose(1, 2)
x = F.silu(x)
# SSM
x = self.ssm(x)
# Gate and project
x = x * F.silu(z)
x = self.out_proj(x)
return x + residual
효율적 병렬 스캔 (개념)¶
def parallel_scan(A_bar, B_bar_x):
"""
Parallel associative scan for SSM.
The recurrence h[k] = A_bar[k] * h[k-1] + B_bar[k] * x[k]
can be computed in O(log n) parallel steps using associative scan.
Args:
A_bar: (batch, seq_len, d_model, state_dim) - discretized A
B_bar_x: (batch, seq_len, d_model, state_dim) - B_bar * x
This is a conceptual implementation.
Real implementation uses custom CUDA kernels for efficiency.
"""
# The key insight is that (A, Bx) pairs form a monoid under:
# (A1, Bx1) * (A2, Bx2) = (A1 * A2, A2 * Bx1 + Bx2)
# This allows parallel prefix sum computation
# See: Blelloch (1990) "Prefix Sums and Their Applications"
# Actual efficient implementation requires CUDA kernels
# See: mamba-ssm library for production implementation
pass
Mamba vs Transformer 비교¶
| 측면 | Mamba | Transformer |
|---|---|---|
| 시퀀스 복잡도 | O(n) | O(n^2) |
| 메모리 (추론) | O(1) | O(n) |
| 장거리 의존성 | 상태 압축 | 직접 참조 |
| 병렬화 | 스캔 기반 | 완전 병렬 |
| In-context Learning | 제한적 | 강력 |
| 생태계 | 초기 단계 | 성숙 |
참고 논문¶
- Gu, A., et al. (2021). "Efficiently Modeling Long Sequences with Structured State Spaces." (S4)
-
arXiv: https://arxiv.org/abs/2111.00396
-
Gu, A., & Dao, T. (2023). "Mamba: Linear-Time Sequence Modeling with Selective State Spaces."
-
arXiv: https://arxiv.org/abs/2312.00752
-
Dao, T., & Gu, A. (2024). "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality." (Mamba-2)
-
arXiv: https://arxiv.org/abs/2405.21060
-
Gu, A., et al. (2020). "HiPPO: Recurrent Memory with Optimal Polynomial Projections."
-
arXiv: https://arxiv.org/abs/2008.07669
-
Peng, B., et al. (2023). "RWKV: Reinventing RNNs for the Transformer Era."
-
arXiv: https://arxiv.org/abs/2305.13048
-
Smith, J., et al. (2022). "Simplified State Space Layers for Sequence Modeling." (S5)
- arXiv: https://arxiv.org/abs/2208.04933