Speculative Decoding¶
메타 정보¶
| 항목 | 내용 |
|---|---|
| 논문 | Fast Inference from Transformers via Speculative Decoding |
| 저자 | Yaniv Leviathan, Matan Kalman, Yossi Matias |
| 소속 | Google Research |
| 발표 | ICML 2023 |
| arXiv | 2211.17192 |
| 분야 | LLM Inference Optimization, Efficient Decoding |
개요¶
Speculative Decoding은 대형 언어 모델(LLM)의 추론 속도를 향상시키는 기법이다. 작고 빠른 "draft" 모델이 여러 토큰을 미리 생성하고, 큰 "target" 모델이 이를 병렬로 검증하여 승인하는 방식으로 작동한다.
핵심 아이디어¶
LLM 추론의 병목은 메모리 대역폭이다. 토큰 하나를 생성할 때마다 전체 모델 가중치를 읽어야 하므로, 연산 자체보다 메모리 I/O가 시간을 지배한다. Speculative Decoding은 여러 토큰을 한 번에 검증함으로써 메모리 접근 횟수를 줄인다.
핵심 특성¶
- 출력 분포가 원래 모델과 완전히 동일 (lossless)
- 추가 학습 없이 기존 모델에 적용 가능
- 2-3배 속도 향상 (모델 및 태스크에 따라 다름)
- KV cache 구조와 호환
이론적 배경¶
자기회귀 디코딩의 문제점¶
표준 자기회귀 디코딩:
문제점: - 각 토큰 생성에 전체 모델 가중치 로드 필요 - GPU 연산 유닛 활용률 낮음 (memory-bound) - 배치 크기 1에서 특히 비효율적
Speculative Decoding 원리¶
두 단계로 구성:
- Draft Phase: 작은 모델 M_q가 K개의 토큰을 빠르게 생성
- Verify Phase: 큰 모델 M_p가 K개 토큰을 병렬로 검증
검증 시 rejection sampling을 사용하여 원래 모델의 분포를 정확히 보존한다.
알고리즘¶
전체 흐름¶
입력:
- M_p: Target 모델 (큰 모델)
- M_q: Draft 모델 (작은 모델)
- K: Speculation 길이
- prefix: 현재까지의 토큰 시퀀스
출력: 생성된 토큰들
1. Draft Phase:
draft_tokens = []
for i = 1 to K:
q_i = M_q(prefix + draft_tokens)
x_i ~ q_i
draft_tokens.append(x_i)
2. Verify Phase:
# 병렬로 K+1개 위치의 확률 분포 계산
p_1, p_2, ..., p_{K+1} = M_p(prefix + draft_tokens)
3. Acceptance Sampling:
accepted = []
for i = 1 to K:
r ~ Uniform(0, 1)
if r < min(1, p_i(x_i) / q_i(x_i)):
accepted.append(x_i)
else:
# Rejection: 수정된 분포에서 샘플링
x' ~ normalize(max(0, p_i - q_i))
accepted.append(x')
break
4. Bonus Token:
# 모든 토큰이 승인되면 추가 토큰 획득
if len(accepted) == K:
x_{K+1} ~ p_{K+1}
accepted.append(x_{K+1})
return accepted
수학적 정당성¶
Rejection sampling의 핵심 공식:
승인 확률:
거절 시 샘플링 분포:
이 방식은 최종 출력 분포가 정확히 p(x)임을 수학적으로 보장한다.
기대 토큰 수¶
한 번의 iteration에서 기대되는 생성 토큰 수:
여기서 alpha_j는 j번째 위치의 평균 승인 확률이다.
구현 상세¶
Draft 모델 선택 기준¶
| 기준 | 설명 |
|---|---|
| 속도 비율 | Target 대비 5-10배 빠른 모델 권장 |
| 분포 유사성 | Target과 유사한 출력 분포일수록 승인률 상승 |
| 아키텍처 호환 | 동일 토크나이저 필수 |
적합한 Draft 모델 예시: - Llama 70B + Llama 7B - GPT-4 + GPT-3.5 - 동일 모델의 양자화 버전
Speculation 길이 (K) 선택¶
| K 값 | 장점 | 단점 |
|---|---|---|
| 작음 (2-3) | 높은 승인률 | 속도 향상 제한적 |
| 큼 (8-10) | 최대 속도 향상 가능 | 낮은 승인률 시 오버헤드 |
권장: K = 4-6에서 시작, 태스크별 튜닝
KV Cache 관리¶
Speculative Decoding에서 KV cache 처리:
1. Draft phase: Draft 모델의 KV cache 유지
2. Verify phase: Target 모델이 K+1 위치 동시 처리
3. Rollback: 거절 시 해당 위치 이후 cache 무효화
4. Update: 승인된 토큰들의 KV cache 보존
Python 구현 예시¶
기본 Speculative Decoding¶
import torch
import torch.nn.functional as F
from typing import List, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer
class SpeculativeDecoder:
def __init__(
self,
target_model_name: str,
draft_model_name: str,
device: str = "cuda",
speculation_length: int = 5
):
"""
Speculative Decoding 구현
Args:
target_model_name: 큰 모델 (검증용)
draft_model_name: 작은 모델 (초안 생성용)
device: 실행 디바이스
speculation_length: 한 번에 추측할 토큰 수 (K)
"""
self.device = device
self.K = speculation_length
# Target 모델 (큰 모델)
self.target_model = AutoModelForCausalLM.from_pretrained(
target_model_name,
torch_dtype=torch.float16,
device_map="auto"
)
self.target_model.eval()
# Draft 모델 (작은 모델)
self.draft_model = AutoModelForCausalLM.from_pretrained(
draft_model_name,
torch_dtype=torch.float16
).to(device)
self.draft_model.eval()
# 토크나이저 (동일해야 함)
self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)
@torch.no_grad()
def draft_tokens(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Draft 모델로 K개 토큰 생성
Returns:
draft_tokens: 생성된 토큰들 [K]
draft_probs: 각 토큰의 확률 분포 [K, vocab_size]
"""
draft_tokens = []
draft_probs = []
current_ids = input_ids.clone()
current_mask = attention_mask.clone()
for _ in range(self.K):
outputs = self.draft_model(
input_ids=current_ids,
attention_mask=current_mask
)
logits = outputs.logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
# 샘플링
next_token = torch.multinomial(probs, num_samples=1)
draft_tokens.append(next_token)
draft_probs.append(probs)
# 다음 단계 준비
current_ids = torch.cat([current_ids, next_token], dim=-1)
current_mask = torch.cat([
current_mask,
torch.ones((1, 1), device=self.device)
], dim=-1)
return (
torch.cat(draft_tokens, dim=-1),
torch.stack([p.squeeze(0) for p in draft_probs], dim=0)
)
@torch.no_grad()
def verify_tokens(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
draft_tokens: torch.Tensor
) -> torch.Tensor:
"""
Target 모델로 K+1개 위치의 확률 분포 병렬 계산
Returns:
target_probs: 확률 분포 [K+1, vocab_size]
"""
# Draft 토큰들을 포함한 전체 시퀀스
full_ids = torch.cat([input_ids, draft_tokens.unsqueeze(0)], dim=-1)
full_mask = torch.cat([
attention_mask,
torch.ones((1, self.K), device=self.device)
], dim=-1)
outputs = self.target_model(
input_ids=full_ids,
attention_mask=full_mask
)
# 마지막 K+1 위치의 logits
logits = outputs.logits[0, -(self.K + 1):, :]
probs = F.softmax(logits, dim=-1)
return probs
@torch.no_grad()
def speculative_sample(
self,
draft_tokens: torch.Tensor,
draft_probs: torch.Tensor,
target_probs: torch.Tensor
) -> torch.Tensor:
"""
Rejection sampling으로 토큰 승인/거절
Returns:
accepted_tokens: 최종 승인된 토큰들
"""
accepted = []
for i in range(self.K):
token = draft_tokens[i].item()
# 승인 확률 계산
p = target_probs[i, token].item()
q = draft_probs[i, token].item()
accept_prob = min(1.0, p / (q + 1e-10))
if torch.rand(1).item() < accept_prob:
# 승인
accepted.append(token)
else:
# 거절: 수정된 분포에서 샘플링
adjusted = torch.clamp(
target_probs[i] - draft_probs[i],
min=0
)
adjusted = adjusted / (adjusted.sum() + 1e-10)
new_token = torch.multinomial(adjusted, num_samples=1)
accepted.append(new_token.item())
break
# 모든 토큰 승인 시 보너스 토큰
if len(accepted) == self.K:
bonus_token = torch.multinomial(
target_probs[self.K],
num_samples=1
)
accepted.append(bonus_token.item())
return torch.tensor(accepted, device=self.device)
def generate(
self,
prompt: str,
max_tokens: int = 100
) -> str:
"""
Speculative Decoding으로 텍스트 생성
"""
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask
generated_tokens = []
while len(generated_tokens) < max_tokens:
# 1. Draft phase
draft_tokens, draft_probs = self.draft_tokens(
input_ids, attention_mask
)
# 2. Verify phase
target_probs = self.verify_tokens(
input_ids, attention_mask, draft_tokens
)
# 3. Acceptance sampling
accepted = self.speculative_sample(
draft_tokens, draft_probs, target_probs
)
generated_tokens.extend(accepted.tolist())
# 다음 iteration 준비
input_ids = torch.cat([
input_ids,
accepted.unsqueeze(0)
], dim=-1)
attention_mask = torch.cat([
attention_mask,
torch.ones((1, len(accepted)), device=self.device)
], dim=-1)
# EOS 체크
if self.tokenizer.eos_token_id in accepted.tolist():
break
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
승인률 모니터링¶
class SpeculativeDecoderWithMetrics(SpeculativeDecoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.total_drafted = 0
self.total_accepted = 0
self.iterations = 0
def speculative_sample(self, draft_tokens, draft_probs, target_probs):
result = super().speculative_sample(
draft_tokens, draft_probs, target_probs
)
# 메트릭 업데이트
self.total_drafted += self.K
self.total_accepted += len(result)
self.iterations += 1
return result
def get_metrics(self) -> dict:
"""현재까지의 성능 메트릭 반환"""
return {
"acceptance_rate": self.total_accepted / max(1, self.total_drafted),
"tokens_per_iteration": self.total_accepted / max(1, self.iterations),
"total_iterations": self.iterations,
"speedup_estimate": self.total_accepted / max(1, self.iterations)
}
def reset_metrics(self):
self.total_drafted = 0
self.total_accepted = 0
self.iterations = 0
실행 예시¶
# 사용 예시
decoder = SpeculativeDecoderWithMetrics(
target_model_name="meta-llama/Llama-2-70b-hf",
draft_model_name="meta-llama/Llama-2-7b-hf",
speculation_length=5
)
prompt = "Explain the theory of relativity in simple terms:"
output = decoder.generate(prompt, max_tokens=200)
print(f"Generated: {output}")
print(f"Metrics: {decoder.get_metrics()}")
변형 및 후속 연구¶
Self-Speculative Decoding¶
- 논문: Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding (2023)
- 특징: 별도 draft 모델 없이 레이어 스킵으로 draft 생성
- 장점: 추가 모델 메모리 불필요
Medusa¶
- 논문: Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads (2024)
- 특징: 여러 개의 prediction head로 병렬 토큰 예측
- 장점: Draft 모델 없이 단일 모델로 구현
SpecInfer¶
- 논문: SpecInfer: Accelerating Generative LLM Serving with Speculative Inference and Token Tree Verification (2024)
- 특징: 트리 구조 speculation, 여러 경로 동시 검증
- 발표: MLSys 2024
Lookahead Decoding¶
- 논문: Break the Sequential Dependency of LLM Inference Using Lookahead Decoding (2024)
- 특징: Jacobi iteration 기반, draft 모델 불필요
- 장점: 단일 모델로 병렬 디코딩
EAGLE¶
- 논문: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty (2024)
- 특징: Feature-level draft로 승인률 향상
- 발표: ICML 2024
Staged Speculative Decoding¶
- 논문: Staged Speculative Decoding for Faster LLM Inference (2024)
- 특징: 여러 단계의 draft 모델 사용 (cascade)
성능 분석¶
속도 향상 요인¶
| 요인 | 영향 |
|---|---|
| Draft 모델 속도 | 빠를수록 좋음 (target의 1/10 이하 권장) |
| 승인률 | 높을수록 speedup 증가 |
| Speculation 길이 | 적절한 K 선택 중요 |
| 하드웨어 | Memory bandwidth가 병목인 환경에서 효과적 |
실험 결과 (원 논문 기준)¶
| 태스크 | Target/Draft | Speedup |
|---|---|---|
| 번역 | T5-XXL / T5-Small | 2.8x |
| 요약 | PaLM 540B / PaLM 8B | 2.5x |
| 대화 | Chinchilla 70B / Chinchilla 7B | 2.3x |
최적 사용 환경¶
- 배치 크기가 작을 때 (특히 batch=1)
- 메모리 대역폭이 제한적일 때
- Greedy/low-temperature 샘플링
- 긴 시퀀스 생성 태스크
한계점 및 고려사항¶
알려진 한계¶
- Draft 모델 필요: 추가 모델 로드/메모리 필요
- Batching 비효율: 배치 크기가 클 때 효과 감소
- High Temperature: 무작위성이 높으면 승인률 감소
- 초기 지연: 짧은 생성에서는 오버헤드가 클 수 있음
적용 가이드¶
적합한 경우: - 대화형 AI (실시간 응답 필요) - 코드 생성 (결정적 출력) - 번역 (일관된 패턴)
비적합한 경우: - 대규모 배치 처리 - 창의적 글쓰기 (high temperature) - 매우 짧은 응답
실무 적용 팁¶
Draft 모델 선택¶
# 좋은 조합 예시
DRAFT_TARGET_PAIRS = {
# Target: Draft
"meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-7b-hf",
"mistralai/Mixtral-8x7B": "mistralai/Mistral-7B-v0.1",
"Qwen/Qwen2-72B": "Qwen/Qwen2-7B",
# Self-draft (양자화)
"model": "model-int4", # 동일 모델의 양자화 버전
}
하이퍼파라미터 튜닝¶
# 태스크별 권장 설정
TASK_CONFIGS = {
"code_generation": {
"speculation_length": 6,
"temperature": 0.2,
"expected_speedup": "2.5-3x"
},
"translation": {
"speculation_length": 5,
"temperature": 0.0,
"expected_speedup": "2.5-3x"
},
"chat": {
"speculation_length": 4,
"temperature": 0.7,
"expected_speedup": "1.8-2.2x"
},
"creative_writing": {
"speculation_length": 3,
"temperature": 1.0,
"expected_speedup": "1.3-1.5x"
}
}
참고 문헌¶
- Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding. ICML 2023.
- Chen, C., et al. (2023). Accelerating Large Language Model Decoding with Speculative Sampling. arXiv:2302.01318.
- Zhang, J., et al. (2023). Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding. arXiv:2309.08168.
- Cai, T., et al. (2024). Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads. ICML 2024.
- Miao, X., et al. (2024). SpecInfer: Accelerating Generative LLM Serving with Speculative Inference and Token Tree Verification. MLSys 2024.
- Li, Y., et al. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. ICML 2024.
- Fu, Y., et al. (2024). Break the Sequential Dependency of LLM Inference Using Lookahead Decoding. arXiv:2402.02057.