Mixture-of-Depths (MoD)
개요
Mixture-of-Depths는 Google DeepMind가 2024년에 발표한 동적 연산 할당 기법이다. 기존 트랜스포머가 모든 토큰에 동일한 연산량을 할당하는 것과 달리, MoD는 토큰별로 필요한 연산량을 동적으로 결정한다.
| 항목 |
내용 |
| 논문 |
Mixture-of-Depths: Dynamically allocating compute in transformer-based language models |
| 저자 |
David Raposo, Sam Ritter, Blake Richards, Timothy Lillicrap, Peter Conway Humphreys, Adam Santoro |
| arXiv |
2404.02258 |
| 소속 |
Google DeepMind |
| 발표 |
2024년 4월 |
핵심 아이디어
기존 트랜스포머의 문제
Standard Transformer:
Layer 1: [tok1] [tok2] [tok3] [tok4] [tok5] <- 모든 토큰 처리
Layer 2: [tok1] [tok2] [tok3] [tok4] [tok5] <- 모든 토큰 처리
Layer 3: [tok1] [tok2] [tok3] [tok4] [tok5] <- 모든 토큰 처리
...
모든 레이어에서 모든 토큰에 동일한 FLOPs 할당
문제점:
- 일부 토큰은 "쉬움" (예: 관사 "the", "a")
- 일부 토큰은 "어려움" (예: 전문 용어, 맥락 의존적 단어)
- 모든 토큰에 동일 연산 = 비효율
MoD 접근법
Mixture-of-Depths:
Layer 1: [tok1] - [tok3] - [tok5] <- top-k 토큰만 처리
Layer 2: - [tok2] - [tok4] [tok5] <- 다른 top-k 토큰
Layer 3: [tok1] [tok2] - - [tok5] <- 또 다른 top-k 토큰
...
각 레이어에서 가장 중요한 k개 토큰만 연산
핵심 원리:
1. Capacity Factor: 각 레이어에서 처리할 토큰 비율 (예: 12.5%)
2. Router: 어떤 토큰을 처리할지 결정
3. Residual Connection: 처리되지 않은 토큰은 skip
아키텍처
Router 메커니즘
입력 시퀀스: [x_1, x_2, ..., x_n]
1. Router 점수 계산:
r_i = W_r * x_i (스칼라 점수)
2. Top-k 선택:
k = capacity_factor * n
S = top_k_indices(r_1, ..., r_n)
3. 선택된 토큰만 처리:
for i in S:
x_i = Attention(x_i) + MLP(x_i)
for i not in S:
x_i = x_i (skip, residual만)
전체 구조
┌─────────────────────────────────────────────────────────┐
│ Input Sequence │
│ [tok1, tok2, tok3, ..., tokN] │
└─────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────┐
│ Router (Layer 1) │
│ scores = W_r @ X │
│ selected_idx = top_k(scores, k=capacity*N) │
└─────────────────────────────────────────────────────────┘
│
┌─────────────┴─────────────┐
│ │
▼ ▼
┌─────────────────────────┐ ┌─────────────────────────┐
│ Selected Tokens │ │ Skipped Tokens │
│ (k tokens) │ │ (N-k tokens) │
├─────────────────────────┤ ├─────────────────────────┤
│ Self-Attention │ │ │
│ + │ │ Identity (Residual) │
│ MLP │ │ │
└─────────────────────────┘ └─────────────────────────┘
│ │
└─────────────┬─────────────┘
│
▼
┌─────────────────────────────────────────────────────────┐
│ Combined Output │
│ (재정렬하여 원래 순서로) │
└─────────────────────────────────────────────────────────┘
│
▼
(다음 레이어 반복)
MoD vs MoE 비교
| 측면 |
Mixture-of-Experts (MoE) |
Mixture-of-Depths (MoD) |
| 라우팅 대상 |
Expert (MLP 변형) |
Layer (전체 블록) |
| 동적 요소 |
어떤 expert를 사용할지 |
얼마나 깊이 처리할지 |
| 파라미터 |
증가 (다중 expert) |
동일 유지 |
| FLOPs |
고정 (선택된 expert만) |
감소 (일부 토큰 skip) |
| 추론 속도 |
메모리 바운드 |
FLOP 절약으로 빨라짐 |
결합 가능성
MoD + MoE 결합:
1. MoD로 어떤 토큰을 처리할지 결정
2. 선택된 토큰에 대해 MoE로 어떤 expert를 사용할지 결정
결과: 토큰별, 레이어별, Expert별 동적 연산
학습 방법
Auxiliary Loss
Router가 의미있는 선택을 하도록 보조 손실 함수 사용:
L_total = L_language + alpha * L_router
L_router: load balancing loss (토큰 분포 균형)
Capacity Factor 설정
| Capacity Factor |
처리 토큰 비율 |
FLOPs 절약 |
성능 영향 |
| 1.0 |
100% |
0% |
베이스라인 |
| 0.5 |
50% |
~50% |
미미 |
| 0.25 |
25% |
~75% |
약간 |
| 0.125 |
12.5% |
~87.5% |
측정 가능 |
논문 결과: capacity_factor=0.125에서도 베이스라인과 동등한 성능 달성
성능 결과
학습 효율성
| 모델 |
FLOPs (학습) |
FLOPs (추론) |
성능 |
| Baseline (12.5B) |
1x |
1x |
1.0 |
| MoD (12.5B) |
1x |
0.5x |
1.0 |
| MoD (12.5B, isoFLOP) |
0.66x |
0.33x |
1.0 |
추론 속도
- 최대 50% 빠른 샘플링 (동일 품질)
- Static compute graph로 효율적 배치 처리
- KV-cache와 완전 호환
Python 구현
기본 Router
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoDRouter(nn.Module):
"""Mixture-of-Depths Router"""
def __init__(self, dim: int, capacity_factor: float = 0.125):
super().__init__()
self.capacity_factor = capacity_factor
self.router = nn.Linear(dim, 1, bias=False)
def forward(self, x: torch.Tensor) -> tuple:
"""
Args:
x: (batch, seq_len, dim)
Returns:
selected_mask: (batch, seq_len) bool tensor
router_weights: (batch, seq_len) for aux loss
"""
batch, seq_len, dim = x.shape
# Router 점수 계산
router_logits = self.router(x).squeeze(-1) # (batch, seq_len)
router_weights = torch.sigmoid(router_logits)
# Top-k 선택
k = int(seq_len * self.capacity_factor)
k = max(1, k) # 최소 1개
# 각 배치에서 top-k 인덱스 선택
_, indices = torch.topk(router_logits, k, dim=-1)
# 마스크 생성
selected_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=x.device)
selected_mask.scatter_(1, indices, True)
return selected_mask, router_weights
class MoDBlock(nn.Module):
"""Mixture-of-Depths Transformer Block"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
capacity_factor: float = 0.125,
dropout: float = 0.0,
):
super().__init__()
self.dim = dim
self.capacity_factor = capacity_factor
# Router
self.router = MoDRouter(dim, capacity_factor)
# Attention
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(
dim, num_heads, dropout=dropout, batch_first=True
)
# MLP
self.norm2 = nn.LayerNorm(dim)
hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor, return_router_weights: bool = False):
"""
Args:
x: (batch, seq_len, dim)
Returns:
output: (batch, seq_len, dim)
router_weights: optional, for aux loss
"""
batch, seq_len, dim = x.shape
# 1. Router로 토큰 선택
selected_mask, router_weights = self.router(x)
# 2. 선택된 토큰 추출
# selected_mask: (batch, seq_len) -> indices
selected_indices = selected_mask.nonzero(as_tuple=False)
# 선택된 토큰만 처리 (효율적 구현)
if selected_mask.any():
# 선택된 토큰 gather
x_selected = x[selected_mask] # (num_selected, dim)
# Attention (선택된 토큰끼리만)
# 실제로는 causal mask와 함께 처리 필요
x_norm = self.norm1(x_selected)
# 간단한 구현: 전체 시퀀스에 대해 attention 후 마스킹
# 실제 구현에서는 효율적인 sparse attention 사용
x_full_norm = self.norm1(x)
attn_out, _ = self.attn(x_full_norm, x_full_norm, x_full_norm)
# MLP
mlp_out = self.mlp(self.norm2(x + attn_out))
# 선택된 토큰만 업데이트
output = x.clone()
output = output + attn_out * selected_mask.unsqueeze(-1).float()
output = output + mlp_out * selected_mask.unsqueeze(-1).float()
else:
output = x
if return_router_weights:
return output, router_weights
return output
class MoDTransformer(nn.Module):
"""Mixture-of-Depths Transformer"""
def __init__(
self,
vocab_size: int,
dim: int = 512,
depth: int = 12,
num_heads: int = 8,
mlp_ratio: float = 4.0,
capacity_factor: float = 0.125,
max_seq_len: int = 2048,
dropout: float = 0.0,
):
super().__init__()
self.dim = dim
# Embeddings
self.token_emb = nn.Embedding(vocab_size, dim)
self.pos_emb = nn.Embedding(max_seq_len, dim)
# MoD Blocks
self.blocks = nn.ModuleList([
MoDBlock(dim, num_heads, mlp_ratio, capacity_factor, dropout)
for _ in range(depth)
])
# Output
self.norm = nn.LayerNorm(dim)
self.head = nn.Linear(dim, vocab_size, bias=False)
def forward(self, x: torch.Tensor, return_router_weights: bool = False):
"""
Args:
x: (batch, seq_len) token indices
Returns:
logits: (batch, seq_len, vocab_size)
"""
batch, seq_len = x.shape
# Embeddings
pos = torch.arange(seq_len, device=x.device)
x = self.token_emb(x) + self.pos_emb(pos)
# MoD Blocks
all_router_weights = []
for block in self.blocks:
if return_router_weights:
x, rw = block(x, return_router_weights=True)
all_router_weights.append(rw)
else:
x = block(x)
# Output
x = self.norm(x)
logits = self.head(x)
if return_router_weights:
return logits, all_router_weights
return logits
Load Balancing Loss
def load_balancing_loss(router_weights_list: list) -> torch.Tensor:
"""
Router weights의 load balancing을 위한 auxiliary loss
Args:
router_weights_list: list of (batch, seq_len) tensors
"""
total_loss = 0.0
for router_weights in router_weights_list:
# 각 토큰이 선택될 확률의 분산을 최소화
# 이상적: 모든 토큰이 균등한 확률로 선택됨
mean_weight = router_weights.mean(dim=-1, keepdim=True)
variance = ((router_weights - mean_weight) ** 2).mean()
total_loss = total_loss + variance
return total_loss / len(router_weights_list)
def train_step(model, optimizer, x, y, aux_weight=0.01):
"""MoD 모델 학습 스텝"""
optimizer.zero_grad()
logits, router_weights = model(x, return_router_weights=True)
# Language modeling loss
lm_loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
y.view(-1)
)
# Auxiliary load balancing loss
aux_loss = load_balancing_loss(router_weights)
# Total loss
total_loss = lm_loss + aux_weight * aux_loss
total_loss.backward()
optimizer.step()
return {
'total_loss': total_loss.item(),
'lm_loss': lm_loss.item(),
'aux_loss': aux_loss.item(),
}
효율적인 추론
class EfficientMoDInference:
"""효율적인 MoD 추론을 위한 wrapper"""
def __init__(self, model: MoDTransformer):
self.model = model
self.model.eval()
@torch.no_grad()
def generate(
self,
prompt_ids: torch.Tensor,
max_new_tokens: int = 100,
temperature: float = 1.0,
top_k: int = 50,
) -> torch.Tensor:
"""
Autoregressive generation with MoD
Args:
prompt_ids: (1, prompt_len) token indices
max_new_tokens: 생성할 최대 토큰 수
"""
generated = prompt_ids.clone()
for _ in range(max_new_tokens):
# Forward pass
logits = self.model(generated)
# 마지막 토큰의 로짓만 사용
next_logits = logits[:, -1, :] / temperature
# Top-k sampling
if top_k > 0:
v, _ = torch.topk(next_logits, top_k)
next_logits[next_logits < v[:, [-1]]] = float('-inf')
probs = F.softmax(next_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
generated = torch.cat([generated, next_token], dim=1)
return generated
def compute_flops_savings(self, seq_len: int) -> dict:
"""FLOPs 절약량 계산"""
capacity_factor = self.model.blocks[0].capacity_factor
num_layers = len(self.model.blocks)
baseline_flops = num_layers * seq_len # 상대적 단위
mod_flops = num_layers * seq_len * capacity_factor
return {
'baseline_flops': baseline_flops,
'mod_flops': mod_flops,
'savings_percent': (1 - capacity_factor) * 100,
}
응용 및 확장
적용 분야
| 분야 |
장점 |
| 실시간 추론 |
50% 이상 속도 향상 |
| 엣지 디바이스 |
FLOPs 절약으로 저전력 |
| 긴 문맥 처리 |
메모리 효율성 |
| 배치 추론 |
처리량 증가 |
변형 및 확장
- Layer-wise Capacity: 각 레이어마다 다른 capacity factor
- Dynamic Capacity: 입력에 따라 capacity 조절
- MoD + MoE: 두 기법 결합
- Auxiliary Router: 별도 경량 모델로 라우팅 결정
한계 및 고려사항
| 한계 |
설명 |
| 학습 복잡도 |
Router 학습에 추가 하이퍼파라미터 필요 |
| 작은 모델 |
큰 모델에서 효과가 더 큼 |
| Task 의존성 |
모든 태스크에서 동일한 효과 X |
| 구현 복잡도 |
효율적 구현에 커스텀 커널 필요 |
관련 연구
| 논문/기법 |
관계 |
| Mixture-of-Experts |
유사한 라우팅 메커니즘, 다른 적용 대상 |
| Early Exit |
토큰이 아닌 샘플 단위 동적 깊이 |
| Adaptive Computation |
일반적인 동적 연산 프레임워크 |
| Universal Transformers |
반복 횟수 동적 조절 |
참고 자료
논문
- Raposo et al. (2024). Mixture-of-Depths: Dynamically allocating compute in transformer-based language models. arXiv:2404.02258
구현
관련 문서