
Mixture of Experts (MoE)

Sparse architecture that routes each input to a subset of specialized "experts", enabling efficient scaling of model capacity while maintaining constant computational cost per example.


Meta

Item Value
Category Deep Learning Architecture
First Proposed Jacobs et al. (1991)
Modern Revival Shazeer et al. (2017)
Key Conferences NeurIPS, ICML, ICLR
Notable Models Switch Transformer, Mixtral, GPT-4 (rumored), DeepSeek-MoE

Key References

  • Jacobs et al. "Adaptive Mixtures of Local Experts" Neural Computation (1991)
  • Shazeer et al. "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer" ICLR (2017)
  • Fedus et al. "Switch Transformers: Scaling to Trillion Parameter Models" JMLR (2022)
  • Jiang et al. "Mixtral of Experts" arXiv:2401.04088 (2024)
  • Fedus et al. "A Review of Sparse Expert Models in Deep Learning" arXiv:2209.01667 (2022)

Core Concept

Mixture of Experts (MoE) is an architecture that uses conditional computation to grow the parameter count dramatically while keeping the inference cost per example roughly constant.

Key Components

Component Description
Experts Independent sub-networks (typically FFNs)
Router/Gate Network that routes each input to the appropriate experts
Top-K Selection Only K experts are activated per token
Load Balancing Mechanisms that spread tokens evenly across experts

Why MoE?

Dense Model: All parameters active for every input
              Compute = O(Parameters)

MoE Model:    Only subset of parameters active per input
              Compute = O(Active Parameters) << O(Total Parameters)

Example: Mixtral 8x7B

  • Total parameters: 47B (8 expert FFNs + shared layers)
  • Active parameters: ~13B (top-2 expert selection)
  • Inference speed: comparable to a ~13B dense model
  • Quality: competitive with 70B dense models
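The total/active split above is simple arithmetic. A minimal sketch, using illustrative numbers chosen to roughly match the 47B / ~13B figures rather than Mixtral's actual parameter breakdown:

```python
def moe_param_counts(shared_params, expert_params, num_experts, top_k):
    """Total vs. active parameter counts for a sparse MoE model."""
    total = shared_params + num_experts * expert_params
    active = shared_params + top_k * expert_params
    return total, active

# Illustrative numbers only (not Mixtral's real layer-by-layer breakdown)
total, active = moe_param_counts(
    shared_params=1_700_000_000,   # attention, embeddings, norms (always active)
    expert_params=5_660_000_000,   # one expert FFN stack across all layers
    num_experts=8,
    top_k=2,
)
print(f"total={total / 1e9:.0f}B, active={active / 1e9:.1f}B")
```

Inference FLOPs track the active count, while memory must still hold the total count.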

Architecture

Standard MoE Layer

The Transformer's FFN layer is replaced with an MoE layer:

Input x
    |
    v
[Router Network] --> expert weights g_1, g_2, ..., g_n
    |
    v
Top-K Selection --> activate experts E_i, E_j, ...
    |
    v
y = sum(g_i * E_i(x)) for selected experts
    |
    v
Output y

Router/Gating Mechanisms

1. Softmax Gating (Basic)

G(x) = Softmax(x * W_g)

2. Noisy Top-K Gating (Shazeer 2017)

H(x)_i = (x * W_g)_i + StandardNormal() * Softplus((x * W_noise)_i)
G(x) = Softmax(KeepTopK(H(x), k))
  • The added noise encourages exploration across experts
  • Top-K keeps the computation sparse
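The two equations above translate almost directly into PyTorch. A minimal sketch with illustrative shapes (softmaxing only the K kept logits is equivalent to KeepTopK's softmax with the rest set to -inf):

```python
import torch
import torch.nn.functional as F

def noisy_top_k_gating(x, w_gate, w_noise, k, training=True):
    """H(x) with a learned noise scale, then softmax over the K kept logits."""
    logits = x @ w_gate                                   # (tokens, num_experts)
    if training:
        noise_std = F.softplus(x @ w_noise)               # per-token, per-expert scale
        logits = logits + torch.randn_like(logits) * noise_std
    top_k_logits, top_k_indices = logits.topk(k, dim=-1)  # KeepTopK
    gates = F.softmax(top_k_logits, dim=-1)               # sparse gate values
    return gates, top_k_indices

x = torch.randn(16, 64)             # 16 tokens, model dim 64
w_gate = torch.randn(64, 8) * 0.02  # 8 experts
w_noise = torch.randn(64, 8) * 0.02
gates, idx = noisy_top_k_gating(x, w_gate, w_noise, k=2)
print(gates.shape, idx.shape)       # torch.Size([16, 2]) torch.Size([16, 2])
```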

3. Expert Choice Routing (Zhou et al. 2022)

Standard routing: each token selects its experts.
Expert Choice: each expert selects its tokens.

For each expert e:
    Select top-c tokens with highest affinity to expert e

Advantages: perfect load balancing by construction, higher throughput
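A toy sketch of the idea, in which each expert keeps its top-`capacity` tokens by router affinity (`expert_choice_routing` is a made-up helper, not the paper's implementation):

```python
import torch
import torch.nn.functional as F

def expert_choice_routing(x, w_gate, capacity):
    """Each expert selects its top-`capacity` tokens by router affinity."""
    scores = F.softmax(x @ w_gate, dim=-1)                # (tokens, experts)
    gates, token_idx = scores.t().topk(capacity, dim=-1)  # experts pick tokens
    return gates, token_idx                               # both (experts, capacity)

x = torch.randn(32, 64)                 # 32 tokens, model dim 64
w_gate = torch.randn(64, 4) * 0.02      # 4 experts
gates, token_idx = expert_choice_routing(x, w_gate, capacity=8)
# Every expert processes exactly 8 tokens: load balancing holds by construction.
print(gates.shape)   # torch.Size([4, 8])
```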


Key Variants

1. Switch Transformer (Fedus et al. 2022)

Simplifies routing to top-1 and scales to 1.6T parameters.

Feature Description
Routing Top-1 (single expert per token)
Scale Up to 2048 experts, 1.6T parameters
Speedup 4x faster pretraining vs T5-XXL

Auxiliary Loss for Load Balancing:

L_aux = alpha * n * sum(f_i * P_i)

f_i = (tokens routed to expert i) / (total tokens)
P_i = (sum of router probabilities for expert i) / (total tokens)
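A literal transcription of these formulas, counting f_i under top-1 dispatch (matching Switch's single-expert routing):

```python
import torch
import torch.nn.functional as F

def switch_aux_loss(router_logits, alpha=0.01):
    """L_aux = alpha * n * sum(f_i * P_i), with f_i from top-1 dispatch counts."""
    n = router_logits.shape[-1]                                 # number of experts
    probs = F.softmax(router_logits, dim=-1)                    # (tokens, n)
    f = F.one_hot(probs.argmax(dim=-1), n).float().mean(dim=0)  # dispatch fractions
    P = probs.mean(dim=0)                                       # mean router probs
    return alpha * n * (f * P).sum()

loss = switch_aux_loss(torch.randn(1024, 8))
# When routing is perfectly uniform, sum(f_i * P_i) = 1/n, so the loss equals alpha.
```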

2. GShard (Lepikhin et al. 2020)

An MoE architecture designed for distributed training at scale.

  • Top-2 gating
  • Random routing for second expert
  • Introduced the concept of expert capacity

3. Mixtral 8x7B (Mistral AI, 2024)

A breakthrough open-weight MoE model.

Spec Value
Experts 8 per layer
Active 2 per token
Total Params 47B
Active Params 13B
Context 32K tokens

Benchmark results: outperforms Llama 2 70B and GPT-3.5 on math, code, and multilingual tasks.

4. DeepSeek-MoE (2024)

Introduces fine-grained expert segmentation plus a shared expert.

y = E_shared(x) + sum(g_i * E_i(x))
  • A few shared experts process every token
  • The remaining experts are chosen by the router
  • More efficient use of parameters
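A minimal sketch of the shared-plus-routed structure in the formula above (layer sizes and the `SharedExpertMoE` class are illustrative, not DeepSeek's actual architecture):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedExpertMoE(nn.Module):
    """Shared expert always active, plus top-k routed experts (DeepSeek-MoE style)."""
    def __init__(self, dim, hidden, num_routed=4, top_k=2):
        super().__init__()
        ffn = lambda: nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
        self.shared = ffn()
        self.routed = nn.ModuleList(ffn() for _ in range(num_routed))
        self.gate = nn.Linear(dim, num_routed, bias=False)
        self.top_k = top_k

    def forward(self, x):                                  # x: (tokens, dim)
        y = self.shared(x)                                 # every token uses the shared expert
        probs = F.softmax(self.gate(x), dim=-1)
        gates, idx = probs.topk(self.top_k, dim=-1)        # (tokens, top_k)
        for k in range(self.top_k):
            for e, expert in enumerate(self.routed):
                mask = idx[:, k] == e
                if mask.any():
                    y[mask] = y[mask] + gates[mask, k:k + 1] * expert(x[mask])
        return y

moe = SharedExpertMoE(dim=32, hidden=64)
out = moe(torch.randn(10, 32))
print(out.shape)   # torch.Size([10, 32])
```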

Training Considerations

Load Balancing Problem

Problem: the router tends to send tokens to only a handful of experts.

Solutions:

Method Description
Auxiliary Loss Extra loss term that balances expert utilization
Capacity Factor Caps how many tokens each expert may process
Noise Injection Adds noise to the router to encourage exploration
Expert Dropout Randomly drops some experts during training

Capacity Factor

Expert Capacity = (tokens_per_batch / num_experts) * capacity_factor

capacity_factor > 1.0: leaves headroom for imbalanced routing
capacity_factor = 1.0: assumes perfectly balanced routing

Overflowed tokens skip the expert and are passed to the next layer through the residual connection (or are simply dropped).
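The capacity formula in code, with the 1.25 headroom commonly used in Switch-style setups:

```python
def expert_capacity(tokens_per_batch, num_experts, capacity_factor=1.25):
    """Max tokens each expert may process before overflow (formula above)."""
    return int(tokens_per_batch / num_experts * capacity_factor)

print(expert_capacity(8192, 8, 1.25))   # 1280: 25% headroom over a perfect split
print(expert_capacity(8192, 8, 1.0))    # 1024: assumes perfectly balanced routing
```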

Router Z-Loss (Switch Transformer)

Penalizes large router logits to improve training stability:

L_z = (1/B) * sum_i (log(sum_j exp(x_ij)))^2

where i indexes the B tokens in the batch and j indexes the experts.
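In code, the z-loss is the batch mean of the squared log-sum-exp of each token's router logits:

```python
import torch

def router_z_loss(router_logits):
    """Mean over tokens of (log-sum-exp of the logits)^2."""
    return torch.logsumexp(router_logits, dim=-1).pow(2).mean()

torch.manual_seed(0)
small = router_z_loss(torch.randn(256, 8) * 0.1)
large = router_z_loss(torch.randn(256, 8) * 10.0)
assert large > small   # larger logits incur a larger penalty
```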

Inference Optimization

Challenges

Challenge Description
Memory All experts must be resident in memory
Batching Dynamic routing reduces batching efficiency
Communication Inter-expert communication overhead in distributed setups

Solutions

1. Expert Parallelism

Place different experts on different GPUs; tokens are exchanged via all-to-all communication.

2. Expert Offloading

Keep inactive experts on CPU/SSD and load them onto the GPU on demand.

# Conceptual: Expert offloading
for token in batch:
    active_experts = router(token)
    load_experts_to_gpu(active_experts)
    output = forward(token, active_experts)
    offload_experts(active_experts)

3. Expert Pruning/Merging

  • Remove rarely used experts
  • Merge similar experts
  • Compress into a dense model via knowledge distillation
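As a toy illustration of merging, two same-shaped linear experts can be averaged in weight space (real methods must also remap the router; `merge_experts` is a made-up helper):

```python
import torch
import torch.nn as nn

def merge_experts(e1: nn.Linear, e2: nn.Linear) -> nn.Linear:
    """Average two same-shaped linear experts in weight space (toy merge)."""
    merged = nn.Linear(e1.in_features, e1.out_features)
    with torch.no_grad():
        merged.weight.copy_((e1.weight + e2.weight) / 2)
        merged.bias.copy_((e1.bias + e2.bias) / 2)
    return merged

a, b = nn.Linear(16, 16), nn.Linear(16, 16)
m = merge_experts(a, b)
x = torch.randn(4, 16)
# For linear layers, averaging weights is the same as averaging outputs:
assert torch.allclose(m(x), (a(x) + b(x)) / 2, atol=1e-5)
```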

Python Implementation

Basic MoE Layer (PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """Single expert network (FFN)."""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))


class Router(nn.Module):
    """Top-K router with noise."""
    def __init__(self, input_dim, num_experts, top_k=2, noise_std=1.0):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std
        self.gate = nn.Linear(input_dim, num_experts, bias=False)

    def forward(self, x):
        # x: (batch, seq, dim)
        logits = self.gate(x)  # (batch, seq, num_experts)

        # Add noise during training
        if self.training and self.noise_std > 0:
            noise = torch.randn_like(logits) * self.noise_std
            logits = logits + noise

        # Top-K selection
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)
        top_k_gates = F.softmax(top_k_logits, dim=-1)

        return top_k_gates, top_k_indices, logits


class MoELayer(nn.Module):
    """Mixture of Experts layer."""
    def __init__(
        self, 
        input_dim, 
        hidden_dim, 
        output_dim, 
        num_experts=8, 
        top_k=2,
        aux_loss_weight=0.01
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.aux_loss_weight = aux_loss_weight

        # Create experts
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, output_dim)
            for _ in range(num_experts)
        ])

        # Router
        self.router = Router(input_dim, num_experts, top_k)

    def compute_aux_loss(self, router_logits):
        """Load balancing auxiliary loss (Switch Transformer style)."""
        # router_logits: (batch, seq, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # f_i: fraction of tokens dispatched to each expert (hard top-k assignment)
        _, top_k_indices = router_logits.topk(self.top_k, dim=-1)
        dispatch_mask = F.one_hot(top_k_indices, self.num_experts).float()
        tokens_per_expert = dispatch_mask.sum(dim=2).mean(dim=[0, 1])

        # P_i: average router probability for each expert
        router_prob_per_expert = router_probs.mean(dim=[0, 1])

        # Minimized when both distributions are uniform across experts
        aux_loss = self.num_experts * (tokens_per_expert * router_prob_per_expert).sum()

        return aux_loss

    def forward(self, x):
        batch_size, seq_len, dim = x.shape

        # Get routing decisions
        gates, indices, router_logits = self.router(x)
        # gates: (batch, seq, top_k)
        # indices: (batch, seq, top_k)

        # Compute auxiliary loss
        aux_loss = self.compute_aux_loss(router_logits)

        # Initialize output
        output = torch.zeros_like(x)

        # Route tokens to experts
        for k in range(self.top_k):
            expert_indices = indices[:, :, k]  # (batch, seq)
            expert_gates = gates[:, :, k:k+1]  # (batch, seq, 1)

            for e in range(self.num_experts):
                # Find tokens routed to this expert
                mask = (expert_indices == e)
                if mask.any():
                    expert_input = x[mask]
                    expert_output = self.experts[e](expert_input)
                    output[mask] += expert_gates[mask] * expert_output

        return output, aux_loss * self.aux_loss_weight


# Example usage
if __name__ == "__main__":
    batch_size, seq_len, dim = 4, 128, 512
    hidden_dim = 2048

    moe = MoELayer(
        input_dim=dim,
        hidden_dim=hidden_dim,
        output_dim=dim,
        num_experts=8,
        top_k=2
    )

    x = torch.randn(batch_size, seq_len, dim)
    output, aux_loss = moe(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Auxiliary loss: {aux_loss.item():.4f}")

Efficient Batched Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class EfficientMoELayer(nn.Module):
    """
    Memory-efficient MoE with batched expert computation.
    Inspired by Megablocks (https://github.com/databricks/megablocks)
    """
    def __init__(
        self,
        input_dim,
        hidden_dim,
        num_experts=8,
        top_k=2,
        capacity_factor=1.25
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor

        # Stacked expert weights for efficient batched computation
        self.w1 = nn.Parameter(torch.randn(num_experts, input_dim, hidden_dim) * 0.02)
        self.w2 = nn.Parameter(torch.randn(num_experts, hidden_dim, input_dim) * 0.02)

        # Router
        self.gate = nn.Linear(input_dim, num_experts, bias=False)

    def forward(self, x):
        batch_size, seq_len, dim = x.shape
        num_tokens = batch_size * seq_len
        x_flat = x.view(num_tokens, dim)

        # Routing
        router_logits = self.gate(x_flat)  # (num_tokens, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Top-K selection
        top_k_probs, top_k_indices = router_probs.topk(self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Capacity
        capacity = int(self.capacity_factor * num_tokens / self.num_experts)

        # Initialize output
        output = torch.zeros_like(x_flat)

        # Process each expert
        for e in range(self.num_experts):
            # Find tokens for this expert
            expert_mask = (top_k_indices == e).any(dim=-1)
            token_indices = expert_mask.nonzero().squeeze(-1)

            if len(token_indices) == 0:
                continue

            # Apply capacity limit
            if len(token_indices) > capacity:
                token_indices = token_indices[:capacity]

            # Get expert weights for these tokens
            expert_weights = torch.zeros(len(token_indices), device=x.device)
            for k in range(self.top_k):
                mask = top_k_indices[token_indices, k] == e
                expert_weights[mask] = top_k_probs[token_indices[mask], k]

            # Expert forward pass
            expert_input = x_flat[token_indices]  # (n, dim)
            hidden = F.gelu(expert_input @ self.w1[e])  # (n, hidden_dim)
            expert_output = hidden @ self.w2[e]  # (n, dim)

            # Weighted output
            output[token_indices] += expert_weights.unsqueeze(-1) * expert_output

        return output.view(batch_size, seq_len, dim)

Using HuggingFace Transformers (Mixtral)

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load Mixtral 8x7B
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",  # Automatic multi-GPU distribution
    load_in_4bit=True   # 4-bit quantization (requires the bitsandbytes package)
)

# Generate text
prompt = "Explain mixture of experts in machine learning:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

When to Use MoE

Good Fit

Scenario Reason
Large-scale pretraining Faster training with same compute
Multi-task learning Different experts can specialize
Diverse data domains Experts handle different domains
Latency-sensitive inference Compute scales with active, not total, parameters

Not Ideal

Scenario Reason
Memory constrained All experts must be loaded
Small datasets Experts may not differentiate
Fine-tuning focus MoE can overfit, harder to fine-tune
Simple tasks Overhead not justified

Comparison with Dense Models

Aspect Dense MoE
Parameters All active Subset active
Pretraining Speed Baseline 2-4x faster
Inference FLOPs Proportional to params Lower than equivalent dense
Memory Proportional to params High (all experts loaded)
Fine-tuning Standard Challenging, prone to overfitting
Implementation Simple Complex (routing, load balancing)

Recent Advances (2024-2025)

Development Description
Expert Choice Routing Experts select tokens (rather than tokens selecting experts)
Soft MoE Continuous routing weights (no hard top-K)
MoE in Vision MoE applied to Vision Transformers (V-MoE)
Sparse Upcycling Converting a pretrained dense model into an MoE
Mixture of LoRA MoE-style adapters for fine-tuning
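Sparse upcycling, for example, can be sketched in a few lines: every expert starts as a copy of the pretrained dense FFN, so at initialization the MoE reproduces the dense model exactly (a simplified sketch, not the paper's full recipe):

```python
import copy
import torch
import torch.nn as nn

def upcycle_ffn(dense_ffn: nn.Module, num_experts: int) -> nn.ModuleList:
    """Initialize every MoE expert as a copy of the pretrained dense FFN."""
    return nn.ModuleList(copy.deepcopy(dense_ffn) for _ in range(num_experts))

ffn = nn.Sequential(nn.Linear(32, 128), nn.GELU(), nn.Linear(128, 32))
experts = upcycle_ffn(ffn, num_experts=4)
x = torch.randn(3, 32)
# At initialization, each expert reproduces the dense FFN exactly:
assert all(torch.allclose(e(x), ffn(x)) for e in experts)
```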

Key Takeaways

  1. MoE decouples model capacity from compute cost through sparse conditional computation
  2. A router network sends each input to a small subset of specialized experts
  3. Load balancing (auxiliary loss, capacity factor) is key to stable, efficient training
  4. Memory requirements are high, but inference compute scales with the active parameters
  5. Open-weight models such as Mixtral 8x7B have made MoE practical to deploy
  6. The difficulty of fine-tuning remains an open research area (e.g., Mixture of LoRA)

Last updated: 2026-02-07