콘텐츠로 이동
Data Prep
상세

Cross-modal Attention

이미지와 텍스트 모달리티 간의 상호작용을 학습하는 메커니즘. VLM의 핵심 구성 요소다.

Cross-Attention 기본

Self-Attention vs Cross-Attention

Self-Attention:
  같은 시퀀스 내에서 상호작용
  Q, K, V 모두 같은 입력에서 유래

Cross-Attention:
  다른 시퀀스 간 상호작용
  Q는 한 모달리티, K/V는 다른 모달리티

VLM에서의 역할

Text가 Image를 참조:
  Q = Text tokens (질문: "이 이미지에서 무엇을 찾을까?")
  K, V = Image tokens (답변 소스)

결과: 각 텍스트 토큰이 관련 이미지 영역에 집중

수식

\[\text{CrossAttn}(Q_{text}, K_{image}, V_{image}) = \text{softmax}\left(\frac{Q_{text}K_{image}^T}{\sqrt{d_k}}\right)V_{image}\]
  • \(Q_{text}\): 텍스트에서 유래한 Query
  • \(K_{image}, V_{image}\): 이미지에서 유래한 Key, Value
  • \(d_k\): Key 차원 (스케일링)

구현

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    """기본 Cross-Attention"""

    def __init__(self, query_dim, kv_dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = query_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Query from text, Key/Value from image
        self.q_proj = nn.Linear(query_dim, query_dim)
        self.k_proj = nn.Linear(kv_dim, query_dim)
        self.v_proj = nn.Linear(kv_dim, query_dim)
        self.out_proj = nn.Linear(query_dim, query_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key_value, attention_mask=None):
        """
        query: 텍스트 토큰 (batch, seq_len, query_dim)
        key_value: 이미지 토큰 (batch, num_patches, kv_dim)
        """
        batch_size, seq_len, _ = query.shape
        num_patches = key_value.shape[1]

        # Linear projections
        Q = self.q_proj(query)
        K = self.k_proj(key_value)
        V = self.v_proj(key_value)

        # Reshape for multi-head attention
        # (batch, seq, dim) -> (batch, heads, seq, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, num_patches, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, num_patches, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores: Q @ K^T / sqrt(d)
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        # attn_weights: (batch, heads, seq_len, num_patches)

        # Optional: attention mask
        if attention_mask is not None:
            attn_weights = attn_weights.masked_fill(attention_mask == 0, float('-inf'))

        # Softmax + dropout
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, V)
        # attn_output: (batch, heads, seq_len, head_dim)

        # Reshape back
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, -1)

        return self.out_proj(attn_output), attn_weights

VLM 퓨전 방식 비교

개요

방식 대표 모델 특징 복잡도
Prefix (Concatenation) LLaVA 단순, 효율적 낮음
Interleaved Cross-Attn Flamingo 깊은 통합 높음
Q-Former BLIP-2 토큰 압축 중간
Interleaved Tokens Gemini 유연함 중간

1. Prefix Tokens (LLaVA 스타일)

이미지 토큰을 텍스트 시퀀스 앞에 연결.

cross-attention diagram 1

class LLaVA(nn.Module):
    """Prefix-style VLM"""

    def __init__(self, vision_encoder, projector, llm):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.projector = projector
        self.llm = llm

    def forward(self, images, input_ids, attention_mask):
        # 1. 이미지 인코딩
        with torch.no_grad():
            image_features = self.vision_encoder(images)

        # 2. LLM 공간으로 투영
        image_tokens = self.projector(image_features)
        num_image_tokens = image_tokens.shape[1]

        # 3. 텍스트 임베딩
        text_embeds = self.llm.embed_tokens(input_ids)

        # 4. [이미지 토큰] + [텍스트 토큰] 연결
        inputs_embeds = torch.cat([image_tokens, text_embeds], dim=1)

        # 5. Attention mask 확장
        batch_size = images.shape[0]
        image_mask = torch.ones(
            (batch_size, num_image_tokens),
            dtype=attention_mask.dtype,
            device=attention_mask.device
        )
        extended_mask = torch.cat([image_mask, attention_mask], dim=1)

        # 6. LLM forward
        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_mask
        )

        return outputs

    def generate(self, images, prompt_ids, max_new_tokens=256):
        """생성"""
        # 이미지 토큰 준비
        with torch.no_grad():
            image_features = self.vision_encoder(images)
        image_tokens = self.projector(image_features)

        # 프롬프트 임베딩
        prompt_embeds = self.llm.embed_tokens(prompt_ids)
        inputs_embeds = torch.cat([image_tokens, prompt_embeds], dim=1)

        # 생성
        return self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7
        )

장점: - 구현 단순, LLM 수정 불필요 - 학습 안정적 - 추론 효율적

