콘텐츠로 이동
Data Prep
상세

Mixture-of-Depths (MoD)

개요

Mixture-of-Depths는 Google DeepMind가 2024년에 발표한 동적 연산 할당 기법이다. 기존 트랜스포머가 모든 토큰에 동일한 연산량을 할당하는 것과 달리, MoD는 토큰별로 필요한 연산량을 동적으로 결정한다.

항목 내용
논문 Mixture-of-Depths: Dynamically allocating compute in transformer-based language models
저자 David Raposo, Sam Ritter, Blake Richards, Timothy Lillicrap, Peter Conway Humphreys, Adam Santoro
arXiv 2404.02258
소속 Google DeepMind
발표 2024년 4월

핵심 아이디어

기존 트랜스포머의 문제

Standard Transformer:
Layer 1:  [tok1] [tok2] [tok3] [tok4] [tok5]  <- 모든 토큰 처리
Layer 2:  [tok1] [tok2] [tok3] [tok4] [tok5]  <- 모든 토큰 처리
Layer 3:  [tok1] [tok2] [tok3] [tok4] [tok5]  <- 모든 토큰 처리
...
모든 레이어에서 모든 토큰에 동일한 FLOPs 할당

문제점: - 일부 토큰은 "쉬움" (예: 관사 "the", "a") - 일부 토큰은 "어려움" (예: 전문 용어, 맥락 의존적 단어) - 모든 토큰에 동일 연산 = 비효율

MoD 접근법

Mixture-of-Depths:
Layer 1:  [tok1]   -   [tok3]   -   [tok5]  <- top-k 토큰만 처리
Layer 2:    -   [tok2]   -   [tok4] [tok5]  <- 다른 top-k 토큰
Layer 3:  [tok1] [tok2]   -     -   [tok5]  <- 또 다른 top-k 토큰
...
각 레이어에서 가장 중요한 k개 토큰만 연산

핵심 원리: 1. Capacity Factor: 각 레이어에서 처리할 토큰 비율 (예: 12.5%) 2. Router: 어떤 토큰을 처리할지 결정 3. Residual Connection: 처리되지 않은 토큰은 skip

아키텍처

Router 메커니즘

입력 시퀀스: [x_1, x_2, ..., x_n]

1. Router 점수 계산:
   r_i = W_r * x_i   (스칼라 점수)

2. Top-k 선택:
   k = capacity_factor * n
   S = top_k_indices(r_1, ..., r_n)

3. 선택된 토큰만 처리:
   for i in S:
       x_i = Attention(x_i) + MLP(x_i)

   for i not in S:
       x_i = x_i  (skip, residual만)

전체 구조

┌─────────────────────────────────────────────────────────┐
│                    Input Sequence                        │
│              [tok1, tok2, tok3, ..., tokN]              │
└─────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────┐
│                     Router (Layer 1)                     │
│   scores = W_r @ X                                       │
│   selected_idx = top_k(scores, k=capacity*N)            │
└─────────────────────────────────────────────────────────┘
              ┌─────────────┴─────────────┐
              │                           │
              ▼                           ▼
┌─────────────────────────┐   ┌─────────────────────────┐
│    Selected Tokens      │   │    Skipped Tokens       │
│   (k tokens)            │   │   (N-k tokens)          │
├─────────────────────────┤   ├─────────────────────────┤
│   Self-Attention        │   │                         │
│         +               │   │   Identity (Residual)   │
│       MLP               │   │                         │
└─────────────────────────┘   └─────────────────────────┘
              │                           │
              └─────────────┬─────────────┘
┌─────────────────────────────────────────────────────────┐
│                  Combined Output                         │
│         (재정렬하여 원래 순서로)                         │
└─────────────────────────────────────────────────────────┘
                    (다음 레이어 반복)

MoD vs MoE 비교

측면 Mixture-of-Experts (MoE) Mixture-of-Depths (MoD)
라우팅 대상 Expert (MLP 변형) Layer (전체 블록)
동적 요소 어떤 expert를 사용할지 얼마나 깊이 처리할지
파라미터 증가 (다중 expert) 동일 유지
FLOPs 고정 (선택된 expert만) 감소 (일부 토큰 skip)
추론 속도 메모리 바운드 FLOP 절약으로 빨라짐

결합 가능성

MoD + MoE 결합:
1. MoD로 어떤 토큰을 처리할지 결정
2. 선택된 토큰에 대해 MoE로 어떤 expert를 사용할지 결정

결과: 토큰별, 레이어별, Expert별 동적 연산

학습 방법

Auxiliary Loss

