콘텐츠로 이동
Data Prep
상세

Token Fusion 방식

Vision Encoder에서 추출한 이미지 토큰과 텍스트 토큰을 결합하는 방식. VLM 아키텍처의 핵심 설계 결정 중 하나다.

개요

VLM에서 이미지와 텍스트를 결합하는 주요 방식:

token-fusion diagram 1

1. Concatenation (연결)

가장 단순한 방식. 이미지 토큰을 텍스트 토큰 앞에 붙여서 하나의 시퀀스로 처리.

아키텍처

token-fusion diagram 2

구현

import torch
import torch.nn as nn

class ConcatFusion(nn.Module):
    """LLaVA 스타일 Concatenation Fusion"""

    def __init__(self, vision_encoder, projector, llm):
        super().__init__()
        self.vision_encoder = vision_encoder  # CLIP ViT
        self.projector = projector            # Linear/MLP
        self.llm = llm                        # LLaMA, Vicuna 등

    def forward(self, images, input_ids, attention_mask):
        batch_size = images.shape[0]

        # 1. 이미지 인코딩
        # vision_encoder 출력: (batch, num_patches, vision_dim)
        image_features = self.vision_encoder(images)

        # 2. LLM 임베딩 공간으로 투영
        # projector 출력: (batch, num_patches, llm_dim)
        image_tokens = self.projector(image_features)
        num_image_tokens = image_tokens.shape[1]

        # 3. 텍스트 임베딩
        text_embeds = self.llm.embed_tokens(input_ids)

        # 4. 연결: [이미지 토큰] + [텍스트 토큰]
        inputs_embeds = torch.cat([image_tokens, text_embeds], dim=1)

        # 5. Attention mask 확장
        image_mask = torch.ones(
            (batch_size, num_image_tokens),
            dtype=attention_mask.dtype,
            device=attention_mask.device
        )
        combined_mask = torch.cat([image_mask, attention_mask], dim=1)

        # 6. LLM Forward
        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=combined_mask,
            use_cache=True
        )

        return outputs

장점과 단점

장점 단점
구현 단순 Self-attention만으로 상호작용
기존 LLM 재사용 용이 시퀀스 길이 증가
학습 안정적 이미지 토큰이 많으면 비효율
추론 빠름 얕은 모달리티 통합

사용 모델

  • LLaVA: 가장 대표적인 Concatenation 방식
  • Qwen-VL: Dynamic resolution + Concatenation
  • InternVL: Concatenation with large vision encoder

2. Cross-Attention

텍스트가 이미지를 쿼리로 참조하는 방식. 더 깊은 모달리티 상호작용.

아키텍처

token-fusion diagram 3

구현

class CrossAttentionFusion(nn.Module):
    """Flamingo 스타일 Cross-Attention Fusion"""

    def __init__(self, llm_dim, vision_dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = llm_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Query from text, Key/Value from image
        self.q_proj = nn.Linear(llm_dim, llm_dim)
        self.k_proj = nn.Linear(vision_dim, llm_dim)
        self.v_proj = nn.Linear(vision_dim, llm_dim)
        self.out_proj = nn.Linear(llm_dim, llm_dim)

        # Layer normalization
        self.norm_text = nn.LayerNorm(llm_dim)
        self.norm_image = nn.LayerNorm(vision_dim)

    def forward(self, text_hidden, image_features):
        """
        text_hidden: (batch, seq_len, llm_dim)
        image_features: (batch, num_patches, vision_dim)
        """
        batch_size, seq_len, _ = text_hidden.shape
        num_patches = image_features.shape[1]

        # Normalize
        text_normed = self.norm_text(text_hidden)
        image_normed = self.norm_image(image_features)

        # Project
        Q = self.q_proj(text_normed)  # (batch, seq_len, llm_dim)
        K = self.k_proj(image_normed) # (batch, num_patches, llm_dim)
        V = self.v_proj(image_normed) # (batch, num_patches, llm_dim)

        # Reshape for multi-head attention
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, num_patches, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, num_patches, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention: text attends to image
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)

        # Apply attention
        attn_output = torch.matmul(attn_weights, V)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.out_proj(attn_output)