단점: - Self-attention만으로 모달리티 상호작용 - 이미지 토큰이 많으면 컨텍스트 소모

2. Interleaved Cross-Attention (Flamingo 스타일)

LLM 레이어 사이에 Cross-Attention 삽입.

cross-attention diagram 2

class FlamingoBlock(nn.Module):
    """Gated Cross-Attention Block"""

    def __init__(self, llm_dim, vision_dim, num_heads=8):
        super().__init__()

        # Cross-Attention
        self.cross_attention = CrossAttention(llm_dim, vision_dim, num_heads)
        self.norm = nn.LayerNorm(llm_dim)

        # Learnable gate (중요!)
        # 초기값 0: 처음에는 Cross-Attention이 거의 영향 없음
        # 학습하며 점차 증가 → 안정적 학습
        self.gate = nn.Parameter(torch.tensor(0.0))

    def forward(self, text_hidden, image_features):
        # Cross-attention
        normed = self.norm(text_hidden)
        attn_out, attn_weights = self.cross_attention(normed, image_features)

        # Gated residual: tanh(gate) * attn_out
        # tanh으로 -1 ~ 1 범위로 제한
        return text_hidden + torch.tanh(self.gate) * attn_out


class FlamingoLLM(nn.Module):
    """Flamingo-style LLM with Cross-Attention"""

    def __init__(self, llm, vision_dim, cross_attn_every=4):
        super().__init__()
        self.llm = llm
        llm_dim = llm.config.hidden_size
        num_layers = llm.config.num_hidden_layers

        # Cross-Attention 레이어 (N개마다 삽입)
        self.cross_attn_layers = nn.ModuleDict()
        for i in range(num_layers):
            if (i + 1) % cross_attn_every == 0:
                self.cross_attn_layers[str(i)] = FlamingoBlock(
                    llm_dim, vision_dim, num_heads=8
                )

        # Perceiver Resampler (이미지 토큰 압축)
        self.perceiver = PerceiverResampler(vision_dim, num_latents=64)

    def forward(self, input_ids, image_features):
        # 이미지 압축
        compressed_image = self.perceiver(image_features)

        # LLM embedding
        hidden_states = self.llm.embed_tokens(input_ids)

        # 각 레이어 통과
        for i, layer in enumerate(self.llm.layers):
            # LLM 레이어
            hidden_states = layer(hidden_states)

            # Cross-Attention (해당 레이어에만)
            if str(i) in self.cross_attn_layers:
                hidden_states = self.cross_attn_layers[str(i)](
                    hidden_states, compressed_image
                )

        # LM head
        hidden_states = self.llm.norm(hidden_states)
        return self.llm.lm_head(hidden_states)

장점: - 깊은 모달리티 통합 - 텍스트가 이미지를 선택적으로 참조 - 시퀀스 길이 유지

단점: - 추가 파라미터 - 구현 복잡 - 학습 어려움 (게이트 필수)

3. Interleaved Tokens (Gemini 스타일)

이미지와 텍스트를 자연스럽게 섞음.

사용자: <image>이미지1</image> 여기서 <image>이미지2</image>와 비교해서...
class InterleavedVLM(nn.Module):
    """이미지-텍스트 인터리빙"""

    def __init__(self, vision_encoder, projector, llm, tokenizer):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.projector = projector
        self.llm = llm
        self.tokenizer = tokenizer

        # 특수 토큰
        self.img_token_id = tokenizer.convert_tokens_to_ids("<image>")

    def forward(self, text_with_placeholders, images):
        """
        text_with_placeholders: "User: <image> What is this?"
        images: [image1] (placeholder 수만큼)
        """
        # 1. 텍스트 토크나이즈
        input_ids = self.tokenizer(text_with_placeholders, return_tensors="pt").input_ids

        # 2. <image> 위치 찾기
        image_positions = (input_ids == self.img_token_id).nonzero(as_tuple=True)[1]

        # 3. 텍스트 임베딩
        embeddings = self.llm.embed_tokens(input_ids)

        # 4. 이미지 인코딩 및 삽입
        offset = 0  # 삽입으로 인한 위치 변화

        for img_idx, pos in enumerate(image_positions):
            # 이미지 인코딩
            with torch.no_grad():
                img_features = self.vision_encoder(images[img_idx:img_idx+1])
            img_tokens = self.projector(img_features)
            num_img_tokens = img_tokens.shape[1]

            # 삽입 위치 조정
            insert_pos = pos.item() + offset

            # 임베딩에 이미지 토큰 삽입
            embeddings = torch.cat([
                embeddings[:, :insert_pos],
                img_tokens,
                embeddings[:, insert_pos+1:]  # <image> 토큰 대체
            ], dim=1)

            # 오프셋 업데이트 (1개 토큰이 num_img_tokens개로)
            offset += num_img_tokens - 1

        # 5. LLM forward
        return self.llm(inputs_embeds=embeddings)

