Hybrid Architecture¶
Overview¶
A hybrid architecture combines the Transformer's attention mechanism with the efficiency of State Space Models (SSMs). By leveraging both the Transformer's strong in-context learning ability and the SSM's linear time complexity, it processes long sequences efficiently.
Core Concepts¶
Why Hybrid?¶
Transformers and SSMs have complementary characteristics:

| Property | Transformer | SSM (Mamba) |
|---|---|---|
| Time complexity | O(n^2) | O(n) |
| Memory (inference) | O(n) KV cache | O(1) fixed state |
| In-context learning | Strong | Limited |
| Long-range dependencies | Direct lookup | State compression |
| Information retrieval | Exact | Approximate |
A hybrid stack combines the strengths of both architectures (a rough memory sketch follows below):

- Most layers: efficient processing with SSM blocks
- A few layers: exact information retrieval with attention
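To make the "Memory (inference)" row concrete, here is a back-of-the-envelope sketch. The formulas and all parameter values are rough assumptions (for example, the Mamba state is approximated as d_model x state_dim per layer and the small convolution buffer is ignored):

def inference_memory_bytes(seq_len: int, n_layers: int, d_model: int,
                           n_kv_heads: int, head_dim: int, state_dim: int,
                           bytes_per_elem: int = 2) -> dict:
    """Rough per-sequence inference memory: the KV cache grows linearly with
    seq_len, while an SSM keeps a fixed-size recurrent state."""
    kv_cache = 2 * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem  # K and V
    ssm_state = n_layers * d_model * state_dim * bytes_per_elem                 # independent of seq_len
    return {"attention_kv_cache": kv_cache, "ssm_state": ssm_state}

# Example: 256K tokens through a hypothetical 32-layer model
print(inference_memory_bytes(seq_len=256_000, n_layers=32, d_model=4096,
                             n_kv_heads=8, head_dim=128, state_dim=16))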
Block Placement Strategies¶
- Interleaved: attention and SSM layers alternate at a fixed interval (e.g., StripedHyena's 1:1 pattern).
- Ratio-based: a fixed attention-to-SSM ratio is kept throughout the stack, such as Jamba's 1:7.
- Task-adaptive: the ratio and placement are tuned to the target workload (see the pattern sketch below).
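As a simple illustration of the first two strategies, the helper below (a hypothetical name, not taken from any of the cited models) produces a layer-type list for a given depth. Jamba's 1:7 pattern corresponds to strategy="ratio" with attention_every=8:

from typing import List

def make_layer_pattern(n_layers: int, strategy: str = "ratio",
                       attention_every: int = 8) -> List[str]:
    """Generate an illustrative layer-placement pattern for a hybrid stack."""
    if strategy == "interleaved":
        # Strict 1:1 alternation, StripedHyena-style
        return ["attention" if i % 2 else "ssm" for i in range(n_layers)]
    if strategy == "ratio":
        # One attention layer every `attention_every` layers, Jamba-style
        return ["attention" if (i + 1) % attention_every == 0 else "ssm"
                for i in range(n_layers)]
    raise ValueError(f"unknown strategy: {strategy}")

pattern = make_layer_pattern(16, strategy="ratio", attention_every=8)
# ['ssm', 'ssm', 'ssm', 'ssm', 'ssm', 'ssm', 'ssm', 'attention', 'ssm', ...]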
Architecture Diagrams¶
Jamba Basic Block Structure¶
Input
|
v
+-------------------+--------------------+
| Jamba Block |
| |
| +--------+ or +------------+ |
| | Mamba | | Attention | |
| | Layer | | Layer | |
| +---+----+ +-----+------+ |
| | | |
| +----------+----------+ |
| | |
| +----------v----------+ |
| | MoE / MLP | |
| | (Feed Forward) | |
| +----------+----------+ |
| | |
+------------------+---------------------+
|
v
Output
Jamba Overall Structure (52B Model Example)¶
+--------------------------------------------------+
| Jamba Model |
+--------------------------------------------------+
| |
| Layer 1-7: [Mamba] [Mamba] ... [Mamba] (7x) |
| Layer 8: [Attention + MoE] |
| |
| Layer 9-15: [Mamba] [Mamba] ... [Mamba] (7x) |
| Layer 16: [Attention + MoE] |
| |
| ... |
| |
|   Total: 32 layers (4 blocks x 8 layers)        |
|     - 28 Mamba layers                           |
|     - 4 Attention + MoE layers                  |
| |
+--------------------------------------------------+
Layer ratio: 1:7 (Attention:Mamba)
Jamba 1.5 Detailed Structure¶
[Input Tokens]
|
v
+----------------+
| Embedding |
+-------+--------+
|
+==========================================+
| REPEATED BLOCK (x9) |
| |
| +------------------------------------+ |
| | MAMBA BLOCK (x7) | |
| | +------------------------------+ | |
| | | LayerNorm | | |
| | +-------------+----------------+ | |
| | | | |
| | +-------------v----------------+ | |
| | | Mamba Layer | | |
| | | (Selective SSM) | | |
| | +-------------+----------------+ | |
| | | | |
| | +-------------v----------------+ | |
| | | MLP (Dense) | | |
| | +------------------------------+ | |
| +------------------------------------+ |
| |
| +------------------------------------+ |
| | ATTENTION + MoE BLOCK (x1) | |
| | +------------------------------+ | |
| | | LayerNorm | | |
| | +-------------+----------------+ | |
| | | | |
| | +-------------v----------------+ | |
| | | Grouped-Query Attention | | |
| | | (GQA, 8 KV heads) | | |
| | +-------------+----------------+ | |
| | | | |
| | +-------------v----------------+ | |
| | | MoE Layer | | |
| | | (16 experts, top-2) | | |
| | +------------------------------+ | |
| +------------------------------------+ |
| |
+==========================================+
|
v
+----------------+
| LayerNorm |
+-------+--------+
|
v
+----------------+
| LM Head |
+----------------+
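The attention layers above use grouped-query attention (GQA) with 8 KV heads. The sketch below shows the core idea, a small number of KV heads shared across all query heads; the class name is hypothetical, it is causal-only with no KV cache, and it is not Jamba's actual implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """Minimal GQA sketch: n_heads query heads share n_kv_heads K/V heads."""
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert n_heads % n_kv_heads == 0
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Each KV head is repeated to serve its group of query heads
        group = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(group, dim=1)
        v = v.repeat_interleave(group, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, -1))

Using fewer KV heads shrinks the KV cache by a factor of n_heads / n_kv_heads, which compounds with the hybrid design's already small number of attention layers.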
Zamba Structure (Shared Attention)¶
Zamba: uses a shared attention block
+----------------------------------------------------+
| |
| Block 1: [Mamba] -> [Shared Attention] -> [MLP] |
| Block 2: [Mamba] -> [Shared Attention] -> [MLP] |
| Block 3: [Mamba] -> [Shared Attention] -> [MLP] |
| ... |
| |
|   * Shared Attention: weights reused by every block |
|   * Improves memory efficiency                      |
| |
+----------------------------------------------------+
StripedHyena Structure¶
StripedHyena: alternating layer placement
+--------------------------------------------+
| Layer 1: [Hyena (Long Conv)] |
| Layer 2: [Attention] |
| Layer 3: [Hyena (Long Conv)] |
| Layer 4: [Attention] |
| ... |
| |
|   * Alternating at a 1:1 ratio            |
|   * Hyena: implicit long convolution      |
+--------------------------------------------+
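Hyena replaces attention with implicit long convolutions that can be evaluated in O(L log L) via the FFT. The function below is a minimal sketch of that convolution primitive only; the real Hyena operator additionally uses gating and data-controlled implicit filters, and the name and shapes here are illustrative:

import torch

def fft_long_conv(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causal depthwise long convolution via FFT.

    x: (batch, seq_len, channels) input sequence
    k: (seq_len, channels) filter, one per channel, defined for lags 0..seq_len-1
    """
    L = x.shape[1]
    # Zero-pad to 2L so circular convolution equals linear (causal) convolution
    x_f = torch.fft.rfft(x, n=2 * L, dim=1)
    k_f = torch.fft.rfft(k, n=2 * L, dim=0)
    y = torch.fft.irfft(x_f * k_f.unsqueeze(0), n=2 * L, dim=1)
    return y[:, :L]  # keep the causal part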
Representative Models¶
| Model | Total Params | Active Params | Structure | Context |
|---|---|---|---|---|
| Jamba 1.5 Large | 398B | 94B | Mamba + Attention + MoE | 256K |
| Jamba 1.5 Mini | 52B | 12B | Mamba + Attention + MoE | 256K |
| Jamba (original) | 52B | 12B | Mamba + Attention + MoE | 256K |
| Zamba 7B | 7.2B | 7.2B | Mamba + Shared Attention | 4K |
| StripedHyena | 7B | 7B | Hyena + Attention | 128K |
Pros and Cons¶
Pros¶
- Efficient long context: the SSM's linear complexity handles long sequences
- Strong retrieval: attention layers provide exact information lookup
- Memory efficiency: most layers are SSM, so the KV cache is much smaller
- Flexible design: the Attention/Mamba ratio can be adjusted per task
Cons¶
- Design complexity: finding the optimal ratio and placement is difficult
- Training difficulty: requires understanding how the two architectures interact
- Early stage: research is still actively evolving
- Immature toolchain: limited ability to reuse Transformer-specific optimizations
Code Examples¶
Jamba-style Hybrid Block¶
import torch
import torch.nn as nn
from typing import Optional
# Assume MambaBlock, MultiHeadAttention, and MoELayer are defined elsewhere
# from ssm import MambaBlock
# from transformer import MultiHeadAttention
class HybridBlock(nn.Module):
"""Single hybrid block that can be either Mamba or Attention"""
def __init__(
self,
d_model: int,
use_attention: bool = False,
n_heads: int = 8,
        d_ff: Optional[int] = None,
state_dim: int = 16,
dropout: float = 0.1
):
super().__init__()
self.use_attention = use_attention
d_ff = d_ff or d_model * 4
self.norm1 = nn.LayerNorm(d_model)
if use_attention:
self.layer = MultiHeadAttention(d_model, n_heads, dropout)
else:
self.layer = MambaBlock(d_model, state_dim=state_dim)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
def forward(self, x, mask: Optional[torch.Tensor] = None):
# Pre-norm architecture
if self.use_attention:
x = x + self.layer(self.norm1(x), mask)
else:
x = x + self.layer(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
class JambaModel(nn.Module):
"""
Jamba-style hybrid model.
Uses a 1:7 ratio of Attention:Mamba layers.
Every 8th layer uses Attention + MoE.
"""
def __init__(
self,
vocab_size: int,
d_model: int = 4096,
n_layers: int = 72,
n_heads: int = 32,
d_ff: int = 14336,
attention_frequency: int = 8, # 1 attention per 8 layers
use_moe: bool = True,
num_experts: int = 16,
top_k_experts: int = 2,
state_dim: int = 16,
dropout: float = 0.1
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList()
for i in range(n_layers):
use_attention = ((i + 1) % attention_frequency == 0)
if use_attention and use_moe:
# Attention + MoE layer
layer = HybridAttentionMoEBlock(
d_model, n_heads, d_ff,
num_experts, top_k_experts, dropout
)
elif use_attention:
# Attention + Dense FFN
layer = HybridBlock(
d_model, use_attention=True,
n_heads=n_heads, d_ff=d_ff, dropout=dropout
)
else:
# Mamba + Dense FFN
layer = HybridBlock(
d_model, use_attention=False,
state_dim=state_dim, d_ff=d_ff, dropout=dropout
)
self.layers.append(layer)
self.norm = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, input_ids, attention_mask=None):
x = self.embedding(input_ids)
for layer in self.layers:
x = layer(x, attention_mask)
x = self.norm(x)
logits = self.lm_head(x)
return logits
class HybridAttentionMoEBlock(nn.Module):
"""Attention block with MoE FFN (for every 8th layer in Jamba)"""
def __init__(
self,
d_model: int,
n_heads: int,
d_ff: int,
num_experts: int,
top_k: int,
dropout: float = 0.1
):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attention = MultiHeadAttention(d_model, n_heads, dropout)
self.norm2 = nn.LayerNorm(d_model)
self.moe = MoELayer(d_model, d_ff, num_experts, top_k, dropout)
def forward(self, x, mask=None):
x = x + self.attention(self.norm1(x), mask)
moe_out, router_logits = self.moe(self.norm2(x))
x = x + moe_out
return x
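HybridAttentionMoEBlock calls a MoELayer that is not defined above. Below is a minimal token-level top-k routing sketch that matches the call signature used there (it returns both the output and the router logits, which could feed a load-balancing auxiliary loss) and reuses the torch / torch.nn imports from the top of this section. It is an assumed illustration, not Jamba's production MoE:

class MoELayer(nn.Module):
    """Minimal top-k routed mixture-of-experts FFN (illustrative sketch)."""
    def __init__(self, d_model: int, d_ff: int, num_experts: int,
                 top_k: int, dropout: float = 0.1):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Linear(d_ff, d_model),
                nn.Dropout(dropout),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        B, T, D = x.shape
        flat = x.reshape(-1, D)                              # (B*T, D)
        router_logits = self.router(flat)                    # (B*T, num_experts)
        weights = torch.softmax(router_logits, dim=-1)
        topk_w, topk_idx = weights.topk(self.top_k, dim=-1)
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)   # renormalize over chosen experts
        out = torch.zeros_like(flat)
        for e, expert in enumerate(self.experts):
            token_idx, slot = (topk_idx == e).nonzero(as_tuple=True)
            if token_idx.numel() > 0:
                out[token_idx] += topk_w[token_idx, slot].unsqueeze(-1) * expert(flat[token_idx])
        return out.view(B, T, D), router_logits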
Zamba-style (Shared Attention)¶
class ZambaModel(nn.Module):
"""
Zamba-style model with shared attention.
All blocks share the same attention layer weights,
reducing memory footprint significantly.
"""
def __init__(
self,
vocab_size: int,
d_model: int = 3072,
n_layers: int = 76,
n_heads: int = 24,
state_dim: int = 64,
dropout: float = 0.1
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
# Shared attention (same weights for all layers)
self.shared_attention = MultiHeadAttention(d_model, n_heads, dropout)
self.attn_norm = nn.LayerNorm(d_model)
# Individual Mamba blocks
self.mamba_blocks = nn.ModuleList([
MambaBlock(d_model, state_dim=state_dim)
for _ in range(n_layers)
])
self.mamba_norms = nn.ModuleList([
nn.LayerNorm(d_model) for _ in range(n_layers)
])
# Individual MLPs
self.mlps = nn.ModuleList([
nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.GELU(),
nn.Linear(d_model * 4, d_model),
nn.Dropout(dropout)
)
for _ in range(n_layers)
])
self.mlp_norms = nn.ModuleList([
nn.LayerNorm(d_model) for _ in range(n_layers)
])
self.final_norm = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, input_ids, attention_mask=None):
x = self.embedding(input_ids)
for i in range(len(self.mamba_blocks)):
# Mamba processing
x = x + self.mamba_blocks[i](self.mamba_norms[i](x))
# Shared attention
x = x + self.shared_attention(self.attn_norm(x), attention_mask)
# MLP
x = x + self.mlps[i](self.mlp_norms[i](x))
x = self.final_norm(x)
return self.lm_head(x)
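A quick smoke test of the sketch above, assuming MambaBlock and MultiHeadAttention implementations are available; all dimensions are toy values chosen only for illustration:

model = ZambaModel(vocab_size=32_000, d_model=256, n_layers=4, n_heads=8)
tokens = torch.randint(0, 32_000, (2, 128))   # (batch, seq_len)
logits = model(tokens)
print(logits.shape)                           # torch.Size([2, 128, 32000])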
Design Considerations¶
Choosing the Attention/Mamba Ratio¶
| Ratio | Focus | Best suited for |
|---|---|---|
| 1:7 (Jamba) | Efficiency first | Long contexts, inference-cost sensitive |
| 1:3 | Balanced | General-purpose use |
| 1:1 | Retrieval first | When in-context learning matters most |
MoE Integration Strategies¶
- All layers: parameter-efficient, but harder to train
- Attention layers only: the Jamba approach, stable
- Mamba layers only: experimental, still under study
References¶
- Lieber, O., et al. (2024). "Jamba: A Hybrid Transformer-Mamba Language Model." arXiv: https://arxiv.org/abs/2403.19887
- AI21 Labs. (2024). "Jamba 1.5: Hybrid Transformer-Mamba Models at Scale." https://www.ai21.com/blog/announcing-jamba-1-5/
- Glorioso, P., et al. (2024). "Zamba: A Compact 7B SSM Hybrid Model." arXiv: https://arxiv.org/abs/2405.18712
- Poli, M., et al. (2023). "Hyena Hierarchy: Towards Larger Convolutional Language Models." arXiv: https://arxiv.org/abs/2302.10866
- Together AI. (2023). "StripedHyena: Moving Beyond Transformers with Hybrid Signal Processing Models." https://www.together.ai/blog/stripedhyena-7b
- Waleffe, R., et al. (2024). "An Empirical Study of Mamba-based Language Models." arXiv: https://arxiv.org/abs/2406.07887