콘텐츠로 이동

State Space Models (SSM) - Mamba

개요

State Space Models (SSM)은 연속 시간 시스템을 이산화하여 시퀀스 모델링에 적용하는 아키텍처다. Mamba는 SSM에 선택적 메커니즘(selective scan)을 도입하여 Transformer에 필적하는 성능을 달성한 모델로, 시퀀스 길이에 대해 선형 시간 복잡도를 가진다.

핵심 개념

State Space Model 기본

연속 시간 상태 공간 모델:

h'(t) = A h(t) + B x(t)     # State equation
y(t)  = C h(t) + D x(t)     # Output equation

여기서: - x(t): 입력 신호 - h(t): 숨겨진 상태 (hidden state) - y(t): 출력 신호 - A, B, C, D: 시스템 행렬

이산화 (Discretization)

연속 시스템을 이산 시퀀스로 변환:

Zero-Order Hold (ZOH):
A_bar = exp(delta * A)
B_bar = (delta * A)^(-1) * (exp(delta * A) - I) * delta * B

Simplified (Euler method):
A_bar = I + delta * A
B_bar = delta * B

이산화된 재귀 관계:

h_k = A_bar * h_{k-1} + B_bar * x_k
y_k = C * h_k

S4 (Structured State Space)

S4는 HiPPO 행렬을 사용하여 장거리 의존성을 효과적으로 모델링:

HiPPO Matrix (LegS):
A_nk = -sqrt(2n+1) * sqrt(2k+1)  if n > k
     = n + 1                      if n = k
     = 0                          if n < k

Selective Scan (Mamba의 핵심)

기존 SSM은 입력에 무관하게 동일한 파라미터를 사용하지만, Mamba는 입력에 따라 B, C, delta를 동적으로 결정:

# Input-dependent parameters
B = Linear_B(x)       # (batch, seq_len, state_dim)
C = Linear_C(x)       # (batch, seq_len, state_dim)
delta = softplus(Linear_delta(x))  # (batch, seq_len, d_model)

# Discretization with input-dependent delta
A_bar = exp(-delta * A)
B_bar = delta * B

# Selective scan
for k in range(seq_len):
    h[k] = A_bar[k] * h[k-1] + B_bar[k] * x[k]
    y[k] = C[k] @ h[k]

아키텍처 다이어그램

S4 Layer

        Input x
           |
           v
    +--------------+
    |   Conv1D     |  (local context)
    +------+-------+
           |
           v
    +--------------+
    |     SSM      |
    |  h' = Ah+Bx  |
    |  y  = Ch     |
    +------+-------+
           |
           v
    +--------------+
    |   Dropout    |
    +------+-------+
           |
           v
       Output y

Mamba Block

                    Input x
                       |
           +-----------+-----------+
           |                       |
           v                       |
    +-------------+                |
    |   Linear    |                |
    | (expand)    |                |
    +------+------+                |
           |                       |
           v                       |
    +-------------+                |
    |   Conv1D    |                |
    | (k=4)       |                |
    +------+------+                |
           |                       |
           v                       |
    +-------------+                |
    |    SiLU     |                |
    +------+------+                |
           |                       |
           v                       |
    +-------------+                |
    | Selective   |                |
    |   SSM       |                |
    | (S6 core)   |                |
    +------+------+                |
           |                       |
           v                       v
    +-------------+         +-------------+
    |  Element-   |<--------|   Linear    |
    |  wise       |         |   + SiLU    |
    |  Multiply   |         +-------------+
    +------+------+
           |
           v
    +-------------+
    |   Linear    |
    | (project)   |
    +------+------+
           |
           v
        Output

Selective SSM 상세

                         x (input)
                            |
              +-------------+-------------+
              |             |             |
              v             v             v
        +----------+  +----------+  +------------+
        | Linear_B |  | Linear_C |  | Linear_dt  |
        +----+-----+  +----+-----+  +-----+------+
             |             |              |
             v             v              v
             B             C         softplus(dt)
             |             |              |
             +-------------+-------+------+
                                   |
                                   v
                          +----------------+
                          | Discretize A,B |
                          | A_bar = e^(-dt*A)|
                          | B_bar = dt * B |
                          +-------+--------+
                                  |
                                  v
                          +----------------+
                          | Parallel Scan  |
                          | h_k = A_bar*h  |
                          |     + B_bar*x  |
                          +-------+--------+
                                  |
                                  v
                          +----------------+
                          | Output         |
                          | y_k = C @ h_k  |
                          +----------------+
                                  |
                                  v
                               Output