class FlamingoLayer(nn.Module):
    """Flamingo 스타일 레이어 (Gated Cross-Attention)"""

    def __init__(self, llm_dim, vision_dim, num_heads=8):
        super().__init__()
        self.cross_attn = CrossAttentionFusion(llm_dim, vision_dim, num_heads)

        # Learnable gate (초기값 0으로 점진적 학습)
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, text_hidden, image_features):
        # Cross-attention
        attn_out = self.cross_attn(text_hidden, image_features)

        # Gated residual connection
        # tanh(gate)는 초기에 0에 가까워 안정적 학습
        return text_hidden + torch.tanh(self.gate) * attn_out

Flamingo 전체 아키텍처

class FlamingoModel(nn.Module):
    def __init__(self, vision_encoder, llm, cross_attn_every_n_layers=4):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.perceiver = PerceiverResampler(dim=vision_encoder.output_dim)
        self.llm = llm

        # Cross-attention 레이어 삽입 (N개 레이어마다)
        self.cross_attn_layers = nn.ModuleDict()
        for i in range(llm.config.num_hidden_layers):
            if (i + 1) % cross_attn_every_n_layers == 0:
                self.cross_attn_layers[str(i)] = FlamingoLayer(
                    llm.config.hidden_size,
                    vision_encoder.output_dim
                )

    def forward(self, images, input_ids):
        # 이미지 인코딩 + Perceiver로 압축
        image_features = self.vision_encoder(images)
        image_features = self.perceiver(image_features)

        # LLM forward with cross-attention injection
        hidden_states = self.llm.embed_tokens(input_ids)

        for i, layer in enumerate(self.llm.layers):
            # 원래 LLM 레이어
            hidden_states = layer(hidden_states)

            # Cross-attention 삽입 (해당 레이어에만)
            if str(i) in self.cross_attn_layers:
                hidden_states = self.cross_attn_layers[str(i)](
                    hidden_states, image_features
                )

        return self.llm.lm_head(hidden_states)

장점과 단점

장점 단점
깊은 모달리티 통합 추가 파라미터 필요
시퀀스 길이 유지 구현 복잡
텍스트가 이미지 선택적 참조 학습 어려움
더 나은 시각적 추론 계산 비용 증가

사용 모델

  • Flamingo: Gated Cross-Attention 원조
  • IDEFICS: Flamingo 오픈소스 구현
  • Otter: Flamingo 기반 instruction tuning

3. Gated Fusion

게이트를 통해 두 모달리티의 기여도를 동적으로 조절.

아키텍처

token-fusion diagram 4

구현

class GatedFusion(nn.Module):
    """게이트 기반 모달리티 융합"""

    def __init__(self, dim):
        super().__init__()
        # 게이트 네트워크
        self.gate_net = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
            nn.Sigmoid()
        )

        # 모달리티별 projection
        self.image_proj = nn.Linear(dim, dim)
        self.text_proj = nn.Linear(dim, dim)

    def forward(self, image_features, text_features):
        """
        image_features: (batch, dim)
        text_features: (batch, dim)
        """
        # 두 모달리티 연결해서 게이트 값 계산
        combined = torch.cat([image_features, text_features], dim=-1)
        gate = self.gate_net(combined)  # (batch, dim)

        # 게이트 적용
        image_proj = self.image_proj(image_features)
        text_proj = self.text_proj(text_features)

        # 가중 합
        fused = gate * image_proj + (1 - gate) * text_proj

        return fused


class TokenWiseGatedFusion(nn.Module):
    """토큰 단위 게이트 융합"""

    def __init__(self, dim):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.Tanh(),
            nn.Linear(dim, 1),
            nn.Sigmoid()
        )

    def forward(self, image_tokens, text_tokens):
        """
        image_tokens: (batch, num_image, dim)
        text_tokens: (batch, num_text, dim)
        """
        # 이미지 토큰을 텍스트 위치에 매핑 (평균 또는 attention)
        image_summary = image_tokens.mean(dim=1, keepdim=True)
        image_expanded = image_summary.expand_as(text_tokens)

        # 각 텍스트 토큰에 대해 게이트 계산
        combined = torch.cat([image_expanded, text_tokens], dim=-1)
        gates = self.gate(combined)  # (batch, num_text, 1)

        # 게이트 적용
        fused = gates * image_expanded + (1 - gates) * text_tokens

        return fused

