콘텐츠로 이동
Data Prep
상세

멀티모달 학습 (Multimodal Learning)

서로 다른 모달리티(이미지, 텍스트, 오디오 등)의 정보를 통합하여 학습하는 방법.

왜 멀티모달인가

단일 모달리티의 한계

multimodal diagram 1

실세계 데이터는 멀티모달

도메인 모달리티 조합
소셜 미디어 이미지 + 텍스트 + 해시태그
의료 X-ray + 진단서 + 환자 정보
자율주행 카메라 + LiDAR + GPS
전자상거래 상품 이미지 + 설명 + 리뷰

모달리티 통합 전략

1. Early Fusion (조기 융합)

입력 단계에서 모달리티 결합.

multimodal diagram 2

import torch
import torch.nn as nn

class EarlyFusion(nn.Module):
    def __init__(self, image_dim, text_dim, hidden_dim, num_classes):
        super().__init__()
        self.image_proj = nn.Linear(image_dim, hidden_dim)
        self.text_proj = nn.Linear(text_dim, hidden_dim)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, image_features, text_features):
        img = self.image_proj(image_features)
        txt = self.text_proj(text_features)

        # 연결 (concatenation)
        fused = torch.cat([img, txt], dim=-1)
        return self.classifier(fused)

장점: 모달리티 간 저수준 상호작용 학습 가능

단점: 한 모달리티가 없으면 사용 불가

실무 사용: 모달리티가 항상 함께 있고, 저수준 상관관계가 중요할 때

2. Late Fusion (후기 융합)

각 모달리티를 개별 처리 후 결합.

multimodal diagram 3

class LateFusion(nn.Module):
    def __init__(self, image_model, text_model, hidden_dim, num_classes):
        super().__init__()
        self.image_model = image_model  # pretrained
        self.text_model = text_model    # pretrained

        # 각 모델 출력 차원
        self.image_fc = nn.Linear(image_model.output_dim, hidden_dim)
        self.text_fc = nn.Linear(text_model.output_dim, hidden_dim)

        # 융합 레이어
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, image, text, text_mask=None):
        # 독립적으로 처리
        img_out = self.image_model(image)
        txt_out = self.text_model(text, attention_mask=text_mask)

        # 최종 representation
        img_repr = self.image_fc(img_out)
        txt_repr = self.text_fc(txt_out)

        # 결합
        combined = torch.cat([img_repr, txt_repr], dim=-1)
        return self.fusion(combined)

    def forward_image_only(self, image):
        """이미지만 있을 때"""
        img_out = self.image_model(image)
        img_repr = self.image_fc(img_out)
        # zero padding for text
        txt_repr = torch.zeros_like(img_repr)
        combined = torch.cat([img_repr, txt_repr], dim=-1)
        return self.fusion(combined)

장점: 모달리티별 사전학습 모델 활용, 한 모달리티 없어도 사용 가능

단점: 모달리티 간 상호작용 학습 제한

실무 사용: 사전학습 모델 재사용, 모달리티가 선택적일 때

3. Cross-modal Fusion (교차 융합)

모달리티 간 상호작용을 Attention으로 학습.

multimodal diagram 4

class CrossModalFusion(nn.Module):
    def __init__(self, dim, num_heads=8, num_layers=4):
        super().__init__()

        self.layers = nn.ModuleList([
            CrossModalLayer(dim, num_heads)
            for _ in range(num_layers)
        ])

    def forward(self, image_tokens, text_tokens):
        for layer in self.layers:
            image_tokens, text_tokens = layer(image_tokens, text_tokens)
        return image_tokens, text_tokens