Mamba-2 (Structured State Space Duality)

    Mamba-2: SSM = Attention with structured mask

    SSM View:                    Attention View:
    h_k = A*h_{k-1} + B*x_k     y = softmax(Q K^T / sqrt(d)) V
    y_k = C*h_k                        |
         |                             |
         +------------+----------------+
                      |
                      v
              [Same Computation]
              [Different Algorithm]

    SSM:  O(n) sequential scan
    Attn: O(n^2) parallel matmul

    Mamba-2 chooses optimal based on sequence length

대표 모델

모델 파라미터 아키텍처 특징
Mamba 130M~2.8B Pure SSM 최초의 선택적 SSM
Mamba-2 130M~2.8B SSD (State Space Duality) SSM-Attention 이중성 활용
Falcon Mamba 7B Pure SSM 최초의 대규모 순수 SSM
Zamba 7B Hybrid (Mamba + Attention) 공유 어텐션 블록
RWKV-6 1.6B~14B RNN + Attention 특성 Linear Attention 변형

시간/공간 복잡도

연산 Transformer Mamba
학습 시간 O(n^2 * d) O(n * d * s)
학습 메모리 O(n^2 + n*d) O(n * s)
추론 시간 (토큰당) O(n * d) O(d * s)
추론 메모리 O(n * d) (KV cache) O(d * s) (fixed)

여기서: - n: 시퀀스 길이 - d: 모델 차원 - s: 상태 차원 (보통 16~64)

장단점

장점

  1. 선형 시간 복잡도: 시퀀스 길이에 대해 O(n)
  2. 고정 메모리 추론: KV 캐시 없이 상수 메모리 사용
  3. 무한 컨텍스트: 이론적으로 무한한 시퀀스 처리 가능
  4. 효율적 학습: 병렬 스캔으로 학습 속도 향상

단점

  1. In-context Learning 제한: 어텐션 대비 약한 검색 능력
  2. 성숙도 부족: Transformer 대비 생태계 미성숙
  3. 하드웨어 최적화: GPU 최적화가 Transformer만큼 발달하지 않음
  4. 정보 압축 손실: 고정 상태에 정보 압축 시 손실 가능

코드 예시

기본 SSM 구현

import torch
import torch.nn as nn
import torch.nn.functional as F

class SSMLayer(nn.Module):
    """Basic State Space Model layer"""
    def __init__(self, d_model: int, state_dim: int = 16):
        super().__init__()
        self.d_model = d_model
        self.state_dim = state_dim

        # SSM parameters (learnable)
        self.A = nn.Parameter(torch.randn(d_model, state_dim))
        self.B = nn.Parameter(torch.randn(d_model, state_dim))
        self.C = nn.Parameter(torch.randn(d_model, state_dim))
        self.D = nn.Parameter(torch.ones(d_model))

        # Discretization step
        self.log_delta = nn.Parameter(torch.zeros(d_model))

    def discretize(self):
        """Zero-order hold discretization"""
        delta = F.softplus(self.log_delta)  # Ensure positive

        # Simplified discretization (Euler method)
        A_bar = torch.exp(-delta.unsqueeze(-1) * self.A)
        B_bar = delta.unsqueeze(-1) * self.B

        return A_bar, B_bar

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        batch_size, seq_len, _ = x.shape

        A_bar, B_bar = self.discretize()

        # Initialize state
        h = torch.zeros(batch_size, self.d_model, self.state_dim, device=x.device)

        outputs = []
        for k in range(seq_len):
            x_k = x[:, k, :]  # (batch, d_model)

            # State update: h = A_bar * h + B_bar * x
            h = A_bar * h + B_bar * x_k.unsqueeze(-1)

            # Output: y = C @ h + D * x
            y_k = (self.C * h).sum(dim=-1) + self.D * x_k

            outputs.append(y_k)

        return torch.stack(outputs, dim=1)  # (batch, seq_len, d_model)