장점과 단점

장점 단점
적응적 모달리티 선택 추가 파라미터
해석 가능 (gate 값) 학습 불안정 가능
모달리티 균형 조절 복잡한 튜닝

4. Q-Former (Learned Query)

학습 가능한 쿼리로 이미지 정보를 압축하는 방식.

아키텍처

token-fusion diagram 5

구현

class QFormer(nn.Module):
    """BLIP-2 스타일 Q-Former"""

    def __init__(
        self,
        num_queries=32,
        hidden_size=768,
        num_heads=12,
        num_layers=6,
        vision_dim=1024
    ):
        super().__init__()

        # 학습 가능한 쿼리
        self.queries = nn.Parameter(torch.randn(1, num_queries, hidden_size))

        # Q-Former 레이어
        self.layers = nn.ModuleList([
            QFormerLayer(hidden_size, num_heads, vision_dim)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, image_features):
        """
        image_features: (batch, num_patches, vision_dim)
        returns: (batch, num_queries, hidden_size)
        """
        batch_size = image_features.shape[0]

        # 쿼리 확장
        queries = self.queries.expand(batch_size, -1, -1)

        # Q-Former 레이어 통과
        for layer in self.layers:
            queries = layer(queries, image_features)

        return self.norm(queries)


class QFormerLayer(nn.Module):
    def __init__(self, hidden_size, num_heads, vision_dim):
        super().__init__()

        # Self-attention (쿼리 간)
        self.self_attn = nn.MultiheadAttention(
            hidden_size, num_heads, batch_first=True
        )
        self.norm1 = nn.LayerNorm(hidden_size)

        # Cross-attention (쿼리 → 이미지)
        self.cross_attn = nn.MultiheadAttention(
            hidden_size, num_heads, batch_first=True,
            kdim=vision_dim, vdim=vision_dim
        )
        self.norm2 = nn.LayerNorm(hidden_size)

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )
        self.norm3 = nn.LayerNorm(hidden_size)

    def forward(self, queries, image_features):
        # Self-attention
        q = self.norm1(queries)
        queries = queries + self.self_attn(q, q, q)[0]

        # Cross-attention with image
        q = self.norm2(queries)
        queries = queries + self.cross_attn(
            q, image_features, image_features
        )[0]

        # FFN
        queries = queries + self.ffn(self.norm3(queries))

        return queries

장점과 단점

장점 단점
시각 토큰 수 고정 (효율) 정보 손실 가능
LLM 수정 불필요 추가 학습 필요
모듈화된 설계 복잡한 사전학습

사용 모델

  • BLIP-2: Q-Former 원조
  • InstructBLIP: Q-Former + Instruction tuning
  • MiniGPT-4: Q-Former 활용

5. Perceiver Resampler

가변 길이 이미지 토큰을 고정 길이로 리샘플링.

아키텍처

token-fusion diagram 6

구현

