Attention Mechanism¶
논문 정보¶
| 항목 | 내용 |
|---|---|
| 제목 | Attention Is All You Need |
| 저자 | Ashish Vaswani et al. (Google) |
| 학회 | NeurIPS 2017 |
| 링크 | https://arxiv.org/abs/1706.03762 |
| 관련 논문 | 내용 |
|---|---|
| Bahdanau Attention (2014) | Neural Machine Translation by Jointly Learning to Align and Translate |
| Luong Attention (2015) | Effective Approaches to Attention-based Neural Machine Translation |
개요¶
문제 정의¶
Seq2Seq 모델의 병목 (Bottleneck) 문제:
- 긴 시퀀스의 정보 손실
- 입력 시퀀스의 특정 부분에 집중 불가
핵심 아이디어¶
"모든 입력을 동일하게 보지 말고, 관련 있는 부분에 집중하자"
\[\text{Attention}(Q, K, V) = \sum_i \alpha_i V_i\]
여기서 \(\alpha_i\)는 Query와 Key의 관련성(유사도)에 기반한 가중치.
Attention의 진화¶
Bahdanau Attention (Additive)¶
h_1 h_2 h_3 h_4
│ │ │ │
▼ ▼ ▼ ▼
┌──────────────────────┐
│ Alignment Model │
│ (Additive/MLP) │
└──────────────────────┘
│
▼
α_1 α_2 α_3 α_4 (attention weights)
│ │ │ │
└──────┴──────┴──────┘
│
▼
Context Vector c
Score 함수:
\[e_{ij} = v_a^T \tanh(W_a s_{i-1} + U_a h_j)\]
| 기호 | 설명 |
|---|---|
| \(s_{i-1}\) | 디코더의 이전 은닉 상태 (Query) |
| \(h_j\) | 인코더의 j번째 은닉 상태 (Key = Value) |
| \(W_a, U_a, v_a\) | 학습 가능한 파라미터 |
Attention 가중치:
\[\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{n} \exp(e_{ik})}\]
Context Vector:
\[c_i = \sum_{j=1}^{n} \alpha_{ij} h_j\]
Luong Attention (Multiplicative)¶
더 단순하고 효율적인 점수 계산:
| 유형 | Score 함수 |
|---|---|
| Dot | \(s_t^T h_s\) |
| General | \(s_t^T W_a h_s\) |
| Concat | \(v_a^T \tanh(W_a [s_t; h_s])\) |
Dot Product가 가장 효율적이고 널리 사용됨.
Self-Attention¶
개념¶
자기 자신의 시퀀스 내에서 각 위치가 다른 모든 위치를 참조:
Input: "The animal didn't cross the street because it was too tired"
↑
"it"이 무엇?
Self-Attention이 "it" -> "animal" 연결을 학습
Query, Key, Value¶
입력 시퀀스 \(X \in \mathbb{R}^{n \times d}\)에서:
\[Q = XW^Q, \quad K = XW^K, \quad V = XW^V\]
| 개념 | 역할 | 비유 |
|---|---|---|
| Query (Q) | 현재 위치에서 "무엇을 찾고 있는지" | 검색어 |
| Key (K) | 각 위치가 "어떤 정보를 가지고 있는지" | 문서 제목 |
| Value (V) | 실제로 전달할 정보 | 문서 내용 |
Scaled Dot-Product Attention¶
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
Scaling 이유: \(d_k\)가 크면 dot product 값이 커져서 softmax가 극단적인 값으로 포화됨.
\[\text{Var}(q \cdot k) = d_k \cdot \text{Var}(q_i) \cdot \text{Var}(k_i) = d_k\]
\(\sqrt{d_k}\)로 나누면 분산이 1로 정규화됨.
계산 흐름¶
Q (n×d_k) K^T (d_k×n) V (n×d_v)
│ │ │
└───────┬───────┘ │
│ │
▼ │
Q @ K^T │
(n × n) │
│ │
▼ │
/ sqrt(d_k) │
│ │
▼ │
Softmax │
(n × n) │
│ │
└──────────┬───────────────┘
│
▼
@ V (n × d_v)
│
▼
Output (n × d_v)
Multi-Head Attention¶
개념¶
단일 Attention 대신 여러 개의 Attention을 병렬로 수행:
\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O\]
\[\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]
장점¶
| 장점 | 설명 |
|---|---|
| 다양한 관계 학습 | 각 head가 다른 관계에 집중 |
| 표현력 증가 | 여러 부분공간(subspace)에서 attention |
| 안정적 학습 | head 간 평균화 효과 |
예시¶
"The cat sat on the mat"
Head 1: 주어-동사 관계 학습 (cat -> sat)
Head 2: 명사-전치사 관계 학습 (mat -> on)
Head 3: 위치 관계 학습 (인접 토큰)
...
파라미터¶
| 파라미터 | 일반적 값 | 설명 |
|---|---|---|
| \(d_{model}\) | 512, 768, 1024 | 모델 차원 |
| \(h\) | 8, 12, 16 | head 수 |
| \(d_k = d_v\) | \(d_{model}/h\) | 각 head의 차원 |
PyTorch 구현¶
Scaled Dot-Product Attention¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Args:
Q: (batch, ..., seq_len, d_k)
K: (batch, ..., seq_len, d_k)
V: (batch, ..., seq_len, d_v)
mask: (batch, ..., seq_len, seq_len) or broadcastable
Returns:
output: (batch, ..., seq_len, d_v)
attention_weights: (batch, ..., seq_len, seq_len)
"""
d_k = Q.size(-1)
# Attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# Masking (optional)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax
attention_weights = F.softmax(scores, dim=-1)
# Weighted sum
output = torch.matmul(attention_weights, V)
return output, attention_weights
Multi-Head Attention¶
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# Linear projections
Q = self.W_q(Q) # (batch, seq, d_model)
K = self.W_k(K)
V = self.W_v(V)
# Split into heads: (batch, seq, d_model) -> (batch, heads, seq, d_k)
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# Attention
attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
# Concatenate heads: (batch, heads, seq, d_k) -> (batch, seq, d_model)
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)
# Final projection
output = self.W_o(attn_output)
return output, attn_weights
# 테스트
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(32, 100, 512) # (batch, seq_len, d_model)
output, weights = mha(x, x, x) # Self-attention
print(f"Output: {output.shape}") # (32, 100, 512)
print(f"Weights: {weights.shape}") # (32, 8, 100, 100)
Attention Masking¶
def create_padding_mask(seq, pad_idx=0):
"""패딩 토큰 마스킹"""
# seq: (batch, seq_len)
# output: (batch, 1, 1, seq_len) for broadcasting
return (seq != pad_idx).unsqueeze(1).unsqueeze(2)
def create_causal_mask(seq_len):
"""미래 토큰 마스킹 (decoder self-attention용)"""
# output: (1, 1, seq_len, seq_len)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
return (mask == 0).unsqueeze(0).unsqueeze(0)
# 예시
seq_len = 10
causal_mask = create_causal_mask(seq_len)
print(causal_mask.squeeze())
# tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
# [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
# ...
PyTorch 내장 MHA¶
# PyTorch 내장 MultiheadAttention
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(32, 100, 512)
# Self-attention
output, attn_weights = mha(x, x, x)
# Cross-attention (encoder-decoder)
encoder_output = torch.randn(32, 50, 512)
decoder_input = torch.randn(32, 30, 512)
output, _ = mha(decoder_input, encoder_output, encoder_output) # Q from decoder, K,V from encoder
Attention 변형¶
Cross-Attention¶
Encoder-Decoder 구조에서 사용:
class CrossAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.mha = MultiHeadAttention(d_model, num_heads)
def forward(self, decoder_hidden, encoder_output, mask=None):
# Q: decoder, K/V: encoder
return self.mha(decoder_hidden, encoder_output, encoder_output, mask)
Efficient Attention 변형¶
| 변형 | 복잡도 | 설명 |
|---|---|---|
| Vanilla | O(n^2) | 기본 Self-Attention |
| Sparse | O(n sqrt(n)) | 일부 위치만 attend |
| Linear | O(n) | Kernel trick 사용 |
| Flash | O(n^2) 메모리 O(n) | IO-aware 최적화 |
Flash Attention¶
# PyTorch 2.0+ 내장
from torch.nn.functional import scaled_dot_product_attention
# 자동으로 Flash Attention 사용 (조건 충족 시)
output = scaled_dot_product_attention(Q, K, V, is_causal=True)
Attention 시각화¶
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(attention_weights, tokens_x, tokens_y=None):
"""
attention_weights: (seq_len_q, seq_len_k)
tokens_x: Key 토큰 리스트
tokens_y: Query 토큰 리스트 (None이면 tokens_x 사용)
"""
if tokens_y is None:
tokens_y = tokens_x
plt.figure(figsize=(10, 8))
sns.heatmap(
attention_weights.detach().numpy(),
xticklabels=tokens_x,
yticklabels=tokens_y,
cmap='Blues',
annot=True,
fmt='.2f'
)
plt.xlabel('Key')
plt.ylabel('Query')
plt.title('Attention Weights')
plt.tight_layout()
plt.savefig('attention_viz.png', dpi=150)
plt.show()
# 예시
tokens = ['The', 'cat', 'sat', 'on', 'mat']
attn = torch.softmax(torch.randn(5, 5), dim=-1)
visualize_attention(attn, tokens)
Attention의 의미¶
학습되는 것¶
| 관계 유형 | 예시 |
|---|---|
| 문법적 | 주어-동사, 관사-명사 |
| 의미적 | 대명사-선행사, 동의어 |
| 위치적 | 인접 토큰 관계 |
| 장거리 | 문장 간 참조 |
Head별 특화¶
연구에 따르면 각 head가 다른 언어적 관계에 특화:
Head 1: 바로 다음 토큰 주목 (positional)
Head 2: 동사-목적어 관계 (syntactic)
Head 3: 대명사 해소 (coreference)
Head 4: 부정어 범위 (semantic)
관련 문서¶
| 주제 | 링크 |
|---|---|
| 딥러닝 기초 | README.md |
| RNN/LSTM | rnn-lstm.md |
| Transformer | ../../architecture/transformer.md |
참고¶
- Vaswani, A. et al. (2017). "Attention Is All You Need"
- Bahdanau, D. et al. (2014). "Neural Machine Translation by Jointly Learning to Align and Translate"
- The Illustrated Transformer: https://jalammar.github.io/illustrated-transformer/
- Attention? Attention! (Lilian Weng): https://lilianweng.github.io/posts/2018-06-24-attention/