장점: - 자연스러운 멀티모달 대화 - 여러 이미지 유연하게 처리 - 이미지 위치가 문맥에 맞음

단점: - 구현 복잡 - 토큰 위치 관리 필요

Perceiver Resampler

가변 길이 이미지 토큰을 고정 길이로 압축. Flamingo의 핵심 컴포넌트.

왜 필요한가

cross-attention diagram 3

구현

class PerceiverResampler(nn.Module):
    """Flamingo의 Perceiver Resampler"""

    def __init__(
        self,
        dim,
        num_latents=64,      # 출력 토큰 수
        num_heads=8,
        num_layers=6,
        ff_mult=4
    ):
        super().__init__()

        # 학습 가능한 latent 쿼리 (출력 토큰)
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # Temporal embedding (비디오용)
        self.time_embed = nn.Embedding(256, dim)

        # Perceiver 레이어
        self.layers = nn.ModuleList([
            PerceiverLayer(dim, num_heads, ff_mult)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(dim)

    def forward(self, image_features, time_indices=None):
        """
        image_features: (batch, num_patches, dim)
                       또는 (batch, num_frames, num_patches, dim)
        """
        batch_size = image_features.shape[0]

        # 비디오 처리: temporal embedding 추가
        if image_features.dim() == 4:
            num_frames, num_patches = image_features.shape[1:3]

            if time_indices is None:
                time_indices = torch.arange(num_frames, device=image_features.device)

            time_emb = self.time_embed(time_indices)  # (frames, dim)
            time_emb = time_emb.unsqueeze(0).unsqueeze(2)  # (1, frames, 1, dim)

            image_features = image_features + time_emb
            image_features = image_features.flatten(1, 2)  # (batch, frames*patches, dim)

        # Latent 쿼리 확장
        latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)

        # Perceiver 레이어 통과
        for layer in self.layers:
            latents = layer(latents, image_features)

        return self.norm(latents)


class PerceiverLayer(nn.Module):
    def __init__(self, dim, num_heads, ff_mult=4):
        super().__init__()

        # Cross-attention: latents attend to image
        self.cross_attn = nn.MultiheadAttention(
            dim, num_heads, batch_first=True
        )
        self.norm1 = nn.LayerNorm(dim)

        # Self-attention: latents interact
        self.self_attn = nn.MultiheadAttention(
            dim, num_heads, batch_first=True
        )
        self.norm2 = nn.LayerNorm(dim)

        # Feed-forward
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * ff_mult),
            nn.GELU(),
            nn.Linear(dim * ff_mult, dim)
        )
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, latents, image_features):
        # Cross-attention with image
        ln = self.norm1(latents)
        attn_out, _ = self.cross_attn(ln, image_features, image_features)
        latents = latents + attn_out

        # Self-attention among latents
        ln = self.norm2(latents)
        attn_out, _ = self.self_attn(ln, ln, ln)
        latents = latents + attn_out

        # FFN
        latents = latents + self.ffn(self.norm3(latents))

        return latents

Q-Former (BLIP-2)

Vision Encoder와 LLM 사이의 경량 브릿지.

아키텍처

cross-attention diagram 4

구현