class PerceiverResampler(nn.Module):
    """Flamingo 스타일 Perceiver Resampler"""

    def __init__(
        self,
        dim,
        num_latents=64,
        num_heads=8,
        num_layers=6
    ):
        super().__init__()

        # 학습 가능한 latent 쿼리
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # Temporal embedding (여러 프레임 처리용)
        self.time_embed = nn.Embedding(100, dim)

        # Perceiver 레이어
        self.layers = nn.ModuleList([
            PerceiverLayer(dim, num_heads)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(dim)

    def forward(self, image_features, time_idx=None):
        """
        image_features: (batch, num_patches, dim) 또는
                       (batch, num_frames, num_patches, dim)
        """
        batch_size = image_features.shape[0]

        # 비디오인 경우 temporal embedding 추가
        if image_features.dim() == 4:
            num_frames = image_features.shape[1]
            time_emb = self.time_embed(
                torch.arange(num_frames, device=image_features.device)
            )
            image_features = image_features + time_emb.unsqueeze(0).unsqueeze(2)
            image_features = image_features.flatten(1, 2)

        # Latent 쿼리 확장
        latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)

        # Perceiver 레이어 통과
        for layer in self.layers:
            latents = layer(latents, image_features)

        return self.norm(latents)


class PerceiverLayer(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()

        # Cross-attention: latents가 image를 attend
        self.cross_attn = nn.MultiheadAttention(
            dim, num_heads, batch_first=True
        )
        self.norm1 = nn.LayerNorm(dim)

        # Self-attention: latents 간 상호작용
        self.self_attn = nn.MultiheadAttention(
            dim, num_heads, batch_first=True
        )
        self.norm2 = nn.LayerNorm(dim)

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, latents, image_features):
        # Cross-attention with image
        ln = self.norm1(latents)
        latents = latents + self.cross_attn(
            ln, image_features, image_features
        )[0]

        # Self-attention
        ln = self.norm2(latents)
        latents = latents + self.self_attn(ln, ln, ln)[0]

        # FFN
        latents = latents + self.ffn(self.norm3(latents))

        return latents

방식별 비교

정량적 비교

방식 추가 파라미터 추론 비용 학습 난이도 성능
Concatenation 낮음 중간 쉬움 좋음
Cross-Attention 중간 높음 어려움 최고
Gated Fusion 낮음 낮음 중간 좋음
Q-Former 높음 낮음 어려움 좋음
Perceiver 중간 낮음 중간 좋음

사용 사례별 권장

사용 사례 권장 방식 이유
빠른 프로토타이핑 Concatenation 구현 단순
최고 성능 추구 Cross-Attention 깊은 통합
효율성 중시 Q-Former/Perceiver 토큰 압축
해석 필요 Gated Fusion 게이트 분석
비디오 처리 Perceiver 프레임 처리

최신 트렌드

Interleaved Fusion (Gemini 스타일)

이미지와 텍스트를 자연스럽게 섞어서 처리:

class InterleavedFusion(nn.Module):
    """이미지-텍스트 인터리빙 방식"""

    def __init__(self, vision_encoder, projector, llm):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.projector = projector
        self.llm = llm

        # 특수 토큰 ID
        self.img_start_id = llm.config.vocab_size  # <img>
        self.img_end_id = llm.config.vocab_size + 1  # </img>

    def forward(self, text_with_placeholders, images, image_positions):
        """
        text_with_placeholders: "User: <img></img> What is this?"
        images: [image1, image2, ...]
        image_positions: [(start1, end1), (start2, end2), ...]
        """
        # 텍스트 임베딩
        text_embeds = self.llm.embed_tokens(text_with_placeholders)

        # 이미지 인코딩
        for i, (img, (start, end)) in enumerate(zip(images, image_positions)):
            img_features = self.vision_encoder(img.unsqueeze(0))
            img_tokens = self.projector(img_features).squeeze(0)

            # 해당 위치에 이미지 토큰 삽입
            text_embeds = torch.cat([
                text_embeds[:, :start],
                img_tokens.unsqueeze(0),
                text_embeds[:, end:]
            ], dim=1)

        return self.llm(inputs_embeds=text_embeds)

Dynamic Resolution Fusion

해상도에 따라 토큰 수가 변하는 방식:

class DynamicResolutionFusion(nn.Module):
    """Qwen-VL 스타일 동적 해상도"""

    def __init__(self, vision_encoder, llm):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.llm = llm

        # 해상도별 position embedding
        self.spatial_pos_embed = nn.Parameter(
            torch.randn(1, 4096, vision_encoder.output_dim)
        )

    def forward(self, images, input_ids):
        all_image_tokens = []

        for img in images:
            # 이미지 크기에 따른 패치 수 결정
            h, w = img.shape[-2:]
            num_patches_h = h // self.vision_encoder.patch_size
            num_patches_w = w // self.vision_encoder.patch_size
            num_patches = num_patches_h * num_patches_w

            # 인코딩
            features = self.vision_encoder(img.unsqueeze(0))

            # 동적 position embedding
            pos_emb = self.spatial_pos_embed[:, :num_patches]
            features = features + pos_emb

            all_image_tokens.append(features)

        # 가변 길이 이미지 토큰과 텍스트 결합
        # ...

참고 자료