콘텐츠로 이동
Data Prep
상세

State Space Models (SSM)

Transformer의 대안으로 부상한 시퀀스 모델링 아키텍처. 선형 시간 복잡도로 긴 시퀀스를 효율적으로 처리한다.

Meta Information

항목 내용
분류 Deep Learning / Sequence Modeling
핵심 논문 S4 (NeurIPS 2021), Mamba (arXiv 2312.00752)
주요 저자 Albert Gu, Tri Dao (Carnegie Mellon, Princeton)
관련 기법 RNN, Transformer, Linear Attention
응용 분야 NLP, Audio, Vision, Time-series, Genomics

1. 개요

1.1 배경

모델 시간 복잡도 메모리 장거리 의존성
RNN O(L) O(1) Vanishing gradient 문제
Transformer O(L^2) O(L^2) 완전한 attention
SSM O(L) 또는 O(L log L) O(L) 구조화된 상태 공간

L = 시퀀스 길이

1.2 핵심 아이디어

State Space Models은 연속 시간 동역학 시스템을 이산화하여 시퀀스를 모델링한다:

연속 시스템:
x'(t) = Ax(t) + Bu(t)
y(t)  = Cx(t) + Du(t)

이산화 (Zero-Order Hold):
x_k = Ā x_{k-1} + B̄ u_k
y_k = C x_k + D u_k

여기서:
- x: 잠재 상태 (hidden state)
- u: 입력 시퀀스
- y: 출력 시퀀스
- A, B, C, D: 학습 가능한 파라미터

2. 주요 발전 과정

2.1 S4 (Structured State Space Sequence Model)

NeurIPS 2021 Outstanding Paper

핵심 기여: - HiPPO (High-order Polynomial Projection Operators) 행렬로 A 초기화 - 장거리 의존성 문제 해결 - FFT 기반 컨볼루션으로 병렬 학습

HiPPO 행렬 (Legendre 기반):
A_nk = -√(2n+1) √(2k+1)  if n > k
     = -(n+1)            if n = k
     = 0                 otherwise

2.2 S5 (Simplified S4)

  • 대각 상태 공간으로 단순화
  • MIMO (Multi-Input Multi-Output) 확장
  • 학습 안정성 개선

2.3 Mamba (2023)

핵심 혁신: Selective State Spaces

기존 SSM의 한계: 입력에 관계없이 동일한 A, B, C 사용 (content-agnostic)

Mamba의 해결책:

B_t = Linear_B(x_t)  # 입력 의존적
C_t = Linear_C(x_t)  # 입력 의존적
Δ_t = softplus(Linear_Δ(x_t))  # 선택적 게이팅

Hardware-Aware 알고리즘: - 커널 퓨전으로 메모리 I/O 최소화 - Recurrent 모드에서도 병렬 처리 - FlashAttention 스타일 최적화

2.4 Mamba-2 (ICML 2024)

State Space Duality (SSD)

Mamba-2는 SSM과 Attention의 이론적 연결을 밝힘: - 특정 조건에서 Structured SSM = Structured Attention - 더 효율적인 행렬 분해 알고리즘 - 8배 빠른 학습 속도

2.5 Jamba (AI21, 2024)

Mamba + Transformer 하이브리드: - Mamba 레이어와 Attention 레이어 혼합 - MoE (Mixture of Experts) 통합 - 256K 컨텍스트 지원


3. 아키텍처 상세

3.1 Mamba Block

┌─────────────────────────────────────────┐
│              Input (B, L, D)            │
├─────────────────────────────────────────┤
│           Linear Projection             │
│              D → 2*E                    │
├──────────────────┬──────────────────────┤
│    Branch 1      │      Branch 2        │
│   Conv1D + SiLU  │       SiLU           │
├──────────────────┤                      │
│   Selective SSM  │                      │
├──────────────────┴──────────────────────┤
│           Element-wise Multiply         │
├─────────────────────────────────────────┤
│           Linear Projection             │
│              E → D                      │
├─────────────────────────────────────────┤
│           Residual Connection           │
└─────────────────────────────────────────┘

3.2 Selective Scan Algorithm

# Pseudo-code
def selective_scan(u, delta, A, B, C):
    """
    u: (B, L, D) 입력
    delta: (B, L, D) 시간 스텝
    A: (D, N) 상태 행렬
    B: (B, L, N) 입력 행렬
    C: (B, L, N) 출력 행렬
    """
    # 이산화
    deltaA = exp(einsum('bld,dn->bldn', delta, A))
    deltaB_u = einsum('bld,bln,bld->bldn', delta, B, u)

    # 순차 스캔 (실제로는 병렬화)
    x = zeros(B, D, N)
    ys = []
    for i in range(L):
        x = deltaA[:, i] * x + deltaB_u[:, i]
        y = einsum('bdn,bn->bd', x, C[:, i])
        ys.append(y)

    return stack(ys, dim=1)

