콘텐츠로 이동
Data Prep
상세

Efficient Attention Mechanisms

개요

Transformer의 Self-Attention은 O(n²) 복잡도로 긴 시퀀스에서 병목이 된다. Efficient Attention은 이 한계를 극복하기 위한 최적화 기법들이다.

표준 Attention 복잡도

시간 복잡도: O(n² · d)
공간 복잡도: O(n² + n·d)

n = 시퀀스 길이
d = 임베딩 차원

문제점: - 시퀀스 길이 2배 → 메모리/연산 4배 - 128K 토큰 컨텍스트: 수백 GB 메모리 필요

Flash Attention

핵심 아이디어

GPU 메모리 계층 구조를 활용한 IO-aware 알고리즘.

┌─────────────────────────────────────┐
│           GPU Architecture          │
├─────────────────────────────────────┤
│  HBM (High Bandwidth Memory)        │  ← 느림, 대용량 (80GB)
│  ↕ 데이터 전송 (병목)                │
│  SRAM (On-chip Memory)              │  ← 빠름, 소용량 (20MB)
└─────────────────────────────────────┘

기존 방식: 전체 Attention 행렬을 HBM에 저장 Flash Attention: 블록 단위로 SRAM에서 연산, HBM 접근 최소화

알고리즘

1. Q, K, V를 블록으로 분할
2. 각 블록을 SRAM에 로드
3. 블록 단위로 Attention 연산
4. Online Softmax로 점진적 정규화
5. 결과를 HBM에 기록
# 개념적 구현 (실제는 CUDA 커널)
def flash_attention(Q, K, V, block_size=256):
    n, d = Q.shape
    O = zeros_like(Q)

    for i in range(0, n, block_size):
        Qi = Q[i:i+block_size]  # SRAM에 로드

        for j in range(0, n, block_size):
            Kj = K[j:j+block_size]
            Vj = V[j:j+block_size]

            # SRAM에서 연산
            Sij = Qi @ Kj.T / sqrt(d)
            Pij = softmax(Sij)
            O[i:i+block_size] += Pij @ Vj

    return O

성능 비교

시퀀스 길이 PyTorch Flash Attention v1 Flash Attention v2
2K 1.0x 2.5x 3.5x
8K OOM 3.0x 4.0x
32K OOM 3.5x 5.0x
128K OOM 4.0x 6.0x

사용법

# PyTorch 2.0+
import torch
from torch.nn.functional import scaled_dot_product_attention

# 자동으로 Flash Attention 사용
output = scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,  # Causal masking
)

# HuggingFace Transformers
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

Flash Attention 3

개선점 (Hopper 아키텍처)

Flash Attention 2 → Flash Attention 3
- 비동기 연산 (warpgroup 레벨)
- FP8 지원
- 블록 양자화 통합
- 1.5-2x 추가 속도 향상

성능

GPU FA2 (TFLOPS) FA3 (TFLOPS) 개선
H100 SXM 335 740 2.2x
H100 PCIe 280 580 2.1x

Multi-Query Attention (MQA)

구조

Standard MHA:
Q: [n_heads, head_dim]
K: [n_heads, head_dim]
V: [n_heads, head_dim]

Multi-Query:
Q: [n_heads, head_dim]
K: [1, head_dim]        ← 공유
V: [1, head_dim]        ← 공유

KV 캐시 절감

Llama 7B 예시:
- MHA KV 캐시: 32 heads × 128 dim × seq_len × 2 (K,V)
- MQA KV 캐시: 1 × 128 dim × seq_len × 2

→ 32배 메모리 절감

품질 vs 효율 트레이드오프

방식 KV 캐시 품질 손실
MHA 100% 0%
GQA (8 groups) 25% 0.5-1%
MQA 3% 1-3%

Grouped-Query Attention (GQA)

MHA와 MQA의 절충안

MHA: 32 heads → 32 KV heads
GQA: 32 heads → 8 KV heads (4개씩 그룹)
MQA: 32 heads → 1 KV head

┌─────────────────────────────────────┐
│ Query Heads:  [1][2][3][4][5][6][7][8]...  │
│               \___/\___/\___/\___/        │
│                 ▼    ▼    ▼    ▼          │
│ KV Groups:    [1]  [2]  [3]  [4]...       │
└─────────────────────────────────────────┘