class QFormer(nn.Module):
    """BLIP-2의 Q-Former"""

    def __init__(
        self,
        num_queries=32,
        hidden_size=768,
        num_heads=12,
        num_layers=6,
        vision_dim=1024
    ):
        super().__init__()

        # 학습 가능한 쿼리
        self.queries = nn.Parameter(torch.randn(1, num_queries, hidden_size))

        # 텍스트 인코더 (BERT 기반, 선택적)
        self.text_encoder = BertModel.from_pretrained("bert-base-uncased")

        # Q-Former 레이어
        self.layers = nn.ModuleList([
            QFormerLayer(hidden_size, num_heads, vision_dim)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(hidden_size)

        # 태스크별 projection
        self.itc_proj = nn.Linear(hidden_size, 256)  # contrastive
        self.itm_head = nn.Linear(hidden_size, 2)    # matching
        self.lm_head = nn.Linear(hidden_size, 30522) # generation

    def forward(self, image_features, text_input_ids=None, mode="fusion"):
        """
        mode: "fusion" | "itc" | "itm" | "generation"
        """
        batch_size = image_features.shape[0]
        queries = self.queries.expand(batch_size, -1, -1)

        # Q-Former 레이어
        for layer in self.layers:
            queries = layer(queries, image_features, text_input_ids)

        queries = self.norm(queries)

        if mode == "itc":
            # Image-Text Contrastive
            return self.itc_proj(queries[:, 0])
        elif mode == "itm":
            # Image-Text Matching
            return self.itm_head(queries[:, 0])
        elif mode == "generation":
            # LLM에 전달할 representation
            return queries
        else:
            return queries


class QFormerLayer(nn.Module):
    def __init__(self, hidden_size, num_heads, vision_dim):
        super().__init__()

        # Self-attention (쿼리 간, 텍스트와 함께)
        self.self_attn = nn.MultiheadAttention(
            hidden_size, num_heads, batch_first=True
        )
        self.norm1 = nn.LayerNorm(hidden_size)

        # Cross-attention (쿼리 → 이미지)
        self.cross_attn = nn.MultiheadAttention(
            hidden_size, num_heads, batch_first=True,
            kdim=vision_dim, vdim=vision_dim
        )
        self.norm2 = nn.LayerNorm(hidden_size)

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )
        self.norm3 = nn.LayerNorm(hidden_size)

    def forward(self, queries, image_features, text_embeds=None):
        # 텍스트가 있으면 쿼리와 연결
        if text_embeds is not None:
            combined = torch.cat([queries, text_embeds], dim=1)
        else:
            combined = queries

        # Self-attention
        q = self.norm1(combined)
        combined = combined + self.self_attn(q, q, q)[0]

        # Cross-attention with image (쿼리 부분만)
        q = self.norm2(combined[:, :queries.shape[1]])
        queries = queries + self.cross_attn(
            q, image_features, image_features
        )[0]

        # FFN
        queries = queries + self.ffn(self.norm3(queries))

        return queries

Attention 시각화

모델이 어디를 보는지 이해하기 위한 시각화.

import matplotlib.pyplot as plt
import numpy as np

def visualize_cross_attention(
    attn_weights,
    image,
    query_tokens,
    patch_size=14,
    query_idx=0
):
    """
    Cross-attention 가중치 시각화

    attn_weights: (batch, heads, seq_len, num_patches)
    """
    # 평균 attention (헤드 평균)
    weights = attn_weights[0].mean(0)  # (seq_len, num_patches)

    # 특정 쿼리 토큰의 attention
    query_attn = weights[query_idx]  # (num_patches,)

    # 2D로 reshape
    h = w = int(np.sqrt(len(query_attn)))
    attn_map = query_attn.reshape(h, w).detach().cpu().numpy()

    # 시각화
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # 원본 이미지
    axes[0].imshow(image)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    # Attention 맵
    axes[1].imshow(attn_map, cmap='viridis')
    axes[1].set_title(f"Attention Map (Query: '{query_tokens[query_idx]}')")
    axes[1].axis('off')

    # 오버레이
    attn_resized = np.array(
        Image.fromarray((attn_map * 255).astype(np.uint8)).resize(
            image.size, Image.BICUBIC
        )
    ) / 255
    axes[2].imshow(image)
    axes[2].imshow(attn_resized, alpha=0.5, cmap='jet')
    axes[2].set_title("Overlay")
    axes[2].axis('off')

    plt.tight_layout()
    return fig


def get_attention_rollout(attn_weights_list):
    """
    Attention Rollout: 여러 레이어의 attention 누적
    더 정확한 attribution
    """
    # 모든 레이어의 attention을 곱함
    rollout = torch.eye(attn_weights_list[0].shape[-1])

    for attn in attn_weights_list:
        # 헤드 평균
        attn = attn.mean(dim=1)

        # Residual connection 고려
        attn = attn + torch.eye(attn.shape[-1])
        attn = attn / attn.sum(dim=-1, keepdim=True)

        rollout = rollout @ attn

    return rollout

실무 적용 가이드

퓨전 방식 선택

상황 추천 방식 이유
빠른 프로토타이핑 Prefix (LLaVA) 구현 쉬움
최고 성능 Cross-Attn (Flamingo) 깊은 통합
메모리 제약 Q-Former 토큰 압축
여러 이미지 대화 Interleaved 유연성
비디오 Perceiver 시간 처리

학습 팁

# 1. Gated Cross-Attention은 gate 초기값 중요
self.gate = nn.Parameter(torch.tensor(0.0))  # 0으로 시작

# 2. Vision Encoder는 보통 고정
for param in vision_encoder.parameters():
    param.requires_grad = False

# 3. Projector는 높은 learning rate
optimizer = torch.optim.AdamW([
    {'params': projector.parameters(), 'lr': 1e-3},
    {'params': llm.parameters(), 'lr': 1e-5}  # LoRA
])

# 4. Warmup 중요
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,
    num_training_steps=total_steps
)

참고 자료

핵심 논문

코드/구현