4. 성능 비교

4.1 언어 모델링 (Perplexity, 낮을수록 좋음)

모델 파라미터 The Pile (PPL)
Transformer 2.7B 8.21
Mamba 2.8B 7.82
Transformer 6.9B 7.50
Mamba 2.8B 7.82 (2.5x 작은 크기로 근접)

4.2 처리량 (Throughput)

시퀀스 길이에 따른 토큰/초 (A100 GPU):

길이 Transformer Mamba 속도 향상
2K 53K 90K 1.7x
8K 32K 85K 2.7x
32K 12K 78K 6.5x
128K OOM 72K -

4.3 Long Range Arena 벤치마크

Task Transformer S4 Mamba
ListOps 36.4 59.6 61.2
Text 64.3 86.8 87.5
Retrieval 57.5 90.9 91.3
Image 42.4 88.7 89.1
Pathfinder 71.4 94.2 94.8
Path-X FAIL 96.4 97.1
Average 54.4 86.1 86.8

5. Python 구현

5.1 기본 SSM 레이어

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, einsum


class SimpleSSM(nn.Module):
    """단순화된 State Space Model 구현"""

    def __init__(self, d_model: int, d_state: int = 16):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # SSM 파라미터
        self.A = nn.Parameter(torch.randn(d_model, d_state))
        self.B = nn.Parameter(torch.randn(d_model, d_state))
        self.C = nn.Parameter(torch.randn(d_model, d_state))
        self.D = nn.Parameter(torch.ones(d_model))

        # 시간 스텝
        self.dt = nn.Parameter(torch.ones(d_model) * 0.1)

        self._init_parameters()

    def _init_parameters(self):
        """HiPPO 스타일 초기화"""
        with torch.no_grad():
            # A를 음수로 초기화 (안정성)
            self.A.copy_(-torch.abs(self.A))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, length, d_model)
        Returns:
            y: (batch, length, d_model)
        """
        batch, length, _ = x.shape

        # 이산화 (Zero-Order Hold)
        dt = F.softplus(self.dt)
        A_discrete = torch.exp(self.A * dt.unsqueeze(-1))
        B_discrete = self.B * dt.unsqueeze(-1)

        # Recurrent 계산
        state = torch.zeros(batch, self.d_model, self.d_state, device=x.device)
        outputs = []

        for t in range(length):
            # x_t: (batch, d_model)
            x_t = x[:, t, :]

            # 상태 업데이트: h_t = A * h_{t-1} + B * x_t
            state = state * A_discrete + x_t.unsqueeze(-1) * B_discrete

            # 출력: y_t = C * h_t + D * x_t
            y_t = (state * self.C).sum(dim=-1) + self.D * x_t
            outputs.append(y_t)

        return torch.stack(outputs, dim=1)


class SelectiveSSM(nn.Module):
    """Mamba 스타일 Selective SSM"""

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # 입력 의존적 파라미터 생성
        self.proj_b = nn.Linear(d_model, d_state, bias=False)
        self.proj_c = nn.Linear(d_model, d_state, bias=False)
        self.proj_dt = nn.Linear(d_model, d_model, bias=True)

        # 고정 파라미터
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1).float()))
        self.D = nn.Parameter(torch.ones(d_model))

        # 1D Convolution
        self.conv1d = nn.Conv1d(
            d_model, d_model, 
            kernel_size=d_conv, 
            padding=d_conv - 1,
            groups=d_model
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, length, d = x.shape

        # Conv1D
        x_conv = self.conv1d(x.transpose(1, 2))[:, :, :length].transpose(1, 2)
        x_conv = F.silu(x_conv)

        # Selective 파라미터 생성
        B = self.proj_b(x_conv)  # (batch, length, d_state)
        C = self.proj_c(x_conv)  # (batch, length, d_state)
        dt = F.softplus(self.proj_dt(x_conv))  # (batch, length, d_model)

        # A 이산화
        A = -torch.exp(self.A_log)  # (d_state,)

        # Selective scan
        y = self._selective_scan(x_conv, dt, A, B, C)

        # Skip connection
        return y + self.D * x

    def _selective_scan(self, u, delta, A, B, C):
        """Selective scan 알고리즘 (단순 버전)"""
        batch, length, d_model = u.shape
        d_state = A.shape[0]

        # 이산화
        deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, D, N)
        deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)

        # 순차 스캔
        state = torch.zeros(batch, d_model, d_state, device=u.device)
        outputs = []

        for t in range(length):
            state = deltaA[:, t] * state + deltaB_u[:, t]
            y = (state * C[:, t].unsqueeze(1)).sum(dim=-1)
            outputs.append(y)

        return torch.stack(outputs, dim=1)

5.2 Mamba Block 구현

class MambaBlock(nn.Module):
    """완전한 Mamba 블록"""

    def __init__(
        self, 
        d_model: int, 
        d_state: int = 16, 
        expand: int = 2,
        d_conv: int = 4
    ):
        super().__init__()
        self.d_inner = d_model * expand

        # 입력 프로젝션
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Selective SSM
        self.ssm = SelectiveSSM(self.d_inner, d_state, d_conv)

        # 출력 프로젝션
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        # Layer Norm
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.norm(x)

        # Split into two branches
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # SSM branch
        x = self.ssm(x)

        # Gating
        x = x * F.silu(z)

        # Output projection
        x = self.out_proj(x)

        return x + residual


class MambaModel(nn.Module):
    """Mamba 언어 모델"""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_layers: int = 12,
        d_state: int = 16,
        expand: int = 2
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            MambaBlock(d_model, d_state, expand)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        return self.lm_head(x)


# 사용 예시
if __name__ == "__main__":
    model = MambaModel(
        vocab_size=32000,
        d_model=768,
        n_layers=12
    )

    # 파라미터 수
    params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {params / 1e6:.1f}M")

    # Forward pass
    x = torch.randint(0, 32000, (2, 1024))
    logits = model(x)
    print(f"Input: {x.shape}, Output: {logits.shape}")

5.3 시계열 예측 적용

class MambaForTimeSeries(nn.Module):
    """시계열 예측을 위한 Mamba 모델"""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 128,
        n_layers: int = 4,
        d_state: int = 16,
        forecast_horizon: int = 24
    ):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.layers = nn.ModuleList([
            MambaBlock(hidden_dim, d_state)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, forecast_horizon)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, input_dim)
        Returns:
            forecast: (batch, forecast_horizon)
        """
        x = self.input_proj(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)

        # 마지막 타임스텝에서 예측
        return self.output_proj(x[:, -1, :])


