콘텐츠로 이동
Data Prep
상세

Attention 메커니즘

Transformer의 핵심. 시퀀스 내 요소들 간의 관계를 학습하여 관련 정보에 집중함.

Scaled Dot-Product Attention

수식

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
  • Q (Query): "무엇을 찾을 것인가"
  • K (Key): "어떤 정보가 있는가"
  • V (Value): "실제 정보 내용"

직관적 이해

문장: "The cat sat on the mat"

"sat"에 대한 attention 계산:
Query: "sat"의 표현
Keys: 모든 토큰의 표현
→ Attention scores: [0.1, 0.3, 0.2, 0.1, 0.1, 0.2]
→ "cat"에 높은 가중치 (주어-동사 관계)

스케일링 이유

d_k가 크면 QK^T의 분산이 커짐
→ softmax 입력이 극단적
→ 기울기 소실

스케일링으로 분산을 1로 유지

구현

import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None, dropout=None):
    """
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    """
    d_k = Q.size(-1)

    # Attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Masking
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax
    attn_weights = F.softmax(scores, dim=-1)

    if dropout is not None:
        attn_weights = dropout(attn_weights)

    # Weighted sum
    output = torch.matmul(attn_weights, V)

    return output, attn_weights

Multi-Head Attention

여러 "관점"에서 동시에 attention 계산.

수식

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$ $$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]

장점

  • 다양한 관계 유형 학습 (문법, 의미, 공지시 등)
  • 병렬 처리 가능
  • 표현력 증가
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 모든 head를 하나의 행렬로 처리
        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # QKV projection
        qkv = self.W_qkv(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, seq, d_k)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Attention
        attn_output, _ = scaled_dot_product_attention(Q, K, V, mask, self.dropout)

        # Concat heads
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)

        return self.W_o(attn_output)

Causal (Masked) Self-Attention

미래 토큰을 보지 못하게 마스킹. Decoder/GPT에서 사용.