class SelectiveSSM(nn.Module):
    """Mamba-style Selective State Space Model"""
    def __init__(self, d_model: int, state_dim: int = 16, dt_rank: int = None):
        super().__init__()
        self.d_model = d_model
        self.state_dim = state_dim
        self.dt_rank = dt_rank or (d_model // 16)

        # Static A parameter (log-spaced initialization)
        A = torch.arange(1, state_dim + 1).float()
        self.A_log = nn.Parameter(torch.log(A).unsqueeze(0).expand(d_model, -1))

        # Input-dependent projections
        self.proj_B = nn.Linear(d_model, state_dim, bias=False)
        self.proj_C = nn.Linear(d_model, state_dim, bias=False)

        # Delta (discretization step) projection
        self.proj_dt = nn.Linear(d_model, self.dt_rank, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, d_model)

        # Skip connection
        self.D = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        batch_size, seq_len, _ = x.shape

        # Input-dependent parameters
        B = self.proj_B(x)  # (batch, seq_len, state_dim)
        C = self.proj_C(x)  # (batch, seq_len, state_dim)

        # Compute delta (discretization step)
        dt = self.dt_proj(self.proj_dt(x))  # (batch, seq_len, d_model)
        dt = F.softplus(dt)

        # Get A (negative for stability)
        A = -torch.exp(self.A_log)  # (d_model, state_dim)

        # Discretize
        A_bar = torch.exp(dt.unsqueeze(-1) * A)  # (batch, seq_len, d_model, state_dim)
        B_bar = dt.unsqueeze(-1) * B.unsqueeze(2)  # (batch, seq_len, d_model, state_dim)

        # Parallel scan (simplified sequential for clarity)
        h = torch.zeros(batch_size, self.d_model, self.state_dim, device=x.device)
        outputs = []

        for k in range(seq_len):
            # State update
            h = A_bar[:, k] * h + B_bar[:, k] * x[:, k].unsqueeze(-1)

            # Output
            y_k = (C[:, k].unsqueeze(1) * h).sum(dim=-1) + self.D * x[:, k]
            outputs.append(y_k)

        return torch.stack(outputs, dim=1)


class MambaBlock(nn.Module):
    """Complete Mamba block"""
    def __init__(self, d_model: int, expand: int = 2, state_dim: int = 16):
        super().__init__()
        d_inner = d_model * expand

        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)

        self.conv = nn.Conv1d(
            d_inner, d_inner,
            kernel_size=4,
            padding=3,
            groups=d_inner
        )

        self.ssm = SelectiveSSM(d_inner, state_dim)

        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        residual = x
        x = self.norm(x)

        # Split into two paths
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Conv path
        x = x.transpose(1, 2)  # (batch, d_inner, seq_len)
        x = self.conv(x)[:, :, :x.shape[-1]]  # Causal conv
        x = x.transpose(1, 2)
        x = F.silu(x)

        # SSM
        x = self.ssm(x)

        # Gate and project
        x = x * F.silu(z)
        x = self.out_proj(x)

        return x + residual

효율적 병렬 스캔 (개념)

def parallel_scan(A_bar, B_bar_x):
    """
    Parallel associative scan for SSM.

    The recurrence h[k] = A_bar[k] * h[k-1] + B_bar[k] * x[k]
    can be computed in O(log n) parallel steps using associative scan.

    Args:
        A_bar: (batch, seq_len, d_model, state_dim) - discretized A
        B_bar_x: (batch, seq_len, d_model, state_dim) - B_bar * x

    This is a conceptual implementation.
    Real implementation uses custom CUDA kernels for efficiency.
    """
    # The key insight is that (A, Bx) pairs form a monoid under:
    # (A1, Bx1) * (A2, Bx2) = (A1 * A2, A2 * Bx1 + Bx2)

    # This allows parallel prefix sum computation
    # See: Blelloch (1990) "Prefix Sums and Their Applications"

    # Actual efficient implementation requires CUDA kernels
    # See: mamba-ssm library for production implementation
    pass

Mamba vs Transformer 비교

측면 Mamba Transformer
시퀀스 복잡도 O(n) O(n^2)
메모리 (추론) O(1) O(n)
장거리 의존성 상태 압축 직접 참조
병렬화 스캔 기반 완전 병렬
In-context Learning 제한적 강력
생태계 초기 단계 성숙

참고 논문

  1. Gu, A., et al. (2021). "Efficiently Modeling Long Sequences with Structured State Spaces." (S4)
  2. arXiv: https://arxiv.org/abs/2111.00396

  3. Gu, A., & Dao, T. (2023). "Mamba: Linear-Time Sequence Modeling with Selective State Spaces."

  4. arXiv: https://arxiv.org/abs/2312.00752

  5. Dao, T., & Gu, A. (2024). "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality." (Mamba-2)

  6. arXiv: https://arxiv.org/abs/2405.21060

  7. Gu, A., et al. (2020). "HiPPO: Recurrent Memory with Optimal Polynomial Projections."

  8. arXiv: https://arxiv.org/abs/2008.07669

  9. Peng, B., et al. (2023). "RWKV: Reinventing RNNs for the Transformer Era."

  10. arXiv: https://arxiv.org/abs/2305.13048

  11. Smith, J., et al. (2022). "Simplified State Space Layers for Sequence Modeling." (S5)

  12. arXiv: https://arxiv.org/abs/2208.04933