# 시계열 학습 예시
def train_timeseries():
    import numpy as np

    # 합성 데이터 생성
    np.random.seed(42)
    T = 1000
    t = np.linspace(0, 100, T)
    data = np.sin(t) + 0.5 * np.sin(3 * t) + 0.1 * np.random.randn(T)

    # 시퀀스 생성
    seq_len, horizon = 96, 24
    X, y = [], []
    for i in range(len(data) - seq_len - horizon):
        X.append(data[i:i+seq_len])
        y.append(data[i+seq_len:i+seq_len+horizon])

    X = torch.tensor(np.array(X), dtype=torch.float32).unsqueeze(-1)
    y = torch.tensor(np.array(y), dtype=torch.float32)

    # 모델 학습
    model = MambaForTimeSeries(input_dim=1, forecast_horizon=horizon)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(100):
        pred = model(X)
        loss = F.mse_loss(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

    return model


if __name__ == "__main__":
    train_timeseries()

6. 실무 적용 가이드

6.1 Mamba vs Transformer 선택 기준

상황 권장
시퀀스 < 2K Transformer
시퀀스 > 8K Mamba
긴 문서/책 처리 Mamba
추론 지연 중요 Mamba
기존 생태계 활용 Transformer
시계열/오디오/유전체 Mamba

6.2 하이퍼파라미터 권장값

# 기본 설정
d_model: 768         # 모델 차원
d_state: 16          # 상태 차원 (16-64)
expand: 2            # 확장 비율
d_conv: 4            # 컨볼루션 커널 크기

# 스케일링
small:   d_model=768,  n_layers=12   # ~125M params
medium:  d_model=1024, n_layers=24   # ~350M params
large:   d_model=2048, n_layers=48   # ~1.3B params

# 학습
learning_rate: 1e-4 to 5e-4
warmup_steps: 1000
weight_decay: 0.1
batch_size: 시퀀스 길이에 따라 조정

6.3 주의사항

  1. 초기화 중요성: A 행렬 초기화가 안정성에 큰 영향
  2. 수치 안정성: 긴 시퀀스에서 overflow 주의 (fp32/bf16 권장)
  3. 하드웨어 최적화: 공식 CUDA 커널 사용 시 성능 크게 향상
  4. 하이브리드 고려: 일부 태스크에서 Attention과 혼합이 효과적

7. 관련 자료

논문

구현

튜토리얼


최종 업데이트: 2026-02-08