주요 모델 채택 현황

모델 KV Heads Query Heads 비율
Llama 3.1 8B 8 32 4:1
Llama 3.1 70B 8 64 8:1
Mistral 7B 8 32 4:1
Gemma 2 8 16 2:1

Sliding Window Attention

구조

Standard Attention:
Token i attends to: [0, 1, 2, ..., i]

Sliding Window (window=4):
Token i attends to: [max(0, i-4), ..., i]

┌───────────────────────────────┐
│    i=5 기준                    │
│    Standard: [0,1,2,3,4,5]    │
│    Sliding:  [1,2,3,4,5]      │
└───────────────────────────────┘

장점

  • 고정 메모리 사용 (시퀀스 길이 무관)
  • 로컬 패턴 효과적 캡처
  • Mistral, Phi-3에서 채택

제한점

  • 장거리 의존성 약화
  • 윈도우 크기 선택 중요

Ring Attention

분산 Attention

Device 1: Q1, K1, V1
Device 2: Q2, K2, V2
Device 3: Q3, K3, V3
Device 4: Q4, K4, V4

↓ Ring 통신으로 KV 순환

Device 1: Q1 × [K1,K2,K3,K4]
Device 2: Q2 × [K2,K3,K4,K1]
...

확장성

  • 시퀀스를 디바이스 수만큼 분할
  • 1M+ 토큰 컨텍스트 가능
  • Google의 Gemini에서 활용

PagedAttention

vLLM의 핵심 기술

기존 KV 캐시:
┌────────────────────────────────┐
│ 연속 메모리 할당 (낭비 발생)     │
│ [████████░░░░░░░░░░░░░░░░░░░░] │
└────────────────────────────────┘

PagedAttention:
┌────┬────┬────┬────┬────┬────┐
│Blk1│Blk2│Blk3│ ...│    │    │
└────┴────┴────┴────┴────┴────┘
  ↓     ↓     ↓
 Page Table로 관리

장점

  • 메모리 단편화 해소
  • 동적 배치 크기
  • prefix 공유 효율화

사용 예시

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    # PagedAttention 자동 적용
    gpu_memory_utilization=0.95,
)

Linear Attention

O(n) 복잡도 달성

Standard: Softmax(QK^T)V → O(n²)
Linear:   φ(Q)(φ(K)^T V) → O(n)

φ = 커널 함수 (e.g., elu + 1)

대표 모델

모델 기법 특징
Linear Transformer Feature map 초기 연구
Performer FAVOR+ Random features
RWKV WKV RNN-like
Mamba S4/S6 State space
RetNet Retention 선형 + 청크

트레이드오프

┌─────────────────────────────────────┐
│ 품질:    Standard > Linear          │
│ 속도:    Linear > Standard (긴 seq)  │
│ 메모리:  Linear > Standard          │
│ 학습:    Standard > Linear (안정성)  │
└─────────────────────────────────────┘

선택 가이드

상황별 추천

상황 추천 기법
일반 추론 최적화 Flash Attention 2/3
긴 컨텍스트 (32K+) Flash + GQA
실시간 서빙 PagedAttention (vLLM)
무한 컨텍스트 Ring Attention
엣지 디바이스 MQA + 양자화
연구/실험 Linear Attention

조합 예시

프로덕션 LLM 서버:

Flash Attention 2 + GQA + PagedAttention
→ 처리량 4-6x 향상

초장문 처리:

Ring Attention + Flash Attention + Sliding Window
→ 1M 토큰 가능

구현 체크리스트

  • [ ] PyTorch 2.0+ 사용 (SDPA 자동 적용)
  • [ ] bfloat16/float16 사용
  • [ ] Flash Attention 2 지원 확인
  • [ ] GQA 모델 선택 (신규 배포 시)
  • [ ] vLLM/TGI 서빙 프레임워크 활용
  • [ ] KV 캐시 크기 모니터링

참고 자료

  • "FlashAttention: Fast and Memory-Efficient Exact Attention" (Dao et al., 2022)
  • "FlashAttention-2: Faster Attention with Better Parallelism" (2023)
  • "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (2024)
  • "GQA: Training Generalized Multi-Query Transformer Models" (Ainslie et al., 2023)
  • "Ring Attention with Blockwise Transformers for Near-Infinite Context" (2024)

최종 업데이트: 2026-02-18