Router가 의미있는 선택을 하도록 보조 손실 함수 사용:

L_total = L_language + alpha * L_router

L_router: load balancing loss (토큰 분포 균형)

Capacity Factor 설정

Capacity Factor 처리 토큰 비율 FLOPs 절약 성능 영향
1.0 100% 0% 베이스라인
0.5 50% ~50% 미미
0.25 25% ~75% 약간
0.125 12.5% ~87.5% 측정 가능

논문 결과: capacity_factor=0.125에서도 베이스라인과 동등한 성능 달성

성능 결과

학습 효율성

모델 FLOPs (학습) FLOPs (추론) 성능
Baseline (12.5B) 1x 1x 1.0
MoD (12.5B) 1x 0.5x 1.0
MoD (12.5B, isoFLOP) 0.66x 0.33x 1.0

추론 속도

  • 최대 50% 빠른 샘플링 (동일 품질)
  • Static compute graph로 효율적 배치 처리
  • KV-cache와 완전 호환

Python 구현

기본 Router

import torch
import torch.nn as nn
import torch.nn.functional as F

class MoDRouter(nn.Module):
    """Mixture-of-Depths Router"""

    def __init__(self, dim: int, capacity_factor: float = 0.125):
        super().__init__()
        self.capacity_factor = capacity_factor
        self.router = nn.Linear(dim, 1, bias=False)

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: (batch, seq_len, dim)
        Returns:
            selected_mask: (batch, seq_len) bool tensor
            router_weights: (batch, seq_len) for aux loss
        """
        batch, seq_len, dim = x.shape

        # Router 점수 계산
        router_logits = self.router(x).squeeze(-1)  # (batch, seq_len)
        router_weights = torch.sigmoid(router_logits)

        # Top-k 선택
        k = int(seq_len * self.capacity_factor)
        k = max(1, k)  # 최소 1개

        # 각 배치에서 top-k 인덱스 선택
        _, indices = torch.topk(router_logits, k, dim=-1)

        # 마스크 생성
        selected_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=x.device)
        selected_mask.scatter_(1, indices, True)

        return selected_mask, router_weights

MoD Transformer Block

class MoDBlock(nn.Module):
    """Mixture-of-Depths Transformer Block"""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        capacity_factor: float = 0.125,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.dim = dim
        self.capacity_factor = capacity_factor

        # Router
        self.router = MoDRouter(dim, capacity_factor)

        # Attention
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )

        # MLP
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, return_router_weights: bool = False):
        """
        Args:
            x: (batch, seq_len, dim)
        Returns:
            output: (batch, seq_len, dim)
            router_weights: optional, for aux loss
        """
        batch, seq_len, dim = x.shape

        # 1. Router로 토큰 선택
        selected_mask, router_weights = self.router(x)

        # 2. 선택된 토큰 추출
        # selected_mask: (batch, seq_len) -> indices
        selected_indices = selected_mask.nonzero(as_tuple=False)

        # 선택된 토큰만 처리 (효율적 구현)
        if selected_mask.any():
            # 선택된 토큰 gather
            x_selected = x[selected_mask]  # (num_selected, dim)

            # Attention (선택된 토큰끼리만)
            # 실제로는 causal mask와 함께 처리 필요
            x_norm = self.norm1(x_selected)

            # 간단한 구현: 전체 시퀀스에 대해 attention 후 마스킹
            # 실제 구현에서는 효율적인 sparse attention 사용
            x_full_norm = self.norm1(x)
            attn_out, _ = self.attn(x_full_norm, x_full_norm, x_full_norm)

            # MLP
            mlp_out = self.mlp(self.norm2(x + attn_out))

            # 선택된 토큰만 업데이트
            output = x.clone()
            output = output + attn_out * selected_mask.unsqueeze(-1).float()
            output = output + mlp_out * selected_mask.unsqueeze(-1).float()
        else:
            output = x

        if return_router_weights:
            return output, router_weights
        return output


class MoDTransformer(nn.Module):
    """Mixture-of-Depths Transformer"""

    def __init__(
        self,
        vocab_size: int,
        dim: int = 512,
        depth: int = 12,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        capacity_factor: float = 0.125,
        max_seq_len: int = 2048,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.dim = dim

        # Embeddings
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # MoD Blocks
        self.blocks = nn.ModuleList([
            MoDBlock(dim, num_heads, mlp_ratio, capacity_factor, dropout)
            for _ in range(depth)
        ])

        # Output
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x: torch.Tensor, return_router_weights: bool = False):
        """
        Args:
            x: (batch, seq_len) token indices
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        batch, seq_len = x.shape

        # Embeddings
        pos = torch.arange(seq_len, device=x.device)
        x = self.token_emb(x) + self.pos_emb(pos)

        # MoD Blocks
        all_router_weights = []
        for block in self.blocks:
            if return_router_weights:
                x, rw = block(x, return_router_weights=True)
                all_router_weights.append(rw)
            else:
                x = block(x)

        # Output
        x = self.norm(x)
        logits = self.head(x)

        if return_router_weights:
            return logits, all_router_weights
        return logits

