Cross-modal Attention¶
이미지와 텍스트 모달리티 간의 상호작용을 학습하는 메커니즘. VLM의 핵심 구성 요소다.
Cross-Attention 기본¶
Self-Attention vs Cross-Attention¶
Self-Attention:
같은 시퀀스 내에서 상호작용
Q, K, V 모두 같은 입력에서 유래
Cross-Attention:
다른 시퀀스 간 상호작용
Q는 한 모달리티, K/V는 다른 모달리티
VLM에서의 역할¶
Text가 Image를 참조:
Q = Text tokens (질문: "이 이미지에서 무엇을 찾을까?")
K, V = Image tokens (답변 소스)
결과: 각 텍스트 토큰이 관련 이미지 영역에 집중
수식¶
\[\text{CrossAttn}(Q_{text}, K_{image}, V_{image}) = \text{softmax}\left(\frac{Q_{text}K_{image}^T}{\sqrt{d_k}}\right)V_{image}\]
- \(Q_{text}\): 텍스트에서 유래한 Query
- \(K_{image}, V_{image}\): 이미지에서 유래한 Key, Value
- \(d_k\): Key 차원 (스케일링)
구현¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossAttention(nn.Module):
"""기본 Cross-Attention"""
def __init__(self, query_dim, kv_dim, num_heads=8, dropout=0.1):
super().__init__()
self.num_heads = num_heads
self.head_dim = query_dim // num_heads
self.scale = self.head_dim ** -0.5
# Query from text, Key/Value from image
self.q_proj = nn.Linear(query_dim, query_dim)
self.k_proj = nn.Linear(kv_dim, query_dim)
self.v_proj = nn.Linear(kv_dim, query_dim)
self.out_proj = nn.Linear(query_dim, query_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key_value, attention_mask=None):
"""
query: 텍스트 토큰 (batch, seq_len, query_dim)
key_value: 이미지 토큰 (batch, num_patches, kv_dim)
"""
batch_size, seq_len, _ = query.shape
num_patches = key_value.shape[1]
# Linear projections
Q = self.q_proj(query)
K = self.k_proj(key_value)
V = self.v_proj(key_value)
# Reshape for multi-head attention
# (batch, seq, dim) -> (batch, heads, seq, head_dim)
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, num_patches, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, num_patches, self.num_heads, self.head_dim).transpose(1, 2)
# Attention scores: Q @ K^T / sqrt(d)
attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
# attn_weights: (batch, heads, seq_len, num_patches)
# Optional: attention mask
if attention_mask is not None:
attn_weights = attn_weights.masked_fill(attention_mask == 0, float('-inf'))
# Softmax + dropout
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = self.dropout(attn_weights)
# Apply attention to values
attn_output = torch.matmul(attn_weights, V)
# attn_output: (batch, heads, seq_len, head_dim)
# Reshape back
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, -1)
return self.out_proj(attn_output), attn_weights
VLM 퓨전 방식 비교¶
개요¶
| 방식 | 대표 모델 | 특징 | 복잡도 |
|---|---|---|---|
| Prefix (Concatenation) | LLaVA | 단순, 효율적 | 낮음 |
| Interleaved Cross-Attn | Flamingo | 깊은 통합 | 높음 |
| Q-Former | BLIP-2 | 토큰 압축 | 중간 |
| Interleaved Tokens | Gemini | 유연함 | 중간 |
1. Prefix Tokens (LLaVA 스타일)¶
이미지 토큰을 텍스트 시퀀스 앞에 연결.
class LLaVA(nn.Module):
"""Prefix-style VLM"""
def __init__(self, vision_encoder, projector, llm):
super().__init__()
self.vision_encoder = vision_encoder
self.projector = projector
self.llm = llm
def forward(self, images, input_ids, attention_mask):
# 1. 이미지 인코딩
with torch.no_grad():
image_features = self.vision_encoder(images)
# 2. LLM 공간으로 투영
image_tokens = self.projector(image_features)
num_image_tokens = image_tokens.shape[1]
# 3. 텍스트 임베딩
text_embeds = self.llm.embed_tokens(input_ids)
# 4. [이미지 토큰] + [텍스트 토큰] 연결
inputs_embeds = torch.cat([image_tokens, text_embeds], dim=1)
# 5. Attention mask 확장
batch_size = images.shape[0]
image_mask = torch.ones(
(batch_size, num_image_tokens),
dtype=attention_mask.dtype,
device=attention_mask.device
)
extended_mask = torch.cat([image_mask, attention_mask], dim=1)
# 6. LLM forward
outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=extended_mask
)
return outputs
def generate(self, images, prompt_ids, max_new_tokens=256):
"""생성"""
# 이미지 토큰 준비
with torch.no_grad():
image_features = self.vision_encoder(images)
image_tokens = self.projector(image_features)
# 프롬프트 임베딩
prompt_embeds = self.llm.embed_tokens(prompt_ids)
inputs_embeds = torch.cat([image_tokens, prompt_embeds], dim=1)
# 생성
return self.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.7
)
장점: - 구현 단순, LLM 수정 불필요 - 학습 안정적 - 추론 효율적
단점: - Self-attention만으로 모달리티 상호작용 - 이미지 토큰이 많으면 컨텍스트 소모
2. Interleaved Cross-Attention (Flamingo 스타일)¶
LLM 레이어 사이에 Cross-Attention 삽입.
class FlamingoBlock(nn.Module):
"""Gated Cross-Attention Block"""
def __init__(self, llm_dim, vision_dim, num_heads=8):
super().__init__()
# Cross-Attention
self.cross_attention = CrossAttention(llm_dim, vision_dim, num_heads)
self.norm = nn.LayerNorm(llm_dim)
# Learnable gate (중요!)
# 초기값 0: 처음에는 Cross-Attention이 거의 영향 없음
# 학습하며 점차 증가 → 안정적 학습
self.gate = nn.Parameter(torch.tensor(0.0))
def forward(self, text_hidden, image_features):
# Cross-attention
normed = self.norm(text_hidden)
attn_out, attn_weights = self.cross_attention(normed, image_features)
# Gated residual: tanh(gate) * attn_out
# tanh으로 -1 ~ 1 범위로 제한
return text_hidden + torch.tanh(self.gate) * attn_out
class FlamingoLLM(nn.Module):
"""Flamingo-style LLM with Cross-Attention"""
def __init__(self, llm, vision_dim, cross_attn_every=4):
super().__init__()
self.llm = llm
llm_dim = llm.config.hidden_size
num_layers = llm.config.num_hidden_layers
# Cross-Attention 레이어 (N개마다 삽입)
self.cross_attn_layers = nn.ModuleDict()
for i in range(num_layers):
if (i + 1) % cross_attn_every == 0:
self.cross_attn_layers[str(i)] = FlamingoBlock(
llm_dim, vision_dim, num_heads=8
)
# Perceiver Resampler (이미지 토큰 압축)
self.perceiver = PerceiverResampler(vision_dim, num_latents=64)
def forward(self, input_ids, image_features):
# 이미지 압축
compressed_image = self.perceiver(image_features)
# LLM embedding
hidden_states = self.llm.embed_tokens(input_ids)
# 각 레이어 통과
for i, layer in enumerate(self.llm.layers):
# LLM 레이어
hidden_states = layer(hidden_states)
# Cross-Attention (해당 레이어에만)
if str(i) in self.cross_attn_layers:
hidden_states = self.cross_attn_layers[str(i)](
hidden_states, compressed_image
)
# LM head
hidden_states = self.llm.norm(hidden_states)
return self.llm.lm_head(hidden_states)
장점: - 깊은 모달리티 통합 - 텍스트가 이미지를 선택적으로 참조 - 시퀀스 길이 유지
단점: - 추가 파라미터 - 구현 복잡 - 학습 어려움 (게이트 필수)
3. Interleaved Tokens (Gemini 스타일)¶
이미지와 텍스트를 자연스럽게 섞음.
class InterleavedVLM(nn.Module):
"""이미지-텍스트 인터리빙"""
def __init__(self, vision_encoder, projector, llm, tokenizer):
super().__init__()
self.vision_encoder = vision_encoder
self.projector = projector
self.llm = llm
self.tokenizer = tokenizer
# 특수 토큰
self.img_token_id = tokenizer.convert_tokens_to_ids("<image>")
def forward(self, text_with_placeholders, images):
"""
text_with_placeholders: "User: <image> What is this?"
images: [image1] (placeholder 수만큼)
"""
# 1. 텍스트 토크나이즈
input_ids = self.tokenizer(text_with_placeholders, return_tensors="pt").input_ids
# 2. <image> 위치 찾기
image_positions = (input_ids == self.img_token_id).nonzero(as_tuple=True)[1]
# 3. 텍스트 임베딩
embeddings = self.llm.embed_tokens(input_ids)
# 4. 이미지 인코딩 및 삽입
offset = 0 # 삽입으로 인한 위치 변화
for img_idx, pos in enumerate(image_positions):
# 이미지 인코딩
with torch.no_grad():
img_features = self.vision_encoder(images[img_idx:img_idx+1])
img_tokens = self.projector(img_features)
num_img_tokens = img_tokens.shape[1]
# 삽입 위치 조정
insert_pos = pos.item() + offset
# 임베딩에 이미지 토큰 삽입
embeddings = torch.cat([
embeddings[:, :insert_pos],
img_tokens,
embeddings[:, insert_pos+1:] # <image> 토큰 대체
], dim=1)
# 오프셋 업데이트 (1개 토큰이 num_img_tokens개로)
offset += num_img_tokens - 1
# 5. LLM forward
return self.llm(inputs_embeds=embeddings)
장점: - 자연스러운 멀티모달 대화 - 여러 이미지 유연하게 처리 - 이미지 위치가 문맥에 맞음
단점: - 구현 복잡 - 토큰 위치 관리 필요
Perceiver Resampler¶
가변 길이 이미지 토큰을 고정 길이로 압축. Flamingo의 핵심 컴포넌트.
왜 필요한가¶
구현¶
class PerceiverResampler(nn.Module):
"""Flamingo의 Perceiver Resampler"""
def __init__(
self,
dim,
num_latents=64, # 출력 토큰 수
num_heads=8,
num_layers=6,
ff_mult=4
):
super().__init__()
# 학습 가능한 latent 쿼리 (출력 토큰)
self.latents = nn.Parameter(torch.randn(num_latents, dim))
# Temporal embedding (비디오용)
self.time_embed = nn.Embedding(256, dim)
# Perceiver 레이어
self.layers = nn.ModuleList([
PerceiverLayer(dim, num_heads, ff_mult)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(dim)
def forward(self, image_features, time_indices=None):
"""
image_features: (batch, num_patches, dim)
또는 (batch, num_frames, num_patches, dim)
"""
batch_size = image_features.shape[0]
# 비디오 처리: temporal embedding 추가
if image_features.dim() == 4:
num_frames, num_patches = image_features.shape[1:3]
if time_indices is None:
time_indices = torch.arange(num_frames, device=image_features.device)
time_emb = self.time_embed(time_indices) # (frames, dim)
time_emb = time_emb.unsqueeze(0).unsqueeze(2) # (1, frames, 1, dim)
image_features = image_features + time_emb
image_features = image_features.flatten(1, 2) # (batch, frames*patches, dim)
# Latent 쿼리 확장
latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
# Perceiver 레이어 통과
for layer in self.layers:
latents = layer(latents, image_features)
return self.norm(latents)
class PerceiverLayer(nn.Module):
def __init__(self, dim, num_heads, ff_mult=4):
super().__init__()
# Cross-attention: latents attend to image
self.cross_attn = nn.MultiheadAttention(
dim, num_heads, batch_first=True
)
self.norm1 = nn.LayerNorm(dim)
# Self-attention: latents interact
self.self_attn = nn.MultiheadAttention(
dim, num_heads, batch_first=True
)
self.norm2 = nn.LayerNorm(dim)
# Feed-forward
self.ffn = nn.Sequential(
nn.Linear(dim, dim * ff_mult),
nn.GELU(),
nn.Linear(dim * ff_mult, dim)
)
self.norm3 = nn.LayerNorm(dim)
def forward(self, latents, image_features):
# Cross-attention with image
ln = self.norm1(latents)
attn_out, _ = self.cross_attn(ln, image_features, image_features)
latents = latents + attn_out
# Self-attention among latents
ln = self.norm2(latents)
attn_out, _ = self.self_attn(ln, ln, ln)
latents = latents + attn_out
# FFN
latents = latents + self.ffn(self.norm3(latents))
return latents
Q-Former (BLIP-2)¶
Vision Encoder와 LLM 사이의 경량 브릿지.
아키텍처¶
구현¶
class QFormer(nn.Module):
"""BLIP-2의 Q-Former"""
def __init__(
self,
num_queries=32,
hidden_size=768,
num_heads=12,
num_layers=6,
vision_dim=1024
):
super().__init__()
# 학습 가능한 쿼리
self.queries = nn.Parameter(torch.randn(1, num_queries, hidden_size))
# 텍스트 인코더 (BERT 기반, 선택적)
self.text_encoder = BertModel.from_pretrained("bert-base-uncased")
# Q-Former 레이어
self.layers = nn.ModuleList([
QFormerLayer(hidden_size, num_heads, vision_dim)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(hidden_size)
# 태스크별 projection
self.itc_proj = nn.Linear(hidden_size, 256) # contrastive
self.itm_head = nn.Linear(hidden_size, 2) # matching
self.lm_head = nn.Linear(hidden_size, 30522) # generation
def forward(self, image_features, text_input_ids=None, mode="fusion"):
"""
mode: "fusion" | "itc" | "itm" | "generation"
"""
batch_size = image_features.shape[0]
queries = self.queries.expand(batch_size, -1, -1)
# Q-Former 레이어
for layer in self.layers:
queries = layer(queries, image_features, text_input_ids)
queries = self.norm(queries)
if mode == "itc":
# Image-Text Contrastive
return self.itc_proj(queries[:, 0])
elif mode == "itm":
# Image-Text Matching
return self.itm_head(queries[:, 0])
elif mode == "generation":
# LLM에 전달할 representation
return queries
else:
return queries
class QFormerLayer(nn.Module):
def __init__(self, hidden_size, num_heads, vision_dim):
super().__init__()
# Self-attention (쿼리 간, 텍스트와 함께)
self.self_attn = nn.MultiheadAttention(
hidden_size, num_heads, batch_first=True
)
self.norm1 = nn.LayerNorm(hidden_size)
# Cross-attention (쿼리 → 이미지)
self.cross_attn = nn.MultiheadAttention(
hidden_size, num_heads, batch_first=True,
kdim=vision_dim, vdim=vision_dim
)
self.norm2 = nn.LayerNorm(hidden_size)
# FFN
self.ffn = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.GELU(),
nn.Linear(hidden_size * 4, hidden_size)
)
self.norm3 = nn.LayerNorm(hidden_size)
def forward(self, queries, image_features, text_embeds=None):
# 텍스트가 있으면 쿼리와 연결
if text_embeds is not None:
combined = torch.cat([queries, text_embeds], dim=1)
else:
combined = queries
# Self-attention
q = self.norm1(combined)
combined = combined + self.self_attn(q, q, q)[0]
# Cross-attention with image (쿼리 부분만)
q = self.norm2(combined[:, :queries.shape[1]])
queries = queries + self.cross_attn(
q, image_features, image_features
)[0]
# FFN
queries = queries + self.ffn(self.norm3(queries))
return queries
Attention 시각화¶
모델이 어디를 보는지 이해하기 위한 시각화.
import matplotlib.pyplot as plt
import numpy as np
def visualize_cross_attention(
attn_weights,
image,
query_tokens,
patch_size=14,
query_idx=0
):
"""
Cross-attention 가중치 시각화
attn_weights: (batch, heads, seq_len, num_patches)
"""
# 평균 attention (헤드 평균)
weights = attn_weights[0].mean(0) # (seq_len, num_patches)
# 특정 쿼리 토큰의 attention
query_attn = weights[query_idx] # (num_patches,)
# 2D로 reshape
h = w = int(np.sqrt(len(query_attn)))
attn_map = query_attn.reshape(h, w).detach().cpu().numpy()
# 시각화
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# 원본 이미지
axes[0].imshow(image)
axes[0].set_title("Original Image")
axes[0].axis('off')
# Attention 맵
axes[1].imshow(attn_map, cmap='viridis')
axes[1].set_title(f"Attention Map (Query: '{query_tokens[query_idx]}')")
axes[1].axis('off')
# 오버레이
attn_resized = np.array(
Image.fromarray((attn_map * 255).astype(np.uint8)).resize(
image.size, Image.BICUBIC
)
) / 255
axes[2].imshow(image)
axes[2].imshow(attn_resized, alpha=0.5, cmap='jet')
axes[2].set_title("Overlay")
axes[2].axis('off')
plt.tight_layout()
return fig
def get_attention_rollout(attn_weights_list):
"""
Attention Rollout: 여러 레이어의 attention 누적
더 정확한 attribution
"""
# 모든 레이어의 attention을 곱함
rollout = torch.eye(attn_weights_list[0].shape[-1])
for attn in attn_weights_list:
# 헤드 평균
attn = attn.mean(dim=1)
# Residual connection 고려
attn = attn + torch.eye(attn.shape[-1])
attn = attn / attn.sum(dim=-1, keepdim=True)
rollout = rollout @ attn
return rollout
실무 적용 가이드¶
퓨전 방식 선택¶
| 상황 | 추천 방식 | 이유 |
|---|---|---|
| 빠른 프로토타이핑 | Prefix (LLaVA) | 구현 쉬움 |
| 최고 성능 | Cross-Attn (Flamingo) | 깊은 통합 |
| 메모리 제약 | Q-Former | 토큰 압축 |
| 여러 이미지 대화 | Interleaved | 유연성 |
| 비디오 | Perceiver | 시간 처리 |
학습 팁¶
# 1. Gated Cross-Attention은 gate 초기값 중요
self.gate = nn.Parameter(torch.tensor(0.0)) # 0으로 시작
# 2. Vision Encoder는 보통 고정
for param in vision_encoder.parameters():
param.requires_grad = False
# 3. Projector는 높은 learning rate
optimizer = torch.optim.AdamW([
{'params': projector.parameters(), 'lr': 1e-3},
{'params': llm.parameters(), 'lr': 1e-5} # LoRA
])
# 4. Warmup 중요
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=1000,
num_training_steps=total_steps
)
참고 자료¶
핵심 논문¶
- Flamingo Paper - Gated Cross-Attention 원조
- BLIP-2 Paper - Q-Former 아키텍처
- LLaVA Paper - Prefix 방식
- Perceiver Paper - Resampler 아이디어