Self-Supervised Learning (SSL)
메타 정보
| 항목 |
내용 |
| 분류 |
Representation Learning / Label-Efficient Learning |
| 핵심 논문 |
"A Simple Framework for Contrastive Learning of Visual Representations" (Chen et al., ICML 2020 - SimCLR), "Momentum Contrast for Unsupervised Visual Representation Learning" (He et al., CVPR 2020 - MoCo v1/v2), "Bootstrap Your Own Latent" (Grill et al., NeurIPS 2020 - BYOL), "Masked Autoencoders Are Scalable Vision Learners" (He et al., CVPR 2022 - MAE), "Emerging Properties in Self-Supervised Vision Transformers" (Caron et al., ICCV 2021 - DINO), "DINOv2: Learning Robust Visual Features without Supervision" (Oquab et al., TMLR 2024), "BERT: Pre-Training of Deep Bidirectional Transformers" (Devlin et al., NAACL 2019), "Language Models are Unsupervised Multitask Learners" (Radford et al., 2019 - GPT-2) |
| 주요 저자 |
Kaiming He (MoCo, MAE), Ting Chen & Geoffrey Hinton (SimCLR), Mathilde Caron (DINO), Jean-Baptiste Grill (BYOL), Jacob Devlin (BERT) |
| 핵심 개념 |
레이블 없는 데이터에서 pretext task를 자동 생성하여 범용 표현(representation)을 학습하는 패러다임 |
| 관련 분야 |
Contrastive Learning, Masked Modeling, Foundation Models, Transfer Learning, Representation Learning |
정의
Self-Supervised Learning(SSL)은 레이블 없는 데이터로부터 pretext task(자동 생성된 보조 과제)를 정의하고, 이를 풀면서 데이터의 구조적 표현을 학습하는 방법론이다. 지도학습의 레이블 비용 문제와 비지도학습의 목적 불명확성을 동시에 해결한다.
학습 패러다임 비교
| 항목 |
지도학습 |
비지도학습 |
자기지도학습 |
| 레이블 |
필요 (비용 높음) |
불필요 |
불필요 |
| 학습 신호 |
사람이 제공한 정답 |
데이터 분포 |
데이터에서 자동 생성 |
| 목적함수 |
cross-entropy 등 |
reconstruction, density |
pretext task별 상이 |
| 표현 품질 |
태스크 특화 |
범용적이나 품질 불균일 |
범용적이며 고품질 |
| 대표 사례 |
ResNet + ImageNet |
VAE, GAN |
SimCLR, MAE, BERT |
| 확장성 |
레이블에 의존 |
데이터 양에 비례 |
데이터 양에 비례 |
핵심 아이디어
Self-Supervised Learning의 2단계 프레임워크:
[1단계: Pretext Task로 사전학습]
+------------------------------------------------------------------+
| |
| 대량의 비레이블 데이터 --> Pretext Task 정의 --> 인코더 학습 |
| |
| 예시: |
| - 이미지 일부 마스킹 후 복원 (MAE) |
| - 같은 이미지의 다른 augmentation끼리 유사하게 (SimCLR) |
| - 문장 내 단어 마스킹 후 예측 (BERT) |
| - 다음 토큰 예측 (GPT) |
| |
+------------------------------------------------------------------+
|
v
[2단계: Downstream Task로 전이]
+------------------------------------------------------------------+
| |
| 학습된 인코더 --> 소량 레이블로 fine-tuning 또는 linear probe |
| |
| 분류, 탐지, 분할, QA 등 다양한 태스크에 적용 |
| |
+------------------------------------------------------------------+
SSL 방법론 분류 체계
SSL 방법론은 크게 네 가지 패밀리로 분류된다:
Self-Supervised Learning 방법론 분류:
+---------------------------------------------+
| 1. Contrastive Methods |
| - positive/negative pair 구성 |
| - InfoNCE 손실 함수 |
| - SimCLR, MoCo, CLIP |
+---------------------------------------------+
|
v
+---------------------------------------------+
| 2. Non-Contrastive (Self-Distillation) |
| - negative pair 없이 학습 |
| - EMA teacher + stop-gradient |
| - BYOL, SimSiam, DINO, DINOv2 |
+---------------------------------------------+
|
v
+---------------------------------------------+
| 3. Masked Modeling |
| - 입력 일부를 마스킹하고 복원 |
| - 토큰 수준 또는 픽셀 수준 |
| - BERT, MAE, BEiT, data2vec |
+---------------------------------------------+
|
v
+---------------------------------------------+
| 4. Predictive (Pretext-based) |
| - 자동 생성 과제 풀기 |
| - 회전 예측, 직소 퍼즐, 순서 예측 |
| - RotNet, Jigsaw, CPC |
+---------------------------------------------+
핵심 방법론
1. Contrastive Methods
SimCLR (Chen et al., ICML 2020)
| 항목 |
내용 |
| 핵심 |
동일 이미지의 두 augmentation을 positive pair로 구성 |
| 인코더 |
ResNet-50 (이후 더 큰 모델도 사용) |
| 프로젝션 헤드 |
2-layer MLP (128-dim) |
| 배치 크기 |
4096~8192 (큰 배치가 핵심) |
| 손실 함수 |
NT-Xent (Normalized Temperature-scaled Cross Entropy) |
NT-Xent Loss:
\[
\ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k) / \tau)}
\]
- \(z_i, z_j\): positive pair의 프로젝션 출력
- \(\tau\): temperature 파라미터 (0.1~0.5)
- \(\text{sim}\): cosine similarity
- 분모에 배치 내 모든 다른 샘플이 negative로 작용
핵심 발견:
- Data augmentation 조합이 성능에 가장 큰 영향 (random crop + color jitter가 최적)
- 비선형 프로젝션 헤드가 선형보다 +10% 성능 향상
- 배치 크기가 클수록 negative 샘플 다양성 증가 -> 성능 향상
MoCo v1/v2 (He et al., CVPR 2020 / Chen et al., 2020)
| 항목 |
내용 |
| 핵심 |
Momentum-updated queue로 large negative pool 유지 |
| 큐 크기 |
65536 (배치 크기와 무관) |
| 모멘텀 계수 |
m = 0.999 (천천히 업데이트) |
| 장점 |
작은 배치에서도 많은 negative 사용 가능 |
Momentum Update:
\[
\theta_k \leftarrow m \cdot \theta_k + (1 - m) \cdot \theta_q
\]
- \(\theta_q\): query encoder (gradient로 업데이트)
- \(\theta_k\): key encoder (momentum으로 업데이트)
- 큐에 가장 최근 key representation을 저장하고 FIFO로 관리
MoCo v2 개선점:
- SimCLR의 MLP 프로젝션 헤드 차용
- 더 강한 augmentation 적용
- cosine learning rate schedule
CLIP (Radford et al., ICML 2021)
| 항목 |
내용 |
| 핵심 |
이미지-텍스트 쌍의 contrastive learning |
| 학습 데이터 |
WebImageText (4억 이미지-텍스트 쌍) |
| 인코더 |
Vision Transformer + Text Transformer |
| 특징 |
Zero-shot 분류/검색 가능 |
Contrastive Loss (image-text):
\[
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \left[\log \frac{\exp(I_i \cdot T_i / \tau)}{\sum_{j=1}^{N} \exp(I_i \cdot T_j / \tau)} + \log \frac{\exp(T_i \cdot I_i / \tau)}{\sum_{j=1}^{N} \exp(T_i \cdot I_j / \tau)}\right]
\]
2. Non-Contrastive (Self-Distillation) Methods
BYOL (Grill et al., NeurIPS 2020)
| 항목 |
내용 |
| 핵심 |
Negative pair 없이 학습 가능 |
| 구조 |
Online network + Target network (EMA) |
| 추가 구성 |
Predictor head (online 쪽에만) |
| 의의 |
Negative pair 없이도 collapse 방지 가능함을 증명 |
BYOL 아키텍처:
이미지 x
|
+-- aug1 --> Online Network --> projector --> predictor --> p
| (theta)
| |
| L2 loss (p, sg(z'))
| |
+-- aug2 --> Target Network --> projector -----------------> z'
(xi, EMA) (stop-gradient)
Target update: xi <- tau * xi + (1-tau) * theta
(tau = 0.996 -> 1.0, cosine schedule)
Collapse 방지 메커니즘:
BYOL이 trivial solution (모든 출력이 동일)으로 수렴하지 않는 이유:
1. Predictor head의 비대칭성
2. EMA target의 느린 업데이트
3. Batch normalization의 implicit regularization
DINO / DINOv2 (Caron et al., ICCV 2021 / Oquab et al., TMLR 2024)
| 항목 |
DINO (2021) |
DINOv2 (2024) |
| 인코더 |
ViT-S/B |
ViT-L/g (1B params) |
| 학습 데이터 |
ImageNet-1K |
LVD-142M (자동 큐레이션) |
| 핵심 기법 |
Self-distillation with centering |
+ 자동 데이터 큐레이션 파이프라인 |
| 성능 |
ImageNet linear probe 77.0% |
ImageNet linear probe 86.5% |
DINO의 핵심:
\[
\mathcal{L} = -\sum_{x \in \{x_1^g, x_2^g\}} \sum_{\substack{x' \in V \\ x' \neq x}} p_t(x) \log p_s(x')
\]
- Teacher: 전체 이미지의 global crop 처리
- Student: local + global crop 모두 처리
- Centering + sharpening으로 collapse 방지
- ViT의 [CLS] 토큰이 자연스럽게 segmentation 능력 획득
SimSiam (Chen & He, CVPR 2021)
| 항목 |
내용 |
| 핵심 |
Momentum encoder도 없이 Siamese network만으로 학습 |
| 구조 |
대칭 + stop-gradient |
| 의의 |
SSL의 최소 필수 구성 요소 규명 |
3. Masked Modeling Methods
MAE (He et al., CVPR 2022)
| 항목 |
내용 |
| 핵심 |
이미지 패치의 75%를 마스킹하고 복원 |
| 인코더 |
ViT (visible patches만 처리) |
| 디코더 |
경량 Transformer (마스킹 패치 복원) |
| 효율성 |
3배 이상 학습 속도 향상 (마스킹된 패치 스킵) |
MAE 아키텍처:
원본 이미지 --> 패치 분할 (예: 16x16)
|
+------+------+------+------+
| P1 | P2 | P3 | P4 | ...
+------+------+------+------+
^ ^ --> 75% 랜덤 마스킹
보임 보임
| |
v v
+----------------------------+
| ViT Encoder | <-- visible patches만 입력
| (깊고 무거움) |
+----------------------------+
|
v
+----------------------------+
| Lightweight Decoder | <-- mask tokens 추가
| (얕고 가벼움) |
+----------------------------+
|
v
복원된 전체 이미지
Loss: MSE(복원 픽셀, 원본 픽셀) (마스킹된 위치만 계산)
핵심 설계 원칙:
- 높은 마스킹 비율 (75%): 이미지는 텍스트보다 정보 중복이 높아 높은 비율 필요
- 비대칭 인코더-디코더: 인코더는 visible만 처리하여 효율적
- 픽셀 수준 복원: 토큰화 불필요, 간단한 목적함수
BEiT (Bao et al., ICLR 2022)
| 항목 |
내용 |
| 핵심 |
Visual tokenizer로 이산 토큰 예측 (BERT 방식) |
| 토크나이저 |
dVAE (discrete VAE) |
| 마스킹 비율 |
40% |
| 의의 |
NLP의 masked language modeling을 vision에 최초 적용 |
data2vec (Baevski et al., ICML 2022)
| 항목 |
내용 |
| 핵심 |
모달리티 무관 통합 SSL 프레임워크 |
| 적용 |
이미지, 텍스트, 음성 동시 지원 |
| 예측 대상 |
Teacher 네트워크의 latent representation |
| Teacher |
EMA 기반 |
4. Predictive (Pretext-based) Methods
초기 SSL 방법론으로, 데이터에서 자동 생성 가능한 과제를 정의한다:
| 방법 |
Pretext Task |
논문 |
연도 |
| Rotation |
0/90/180/270도 회전 예측 |
Gidaris et al. (ICLR 2018) |
2018 |
| Jigsaw |
9개 패치 순서 맞추기 |
Noroozi & Favaro (ECCV 2016) |
2016 |
| Colorization |
흑백->컬러 복원 |
Zhang et al. (ECCV 2016) |
2016 |
| CPC |
시퀀스의 미래 예측 |
van den Oord et al. (2018) |
2018 |
| Inpainting |
제거된 영역 복원 |
Pathak et al. (CVPR 2016) |
2016 |
NLP에서의 Self-Supervised Learning
NLP 분야에서 SSL은 사실상 표준 사전학습 방법론이 되었다:
Masked Language Modeling (MLM) - BERT 계열
입력: "The [MASK] sat on the [MASK]"
목표: [MASK] -> "cat", [MASK] -> "mat"
학습 전략:
- 전체 토큰의 15% 마스킹
- 80%: [MASK] 토큰으로 교체
- 10%: 랜덤 토큰으로 교체
- 10%: 원본 유지
Autoregressive Language Modeling - GPT 계열
입력: "The cat sat on"
목표: "the" (다음 토큰 예측)
학습:
P(x_t | x_1, ..., x_{t-1}) -- 왼쪽에서 오른쪽으로 순차 예측
비교
| 항목 |
BERT (MLM) |
GPT (AR) |
| 마스킹 방향 |
양방향 |
왼쪽->오른쪽 |
| 적합 태스크 |
분류, NER, QA |
생성, 요약, 대화 |
| 컨텍스트 |
전체 문맥 참조 가능 |
이전 토큰만 참조 |
| 스케일링 |
BERT -> RoBERTa -> DeBERTa |
GPT-2 -> GPT-3 -> GPT-4 |
음성에서의 Self-Supervised Learning
| 방법 |
핵심 |
논문 |
| wav2vec 2.0 |
음성 파형의 contrastive + masked prediction |
Baevski et al. (NeurIPS 2020) |
| HuBERT |
오프라인 클러스터링 + masked prediction |
Hsu et al. (IEEE/ACM 2021) |
| Whisper |
대규모 약한 지도 (자막-음성 매칭) |
Radford et al. (ICML 2023) |
성능 비교
ImageNet Linear Probe (Top-1 Accuracy)
| 방법 |
인코더 |
Epochs |
Top-1 (%) |
연도 |
| SimCLR |
ResNet-50 |
1000 |
69.3 |
2020 |
| MoCo v2 |
ResNet-50 |
800 |
71.1 |
2020 |
| BYOL |
ResNet-50 |
1000 |
74.3 |
2020 |
| SwAV |
ResNet-50 |
800 |
75.3 |
2020 |
| DINO |
ViT-B/16 |
400 |
78.2 |
2021 |
| MAE |
ViT-L/16 |
1600 |
75.8 |
2022 |
| iBOT |
ViT-L/16 |
800 |
81.6 |
2022 |
| DINOv2 |
ViT-g/14 |
- |
86.5 |
2024 |
주의: MAE는 linear probe보다 fine-tuning에서 더 강함 (ViT-L fine-tune: 85.9%)
ImageNet Fine-tuning (Top-1 Accuracy)
| 방법 |
인코더 |
Top-1 (%) |
| MAE |
ViT-H/14 |
87.8 |
| DINOv2 |
ViT-g/14 |
87.0 |
| BEiT v2 |
ViT-L/16 |
87.3 |
| Supervised baseline |
ViT-H/14 |
87.2 |
SSL이 지도학습과 동등하거나 초과하는 성능에 도달.
Collapse 문제와 해결
SSL에서 가장 큰 기술적 도전은 representation collapse -- 모든 입력에 대해 동일한 표현을 출력하는 trivial solution:
Collapse 유형
| 유형 |
설명 |
결과 |
| Complete collapse |
모든 출력이 상수 벡터 |
완전히 무의미한 표현 |
| Dimensional collapse |
출력의 일부 차원만 사용 |
표현 용량 낭비 |
| Cluster collapse |
소수의 클러스터로 수렴 |
세밀한 구분 불가 |
방지 전략
| 전략 |
방법 |
사용 모델 |
| Negative samples |
다른 샘플을 밀어내어 균일 분포 유도 |
SimCLR, MoCo |
| Stop-gradient |
한쪽 branch의 gradient 차단 |
BYOL, SimSiam |
| Centering |
Teacher 출력의 평균을 빼서 편향 제거 |
DINO |
| Sharpening |
Temperature를 낮추어 분포 첨예화 |
DINO |
| Variance regularization |
표현의 분산 유지 강제 |
VICReg |
| Batch normalization |
배치 내 통계로 implicit regularization |
BYOL |
최신 동향 (2024-2025)
1. Foundation Model 시대의 SSL
DINOv2, MAE v2 등 대규모 SSL 모델이 범용 vision backbone으로 자리매김. ImageNet 사전학습 대비 더 강건한 out-of-distribution 성능.
2. Multimodal SSL
CLIP, SigLIP, ImageBind 등 다중 모달리티를 통합하는 SSL이 주류. 텍스트-이미지-음성-비디오를 하나의 표현 공간에 정렬.
3. 효율적 SSL
- Masked modeling의 효율성 (MAE: 75% 마스킹으로 3배 빠른 학습)
- 작은 데이터셋에서의 SSL 적용 연구
- Self-supervised pre-training의 few-shot 성능 개선
4. 생성 모델과의 융합
Diffusion model의 내부 표현을 SSL로 추출하여 discriminative task에 활용. Score matching과 contrastive learning의 이론적 연결.
Python 구현 예시
SimCLR 핵심 구현
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
class SimCLR(nn.Module):
"""SimCLR 프레임워크 핵심 구현."""
def __init__(self, backbone='resnet50', projection_dim=128, hidden_dim=2048):
super().__init__()
# 인코더: pretrained weights 없이 시작
resnet = getattr(models, backbone)(weights=None)
self.encoder = nn.Sequential(*list(resnet.children())[:-1]) # avgpool까지
encoder_dim = resnet.fc.in_features # 2048 for ResNet-50
# 프로젝션 헤드: 2-layer MLP
self.projector = nn.Sequential(
nn.Linear(encoder_dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, projection_dim)
)
def forward(self, x):
h = self.encoder(x).flatten(1) # representation
z = self.projector(h) # projection
return F.normalize(z, dim=1) # L2 정규화
def nt_xent_loss(z_i, z_j, temperature=0.5):
"""NT-Xent (Normalized Temperature-scaled Cross-Entropy) Loss.
Args:
z_i: [B, D] - 첫 번째 augmentation의 projection
z_j: [B, D] - 두 번째 augmentation의 projection
temperature: softmax temperature
Returns:
scalar loss
"""
batch_size = z_i.size(0)
# 전체 representation 결합: [2B, D]
z = torch.cat([z_i, z_j], dim=0)
# 코사인 유사도 행렬: [2B, 2B]
sim_matrix = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2)
sim_matrix = sim_matrix / temperature
# 자기 자신과의 유사도 제거
mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
sim_matrix.masked_fill_(mask, -float('inf'))
# Positive pair 인덱스 구성
# z_i[k]의 positive는 z_j[k] (인덱스 k+B)
# z_j[k]의 positive는 z_i[k] (인덱스 k)
pos_indices = torch.cat([
torch.arange(batch_size, 2 * batch_size),
torch.arange(batch_size)
]).to(z.device)
# Cross-entropy loss
loss = F.cross_entropy(sim_matrix, pos_indices)
return loss
# 학습 augmentation 파이프라인
simclr_transform = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 사용 예시
if __name__ == "__main__":
model = SimCLR(projection_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
# 의사 학습 루프
for epoch in range(100):
# x1, x2 = 같은 이미지에 다른 augmentation 적용
x1 = torch.randn(256, 3, 224, 224) # 실제로는 DataLoader에서
x2 = torch.randn(256, 3, 224, 224)
z1 = model(x1)
z2 = model(x2)
loss = nt_xent_loss(z1, z2, temperature=0.5)
optimizer.zero_grad()
loss.backward()
optimizer.step()
MAE 핵심 구현
import torch
import torch.nn as nn
from functools import partial
class MAE(nn.Module):
"""Masked Autoencoder 핵심 구현 (간소화)."""
def __init__(
self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dim=768,
encoder_depth=12,
decoder_embed_dim=512,
decoder_depth=4,
mask_ratio=0.75,
num_heads=12,
):
super().__init__()
self.patch_size = patch_size
self.mask_ratio = mask_ratio
num_patches = (img_size // patch_size) ** 2
# Patch Embedding
self.patch_embed = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches, embed_dim)
)
# Encoder (visible patches만 처리)
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim, nhead=num_heads,
dim_feedforward=embed_dim * 4, batch_first=True
)
self.encoder = nn.TransformerEncoder(
encoder_layer, num_layers=encoder_depth
)
# Decoder
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, decoder_embed_dim)
)
decoder_layer = nn.TransformerEncoderLayer(
d_model=decoder_embed_dim, nhead=8,
dim_feedforward=decoder_embed_dim * 4, batch_first=True
)
self.decoder = nn.TransformerEncoder(
decoder_layer, num_layers=decoder_depth
)
# 픽셀 복원 헤드
self.pred = nn.Linear(
decoder_embed_dim, patch_size ** 2 * in_channels
)
def random_masking(self, x, mask_ratio):
"""랜덤 마스킹: 패치의 mask_ratio만큼 제거.
Args:
x: [B, N, D] 패치 임베딩
mask_ratio: 마스킹 비율
Returns:
x_masked: [B, N*(1-mask_ratio), D] visible patches
mask: [B, N] binary mask (1=masked)
ids_restore: 복원 순서
"""
B, N, D = x.shape
len_keep = int(N * (1 - mask_ratio))
# 랜덤 순열 생성
noise = torch.rand(B, N, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# 상위 len_keep개만 유지
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(
x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D)
)
# Binary mask 생성
mask = torch.ones(B, N, device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, 1, ids_restore)
return x_masked, mask, ids_restore
def forward(self, x):
# 패치 임베딩
patches = self.patch_embed(x).flatten(2).transpose(1, 2)
patches = patches + self.pos_embed
# 랜덤 마스킹
visible, mask, ids_restore = self.random_masking(
patches, self.mask_ratio
)
# 인코더 (visible patches만)
encoded = self.encoder(visible)
# 디코더 입력 준비
decoded = self.decoder_embed(encoded)
# mask tokens 추가
B, N_vis, D = decoded.shape
N_total = self.pos_embed.shape[1]
mask_tokens = self.mask_token.expand(B, N_total - N_vis, -1)
full_tokens = torch.cat([decoded, mask_tokens], dim=1)
# 원래 순서로 복원
full_tokens = torch.gather(
full_tokens, 1,
ids_restore.unsqueeze(-1).expand(-1, -1, D)
)
full_tokens = full_tokens + self.decoder_pos_embed
# 디코더
decoded = self.decoder(full_tokens)
pred = self.pred(decoded)
# Loss: 마스킹된 패치에 대해서만 MSE
return pred, mask
def mae_loss(pred, target, mask, patch_size=16):
"""MAE 복원 손실 (마스킹된 패치만).
Args:
pred: [B, N, patch_size^2 * 3] 예측 픽셀
target: [B, 3, H, W] 원본 이미지
mask: [B, N] binary mask (1=masked)
patch_size: 패치 크기
"""
# 원본을 패치로 변환
B, C, H, W = target.shape
p = patch_size
target_patches = target.reshape(B, C, H // p, p, W // p, p)
target_patches = target_patches.permute(0, 2, 4, 3, 5, 1)
target_patches = target_patches.reshape(B, -1, p * p * C)
# 패치별 정규화 (선택적)
mean = target_patches.mean(dim=-1, keepdim=True)
var = target_patches.var(dim=-1, keepdim=True)
target_patches = (target_patches - mean) / (var + 1e-6).sqrt()
# 마스킹된 패치에 대해서만 MSE
loss = (pred - target_patches) ** 2
loss = loss.mean(dim=-1) # [B, N]
loss = (loss * mask).sum() / mask.sum()
return loss
Linear Probe 평가
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
def linear_probe(encoder, train_loader, test_loader,
feature_dim=2048, num_classes=1000, epochs=100):
"""SSL 모델의 표현 품질을 평가하는 Linear Probe.
인코더를 고정하고 선형 분류기만 학습.
"""
# 인코더 고정
encoder.eval()
for param in encoder.parameters():
param.requires_grad = False
# 선형 분류기
classifier = nn.Linear(feature_dim, num_classes).cuda()
optimizer = torch.optim.SGD(
classifier.parameters(), lr=0.3,
momentum=0.9, weight_decay=0
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=epochs
)
# 학습
for epoch in range(epochs):
classifier.train()
for images, labels in train_loader:
images, labels = images.cuda(), labels.cuda()
with torch.no_grad():
features = encoder(images).flatten(1)
logits = classifier(features)
loss = nn.functional.cross_entropy(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
# 평가
classifier.eval()
correct, total = 0, 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.cuda(), labels.cuda()
features = encoder(images).flatten(1)
logits = classifier(features)
correct += (logits.argmax(1) == labels).sum().item()
total += labels.size(0)
accuracy = correct / total * 100
return accuracy
실무 적용 가이드
어떤 SSL 방법을 선택할 것인가
| 상황 |
추천 방법 |
이유 |
| 범용 vision backbone |
DINOv2 |
최고 성능, 다양한 downstream task |
| 대규모 ViT 사전학습 |
MAE |
학습 효율성 (75% 마스킹) |
| 작은 데이터셋 |
MoCo v3 |
Queue로 배치 크기 무관하게 학습 |
| 멀티모달 |
CLIP / SigLIP |
텍스트-이미지 정렬 |
| 의료/특수 도메인 |
MAE + domain data |
도메인 특화 사전학습 |
| NLP |
BERT/RoBERTa 또는 GPT 계열 |
태스크 특성에 따라 선택 |
| 음성 |
wav2vec 2.0 / HuBERT |
레이블 효율적 음성 인식 |
Hyperparameter 가이드
| 파라미터 |
SimCLR |
MoCo v3 |
MAE |
DINO |
| 학습률 |
0.3 (LARS) |
1.5e-4 (AdamW) |
1.5e-4 (AdamW) |
5e-4 (AdamW) |
| 배치 크기 |
4096 |
4096 |
4096 |
1024 |
| Epochs |
800-1000 |
300-600 |
800-1600 |
300-400 |
| Weight decay |
1e-6 |
0.1 |
0.05 |
0.04-0.4 |
| Temperature |
0.1-0.5 |
0.2 |
N/A |
0.04-0.07 |
| EMA momentum |
N/A |
0.99-0.999 |
N/A |
0.996-1.0 |
참고 자료
| 자료 |
링크 |
| SimCLR 논문 |
https://arxiv.org/abs/2002.05709 |
| MoCo 논문 |
https://arxiv.org/abs/1911.05722 |
| BYOL 논문 |
https://arxiv.org/abs/2006.07733 |
| MAE 논문 |
https://arxiv.org/abs/2111.06377 |
| DINO 논문 |
https://arxiv.org/abs/2104.14294 |
| DINOv2 논문 |
https://arxiv.org/abs/2304.07193 |
| CLIP 논문 |
https://arxiv.org/abs/2103.00020 |
| SSL Survey (Gui et al.) |
https://arxiv.org/abs/2301.05712 |
| NeurIPS 2024 SSL Workshop |
https://sslneurips2024.github.io/ |
| lightly (SSL 라이브러리) |
https://github.com/lightly-ai/lightly |
| solo-learn (SSL 벤치마크) |
https://github.com/vturrisi/solo-learn |