Load Balancing Loss

def load_balancing_loss(router_weights_list: list) -> torch.Tensor:
    """
    Router weights의 load balancing을 위한 auxiliary loss

    Args:
        router_weights_list: list of (batch, seq_len) tensors
    """
    total_loss = 0.0

    for router_weights in router_weights_list:
        # 각 토큰이 선택될 확률의 분산을 최소화
        # 이상적: 모든 토큰이 균등한 확률로 선택됨
        mean_weight = router_weights.mean(dim=-1, keepdim=True)
        variance = ((router_weights - mean_weight) ** 2).mean()
        total_loss = total_loss + variance

    return total_loss / len(router_weights_list)


def train_step(model, optimizer, x, y, aux_weight=0.01):
    """MoD 모델 학습 스텝"""
    optimizer.zero_grad()

    logits, router_weights = model(x, return_router_weights=True)

    # Language modeling loss
    lm_loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        y.view(-1)
    )

    # Auxiliary load balancing loss
    aux_loss = load_balancing_loss(router_weights)

    # Total loss
    total_loss = lm_loss + aux_weight * aux_loss

    total_loss.backward()
    optimizer.step()

    return {
        'total_loss': total_loss.item(),
        'lm_loss': lm_loss.item(),
        'aux_loss': aux_loss.item(),
    }

효율적인 추론

class EfficientMoDInference:
    """효율적인 MoD 추론을 위한 wrapper"""

    def __init__(self, model: MoDTransformer):
        self.model = model
        self.model.eval()

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
    ) -> torch.Tensor:
        """
        Autoregressive generation with MoD

        Args:
            prompt_ids: (1, prompt_len) token indices
            max_new_tokens: 생성할 최대 토큰 수
        """
        generated = prompt_ids.clone()

        for _ in range(max_new_tokens):
            # Forward pass
            logits = self.model(generated)

            # 마지막 토큰의 로짓만 사용
            next_logits = logits[:, -1, :] / temperature

            # Top-k sampling
            if top_k > 0:
                v, _ = torch.topk(next_logits, top_k)
                next_logits[next_logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            generated = torch.cat([generated, next_token], dim=1)

        return generated

    def compute_flops_savings(self, seq_len: int) -> dict:
        """FLOPs 절약량 계산"""
        capacity_factor = self.model.blocks[0].capacity_factor
        num_layers = len(self.model.blocks)

        baseline_flops = num_layers * seq_len  # 상대적 단위
        mod_flops = num_layers * seq_len * capacity_factor

        return {
            'baseline_flops': baseline_flops,
            'mod_flops': mod_flops,
            'savings_percent': (1 - capacity_factor) * 100,
        }

응용 및 확장

적용 분야

분야 장점
실시간 추론 50% 이상 속도 향상
엣지 디바이스 FLOPs 절약으로 저전력
긴 문맥 처리 메모리 효율성
배치 추론 처리량 증가

변형 및 확장

  1. Layer-wise Capacity: 각 레이어마다 다른 capacity factor
  2. Dynamic Capacity: 입력에 따라 capacity 조절
  3. MoD + MoE: 두 기법 결합
  4. Auxiliary Router: 별도 경량 모델로 라우팅 결정

한계 및 고려사항

한계 설명
학습 복잡도 Router 학습에 추가 하이퍼파라미터 필요
작은 모델 큰 모델에서 효과가 더 큼
Task 의존성 모든 태스크에서 동일한 효과 X
구현 복잡도 효율적 구현에 커스텀 커널 필요

관련 연구

논문/기법 관계
Mixture-of-Experts 유사한 라우팅 메커니즘, 다른 적용 대상
Early Exit 토큰이 아닌 샘플 단위 동적 깊이
Adaptive Computation 일반적인 동적 연산 프레임워크
Universal Transformers 반복 횟수 동적 조절

참고 자료

논문

  • Raposo et al. (2024). Mixture-of-Depths: Dynamically allocating compute in transformer-based language models. arXiv:2404.02258

구현

관련 문서