Attention 메커니즘¶
Transformer의 핵심. 시퀀스 내 요소들 간의 관계를 학습하여 관련 정보에 집중함.
Scaled Dot-Product Attention¶
수식¶
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
- Q (Query): "무엇을 찾을 것인가"
- K (Key): "어떤 정보가 있는가"
- V (Value): "실제 정보 내용"
직관적 이해¶
문장: "The cat sat on the mat"
"sat"에 대한 attention 계산:
Query: "sat"의 표현
Keys: 모든 토큰의 표현
→ Attention scores: [0.1, 0.3, 0.2, 0.1, 0.1, 0.2]
→ "cat"에 높은 가중치 (주어-동사 관계)
스케일링 이유¶
구현¶
import torch
import torch.nn.functional as F
import math
def scaled_dot_product_attention(Q, K, V, mask=None, dropout=None):
"""
Q: (batch, heads, seq_len, d_k)
K: (batch, heads, seq_len, d_k)
V: (batch, heads, seq_len, d_v)
"""
d_k = Q.size(-1)
# Attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# Masking
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax
attn_weights = F.softmax(scores, dim=-1)
if dropout is not None:
attn_weights = dropout(attn_weights)
# Weighted sum
output = torch.matmul(attn_weights, V)
return output, attn_weights
Multi-Head Attention¶
여러 "관점"에서 동시에 attention 계산.
수식¶
\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]
장점¶
- 다양한 관계 유형 학습 (문법, 의미, 공지시 등)
- 병렬 처리 가능
- 표현력 증가
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 모든 head를 하나의 행렬로 처리
self.W_qkv = nn.Linear(d_model, 3 * d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# QKV projection
qkv = self.W_qkv(x)
qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch, heads, seq, d_k)
Q, K, V = qkv[0], qkv[1], qkv[2]
# Attention
attn_output, _ = scaled_dot_product_attention(Q, K, V, mask, self.dropout)
# Concat heads
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.d_model)
return self.W_o(attn_output)
Causal (Masked) Self-Attention¶
미래 토큰을 보지 못하게 마스킹. Decoder/GPT에서 사용.
def create_causal_mask(seq_len, device):
"""
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
"""
mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
return mask
# Attention에서 적용
scores = scores.masked_fill(mask == 0, float('-inf'))
# -inf는 softmax 후 0이 됨
Cross-Attention¶
Encoder-Decoder 모델에서 사용. Decoder가 Encoder 출력을 참조.
class CrossAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.W_q = nn.Linear(d_model, d_model) # Decoder에서
self.W_k = nn.Linear(d_model, d_model) # Encoder에서
self.W_v = nn.Linear(d_model, d_model) # Encoder에서
self.W_o = nn.Linear(d_model, d_model)
...
def forward(self, decoder_hidden, encoder_output, mask=None):
Q = self.W_q(decoder_hidden) # Query: Decoder
K = self.W_k(encoder_output) # Key: Encoder
V = self.W_v(encoder_output) # Value: Encoder
# 나머지 동일
...
Grouped Query Attention (GQA)¶
LLaMA 2, Gemma 등에서 사용. Key-Value 헤드 수 줄여 메모리 절약.
Multi-Head Attention (MHA):
Q heads: 32, K heads: 32, V heads: 32
Grouped Query Attention (GQA):
Q heads: 32, K heads: 8, V heads: 8
→ 4개의 Q head가 1개의 KV 공유
Multi-Query Attention (MQA):
Q heads: 32, K heads: 1, V heads: 1
→ 모든 Q가 1개의 KV 공유
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, num_heads, num_kv_heads):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.num_groups = num_heads // num_kv_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k)
K = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k)
V = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k)
# KV를 Q head 수만큼 반복
K = K.repeat_interleave(self.num_groups, dim=2)
V = V.repeat_interleave(self.num_groups, dim=2)
# 나머지 attention 계산 동일
...
Flash Attention¶
GPU 메모리 계층을 고려한 효율적 구현.
문제¶
해결¶
# PyTorch 2.0+
from torch.nn.functional import scaled_dot_product_attention
# 자동으로 Flash Attention 사용
output = scaled_dot_product_attention(Q, K, V, attn_mask=mask)
# 또는 xformers
from xformers.ops import memory_efficient_attention
output = memory_efficient_attention(Q, K, V, attn_bias=mask)
# Flash Attention 라이브러리
from flash_attn import flash_attn_func
output = flash_attn_func(Q, K, V, causal=True)
KV Cache¶
자기회귀 생성 시 이전 토큰의 K, V 재사용.
Step 1: "The" → K1, V1 계산 → 캐시
Step 2: "The cat" → K2, V2만 계산 → 캐시에 추가
Step 3: "The cat sat" → K3, V3만 계산 → 캐시에 추가
...
각 스텝에서 전체 시퀀스의 K, V 재계산 불필요
→ O(N²) → O(N) 복잡도 감소
class AttentionWithKVCache(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_k = d_model // num_heads
self.num_heads = num_heads
...
def forward(self, x, kv_cache=None, use_cache=False):
batch_size, seq_len, _ = x.shape
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
if kv_cache is not None:
# 캐시된 KV와 연결
K = torch.cat([kv_cache['k'], K], dim=1)
V = torch.cat([kv_cache['v'], V], dim=1)
# Attention 계산 (Q는 현재 토큰만)
...
if use_cache:
new_cache = {'k': K, 'v': V}
return output, new_cache
return output, None
Sparse Attention¶
긴 시퀀스를 위한 희소 패턴.
패턴 종류¶
# Longformer 스타일 (local + global)
def create_longformer_mask(seq_len, local_window=512, global_tokens=[0]):
mask = torch.zeros(seq_len, seq_len)
# Local attention
for i in range(seq_len):
start = max(0, i - local_window // 2)
end = min(seq_len, i + local_window // 2)
mask[i, start:end] = 1
# Global tokens
for g in global_tokens:
mask[g, :] = 1
mask[:, g] = 1
return mask
Sliding Window Attention¶
Mistral 등에서 사용. 고정 크기 윈도우 내에서만 attention.
class SlidingWindowAttention(nn.Module):
def __init__(self, d_model, num_heads, window_size=4096):
super().__init__()
self.window_size = window_size
...
def forward(self, x, mask=None):
# 윈도우 크기 내에서만 attention 계산
seq_len = x.size(1)
# Sliding window mask
window_mask = torch.ones(seq_len, seq_len)
for i in range(seq_len):
start = max(0, i - self.window_size)
window_mask[i, :start] = 0
# Causal mask와 결합
if mask is not None:
mask = mask * window_mask
else:
mask = window_mask
...
실무 트레이드오프¶
Attention 변형 선택 가이드¶
| 기법 | 메모리 | 속도 | 시퀀스 길이 | 적용 모델 |
|---|---|---|---|---|
| Vanilla MHA | O(N^2) | 기준 | ~2K | BERT, GPT-2 |
| Flash Attention | O(N) | 2-4x | ~8K | 대부분 현대 모델 |
| GQA | O(N^2) | 1.5x | ~4K | LLaMA 2, Gemma |
| MQA | O(N^2) | 2x | ~4K | Falcon, PaLM |
| Sliding Window | O(N*W) | 빠름 | ~128K | Mistral, Longformer |
KV Cache 크기 계산¶
def kv_cache_size_gb(
batch_size,
seq_len,
num_layers,
num_kv_heads,
head_dim,
dtype_bytes=2 # FP16
):
"""KV Cache 메모리 계산"""
# K, V 각각 저장
cache_size = 2 * batch_size * seq_len * num_layers * num_kv_heads * head_dim * dtype_bytes
return cache_size / (1024 ** 3) # GB
# LLaMA-7B 예시 (batch=1, seq=4096)
size = kv_cache_size_gb(1, 4096, 32, 32, 128)
# 약 2GB
# LLaMA-70B 예시 (batch=1, seq=4096, GQA 8 heads)
size = kv_cache_size_gb(1, 4096, 80, 8, 128)
# 약 2.5GB (GQA 덕분에 70B가 7B와 비슷)
Flash Attention 사용 조건¶
# PyTorch 2.0+ 자동 선택
# 다음 조건에서 Flash Attention 활성화:
# 1. CUDA GPU (Ampere 이상 권장)
# 2. Head dimension이 특정 값 (64, 128 등)
# 3. Causal mask 또는 no mask
# 수동 확인
import torch
print(torch.backends.cuda.flash_sdp_enabled()) # True면 활성화
# 강제 활성화/비활성화
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
흔한 실수¶
| 실수 | 증상 | 해결책 |
|---|---|---|
| Softmax 축 오류 | 이상한 attention 분포 | dim=-1 확인 |
| Mask dtype 불일치 | CUDA 에러 | mask를 float로 변환 |
| KV Cache 미사용 | 생성 속도 느림 | use_cache=True |
| Attention score 스케일 누락 | 학습 불안정 | / sqrt(d_k) 적용 |