콘텐츠로 이동
Data Prep
상세

Speculative Decoding

메타 정보

항목 내용
논문 Fast Inference from Transformers via Speculative Decoding
저자 Yaniv Leviathan, Matan Kalman, Yossi Matias
소속 Google Research
발표 ICML 2023
arXiv 2211.17192
분야 LLM Inference Optimization, Efficient Decoding

개요

Speculative Decoding은 대형 언어 모델(LLM)의 추론 속도를 향상시키는 기법이다. 작고 빠른 "draft" 모델이 여러 토큰을 미리 생성하고, 큰 "target" 모델이 이를 병렬로 검증하여 승인하는 방식으로 작동한다.

핵심 아이디어

LLM 추론의 병목은 메모리 대역폭이다. 토큰 하나를 생성할 때마다 전체 모델 가중치를 읽어야 하므로, 연산 자체보다 메모리 I/O가 시간을 지배한다. Speculative Decoding은 여러 토큰을 한 번에 검증함으로써 메모리 접근 횟수를 줄인다.

핵심 특성

  • 출력 분포가 원래 모델과 완전히 동일 (lossless)
  • 추가 학습 없이 기존 모델에 적용 가능
  • 2-3배 속도 향상 (모델 및 태스크에 따라 다름)
  • KV cache 구조와 호환

이론적 배경

자기회귀 디코딩의 문제점

표준 자기회귀 디코딩:

for t = 1 to T:
    p_t = Model(x_1, ..., x_{t-1})  # 전체 모델 forward pass
    x_t ~ p_t                        # 샘플링

문제점: - 각 토큰 생성에 전체 모델 가중치 로드 필요 - GPU 연산 유닛 활용률 낮음 (memory-bound) - 배치 크기 1에서 특히 비효율적

Speculative Decoding 원리

두 단계로 구성:

  1. Draft Phase: 작은 모델 M_q가 K개의 토큰을 빠르게 생성
  2. Verify Phase: 큰 모델 M_p가 K개 토큰을 병렬로 검증

검증 시 rejection sampling을 사용하여 원래 모델의 분포를 정확히 보존한다.


알고리즘

전체 흐름

입력:
  - M_p: Target 모델 (큰 모델)
  - M_q: Draft 모델 (작은 모델)
  - K: Speculation 길이
  - prefix: 현재까지의 토큰 시퀀스

출력: 생성된 토큰들

1. Draft Phase:
   draft_tokens = []
   for i = 1 to K:
       q_i = M_q(prefix + draft_tokens)
       x_i ~ q_i
       draft_tokens.append(x_i)

2. Verify Phase:
   # 병렬로 K+1개 위치의 확률 분포 계산
   p_1, p_2, ..., p_{K+1} = M_p(prefix + draft_tokens)

