Mixture of Experts (MoE)¶
Sparse architecture that routes each input to a subset of specialized "experts", enabling efficient scaling of model capacity while maintaining constant computational cost per example.
Meta¶
| Item | Value |
|---|---|
| Category | Deep Learning Architecture |
| First Proposed | Jacobs et al. (1991) |
| Modern Revival | Shazeer et al. (2017) |
| Key Conferences | NeurIPS, ICML, ICLR |
| Notable Models | Switch Transformer, Mixtral, GPT-4 (rumored), DeepSeek-MoE |
Key References¶
- Jacobs et al. "Adaptive Mixtures of Local Experts" Neural Computation (1991)
- Shazeer et al. "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer" ICLR (2017)
- Fedus et al. "Switch Transformers: Scaling to Trillion Parameter Models" JMLR (2022)
- Jiang et al. "Mixtral of Experts" arXiv:2401.04088 (2024)
- Fedus et al. "A Review of Sparse Expert Models in Deep Learning" arXiv:2209.01667 (2022)
Core Concept¶
Mixture of Experts (MoE) is an architecture that uses conditional computation to greatly increase the number of model parameters while keeping inference cost roughly constant.
Key Components¶
| Component | Description |
|---|---|
| Experts | Independent sub-networks (typically FFNs) |
| Router/Gate | Network that routes each input to the appropriate experts |
| Top-K Selection | Only K experts are activated per token |
| Load Balancing | Mechanisms that spread tokens evenly across experts |
Why MoE?¶
```
Dense Model: all parameters active for every input
    Compute = O(Parameters)

MoE Model:   only a subset of parameters active per input
    Compute = O(Active Parameters) << O(Total Parameters)
```
Example: Mixtral 8x7B

- Total parameters: 47B (8 expert FFNs + shared layers)
- Active parameters: ~13B (top-2 expert selection)
- Inference speed: comparable to a ~12B dense model
- Performance: competitive with a 70B dense model
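The arithmetic behind these numbers can be sketched in a few lines. The shared/per-expert split below is an illustrative back-calculation from the quoted 47B/13B figures, not an official spec breakdown, and `moe_param_counts` is a hypothetical helper name:

```python
# Back-of-the-envelope sketch of why MoE decouples capacity from compute.
# The shared/per-expert split is illustrative, reverse-engineered from the
# ~47B total / ~13B active figures quoted for Mixtral 8x7B.

def moe_param_counts(shared_params, expert_params, num_experts, top_k):
    """Return (total, active) parameter counts for an MoE model."""
    total = shared_params + num_experts * expert_params   # held in memory
    active = shared_params + top_k * expert_params        # used per token
    return total, active

total, active = moe_param_counts(
    shared_params=1.7e9,    # attention, embeddings, norms (assumed)
    expert_params=5.66e9,   # one expert FFN stack across layers (assumed)
    num_experts=8,
    top_k=2,
)
print(f"total:  {total / 1e9:.1f}B")   # -> total:  47.0B
print(f"active: {active / 1e9:.1f}B")  # -> active: 13.0B
```

Memory cost scales with `total`, per-token compute with `active`, which is the core trade-off of the architecture.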
Architecture¶
Standard MoE Layer¶
Replace the Transformer's FFN layer with an MoE layer:
```
Input x
   |
   v
[Router Network] --> expert weights g_1, g_2, ..., g_n
   |
   v
Top-K Selection --> activate experts E_i, E_j, ...
   |
   v
y = sum(g_i * E_i(x)) over selected experts
   |
   v
Output y
```
Router/Gating Mechanisms¶
1. Softmax Gating (Basic)

2. Noisy Top-K Gating (Shazeer et al. 2017)

```
H(x)_i = (x * W_g)_i + StandardNormal() * Softplus((x * W_noise)_i)
G(x) = Softmax(KeepTopK(H(x), k))
```

- Adding noise encourages exploration across experts
- Top-K enforces sparsity
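The gating formula above can be traced on a single token with plain PyTorch ops. The weights `W_g` and `W_noise` here are random toy matrices, not trained parameters:

```python
# Minimal sketch of Noisy Top-K gating (Shazeer et al., 2017) for one token.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, num_experts, k = 16, 8, 2

x = torch.randn(dim)
W_g = torch.randn(dim, num_experts)       # toy gating weights
W_noise = torch.randn(dim, num_experts)   # toy noise-scale weights

# H(x)_i = (x W_g)_i + N(0, 1) * Softplus((x W_noise)_i)
clean = x @ W_g
noise = torch.randn(num_experts) * F.softplus(x @ W_noise)
h = clean + noise

# Keep the top-k logits, softmax over only those -> sparse gate values
topk_vals, topk_idx = h.topk(k)
gates = F.softmax(topk_vals, dim=-1)

print("selected experts:", topk_idx.tolist())
print("gate weights:", gates.tolist())  # sum to 1 over the k kept experts
```

Only the `k` selected experts run a forward pass; the remaining `num_experts - k` experts cost nothing for this token.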
3. Expert Choice Routing (Zhou et al. 2022)

Standard routing: each token selects its experts. Expert Choice: each expert selects its tokens.

Advantages: perfect load balancing, higher throughput.
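The inversion is easy to see in code: each expert takes its top-`c` tokens, so per-expert load is exactly `c` by construction. A toy-dimensional sketch, not a faithful reimplementation of the paper:

```python
# Sketch of Expert Choice routing (Zhou et al., 2022): experts pick tokens.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, dim, num_experts = 32, 16, 4
capacity = 2 * num_tokens // num_experts  # c tokens per expert

x = torch.randn(num_tokens, dim)
gate = torch.randn(dim, num_experts)      # toy gating weights

# Token-expert affinity, normalized over tokens within each expert column
scores = F.softmax(x @ gate, dim=0)       # (num_tokens, num_experts)

# Each expert selects its top-`capacity` tokens
weights, token_idx = scores.topk(capacity, dim=0)  # (capacity, num_experts)

# Every expert processes exactly `capacity` tokens: perfect load balance
counts = torch.zeros(num_experts)
for e in range(num_experts):
    counts[e] = token_idx[:, e].numel()
print(counts.tolist())  # -> [16.0, 16.0, 16.0, 16.0]
```

Note the trade-off: some tokens may be picked by many experts and others by none, which is why the paper pairs this with careful capacity choices.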
Key Variants¶
1. Switch Transformer (Fedus et al. 2022)¶
Simplifies routing to top-1 and scales to 1.6T parameters.
| Feature | Description |
|---|---|
| Routing | Top-1 (single expert per token) |
| Scale | Up to 2048 experts, 1.6T parameters |
| Speedup | 4x faster pretraining vs T5-XXL |
Auxiliary Loss for Load Balancing:

```
L_aux = alpha * n * sum(f_i * P_i)

f_i = (tokens routed to expert i) / (total tokens)
P_i = (sum of router probabilities for expert i) / (total tokens)
```
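The loss can be computed on toy routing data in a few lines. The logits and the value of `alpha` below are made-up illustrative inputs:

```python
# Toy computation of the Switch load-balancing loss for top-1 routing:
# L_aux = alpha * n * sum(f_i * P_i)
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts, alpha = 64, 4, 0.01

router_logits = torch.randn(num_tokens, num_experts)
probs = F.softmax(router_logits, dim=-1)

# f_i: fraction of tokens whose top-1 expert is i (hard dispatch counts)
assignments = probs.argmax(dim=-1)
f = F.one_hot(assignments, num_experts).float().mean(dim=0)

# P_i: mean router probability mass on expert i (soft, differentiable part)
P = probs.mean(dim=0)

aux_loss = alpha * num_experts * (f * P).sum()
print(f"aux loss: {aux_loss.item():.5f}")
```

With perfectly uniform routing, f_i = P_i = 1/n, so the loss bottoms out at alpha * n * n * (1/n)^2 = alpha; skewed routing pushes it above that floor, which is the gradient signal the router receives.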
2. GShard (Lepikhin et al. 2020)¶
An MoE architecture designed for distributed training.

- Top-2 gating
- Random routing for the second expert
- Introduced the expert capacity concept
3. Mixtral 8x7B (Mistral AI, 2024)¶
A game changer for open-source MoE models.
| Spec | Value |
|---|---|
| Experts | 8 per layer |
| Active | 2 per token |
| Total Params | 47B |
| Active Params | 13B |
| Context | 32K tokens |
Benchmark results: outperforms Llama 2 70B and GPT-3.5 on math, code, and multilingual tasks.
4. DeepSeek-MoE (2024)¶
Introduces fine-grained expert segmentation plus shared experts.

- Some experts are shared across all tokens
- The remaining experts are selected by the router
- More efficient parameter utilization
Training Considerations¶
Load Balancing Problem¶
Problem: the router tends to send most tokens to only a few experts.

Solutions:
| Method | Description |
|---|---|
| Auxiliary Loss | Extra loss term that balances expert utilization |
| Capacity Factor | Caps the number of tokens each expert may process |
| Noise Injection | Adds noise to the router to encourage exploration |
| Expert Dropout | Randomly drops some experts during training |
Capacity Factor¶
```
Expert Capacity = (tokens_per_batch / num_experts) * capacity_factor
```

- capacity_factor > 1.0: allows headroom for routing imbalance
- capacity_factor = 1.0: assumes perfectly balanced routing

Tokens that overflow an expert's capacity pass through the residual connection to the next layer (or are dropped).
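The capacity formula is simple enough to check directly; the batch and expert counts below are illustrative, and `expert_capacity` is a hypothetical helper name:

```python
# Sketch of the expert-capacity calculation with illustrative batch numbers.
def expert_capacity(tokens_per_batch, num_experts, capacity_factor):
    """Max number of tokens each expert may process in one batch."""
    return int(capacity_factor * tokens_per_batch / num_experts)

tokens_per_batch = 4096
num_experts = 8

for cf in (1.0, 1.25, 2.0):
    cap = expert_capacity(tokens_per_batch, num_experts, cf)
    print(f"capacity_factor={cf}: {cap} tokens/expert")
# capacity_factor=1.0 : 512  (assumes perfectly balanced routing)
# capacity_factor=1.25: 640  (25% headroom for imbalance)
# capacity_factor=2.0 : 1024 (generous buffer, more padded/wasted compute)
```

Higher capacity factors drop fewer tokens but pad more expert batches, so the value is a throughput/quality trade-off.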
Router Z-Loss (ST-MoE, Zoph et al. 2022)¶
Constrains the magnitude of the router logits to improve training stability.
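The formula itself, as given in the ST-MoE paper, is the mean squared log-partition-function of the per-token router logits: L_z = (1/B) * sum_i (log sum_j exp(logit_ij))^2. A direct PyTorch translation over toy logits:

```python
# Router z-loss over toy logits: penalizes large logit magnitudes so the
# router's softmax stays numerically well-behaved during training.
import torch

torch.manual_seed(0)
num_tokens, num_experts = 64, 8
router_logits = torch.randn(num_tokens, num_experts)

# log-sum-exp per token, squared, averaged over the batch
z = torch.logsumexp(router_logits, dim=-1)  # (num_tokens,)
z_loss = (z ** 2).mean()
print(f"router z-loss: {z_loss.item():.4f}")
```

In practice this is added to the training objective with a small coefficient (ST-MoE uses 1e-3) alongside the load-balancing auxiliary loss.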
Inference Optimization¶
Challenges¶
| Challenge | Description |
|---|---|
| Memory | All experts must be resident in memory |
| Batching | Dynamic routing reduces batching efficiency |
| Communication | Inter-expert communication overhead in distributed settings |
Solutions¶
1. Expert Parallelism

Place different experts on different GPUs; exchange tokens via all-to-all communication.

2. Expert Offloading

Keep inactive experts on CPU/SSD and load them on demand.
```python
# Conceptual pseudocode: expert offloading
for token in batch:
    active_experts = router(token)
    load_experts_to_gpu(active_experts)      # page in only the needed experts
    output = forward(token, active_experts)
    offload_experts(active_experts)          # page back out to CPU/SSD
```
3. Expert Pruning/Merging

- Remove rarely used experts
- Merge similar experts
- Compress into a dense model via knowledge distillation
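The merging step can be sketched as parameter averaging of the most similar expert pair. The flattening-and-cosine-similarity criterion below is an illustrative choice, not a method from any specific paper:

```python
# Hypothetical sketch of expert merging: find the most similar pair of
# expert weight stacks (by cosine similarity of flattened weights) and
# replace them with their parameter average.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, dim, hidden = 4, 8, 16

# Toy expert weights; make experts 0 and 1 near-duplicates on purpose
w = torch.randn(num_experts, dim, hidden)
w[1] = w[0] + 0.01 * torch.randn(dim, hidden)

flat = w.view(num_experts, -1)
sim = F.cosine_similarity(flat.unsqueeze(1), flat.unsqueeze(0), dim=-1)
sim.fill_diagonal_(-1.0)  # exclude self-similarity

# Merge the most similar pair by averaging their parameters
i, j = divmod(sim.argmax().item(), num_experts)
merged = 0.5 * (w[i] + w[j])
print(f"merging experts {i} and {j} (cosine={sim[i, j].item():.3f})")
```

A real pipeline would also fold the router columns of the merged experts together and would typically weight the average by each expert's token utilization.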
Python Implementation¶
Basic MoE Layer (PyTorch)¶
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """Single expert network (FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))


class Router(nn.Module):
    """Top-K router with noise during training."""

    def __init__(self, input_dim, num_experts, top_k=2, noise_std=1.0):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std
        self.gate = nn.Linear(input_dim, num_experts, bias=False)

    def forward(self, x):
        # x: (batch, seq, dim)
        logits = self.gate(x)  # (batch, seq, num_experts)

        # Add noise during training to encourage exploration
        if self.training and self.noise_std > 0:
            logits = logits + torch.randn_like(logits) * self.noise_std

        # Top-K selection; softmax only over the kept logits
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)
        top_k_gates = F.softmax(top_k_logits, dim=-1)
        return top_k_gates, top_k_indices, logits


class MoELayer(nn.Module):
    """Mixture of Experts layer."""

    def __init__(self, input_dim, hidden_dim, output_dim,
                 num_experts=8, top_k=2, aux_loss_weight=0.01):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.aux_loss_weight = aux_loss_weight

        # Create experts
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, output_dim)
            for _ in range(num_experts)
        ])
        # Router
        self.router = Router(input_dim, num_experts, top_k)

    def compute_aux_loss(self, router_logits, expert_indices):
        """Load balancing loss: n * sum(f_i * P_i).

        f_i must come from the hard routing assignments (fraction of
        token-slots dispatched to expert i), while P_i uses the soft router
        probabilities; computing both from probabilities would make the two
        factors identical and the loss meaningless.
        """
        router_probs = F.softmax(router_logits, dim=-1)
        # P_i: average router probability for each expert
        router_prob_per_expert = router_probs.mean(dim=[0, 1])
        # f_i: fraction of dispatched token-slots per expert (hard counts)
        one_hot = F.one_hot(expert_indices, self.num_experts).float()
        tokens_per_expert = one_hot.sum(dim=2).mean(dim=[0, 1]) / self.top_k
        return self.num_experts * (tokens_per_expert * router_prob_per_expert).sum()

    def forward(self, x):
        batch_size, seq_len, dim = x.shape

        # Get routing decisions
        gates, indices, router_logits = self.router(x)
        # gates, indices: (batch, seq, top_k)

        aux_loss = self.compute_aux_loss(router_logits, indices)

        output = torch.zeros_like(x)

        # Route tokens to experts (readable but slow: loops over k and experts)
        for k in range(self.top_k):
            expert_indices = indices[:, :, k]    # (batch, seq)
            expert_gates = gates[:, :, k:k + 1]  # (batch, seq, 1)
            for e in range(self.num_experts):
                # Tokens routed to this expert in slot k
                mask = expert_indices == e
                if mask.any():
                    expert_output = self.experts[e](x[mask])
                    # expert_gates[mask]: (n, 1) broadcasts over (n, dim)
                    output[mask] += expert_gates[mask] * expert_output

        return output, aux_loss * self.aux_loss_weight


# Example usage
if __name__ == "__main__":
    batch_size, seq_len, dim = 4, 128, 512
    moe = MoELayer(input_dim=dim, hidden_dim=2048, output_dim=dim,
                   num_experts=8, top_k=2)
    x = torch.randn(batch_size, seq_len, dim)
    output, aux_loss = moe(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Auxiliary loss: {aux_loss.item():.4f}")
```
Efficient Batched Implementation¶
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EfficientMoELayer(nn.Module):
    """Memory-efficient MoE with batched expert computation.

    Inspired by Megablocks (https://github.com/databricks/megablocks).
    """

    def __init__(self, input_dim, hidden_dim, num_experts=8,
                 top_k=2, capacity_factor=1.25):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor

        # Stacked expert weights for efficient batched computation
        self.w1 = nn.Parameter(torch.randn(num_experts, input_dim, hidden_dim) * 0.02)
        self.w2 = nn.Parameter(torch.randn(num_experts, hidden_dim, input_dim) * 0.02)
        # Router
        self.gate = nn.Linear(input_dim, num_experts, bias=False)

    def forward(self, x):
        batch_size, seq_len, dim = x.shape
        num_tokens = batch_size * seq_len
        x_flat = x.view(num_tokens, dim)

        # Routing
        router_logits = self.gate(x_flat)  # (num_tokens, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Top-K selection, renormalized so the kept gates sum to 1
        top_k_probs, top_k_indices = router_probs.topk(self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Per-expert token budget
        capacity = int(self.capacity_factor * num_tokens / self.num_experts)

        output = torch.zeros_like(x_flat)

        # Process each expert on its gathered tokens
        for e in range(self.num_experts):
            # Tokens that selected this expert in any of their top-k slots
            expert_mask = (top_k_indices == e).any(dim=-1)
            token_indices = expert_mask.nonzero().squeeze(-1)
            if len(token_indices) == 0:
                continue

            # Capacity limit: overflowing tokens are dropped for this expert
            if len(token_indices) > capacity:
                token_indices = token_indices[:capacity]

            # Gather the gate weight each token assigned to this expert
            expert_weights = torch.zeros(len(token_indices), device=x.device)
            for k in range(self.top_k):
                mask = top_k_indices[token_indices, k] == e
                expert_weights[mask] = top_k_probs[token_indices[mask], k]

            # Batched expert forward pass
            expert_input = x_flat[token_indices]        # (n, dim)
            hidden = F.gelu(expert_input @ self.w1[e])  # (n, hidden_dim)
            expert_output = hidden @ self.w2[e]         # (n, dim)

            # Weighted scatter back into the output buffer
            output[token_indices] += expert_weights.unsqueeze(-1) * expert_output

        return output.view(batch_size, seq_len, dim)
```
Using HuggingFace Transformers (Mixtral)¶
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load Mixtral 8x7B
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",   # automatic multi-GPU distribution
    load_in_4bit=True,   # quantization for memory efficiency
)

# Generate text
prompt = "Explain mixture of experts in machine learning:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
When to Use MoE¶
Good Fit¶
| Scenario | Reason |
|---|---|
| Large-scale pretraining | Faster training with same compute |
| Multi-task learning | Different experts can specialize |
| Diverse data domains | Experts handle different domains |
| Inference latency critical | Active params < total params |
Not Ideal¶
| Scenario | Reason |
|---|---|
| Memory constrained | All experts must be loaded |
| Small datasets | Experts may not differentiate |
| Fine-tuning focus | MoE can overfit, harder to fine-tune |
| Simple tasks | Overhead not justified |
Comparison with Dense Models¶
| Aspect | Dense | MoE |
|---|---|---|
| Parameters | All active | Subset active |
| Pretraining Speed | Baseline | 2-4x faster |
| Inference FLOPs | Proportional to params | Lower than equivalent dense |
| Memory | Proportional to params | High (all experts loaded) |
| Fine-tuning | Standard | Challenging, prone to overfitting |
| Implementation | Simple | Complex (routing, load balancing) |
Recent Advances (2024-2025)¶
| Development | Description |
|---|---|
| Expert Choice Routing | Experts select tokens (instead of tokens selecting experts) |
| Soft MoE | Continuous routing weights (not hard top-K) |
| MoE in Vision | MoE applied to Vision Transformers (V-MoE) |
| Sparse Upcycling | Converting a dense model into an MoE |
| Mixture of LoRA | MoE adapters for fine-tuning |
Key Takeaways¶
- MoE decouples model capacity from compute cost via sparse conditional computation
- A router network sends each input to an appropriate subset of experts
- Load balancing (auxiliary loss, capacity factor) is key to training stability and efficiency
- Memory requirements are high, but inference cost scales with active parameters only
- Open-source models such as Mixtral 8x7B have made MoE practical to deploy
- Fine-tuning difficulties remain an ongoing research area (e.g., Mixture of LoRA)
Last updated: 2026-02-07