Efficient Attention Mechanisms¶
개요¶
Transformer의 Self-Attention은 O(n²) 복잡도로 긴 시퀀스에서 병목이 된다. Efficient Attention은 이 한계를 극복하기 위한 최적화 기법들이다.
표준 Attention 복잡도¶
문제점: - 시퀀스 길이 2배 → 메모리/연산 4배 - 128K 토큰 컨텍스트: 수백 GB 메모리 필요
Flash Attention¶
핵심 아이디어¶
GPU 메모리 계층 구조를 활용한 IO-aware 알고리즘.
┌─────────────────────────────────────┐
│ GPU Architecture │
├─────────────────────────────────────┤
│ HBM (High Bandwidth Memory) │ ← 느림, 대용량 (80GB)
│ ↕ 데이터 전송 (병목) │
│ SRAM (On-chip Memory) │ ← 빠름, 소용량 (20MB)
└─────────────────────────────────────┘
기존 방식: 전체 Attention 행렬을 HBM에 저장 Flash Attention: 블록 단위로 SRAM에서 연산, HBM 접근 최소화
알고리즘¶
1. Q, K, V를 블록으로 분할
2. 각 블록을 SRAM에 로드
3. 블록 단위로 Attention 연산
4. Online Softmax로 점진적 정규화
5. 결과를 HBM에 기록
# 개념적 구현 (실제는 CUDA 커널)
def flash_attention(Q, K, V, block_size=256):
n, d = Q.shape
O = zeros_like(Q)
for i in range(0, n, block_size):
Qi = Q[i:i+block_size] # SRAM에 로드
for j in range(0, n, block_size):
Kj = K[j:j+block_size]
Vj = V[j:j+block_size]
# SRAM에서 연산
Sij = Qi @ Kj.T / sqrt(d)
Pij = softmax(Sij)
O[i:i+block_size] += Pij @ Vj
return O
성능 비교¶
| 시퀀스 길이 | PyTorch | Flash Attention v1 | Flash Attention v2 |
|---|---|---|---|
| 2K | 1.0x | 2.5x | 3.5x |
| 8K | OOM | 3.0x | 4.0x |
| 32K | OOM | 3.5x | 5.0x |
| 128K | OOM | 4.0x | 6.0x |
사용법¶
# PyTorch 2.0+
import torch
from torch.nn.functional import scaled_dot_product_attention
# 자동으로 Flash Attention 사용
output = scaled_dot_product_attention(
query, key, value,
attn_mask=None,
dropout_p=0.0,
is_causal=True, # Causal masking
)
# HuggingFace Transformers
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
Flash Attention 3¶
개선점 (Hopper 아키텍처)¶
Flash Attention 2 → Flash Attention 3
- 비동기 연산 (warpgroup 레벨)
- FP8 지원
- 블록 양자화 통합
- 1.5-2x 추가 속도 향상
성능¶
| GPU | FA2 (TFLOPS) | FA3 (TFLOPS) | 개선 |
|---|---|---|---|
| H100 SXM | 335 | 740 | 2.2x |
| H100 PCIe | 280 | 580 | 2.1x |
Multi-Query Attention (MQA)¶
구조¶
Standard MHA:
Q: [n_heads, head_dim]
K: [n_heads, head_dim]
V: [n_heads, head_dim]
Multi-Query:
Q: [n_heads, head_dim]
K: [1, head_dim] ← 공유
V: [1, head_dim] ← 공유
KV 캐시 절감¶
Llama 7B 예시:
- MHA KV 캐시: 32 heads × 128 dim × seq_len × 2 (K,V)
- MQA KV 캐시: 1 × 128 dim × seq_len × 2
→ 32배 메모리 절감
품질 vs 효율 트레이드오프¶
| 방식 | KV 캐시 | 품질 손실 |
|---|---|---|
| MHA | 100% | 0% |
| GQA (8 groups) | 25% | 0.5-1% |
| MQA | 3% | 1-3% |
Grouped-Query Attention (GQA)¶
MHA와 MQA의 절충안¶
MHA: 32 heads → 32 KV heads
GQA: 32 heads → 8 KV heads (4개씩 그룹)
MQA: 32 heads → 1 KV head
┌─────────────────────────────────────┐
│ Query Heads: [1][2][3][4][5][6][7][8]... │
│ \___/\___/\___/\___/ │
│ ▼ ▼ ▼ ▼ │
│ KV Groups: [1] [2] [3] [4]... │
└─────────────────────────────────────────┘
주요 모델 채택 현황¶
| 모델 | KV Heads | Query Heads | 비율 |
|---|---|---|---|
| Llama 3.1 8B | 8 | 32 | 4:1 |
| Llama 3.1 70B | 8 | 64 | 8:1 |
| Mistral 7B | 8 | 32 | 4:1 |
| Gemma 2 | 8 | 16 | 2:1 |
Sliding Window Attention¶
구조¶
Standard Attention:
Token i attends to: [0, 1, 2, ..., i]
Sliding Window (window=4):
Token i attends to: [max(0, i-4), ..., i]
┌───────────────────────────────┐
│ i=5 기준 │
│ Standard: [0,1,2,3,4,5] │
│ Sliding: [1,2,3,4,5] │
└───────────────────────────────┘
장점¶
- 고정 메모리 사용 (시퀀스 길이 무관)
- 로컬 패턴 효과적 캡처
- Mistral, Phi-3에서 채택
제한점¶
- 장거리 의존성 약화
- 윈도우 크기 선택 중요
Ring Attention¶
분산 Attention¶
Device 1: Q1, K1, V1
Device 2: Q2, K2, V2
Device 3: Q3, K3, V3
Device 4: Q4, K4, V4
↓ Ring 통신으로 KV 순환
Device 1: Q1 × [K1,K2,K3,K4]
Device 2: Q2 × [K2,K3,K4,K1]
...
확장성¶
- 시퀀스를 디바이스 수만큼 분할
- 1M+ 토큰 컨텍스트 가능
- Google의 Gemini에서 활용
PagedAttention¶
vLLM의 핵심 기술¶
기존 KV 캐시:
┌────────────────────────────────┐
│ 연속 메모리 할당 (낭비 발생) │
│ [████████░░░░░░░░░░░░░░░░░░░░] │
└────────────────────────────────┘
PagedAttention:
┌────┬────┬────┬────┬────┬────┐
│Blk1│Blk2│Blk3│ ...│ │ │
└────┴────┴────┴────┴────┴────┘
↓ ↓ ↓
Page Table로 관리
장점¶
- 메모리 단편화 해소
- 동적 배치 크기
- prefix 공유 효율화
사용 예시¶
from vllm import LLM
llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
# PagedAttention 자동 적용
gpu_memory_utilization=0.95,
)
Linear Attention¶
O(n) 복잡도 달성¶
대표 모델¶
| 모델 | 기법 | 특징 |
|---|---|---|
| Linear Transformer | Feature map | 초기 연구 |
| Performer | FAVOR+ | Random features |
| RWKV | WKV | RNN-like |
| Mamba | S4/S6 | State space |
| RetNet | Retention | 선형 + 청크 |
트레이드오프¶
┌─────────────────────────────────────┐
│ 품질: Standard > Linear │
│ 속도: Linear > Standard (긴 seq) │
│ 메모리: Linear > Standard │
│ 학습: Standard > Linear (안정성) │
└─────────────────────────────────────┘
선택 가이드¶
상황별 추천¶
| 상황 | 추천 기법 |
|---|---|
| 일반 추론 최적화 | Flash Attention 2/3 |
| 긴 컨텍스트 (32K+) | Flash + GQA |
| 실시간 서빙 | PagedAttention (vLLM) |
| 무한 컨텍스트 | Ring Attention |
| 엣지 디바이스 | MQA + 양자화 |
| 연구/실험 | Linear Attention |
조합 예시¶
프로덕션 LLM 서버:
초장문 처리:
구현 체크리스트¶
- [ ] PyTorch 2.0+ 사용 (SDPA 자동 적용)
- [ ] bfloat16/float16 사용
- [ ] Flash Attention 2 지원 확인
- [ ] GQA 모델 선택 (신규 배포 시)
- [ ] vLLM/TGI 서빙 프레임워크 활용
- [ ] KV 캐시 크기 모니터링
참고 자료¶
- "FlashAttention: Fast and Memory-Efficient Exact Attention" (Dao et al., 2022)
- "FlashAttention-2: Faster Attention with Better Parallelism" (2023)
- "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (2024)
- "GQA: Training Generalized Multi-Query Transformer Models" (Ainslie et al., 2023)
- "Ring Attention with Blockwise Transformers for Near-Infinite Context" (2024)
최종 업데이트: 2026-02-18