class CrossModalLayer(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()

        # Image attends to Text
        self.img_cross_attn = nn.MultiheadAttention(
            dim, num_heads, batch_first=True
        )

        # Text attends to Image
        self.txt_cross_attn = nn.MultiheadAttention(
            dim, num_heads, batch_first=True
        )

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.norm4 = nn.LayerNorm(dim)

        self.ffn_img = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
        )
        self.ffn_txt = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
        )

    def forward(self, image_tokens, text_tokens):
        # Image attends to Text
        img_normed = self.norm1(image_tokens)
        img_attn, _ = self.img_cross_attn(
            img_normed, text_tokens, text_tokens
        )
        image_tokens = image_tokens + img_attn
        image_tokens = image_tokens + self.ffn_img(self.norm2(image_tokens))

        # Text attends to Image
        txt_normed = self.norm3(text_tokens)
        txt_attn, _ = self.txt_cross_attn(
            txt_normed, image_tokens, image_tokens
        )
        text_tokens = text_tokens + txt_attn
        text_tokens = text_tokens + self.ffn_txt(self.norm4(text_tokens))

        return image_tokens, text_tokens

장점: 깊은 모달리티 상호작용, 세밀한 정렬

단점: 계산 비용, 구현 복잡도

실무 사용: VLM의 핵심 메커니즘, 정교한 이해가 필요할 때

Contrastive Learning (CLIP)

이미지-텍스트 쌍의 유사도 학습. VLM의 핵심 사전학습 방법.

핵심 아이디어

multimodal diagram 5

InfoNCE Loss

\[L = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp(s_{ii}/\tau)}{\sum_{j=1}^{N}\exp(s_{ij}/\tau)}\]
  • \(s_{ij}\): 이미지 i와 텍스트 j의 유사도 (코사인)
  • \(\tau\): 온도 파라미터 (작을수록 sharp)

구현

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class CLIPLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        # 학습 가능한 temperature
        self.logit_scale = nn.Parameter(
            torch.ones([]) * np.log(1 / temperature)
        )

    def forward(self, image_features, text_features):
        # L2 정규화 (코사인 유사도 계산을 위해)
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        # 유사도 행렬 계산
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logits_per_image.T

        # Ground truth: 대각선이 positive
        batch_size = image_features.shape[0]
        labels = torch.arange(batch_size, device=image_features.device)

        # 양방향 Cross-entropy Loss
        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)

        return (loss_i2t + loss_t2i) / 2


class CLIP(nn.Module):
    def __init__(self, vision_encoder, text_encoder, embed_dim=512):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder

        # Projection heads (임베딩 공간 정렬)
        self.visual_projection = nn.Linear(
            vision_encoder.output_dim, embed_dim, bias=False
        )
        self.text_projection = nn.Linear(
            text_encoder.output_dim, embed_dim, bias=False
        )

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

    def encode_image(self, image):
        features = self.vision_encoder(image)
        return self.visual_projection(features)

    def encode_text(self, text):
        features = self.text_encoder(text)
        return self.text_projection(features)

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # 정규화
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        return image_features, text_features

CLIP 학습 상세

def train_clip_step(model, images, texts, optimizer, loss_fn):
    optimizer.zero_grad()

    # Forward
    image_features, text_features = model(images, texts)

    # Loss 계산
    loss = loss_fn(image_features, text_features)

    # Backward
    loss.backward()

    # Gradient clipping (안정성)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()

    return loss.item()


# 학습 설정
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.1,
    betas=(0.9, 0.98)
)

# Learning rate schedule
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=1e-6
)

왜 Contrastive Learning이 효과적인가

특성 설명
Self-supervised 레이블 없이 이미지-텍스트 쌍만으로 학습
대규모 학습 웹 크롤링 데이터 활용 가능 (400M+ 쌍)
Zero-shot 학습하지 않은 클래스도 분류 가능
전이 학습 다양한 downstream task에 적용

실무 활용: 멀티모달 표현 학습

제로샷 분류

def zero_shot_classification(model, image, class_descriptions):
    """
    학습하지 않은 클래스에 대해 분류

    Args:
        class_descriptions: ["a photo of a cat", "a photo of a dog", ...]
    """
    # 이미지 인코딩
    image_features = model.encode_image(image)
    image_features = F.normalize(image_features, dim=-1)

    # 클래스 텍스트 인코딩
    text_features = model.encode_text(class_descriptions)
    text_features = F.normalize(text_features, dim=-1)

    # 유사도 계산
    similarities = (image_features @ text_features.T) * 100

    # 예측
    probs = F.softmax(similarities, dim=-1)
    predicted_class = probs.argmax(dim=-1)

    return predicted_class, probs