def create_causal_mask(seq_len, device):
    """
    [[1, 0, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 1, 1]]
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    return mask

# Attention에서 적용
scores = scores.masked_fill(mask == 0, float('-inf'))
# -inf는 softmax 후 0이 됨

Cross-Attention

Encoder-Decoder 모델에서 사용. Decoder가 Encoder 출력을 참조.

class CrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()

        self.W_q = nn.Linear(d_model, d_model)  # Decoder에서
        self.W_k = nn.Linear(d_model, d_model)  # Encoder에서
        self.W_v = nn.Linear(d_model, d_model)  # Encoder에서
        self.W_o = nn.Linear(d_model, d_model)
        ...

    def forward(self, decoder_hidden, encoder_output, mask=None):
        Q = self.W_q(decoder_hidden)   # Query: Decoder
        K = self.W_k(encoder_output)   # Key: Encoder
        V = self.W_v(encoder_output)   # Value: Encoder

        # 나머지 동일
        ...

Grouped Query Attention (GQA)

LLaMA 2, Gemma 등에서 사용. Key-Value 헤드 수 줄여 메모리 절약.

Multi-Head Attention (MHA):
Q heads: 32, K heads: 32, V heads: 32

Grouped Query Attention (GQA):
Q heads: 32, K heads: 8, V heads: 8
→ 4개의 Q head가 1개의 KV 공유

Multi-Query Attention (MQA):
Q heads: 32, K heads: 1, V heads: 1
→ 모든 Q가 1개의 KV 공유
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_heads // num_kv_heads

        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k)
        K = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k)
        V = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k)

        # KV를 Q head 수만큼 반복
        K = K.repeat_interleave(self.num_groups, dim=2)
        V = V.repeat_interleave(self.num_groups, dim=2)

        # 나머지 attention 계산 동일
        ...

Flash Attention

GPU 메모리 계층을 고려한 효율적 구현.

문제

기존: O(N²) 메모리 (attention matrix 전체 저장)
N=4096 → 64GB 필요 (FP32)

해결

Tiling: 블록 단위로 계산
Recomputation: 역전파 시 재계산
→ O(N) 메모리, 2-4x 속도 향상
# PyTorch 2.0+
from torch.nn.functional import scaled_dot_product_attention

# 자동으로 Flash Attention 사용
output = scaled_dot_product_attention(Q, K, V, attn_mask=mask)

# 또는 xformers
from xformers.ops import memory_efficient_attention
output = memory_efficient_attention(Q, K, V, attn_bias=mask)

# Flash Attention 라이브러리
from flash_attn import flash_attn_func
output = flash_attn_func(Q, K, V, causal=True)

KV Cache

자기회귀 생성 시 이전 토큰의 K, V 재사용.

Step 1: "The" → K1, V1 계산 → 캐시
Step 2: "The cat" → K2, V2만 계산 → 캐시에 추가
Step 3: "The cat sat" → K3, V3만 계산 → 캐시에 추가
...

각 스텝에서 전체 시퀀스의 K, V 재계산 불필요
→ O(N²) → O(N) 복잡도 감소
class AttentionWithKVCache(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        ...

    def forward(self, x, kv_cache=None, use_cache=False):
        batch_size, seq_len, _ = x.shape

        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        if kv_cache is not None:
            # 캐시된 KV와 연결
            K = torch.cat([kv_cache['k'], K], dim=1)
            V = torch.cat([kv_cache['v'], V], dim=1)

        # Attention 계산 (Q는 현재 토큰만)
        ...

        if use_cache:
            new_cache = {'k': K, 'v': V}
            return output, new_cache

        return output, None

Sparse Attention

긴 시퀀스를 위한 희소 패턴.

패턴 종류

Local: 주변 토큰만 attention
Strided: 일정 간격의 토큰
Global: 특정 토큰 (예: [CLS])
Block-sparse: 블록 단위
# Longformer 스타일 (local + global)
def create_longformer_mask(seq_len, local_window=512, global_tokens=[0]):
    mask = torch.zeros(seq_len, seq_len)

    # Local attention
    for i in range(seq_len):
        start = max(0, i - local_window // 2)
        end = min(seq_len, i + local_window // 2)
        mask[i, start:end] = 1

    # Global tokens
    for g in global_tokens:
        mask[g, :] = 1
        mask[:, g] = 1

    return mask

Sliding Window Attention

Mistral 등에서 사용. 고정 크기 윈도우 내에서만 attention.

class SlidingWindowAttention(nn.Module):
    def __init__(self, d_model, num_heads, window_size=4096):
        super().__init__()
        self.window_size = window_size
        ...

    def forward(self, x, mask=None):
        # 윈도우 크기 내에서만 attention 계산
        seq_len = x.size(1)

        # Sliding window mask
        window_mask = torch.ones(seq_len, seq_len)
        for i in range(seq_len):
            start = max(0, i - self.window_size)
            window_mask[i, :start] = 0

        # Causal mask와 결합
        if mask is not None:
            mask = mask * window_mask
        else:
            mask = window_mask

        ...

실무 트레이드오프

Attention 변형 선택 가이드

기법 메모리 속도 시퀀스 길이 적용 모델
Vanilla MHA O(N^2) 기준 ~2K BERT, GPT-2
Flash Attention O(N) 2-4x ~8K 대부분 현대 모델
GQA O(N^2) 1.5x ~4K LLaMA 2, Gemma
MQA O(N^2) 2x ~4K Falcon, PaLM
Sliding Window O(N*W) 빠름 ~128K Mistral, Longformer

KV Cache 크기 계산

def kv_cache_size_gb(
    batch_size,
    seq_len,
    num_layers,
    num_kv_heads,
    head_dim,
    dtype_bytes=2  # FP16
):
    """KV Cache 메모리 계산"""
    # K, V 각각 저장
    cache_size = 2 * batch_size * seq_len * num_layers * num_kv_heads * head_dim * dtype_bytes
    return cache_size / (1024 ** 3)  # GB

# LLaMA-7B 예시 (batch=1, seq=4096)
size = kv_cache_size_gb(1, 4096, 32, 32, 128)
# 약 2GB

# LLaMA-70B 예시 (batch=1, seq=4096, GQA 8 heads)
size = kv_cache_size_gb(1, 4096, 80, 8, 128)
# 약 2.5GB (GQA 덕분에 70B가 7B와 비슷)

Flash Attention 사용 조건

# PyTorch 2.0+ 자동 선택
# 다음 조건에서 Flash Attention 활성화:
# 1. CUDA GPU (Ampere 이상 권장)
# 2. Head dimension이 특정 값 (64, 128 등)
# 3. Causal mask 또는 no mask

# 수동 확인
import torch
print(torch.backends.cuda.flash_sdp_enabled())  # True면 활성화

# 강제 활성화/비활성화
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

흔한 실수

실수 증상 해결책
Softmax 축 오류 이상한 attention 분포 dim=-1 확인
Mask dtype 불일치 CUDA 에러 mask를 float로 변환
KV Cache 미사용 생성 속도 느림 use_cache=True
Attention score 스케일 누락 학습 불안정 / sqrt(d_k) 적용

참고 자료