3. Acceptance Sampling:
   accepted = []
   for i = 1 to K:
       r ~ Uniform(0, 1)
       if r < min(1, p_i(x_i) / q_i(x_i)):
           accepted.append(x_i)
       else:
           # Rejection: 수정된 분포에서 샘플링
           x' ~ normalize(max(0, p_i - q_i))
           accepted.append(x')
           break

4. Bonus Token:
   # 모든 토큰이 승인되면 추가 토큰 획득
   if len(accepted) == K:
       x_{K+1} ~ p_{K+1}
       accepted.append(x_{K+1})

return accepted

수학적 정당성

Rejection sampling의 핵심 공식:

승인 확률:

P(accept x) = min(1, p(x) / q(x))

거절 시 샘플링 분포:

p'(x) = normalize(max(0, p(x) - q(x)))

이 방식은 최종 출력 분포가 정확히 p(x)임을 수학적으로 보장한다.

기대 토큰 수

한 번의 iteration에서 기대되는 생성 토큰 수:

E[# accepted] = sum_{i=1}^{K} prod_{j=1}^{i} alpha_j + prod_{j=1}^{K} alpha_j

여기서 alpha_j는 j번째 위치의 평균 승인 확률이다.


구현 상세

Draft 모델 선택 기준

기준 설명
속도 비율 Target 대비 5-10배 빠른 모델 권장
분포 유사성 Target과 유사한 출력 분포일수록 승인률 상승
아키텍처 호환 동일 토크나이저 필수

적합한 Draft 모델 예시: - Llama 70B + Llama 7B - GPT-4 + GPT-3.5 - 동일 모델의 양자화 버전

Speculation 길이 (K) 선택

K 값 장점 단점
작음 (2-3) 높은 승인률 속도 향상 제한적
큼 (8-10) 최대 속도 향상 가능 낮은 승인률 시 오버헤드

권장: K = 4-6에서 시작, 태스크별 튜닝

KV Cache 관리

Speculative Decoding에서 KV cache 처리:

1. Draft phase: Draft 모델의 KV cache 유지
2. Verify phase: Target 모델이 K+1 위치 동시 처리
3. Rollback: 거절 시 해당 위치 이후 cache 무효화
4. Update: 승인된 토큰들의 KV cache 보존

Python 구현 예시

기본 Speculative Decoding

import torch
import torch.nn.functional as F
from typing import List, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer


class SpeculativeDecoder:
    def __init__(
        self,
        target_model_name: str,
        draft_model_name: str,
        device: str = "cuda",
        speculation_length: int = 5
    ):
        """
        Speculative Decoding 구현

        Args:
            target_model_name: 큰 모델 (검증용)
            draft_model_name: 작은 모델 (초안 생성용)
            device: 실행 디바이스
            speculation_length: 한 번에 추측할 토큰 수 (K)
        """
        self.device = device
        self.K = speculation_length

        # Target 모델 (큰 모델)
        self.target_model = AutoModelForCausalLM.from_pretrained(
            target_model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.target_model.eval()

        # Draft 모델 (작은 모델)
        self.draft_model = AutoModelForCausalLM.from_pretrained(
            draft_model_name,
            torch_dtype=torch.float16
        ).to(device)
        self.draft_model.eval()

        # 토크나이저 (동일해야 함)
        self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)

    @torch.no_grad()
    def draft_tokens(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draft 모델로 K개 토큰 생성

        Returns:
            draft_tokens: 생성된 토큰들 [K]
            draft_probs: 각 토큰의 확률 분포 [K, vocab_size]
        """
        draft_tokens = []
        draft_probs = []

        current_ids = input_ids.clone()
        current_mask = attention_mask.clone()

        for _ in range(self.K):
            outputs = self.draft_model(
                input_ids=current_ids,
                attention_mask=current_mask
            )
            logits = outputs.logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)

            # 샘플링
            next_token = torch.multinomial(probs, num_samples=1)

            draft_tokens.append(next_token)
            draft_probs.append(probs)

            # 다음 단계 준비
            current_ids = torch.cat([current_ids, next_token], dim=-1)
            current_mask = torch.cat([
                current_mask,
                torch.ones((1, 1), device=self.device)
            ], dim=-1)

        return (
            torch.cat(draft_tokens, dim=-1),
            torch.stack([p.squeeze(0) for p in draft_probs], dim=0)
        )

    @torch.no_grad()
    def verify_tokens(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> torch.Tensor:
        """
        Target 모델로 K+1개 위치의 확률 분포 병렬 계산

        Returns:
            target_probs: 확률 분포 [K+1, vocab_size]
        """
        # Draft 토큰들을 포함한 전체 시퀀스
        full_ids = torch.cat([input_ids, draft_tokens.unsqueeze(0)], dim=-1)
        full_mask = torch.cat([
            attention_mask,
            torch.ones((1, self.K), device=self.device)
        ], dim=-1)

        outputs = self.target_model(
            input_ids=full_ids,
            attention_mask=full_mask
        )

        # 마지막 K+1 위치의 logits
        logits = outputs.logits[0, -(self.K + 1):, :]
        probs = F.softmax(logits, dim=-1)

        return probs

    @torch.no_grad()
    def speculative_sample(
        self,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor
    ) -> torch.Tensor:
        """
        Rejection sampling으로 토큰 승인/거절

        Returns:
            accepted_tokens: 최종 승인된 토큰들
        """
        accepted = []

        for i in range(self.K):
            token = draft_tokens[i].item()

            # 승인 확률 계산
            p = target_probs[i, token].item()
            q = draft_probs[i, token].item()

            accept_prob = min(1.0, p / (q + 1e-10))

            if torch.rand(1).item() < accept_prob:
                # 승인
                accepted.append(token)
            else:
                # 거절: 수정된 분포에서 샘플링
                adjusted = torch.clamp(
                    target_probs[i] - draft_probs[i],
                    min=0
                )
                adjusted = adjusted / (adjusted.sum() + 1e-10)

                new_token = torch.multinomial(adjusted, num_samples=1)
                accepted.append(new_token.item())
                break

        # 모든 토큰 승인 시 보너스 토큰
        if len(accepted) == self.K:
            bonus_token = torch.multinomial(
                target_probs[self.K],
                num_samples=1
            )
            accepted.append(bonus_token.item())

        return torch.tensor(accepted, device=self.device)

    def generate(
        self,
        prompt: str,
        max_tokens: int = 100
    ) -> str:
        """
        Speculative Decoding으로 텍스트 생성
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask

        generated_tokens = []

        while len(generated_tokens) < max_tokens:
            # 1. Draft phase
            draft_tokens, draft_probs = self.draft_tokens(
                input_ids, attention_mask
            )

            # 2. Verify phase
            target_probs = self.verify_tokens(
                input_ids, attention_mask, draft_tokens
            )

            # 3. Acceptance sampling
            accepted = self.speculative_sample(
                draft_tokens, draft_probs, target_probs
            )

            generated_tokens.extend(accepted.tolist())

            # 다음 iteration 준비
            input_ids = torch.cat([
                input_ids,
                accepted.unsqueeze(0)
            ], dim=-1)
            attention_mask = torch.cat([
                attention_mask,
                torch.ones((1, len(accepted)), device=self.device)
            ], dim=-1)

            # EOS 체크
            if self.tokenizer.eos_token_id in accepted.tolist():
                break

        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

승인률 모니터링

class SpeculativeDecoderWithMetrics(SpeculativeDecoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.total_drafted = 0
        self.total_accepted = 0
        self.iterations = 0

    def speculative_sample(self, draft_tokens, draft_probs, target_probs):
        result = super().speculative_sample(
            draft_tokens, draft_probs, target_probs
        )

        # 메트릭 업데이트
        self.total_drafted += self.K
        self.total_accepted += len(result)
        self.iterations += 1

        return result

    def get_metrics(self) -> dict:
        """현재까지의 성능 메트릭 반환"""
        return {
            "acceptance_rate": self.total_accepted / max(1, self.total_drafted),
            "tokens_per_iteration": self.total_accepted / max(1, self.iterations),
            "total_iterations": self.iterations,
            "speedup_estimate": self.total_accepted / max(1, self.iterations)
        }

    def reset_metrics(self):
        self.total_drafted = 0
        self.total_accepted = 0
        self.iterations = 0

실행 예시

# 사용 예시
decoder = SpeculativeDecoderWithMetrics(
    target_model_name="meta-llama/Llama-2-70b-hf",
    draft_model_name="meta-llama/Llama-2-7b-hf",
    speculation_length=5
)

prompt = "Explain the theory of relativity in simple terms:"
output = decoder.generate(prompt, max_tokens=200)

print(f"Generated: {output}")
print(f"Metrics: {decoder.get_metrics()}")

변형 및 후속 연구

Self-Speculative Decoding

  • 논문: Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding (2023)
  • 특징: 별도 draft 모델 없이 레이어 스킵으로 draft 생성
  • 장점: 추가 모델 메모리 불필요

Medusa

  • 논문: Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads (2024)
  • 특징: 여러 개의 prediction head로 병렬 토큰 예측
  • 장점: Draft 모델 없이 단일 모델로 구현

SpecInfer

  • 논문: SpecInfer: Accelerating Generative LLM Serving with Speculative Inference and Token Tree Verification (2024)
  • 특징: 트리 구조 speculation, 여러 경로 동시 검증
  • 발표: MLSys 2024

Lookahead Decoding

  • 논문: Break the Sequential Dependency of LLM Inference Using Lookahead Decoding (2024)
  • 특징: Jacobi iteration 기반, draft 모델 불필요
  • 장점: 단일 모델로 병렬 디코딩

EAGLE

  • 논문: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty (2024)
  • 특징: Feature-level draft로 승인률 향상
  • 발표: ICML 2024

Staged Speculative Decoding

  • 논문: Staged Speculative Decoding for Faster LLM Inference (2024)
  • 특징: 여러 단계의 draft 모델 사용 (cascade)

성능 분석

속도 향상 요인

요인 영향
Draft 모델 속도 빠를수록 좋음 (target의 1/10 이하 권장)
승인률 높을수록 speedup 증가
Speculation 길이 적절한 K 선택 중요
하드웨어 Memory bandwidth가 병목인 환경에서 효과적

실험 결과 (원 논문 기준)

태스크 Target/Draft Speedup
번역 T5-XXL / T5-Small 2.8x
요약 PaLM 540B / PaLM 8B 2.5x
대화 Chinchilla 70B / Chinchilla 7B 2.3x

최적 사용 환경

  • 배치 크기가 작을 때 (특히 batch=1)
  • 메모리 대역폭이 제한적일 때
  • Greedy/low-temperature 샘플링
  • 긴 시퀀스 생성 태스크

한계점 및 고려사항

알려진 한계

  1. Draft 모델 필요: 추가 모델 로드/메모리 필요
  2. Batching 비효율: 배치 크기가 클 때 효과 감소
  3. High Temperature: 무작위성이 높으면 승인률 감소
  4. 초기 지연: 짧은 생성에서는 오버헤드가 클 수 있음

적용 가이드

적합한 경우: - 대화형 AI (실시간 응답 필요) - 코드 생성 (결정적 출력) - 번역 (일관된 패턴)

비적합한 경우: - 대규모 배치 처리 - 창의적 글쓰기 (high temperature) - 매우 짧은 응답


실무 적용 팁

Draft 모델 선택

# 좋은 조합 예시
DRAFT_TARGET_PAIRS = {
    # Target: Draft
    "meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-7b-hf",
    "mistralai/Mixtral-8x7B": "mistralai/Mistral-7B-v0.1",
    "Qwen/Qwen2-72B": "Qwen/Qwen2-7B",
    # Self-draft (양자화)
    "model": "model-int4",  # 동일 모델의 양자화 버전
}

하이퍼파라미터 튜닝

# 태스크별 권장 설정
TASK_CONFIGS = {
    "code_generation": {
        "speculation_length": 6,
        "temperature": 0.2,
        "expected_speedup": "2.5-3x"
    },
    "translation": {
        "speculation_length": 5,
        "temperature": 0.0,
        "expected_speedup": "2.5-3x"
    },
    "chat": {
        "speculation_length": 4,
        "temperature": 0.7,
        "expected_speedup": "1.8-2.2x"
    },
    "creative_writing": {
        "speculation_length": 3,
        "temperature": 1.0,
        "expected_speedup": "1.3-1.5x"
    }
}

참고 문헌

  1. Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding. ICML 2023.
  2. Chen, C., et al. (2023). Accelerating Large Language Model Decoding with Speculative Sampling. arXiv:2302.01318.
  3. Zhang, J., et al. (2023). Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding. arXiv:2309.08168.
  4. Cai, T., et al. (2024). Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads. ICML 2024.
  5. Miao, X., et al. (2024). SpecInfer: Accelerating Generative LLM Serving with Speculative Inference and Token Tree Verification. MLSys 2024.
  6. Li, Y., et al. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. ICML 2024.
  7. Fu, Y., et al. (2024). Break the Sequential Dependency of LLM Inference Using Lookahead Decoding. arXiv:2402.02057.