# 사용 예시
classes = [
    "a photo of a cat",
    "a photo of a dog", 
    "a photo of a bird",
    "a photo of a car"
]

# Prompt engineering으로 성능 향상
classes_enhanced = [
    "a photo of a cat, a type of pet",
    "a photo of a dog, a type of pet",
    "a photo of a bird, a type of animal",
    "a photo of a car, a type of vehicle"
]

멀티모달 검색

class MultimodalRetriever:
    def __init__(self, model, text_database, text_embeddings=None):
        self.model = model
        self.text_database = text_database

        # 텍스트 임베딩 사전 계산 (효율성)
        if text_embeddings is None:
            with torch.no_grad():
                self.text_embeddings = model.encode_text(text_database)
                self.text_embeddings = F.normalize(self.text_embeddings, dim=-1)
        else:
            self.text_embeddings = text_embeddings

    def image_to_text(self, image, top_k=5):
        """이미지로 관련 텍스트 검색"""
        with torch.no_grad():
            image_feature = self.model.encode_image(image)
            image_feature = F.normalize(image_feature, dim=-1)

        # 유사도 계산
        similarities = image_feature @ self.text_embeddings.T

        # Top-k 결과
        top_indices = similarities.topk(k=top_k).indices

        return [self.text_database[i] for i in top_indices]

    def text_to_image(self, text, image_embeddings, image_database, top_k=5):
        """텍스트로 관련 이미지 검색"""
        with torch.no_grad():
            text_feature = self.model.encode_text([text])
            text_feature = F.normalize(text_feature, dim=-1)

        similarities = text_feature @ image_embeddings.T
        top_indices = similarities.topk(k=top_k).indices

        return [image_database[i] for i in top_indices]

유사 이미지 검색

def find_similar_images(model, query_image, image_database, top_k=5):
    """이미지 임베딩 기반 유사 이미지 검색"""

    # 쿼리 이미지 인코딩
    query_embedding = model.encode_image(query_image)
    query_embedding = F.normalize(query_embedding, dim=-1)

    # 데이터베이스 이미지 인코딩 (사전 계산 권장)
    db_embeddings = []
    for img in image_database:
        emb = model.encode_image(img)
        db_embeddings.append(F.normalize(emb, dim=-1))
    db_embeddings = torch.stack(db_embeddings)

    # 유사도 계산
    similarities = query_embedding @ db_embeddings.T
    top_indices = similarities.topk(k=top_k).indices

    return [image_database[i] for i in top_indices]

사전학습 태스크

VLM 학습에 사용되는 주요 사전학습 태스크:

Image-Text Contrastive (ITC)

CLIP 스타일 대조 학습. 위에서 설명.

Image-Text Matching (ITM)

이미지-텍스트 쌍이 매칭되는지 이진 분류.

class ITMHead(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)  # match / not match
        )

    def forward(self, multimodal_features):
        # 보통 [CLS] 토큰 사용
        return self.classifier(multimodal_features[:, 0])


def itm_loss(model, images, texts, hard_negative_mining=True):
    batch_size = images.shape[0]

    # Positive pairs
    pos_features = model.encode_multimodal(images, texts)
    pos_labels = torch.ones(batch_size)

    # Negative pairs (Hard negative mining)
    if hard_negative_mining:
        # ITC 유사도 기반으로 어려운 negative 선택
        with torch.no_grad():
            img_emb = model.encode_image(images)
            txt_emb = model.encode_text(texts)
            sim = img_emb @ txt_emb.T

            # 대각선 마스킹 (positive 제외)
            sim.fill_diagonal_(-float('inf'))

            # 가장 유사한 negative 선택
            hard_neg_idx = sim.argmax(dim=1)

        neg_texts = texts[hard_neg_idx]
    else:
        # Random shuffle
        neg_texts = texts[torch.randperm(batch_size)]

    neg_features = model.encode_multimodal(images, neg_texts)
    neg_labels = torch.zeros(batch_size)

    # Loss 계산
    all_features = torch.cat([pos_features, neg_features])
    all_labels = torch.cat([pos_labels, neg_labels]).long()

    logits = model.itm_head(all_features)
    return F.cross_entropy(logits, all_labels)

