Token Fusion 방식¶
Vision Encoder에서 추출한 이미지 토큰과 텍스트 토큰을 결합하는 방식. VLM 아키텍처의 핵심 설계 결정 중 하나다.
개요¶
VLM에서 이미지와 텍스트를 결합하는 주요 방식:
1. Concatenation (연결)¶
가장 단순한 방식. 이미지 토큰을 텍스트 토큰 앞에 붙여서 하나의 시퀀스로 처리.
아키텍처¶
구현¶
import torch
import torch.nn as nn
class ConcatFusion(nn.Module):
"""LLaVA 스타일 Concatenation Fusion"""
def __init__(self, vision_encoder, projector, llm):
super().__init__()
self.vision_encoder = vision_encoder # CLIP ViT
self.projector = projector # Linear/MLP
self.llm = llm # LLaMA, Vicuna 등
def forward(self, images, input_ids, attention_mask):
batch_size = images.shape[0]
# 1. 이미지 인코딩
# vision_encoder 출력: (batch, num_patches, vision_dim)
image_features = self.vision_encoder(images)
# 2. LLM 임베딩 공간으로 투영
# projector 출력: (batch, num_patches, llm_dim)
image_tokens = self.projector(image_features)
num_image_tokens = image_tokens.shape[1]
# 3. 텍스트 임베딩
text_embeds = self.llm.embed_tokens(input_ids)
# 4. 연결: [이미지 토큰] + [텍스트 토큰]
inputs_embeds = torch.cat([image_tokens, text_embeds], dim=1)
# 5. Attention mask 확장
image_mask = torch.ones(
(batch_size, num_image_tokens),
dtype=attention_mask.dtype,
device=attention_mask.device
)
combined_mask = torch.cat([image_mask, attention_mask], dim=1)
# 6. LLM Forward
outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=combined_mask,
use_cache=True
)
return outputs
장점과 단점¶
| 장점 | 단점 |
|---|---|
| 구현 단순 | Self-attention만으로 상호작용 |
| 기존 LLM 재사용 용이 | 시퀀스 길이 증가 |
| 학습 안정적 | 이미지 토큰이 많으면 비효율 |
| 추론 빠름 | 얕은 모달리티 통합 |
사용 모델¶
- LLaVA: 가장 대표적인 Concatenation 방식
- Qwen-VL: Dynamic resolution + Concatenation
- InternVL: Concatenation with large vision encoder
2. Cross-Attention¶
텍스트가 이미지를 쿼리로 참조하는 방식. 더 깊은 모달리티 상호작용.
아키텍처¶
구현¶
class CrossAttentionFusion(nn.Module):
"""Flamingo 스타일 Cross-Attention Fusion"""
def __init__(self, llm_dim, vision_dim, num_heads=8):
super().__init__()
self.num_heads = num_heads
self.head_dim = llm_dim // num_heads
self.scale = self.head_dim ** -0.5
# Query from text, Key/Value from image
self.q_proj = nn.Linear(llm_dim, llm_dim)
self.k_proj = nn.Linear(vision_dim, llm_dim)
self.v_proj = nn.Linear(vision_dim, llm_dim)
self.out_proj = nn.Linear(llm_dim, llm_dim)
# Layer normalization
self.norm_text = nn.LayerNorm(llm_dim)
self.norm_image = nn.LayerNorm(vision_dim)
def forward(self, text_hidden, image_features):
"""
text_hidden: (batch, seq_len, llm_dim)
image_features: (batch, num_patches, vision_dim)
"""
batch_size, seq_len, _ = text_hidden.shape
num_patches = image_features.shape[1]
# Normalize
text_normed = self.norm_text(text_hidden)
image_normed = self.norm_image(image_features)
# Project
Q = self.q_proj(text_normed) # (batch, seq_len, llm_dim)
K = self.k_proj(image_normed) # (batch, num_patches, llm_dim)
V = self.v_proj(image_normed) # (batch, num_patches, llm_dim)
# Reshape for multi-head attention
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, num_patches, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, num_patches, self.num_heads, self.head_dim).transpose(1, 2)
# Attention: text attends to image
attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
attn_weights = F.softmax(attn_weights, dim=-1)
# Apply attention
attn_output = torch.matmul(attn_weights, V)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
return self.out_proj(attn_output)
class FlamingoLayer(nn.Module):
"""Flamingo 스타일 레이어 (Gated Cross-Attention)"""
def __init__(self, llm_dim, vision_dim, num_heads=8):
super().__init__()
self.cross_attn = CrossAttentionFusion(llm_dim, vision_dim, num_heads)
# Learnable gate (초기값 0으로 점진적 학습)
self.gate = nn.Parameter(torch.zeros(1))
def forward(self, text_hidden, image_features):
# Cross-attention
attn_out = self.cross_attn(text_hidden, image_features)
# Gated residual connection
# tanh(gate)는 초기에 0에 가까워 안정적 학습
return text_hidden + torch.tanh(self.gate) * attn_out
Flamingo 전체 아키텍처¶
class FlamingoModel(nn.Module):
def __init__(self, vision_encoder, llm, cross_attn_every_n_layers=4):
super().__init__()
self.vision_encoder = vision_encoder
self.perceiver = PerceiverResampler(dim=vision_encoder.output_dim)
self.llm = llm
# Cross-attention 레이어 삽입 (N개 레이어마다)
self.cross_attn_layers = nn.ModuleDict()
for i in range(llm.config.num_hidden_layers):
if (i + 1) % cross_attn_every_n_layers == 0:
self.cross_attn_layers[str(i)] = FlamingoLayer(
llm.config.hidden_size,
vision_encoder.output_dim
)
def forward(self, images, input_ids):
# 이미지 인코딩 + Perceiver로 압축
image_features = self.vision_encoder(images)
image_features = self.perceiver(image_features)
# LLM forward with cross-attention injection
hidden_states = self.llm.embed_tokens(input_ids)
for i, layer in enumerate(self.llm.layers):
# 원래 LLM 레이어
hidden_states = layer(hidden_states)
# Cross-attention 삽입 (해당 레이어에만)
if str(i) in self.cross_attn_layers:
hidden_states = self.cross_attn_layers[str(i)](
hidden_states, image_features
)
return self.llm.lm_head(hidden_states)
장점과 단점¶
| 장점 | 단점 |
|---|---|
| 깊은 모달리티 통합 | 추가 파라미터 필요 |
| 시퀀스 길이 유지 | 구현 복잡 |
| 텍스트가 이미지 선택적 참조 | 학습 어려움 |
| 더 나은 시각적 추론 | 계산 비용 증가 |
사용 모델¶
- Flamingo: Gated Cross-Attention 원조
- IDEFICS: Flamingo 오픈소스 구현
- Otter: Flamingo 기반 instruction tuning
3. Gated Fusion¶
게이트를 통해 두 모달리티의 기여도를 동적으로 조절.
아키텍처¶
구현¶
class GatedFusion(nn.Module):
"""게이트 기반 모달리티 융합"""
def __init__(self, dim):
super().__init__()
# 게이트 네트워크
self.gate_net = nn.Sequential(
nn.Linear(dim * 2, dim),
nn.ReLU(),
nn.Linear(dim, dim),
nn.Sigmoid()
)
# 모달리티별 projection
self.image_proj = nn.Linear(dim, dim)
self.text_proj = nn.Linear(dim, dim)
def forward(self, image_features, text_features):
"""
image_features: (batch, dim)
text_features: (batch, dim)
"""
# 두 모달리티 연결해서 게이트 값 계산
combined = torch.cat([image_features, text_features], dim=-1)
gate = self.gate_net(combined) # (batch, dim)
# 게이트 적용
image_proj = self.image_proj(image_features)
text_proj = self.text_proj(text_features)
# 가중 합
fused = gate * image_proj + (1 - gate) * text_proj
return fused
class TokenWiseGatedFusion(nn.Module):
"""토큰 단위 게이트 융합"""
def __init__(self, dim):
super().__init__()
self.gate = nn.Sequential(
nn.Linear(dim * 2, dim),
nn.Tanh(),
nn.Linear(dim, 1),
nn.Sigmoid()
)
def forward(self, image_tokens, text_tokens):
"""
image_tokens: (batch, num_image, dim)
text_tokens: (batch, num_text, dim)
"""
# 이미지 토큰을 텍스트 위치에 매핑 (평균 또는 attention)
image_summary = image_tokens.mean(dim=1, keepdim=True)
image_expanded = image_summary.expand_as(text_tokens)
# 각 텍스트 토큰에 대해 게이트 계산
combined = torch.cat([image_expanded, text_tokens], dim=-1)
gates = self.gate(combined) # (batch, num_text, 1)
# 게이트 적용
fused = gates * image_expanded + (1 - gates) * text_tokens
return fused
장점과 단점¶
| 장점 | 단점 |
|---|---|
| 적응적 모달리티 선택 | 추가 파라미터 |
| 해석 가능 (gate 값) | 학습 불안정 가능 |
| 모달리티 균형 조절 | 복잡한 튜닝 |
4. Q-Former (Learned Query)¶
학습 가능한 쿼리로 이미지 정보를 압축하는 방식.
아키텍처¶
구현¶
class QFormer(nn.Module):
"""BLIP-2 스타일 Q-Former"""
def __init__(
self,
num_queries=32,
hidden_size=768,
num_heads=12,
num_layers=6,
vision_dim=1024
):
super().__init__()
# 학습 가능한 쿼리
self.queries = nn.Parameter(torch.randn(1, num_queries, hidden_size))
# Q-Former 레이어
self.layers = nn.ModuleList([
QFormerLayer(hidden_size, num_heads, vision_dim)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(hidden_size)
def forward(self, image_features):
"""
image_features: (batch, num_patches, vision_dim)
returns: (batch, num_queries, hidden_size)
"""
batch_size = image_features.shape[0]
# 쿼리 확장
queries = self.queries.expand(batch_size, -1, -1)
# Q-Former 레이어 통과
for layer in self.layers:
queries = layer(queries, image_features)
return self.norm(queries)
class QFormerLayer(nn.Module):
def __init__(self, hidden_size, num_heads, vision_dim):
super().__init__()
# Self-attention (쿼리 간)
self.self_attn = nn.MultiheadAttention(
hidden_size, num_heads, batch_first=True
)
self.norm1 = nn.LayerNorm(hidden_size)
# Cross-attention (쿼리 → 이미지)
self.cross_attn = nn.MultiheadAttention(
hidden_size, num_heads, batch_first=True,
kdim=vision_dim, vdim=vision_dim
)
self.norm2 = nn.LayerNorm(hidden_size)
# FFN
self.ffn = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.GELU(),
nn.Linear(hidden_size * 4, hidden_size)
)
self.norm3 = nn.LayerNorm(hidden_size)
def forward(self, queries, image_features):
# Self-attention
q = self.norm1(queries)
queries = queries + self.self_attn(q, q, q)[0]
# Cross-attention with image
q = self.norm2(queries)
queries = queries + self.cross_attn(
q, image_features, image_features
)[0]
# FFN
queries = queries + self.ffn(self.norm3(queries))
return queries
장점과 단점¶
| 장점 | 단점 |
|---|---|
| 시각 토큰 수 고정 (효율) | 정보 손실 가능 |
| LLM 수정 불필요 | 추가 학습 필요 |
| 모듈화된 설계 | 복잡한 사전학습 |
사용 모델¶
- BLIP-2: Q-Former 원조
- InstructBLIP: Q-Former + Instruction tuning
- MiniGPT-4: Q-Former 활용
5. Perceiver Resampler¶
가변 길이 이미지 토큰을 고정 길이로 리샘플링.
아키텍처¶
구현¶
class PerceiverResampler(nn.Module):
"""Flamingo 스타일 Perceiver Resampler"""
def __init__(
self,
dim,
num_latents=64,
num_heads=8,
num_layers=6
):
super().__init__()
# 학습 가능한 latent 쿼리
self.latents = nn.Parameter(torch.randn(num_latents, dim))
# Temporal embedding (여러 프레임 처리용)
self.time_embed = nn.Embedding(100, dim)
# Perceiver 레이어
self.layers = nn.ModuleList([
PerceiverLayer(dim, num_heads)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(dim)
def forward(self, image_features, time_idx=None):
"""
image_features: (batch, num_patches, dim) 또는
(batch, num_frames, num_patches, dim)
"""
batch_size = image_features.shape[0]
# 비디오인 경우 temporal embedding 추가
if image_features.dim() == 4:
num_frames = image_features.shape[1]
time_emb = self.time_embed(
torch.arange(num_frames, device=image_features.device)
)
image_features = image_features + time_emb.unsqueeze(0).unsqueeze(2)
image_features = image_features.flatten(1, 2)
# Latent 쿼리 확장
latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
# Perceiver 레이어 통과
for layer in self.layers:
latents = layer(latents, image_features)
return self.norm(latents)
class PerceiverLayer(nn.Module):
def __init__(self, dim, num_heads):
super().__init__()
# Cross-attention: latents가 image를 attend
self.cross_attn = nn.MultiheadAttention(
dim, num_heads, batch_first=True
)
self.norm1 = nn.LayerNorm(dim)
# Self-attention: latents 간 상호작용
self.self_attn = nn.MultiheadAttention(
dim, num_heads, batch_first=True
)
self.norm2 = nn.LayerNorm(dim)
# FFN
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
self.norm3 = nn.LayerNorm(dim)
def forward(self, latents, image_features):
# Cross-attention with image
ln = self.norm1(latents)
latents = latents + self.cross_attn(
ln, image_features, image_features
)[0]
# Self-attention
ln = self.norm2(latents)
latents = latents + self.self_attn(ln, ln, ln)[0]
# FFN
latents = latents + self.ffn(self.norm3(latents))
return latents
방식별 비교¶
정량적 비교¶
| 방식 | 추가 파라미터 | 추론 비용 | 학습 난이도 | 성능 |
|---|---|---|---|---|
| Concatenation | 낮음 | 중간 | 쉬움 | 좋음 |
| Cross-Attention | 중간 | 높음 | 어려움 | 최고 |
| Gated Fusion | 낮음 | 낮음 | 중간 | 좋음 |
| Q-Former | 높음 | 낮음 | 어려움 | 좋음 |
| Perceiver | 중간 | 낮음 | 중간 | 좋음 |
사용 사례별 권장¶
| 사용 사례 | 권장 방식 | 이유 |
|---|---|---|
| 빠른 프로토타이핑 | Concatenation | 구현 단순 |
| 최고 성능 추구 | Cross-Attention | 깊은 통합 |
| 효율성 중시 | Q-Former/Perceiver | 토큰 압축 |
| 해석 필요 | Gated Fusion | 게이트 분석 |
| 비디오 처리 | Perceiver | 프레임 처리 |
최신 트렌드¶
Interleaved Fusion (Gemini 스타일)¶
이미지와 텍스트를 자연스럽게 섞어서 처리:
class InterleavedFusion(nn.Module):
"""이미지-텍스트 인터리빙 방식"""
def __init__(self, vision_encoder, projector, llm):
super().__init__()
self.vision_encoder = vision_encoder
self.projector = projector
self.llm = llm
# 특수 토큰 ID
self.img_start_id = llm.config.vocab_size # <img>
self.img_end_id = llm.config.vocab_size + 1 # </img>
def forward(self, text_with_placeholders, images, image_positions):
"""
text_with_placeholders: "User: <img></img> What is this?"
images: [image1, image2, ...]
image_positions: [(start1, end1), (start2, end2), ...]
"""
# 텍스트 임베딩
text_embeds = self.llm.embed_tokens(text_with_placeholders)
# 이미지 인코딩
for i, (img, (start, end)) in enumerate(zip(images, image_positions)):
img_features = self.vision_encoder(img.unsqueeze(0))
img_tokens = self.projector(img_features).squeeze(0)
# 해당 위치에 이미지 토큰 삽입
text_embeds = torch.cat([
text_embeds[:, :start],
img_tokens.unsqueeze(0),
text_embeds[:, end:]
], dim=1)
return self.llm(inputs_embeds=text_embeds)
Dynamic Resolution Fusion¶
해상도에 따라 토큰 수가 변하는 방식:
class DynamicResolutionFusion(nn.Module):
"""Qwen-VL 스타일 동적 해상도"""
def __init__(self, vision_encoder, llm):
super().__init__()
self.vision_encoder = vision_encoder
self.llm = llm
# 해상도별 position embedding
self.spatial_pos_embed = nn.Parameter(
torch.randn(1, 4096, vision_encoder.output_dim)
)
def forward(self, images, input_ids):
all_image_tokens = []
for img in images:
# 이미지 크기에 따른 패치 수 결정
h, w = img.shape[-2:]
num_patches_h = h // self.vision_encoder.patch_size
num_patches_w = w // self.vision_encoder.patch_size
num_patches = num_patches_h * num_patches_w
# 인코딩
features = self.vision_encoder(img.unsqueeze(0))
# 동적 position embedding
pos_emb = self.spatial_pos_embed[:, :num_patches]
features = features + pos_emb
all_image_tokens.append(features)
# 가변 길이 이미지 토큰과 텍스트 결합
# ...
참고 자료¶
- LLaVA Paper - Concatenation
- Flamingo Paper - Cross-Attention, Perceiver
- BLIP-2 Paper - Q-Former
- Qwen-VL Paper - Dynamic Resolution
- awesome-vlm-architectures - VLM 아키텍처 모음