멀티모달 학습 (Multimodal Learning)¶
서로 다른 모달리티(이미지, 텍스트, 오디오 등)의 정보를 통합하여 학습하는 방법.
왜 멀티모달인가¶
단일 모달리티의 한계¶
실세계 데이터는 멀티모달¶
| 도메인 | 모달리티 조합 |
|---|---|
| 소셜 미디어 | 이미지 + 텍스트 + 해시태그 |
| 의료 | X-ray + 진단서 + 환자 정보 |
| 자율주행 | 카메라 + LiDAR + GPS |
| 전자상거래 | 상품 이미지 + 설명 + 리뷰 |
모달리티 통합 전략¶
1. Early Fusion (조기 융합)¶
입력 단계에서 모달리티 결합.
import torch
import torch.nn as nn
class EarlyFusion(nn.Module):
def __init__(self, image_dim, text_dim, hidden_dim, num_classes):
super().__init__()
self.image_proj = nn.Linear(image_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim)
self.classifier = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, num_classes)
)
def forward(self, image_features, text_features):
img = self.image_proj(image_features)
txt = self.text_proj(text_features)
# 연결 (concatenation)
fused = torch.cat([img, txt], dim=-1)
return self.classifier(fused)
장점: 모달리티 간 저수준 상호작용 학습 가능
단점: 한 모달리티가 없으면 사용 불가
실무 사용: 모달리티가 항상 함께 있고, 저수준 상관관계가 중요할 때
2. Late Fusion (후기 융합)¶
각 모달리티를 개별 처리 후 결합.
class LateFusion(nn.Module):
def __init__(self, image_model, text_model, hidden_dim, num_classes):
super().__init__()
self.image_model = image_model # pretrained
self.text_model = text_model # pretrained
# 각 모델 출력 차원
self.image_fc = nn.Linear(image_model.output_dim, hidden_dim)
self.text_fc = nn.Linear(text_model.output_dim, hidden_dim)
# 융합 레이어
self.fusion = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, num_classes)
)
def forward(self, image, text, text_mask=None):
# 독립적으로 처리
img_out = self.image_model(image)
txt_out = self.text_model(text, attention_mask=text_mask)
# 최종 representation
img_repr = self.image_fc(img_out)
txt_repr = self.text_fc(txt_out)
# 결합
combined = torch.cat([img_repr, txt_repr], dim=-1)
return self.fusion(combined)
def forward_image_only(self, image):
"""이미지만 있을 때"""
img_out = self.image_model(image)
img_repr = self.image_fc(img_out)
# zero padding for text
txt_repr = torch.zeros_like(img_repr)
combined = torch.cat([img_repr, txt_repr], dim=-1)
return self.fusion(combined)
장점: 모달리티별 사전학습 모델 활용, 한 모달리티 없어도 사용 가능
단점: 모달리티 간 상호작용 학습 제한
실무 사용: 사전학습 모델 재사용, 모달리티가 선택적일 때
3. Cross-modal Fusion (교차 융합)¶
모달리티 간 상호작용을 Attention으로 학습.
class CrossModalFusion(nn.Module):
def __init__(self, dim, num_heads=8, num_layers=4):
super().__init__()
self.layers = nn.ModuleList([
CrossModalLayer(dim, num_heads)
for _ in range(num_layers)
])
def forward(self, image_tokens, text_tokens):
for layer in self.layers:
image_tokens, text_tokens = layer(image_tokens, text_tokens)
return image_tokens, text_tokens
class CrossModalLayer(nn.Module):
def __init__(self, dim, num_heads):
super().__init__()
# Image attends to Text
self.img_cross_attn = nn.MultiheadAttention(
dim, num_heads, batch_first=True
)
# Text attends to Image
self.txt_cross_attn = nn.MultiheadAttention(
dim, num_heads, batch_first=True
)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.norm4 = nn.LayerNorm(dim)
self.ffn_img = nn.Sequential(
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
)
self.ffn_txt = nn.Sequential(
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
)
def forward(self, image_tokens, text_tokens):
# Image attends to Text
img_normed = self.norm1(image_tokens)
img_attn, _ = self.img_cross_attn(
img_normed, text_tokens, text_tokens
)
image_tokens = image_tokens + img_attn
image_tokens = image_tokens + self.ffn_img(self.norm2(image_tokens))
# Text attends to Image
txt_normed = self.norm3(text_tokens)
txt_attn, _ = self.txt_cross_attn(
txt_normed, image_tokens, image_tokens
)
text_tokens = text_tokens + txt_attn
text_tokens = text_tokens + self.ffn_txt(self.norm4(text_tokens))
return image_tokens, text_tokens
장점: 깊은 모달리티 상호작용, 세밀한 정렬
단점: 계산 비용, 구현 복잡도
실무 사용: VLM의 핵심 메커니즘, 정교한 이해가 필요할 때
Contrastive Learning (CLIP)¶
이미지-텍스트 쌍의 유사도 학습. VLM의 핵심 사전학습 방법.
핵심 아이디어¶
InfoNCE Loss¶
\[L = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp(s_{ii}/\tau)}{\sum_{j=1}^{N}\exp(s_{ij}/\tau)}\]
- \(s_{ij}\): 이미지 i와 텍스트 j의 유사도 (코사인)
- \(\tau\): 온도 파라미터 (작을수록 sharp)
구현¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class CLIPLoss(nn.Module):
def __init__(self, temperature=0.07):
super().__init__()
# 학습 가능한 temperature
self.logit_scale = nn.Parameter(
torch.ones([]) * np.log(1 / temperature)
)
def forward(self, image_features, text_features):
# L2 정규화 (코사인 유사도 계산을 위해)
image_features = F.normalize(image_features, dim=-1)
text_features = F.normalize(text_features, dim=-1)
# 유사도 행렬 계산
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logits_per_image.T
# Ground truth: 대각선이 positive
batch_size = image_features.shape[0]
labels = torch.arange(batch_size, device=image_features.device)
# 양방향 Cross-entropy Loss
loss_i2t = F.cross_entropy(logits_per_image, labels)
loss_t2i = F.cross_entropy(logits_per_text, labels)
return (loss_i2t + loss_t2i) / 2
class CLIP(nn.Module):
def __init__(self, vision_encoder, text_encoder, embed_dim=512):
super().__init__()
self.vision_encoder = vision_encoder
self.text_encoder = text_encoder
# Projection heads (임베딩 공간 정렬)
self.visual_projection = nn.Linear(
vision_encoder.output_dim, embed_dim, bias=False
)
self.text_projection = nn.Linear(
text_encoder.output_dim, embed_dim, bias=False
)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))
def encode_image(self, image):
features = self.vision_encoder(image)
return self.visual_projection(features)
def encode_text(self, text):
features = self.text_encoder(text)
return self.text_projection(features)
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# 정규화
image_features = F.normalize(image_features, dim=-1)
text_features = F.normalize(text_features, dim=-1)
return image_features, text_features
CLIP 학습 상세¶
def train_clip_step(model, images, texts, optimizer, loss_fn):
optimizer.zero_grad()
# Forward
image_features, text_features = model(images, texts)
# Loss 계산
loss = loss_fn(image_features, text_features)
# Backward
loss.backward()
# Gradient clipping (안정성)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
return loss.item()
# 학습 설정
optimizer = torch.optim.AdamW(
model.parameters(),
lr=1e-4,
weight_decay=0.1,
betas=(0.9, 0.98)
)
# Learning rate schedule
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=num_epochs,
eta_min=1e-6
)
왜 Contrastive Learning이 효과적인가¶
| 특성 | 설명 |
|---|---|
| Self-supervised | 레이블 없이 이미지-텍스트 쌍만으로 학습 |
| 대규모 학습 | 웹 크롤링 데이터 활용 가능 (400M+ 쌍) |
| Zero-shot | 학습하지 않은 클래스도 분류 가능 |
| 전이 학습 | 다양한 downstream task에 적용 |
실무 활용: 멀티모달 표현 학습¶
제로샷 분류¶
def zero_shot_classification(model, image, class_descriptions):
"""
학습하지 않은 클래스에 대해 분류
Args:
class_descriptions: ["a photo of a cat", "a photo of a dog", ...]
"""
# 이미지 인코딩
image_features = model.encode_image(image)
image_features = F.normalize(image_features, dim=-1)
# 클래스 텍스트 인코딩
text_features = model.encode_text(class_descriptions)
text_features = F.normalize(text_features, dim=-1)
# 유사도 계산
similarities = (image_features @ text_features.T) * 100
# 예측
probs = F.softmax(similarities, dim=-1)
predicted_class = probs.argmax(dim=-1)
return predicted_class, probs
# 사용 예시
classes = [
"a photo of a cat",
"a photo of a dog",
"a photo of a bird",
"a photo of a car"
]
# Prompt engineering으로 성능 향상
classes_enhanced = [
"a photo of a cat, a type of pet",
"a photo of a dog, a type of pet",
"a photo of a bird, a type of animal",
"a photo of a car, a type of vehicle"
]
멀티모달 검색¶
class MultimodalRetriever:
def __init__(self, model, text_database, text_embeddings=None):
self.model = model
self.text_database = text_database
# 텍스트 임베딩 사전 계산 (효율성)
if text_embeddings is None:
with torch.no_grad():
self.text_embeddings = model.encode_text(text_database)
self.text_embeddings = F.normalize(self.text_embeddings, dim=-1)
else:
self.text_embeddings = text_embeddings
def image_to_text(self, image, top_k=5):
"""이미지로 관련 텍스트 검색"""
with torch.no_grad():
image_feature = self.model.encode_image(image)
image_feature = F.normalize(image_feature, dim=-1)
# 유사도 계산
similarities = image_feature @ self.text_embeddings.T
# Top-k 결과
top_indices = similarities.topk(k=top_k).indices
return [self.text_database[i] for i in top_indices]
def text_to_image(self, text, image_embeddings, image_database, top_k=5):
"""텍스트로 관련 이미지 검색"""
with torch.no_grad():
text_feature = self.model.encode_text([text])
text_feature = F.normalize(text_feature, dim=-1)
similarities = text_feature @ image_embeddings.T
top_indices = similarities.topk(k=top_k).indices
return [image_database[i] for i in top_indices]
유사 이미지 검색¶
def find_similar_images(model, query_image, image_database, top_k=5):
"""이미지 임베딩 기반 유사 이미지 검색"""
# 쿼리 이미지 인코딩
query_embedding = model.encode_image(query_image)
query_embedding = F.normalize(query_embedding, dim=-1)
# 데이터베이스 이미지 인코딩 (사전 계산 권장)
db_embeddings = []
for img in image_database:
emb = model.encode_image(img)
db_embeddings.append(F.normalize(emb, dim=-1))
db_embeddings = torch.stack(db_embeddings)
# 유사도 계산
similarities = query_embedding @ db_embeddings.T
top_indices = similarities.topk(k=top_k).indices
return [image_database[i] for i in top_indices]
사전학습 태스크¶
VLM 학습에 사용되는 주요 사전학습 태스크:
Image-Text Contrastive (ITC)¶
CLIP 스타일 대조 학습. 위에서 설명.
Image-Text Matching (ITM)¶
이미지-텍스트 쌍이 매칭되는지 이진 분류.
class ITMHead(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.classifier = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 2) # match / not match
)
def forward(self, multimodal_features):
# 보통 [CLS] 토큰 사용
return self.classifier(multimodal_features[:, 0])
def itm_loss(model, images, texts, hard_negative_mining=True):
batch_size = images.shape[0]
# Positive pairs
pos_features = model.encode_multimodal(images, texts)
pos_labels = torch.ones(batch_size)
# Negative pairs (Hard negative mining)
if hard_negative_mining:
# ITC 유사도 기반으로 어려운 negative 선택
with torch.no_grad():
img_emb = model.encode_image(images)
txt_emb = model.encode_text(texts)
sim = img_emb @ txt_emb.T
# 대각선 마스킹 (positive 제외)
sim.fill_diagonal_(-float('inf'))
# 가장 유사한 negative 선택
hard_neg_idx = sim.argmax(dim=1)
neg_texts = texts[hard_neg_idx]
else:
# Random shuffle
neg_texts = texts[torch.randperm(batch_size)]
neg_features = model.encode_multimodal(images, neg_texts)
neg_labels = torch.zeros(batch_size)
# Loss 계산
all_features = torch.cat([pos_features, neg_features])
all_labels = torch.cat([pos_labels, neg_labels]).long()
logits = model.itm_head(all_features)
return F.cross_entropy(logits, all_labels)
Masked Language Modeling (MLM)¶
이미지 컨텍스트로 마스킹된 텍스트 예측.
def mlm_loss(model, images, texts, mask_prob=0.15):
# 텍스트 마스킹
masked_texts, labels = mask_tokens(texts, mask_prob)
# 이미지와 함께 예측
outputs = model.forward_mlm(images, masked_texts)
# 마스킹된 위치만 loss 계산
loss = F.cross_entropy(
outputs.view(-1, vocab_size),
labels.view(-1),
ignore_index=-100 # 마스킹되지 않은 위치
)
return loss
def mask_tokens(texts, mask_prob, mask_token_id, vocab_size):
"""BERT 스타일 마스킹"""
labels = texts.clone()
# 마스킹 확률 행렬
prob_matrix = torch.full(texts.shape, mask_prob)
# 특수 토큰은 마스킹 안함
special_tokens_mask = get_special_tokens_mask(texts)
prob_matrix.masked_fill_(special_tokens_mask, 0.0)
# 마스킹 위치 선택
masked_indices = torch.bernoulli(prob_matrix).bool()
labels[~masked_indices] = -100 # loss 계산 안함
# 80% [MASK], 10% random, 10% unchanged
indices_replaced = torch.bernoulli(torch.full(texts.shape, 0.8)).bool() & masked_indices
texts[indices_replaced] = mask_token_id
indices_random = torch.bernoulli(torch.full(texts.shape, 0.1)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(vocab_size, texts.shape)
texts[indices_random] = random_words[indices_random]
return texts, labels
멀티모달 데이터셋¶
대규모 사전학습 데이터¶
| 데이터셋 | 규모 | 특징 | 용도 |
|---|---|---|---|
| LAION-5B | 5.8B 쌍 | 웹 크롤링, 다국어 | CLIP 학습 |
| LAION-400M | 400M 쌍 | 영어 중심 | CLIP 학습 |
| CC3M | 3.3M | 품질 필터링 | 소규모 실험 |
| CC12M | 12M | 중간 규모 | 연구용 |
| DataComp | 12.8B | 품질 다양 | 데이터 연구 |
평가/미세조정 데이터¶
| 데이터셋 | 태스크 | 규모 |
|---|---|---|
| COCO | 캡셔닝, 검색 | 330K |
| Visual Genome | 상세 annotation | 100K |
| VQAv2 | 시각 질의응답 | 1.1M QA |
| Flickr30k | 검색 | 31K |
CLIP 변형 모델¶
SigLIP (Sigmoid Loss)¶
class SigLIPLoss(nn.Module):
"""
CLIP의 softmax 대신 sigmoid 사용
- 메모리 효율적 (배치 내 모든 쌍 계산 불필요)
- 더 안정적인 학습
"""
def __init__(self, temperature=10.0, bias=-10.0):
super().__init__()
self.temperature = nn.Parameter(torch.tensor(temperature))
self.bias = nn.Parameter(torch.tensor(bias))
def forward(self, image_features, text_features, labels=None):
# 유사도 계산
logits = image_features @ text_features.T * self.temperature + self.bias
# Labels: 대각선이 1, 나머지 0 (또는 커스텀)
if labels is None:
batch_size = image_features.shape[0]
labels = torch.eye(batch_size, device=image_features.device)
# Binary cross-entropy (각 쌍에 대해 독립적)
loss = F.binary_cross_entropy_with_logits(logits, labels)
return loss
OpenCLIP¶
# OpenCLIP 사용 예시
import open_clip
# 모델 로드
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
'ViT-L-14',
pretrained='laion2b_s32b_b82k'
)
tokenizer = open_clip.get_tokenizer('ViT-L-14')
# 이미지 인코딩
image = preprocess_val(Image.open("image.jpg")).unsqueeze(0)
image_features = model.encode_image(image)
# 텍스트 인코딩
text = tokenizer(["a cat", "a dog"])
text_features = model.encode_text(text)
# 유사도
similarity = (image_features @ text_features.T).softmax(dim=-1)
참고 자료¶
핵심 논문¶
- CLIP Paper - Contrastive Learning 기초
- ALIGN Paper - 대규모 노이즈 데이터 학습
- BLIP Paper - 캡셔닝과 이해 통합
- SigLIP Paper - Sigmoid 기반 학습
코드/라이브러리¶
- OpenCLIP - 오픈소스 CLIP 구현
- Hugging Face Transformers - CLIP 모델
- LAION - 대규모 데이터셋