Masked Language Modeling (MLM)

이미지 컨텍스트로 마스킹된 텍스트 예측.

def mlm_loss(model, images, texts, mask_prob=0.15):
    # 텍스트 마스킹
    masked_texts, labels = mask_tokens(texts, mask_prob)

    # 이미지와 함께 예측
    outputs = model.forward_mlm(images, masked_texts)

    # 마스킹된 위치만 loss 계산
    loss = F.cross_entropy(
        outputs.view(-1, vocab_size),
        labels.view(-1),
        ignore_index=-100  # 마스킹되지 않은 위치
    )

    return loss


def mask_tokens(texts, mask_prob, mask_token_id, vocab_size):
    """BERT 스타일 마스킹"""
    labels = texts.clone()

    # 마스킹 확률 행렬
    prob_matrix = torch.full(texts.shape, mask_prob)

    # 특수 토큰은 마스킹 안함
    special_tokens_mask = get_special_tokens_mask(texts)
    prob_matrix.masked_fill_(special_tokens_mask, 0.0)

    # 마스킹 위치 선택
    masked_indices = torch.bernoulli(prob_matrix).bool()
    labels[~masked_indices] = -100  # loss 계산 안함

    # 80% [MASK], 10% random, 10% unchanged
    indices_replaced = torch.bernoulli(torch.full(texts.shape, 0.8)).bool() & masked_indices
    texts[indices_replaced] = mask_token_id

    indices_random = torch.bernoulli(torch.full(texts.shape, 0.1)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(vocab_size, texts.shape)
    texts[indices_random] = random_words[indices_random]

    return texts, labels

멀티모달 데이터셋

대규모 사전학습 데이터

데이터셋 규모 특징 용도
LAION-5B 5.8B 쌍 웹 크롤링, 다국어 CLIP 학습
LAION-400M 400M 쌍 영어 중심 CLIP 학습
CC3M 3.3M 품질 필터링 소규모 실험
CC12M 12M 중간 규모 연구용
DataComp 12.8B 품질 다양 데이터 연구

평가/미세조정 데이터

데이터셋 태스크 규모
COCO 캡셔닝, 검색 330K
Visual Genome 상세 annotation 100K
VQAv2 시각 질의응답 1.1M QA
Flickr30k 검색 31K

CLIP 변형 모델

SigLIP (Sigmoid Loss)

class SigLIPLoss(nn.Module):
    """
    CLIP의 softmax 대신 sigmoid 사용
    - 메모리 효율적 (배치 내 모든 쌍 계산 불필요)
    - 더 안정적인 학습
    """
    def __init__(self, temperature=10.0, bias=-10.0):
        super().__init__()
        self.temperature = nn.Parameter(torch.tensor(temperature))
        self.bias = nn.Parameter(torch.tensor(bias))

    def forward(self, image_features, text_features, labels=None):
        # 유사도 계산
        logits = image_features @ text_features.T * self.temperature + self.bias

        # Labels: 대각선이 1, 나머지 0 (또는 커스텀)
        if labels is None:
            batch_size = image_features.shape[0]
            labels = torch.eye(batch_size, device=image_features.device)

        # Binary cross-entropy (각 쌍에 대해 독립적)
        loss = F.binary_cross_entropy_with_logits(logits, labels)

        return loss

OpenCLIP

# OpenCLIP 사용 예시
import open_clip

# 모델 로드
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-L-14',
    pretrained='laion2b_s32b_b82k'
)
tokenizer = open_clip.get_tokenizer('ViT-L-14')

# 이미지 인코딩
image = preprocess_val(Image.open("image.jpg")).unsqueeze(0)
image_features = model.encode_image(image)

# 텍스트 인코딩
text = tokenizer(["a cat", "a dog"])
text_features = model.encode_text(text)

# 유사도
similarity = (image_features @ text_features.T).softmax(dim=-1)

참고 자료

핵심 논문

코드/라이브러리