콘텐츠로 이동
Data Prep
상세

Self-Supervised Learning (SSL)

메타 정보

항목 내용
분류 Representation Learning / Label-Efficient Learning
핵심 논문 "A Simple Framework for Contrastive Learning of Visual Representations" (Chen et al., ICML 2020 - SimCLR), "Momentum Contrast for Unsupervised Visual Representation Learning" (He et al., CVPR 2020 - MoCo v1/v2), "Bootstrap Your Own Latent" (Grill et al., NeurIPS 2020 - BYOL), "Masked Autoencoders Are Scalable Vision Learners" (He et al., CVPR 2022 - MAE), "Emerging Properties in Self-Supervised Vision Transformers" (Caron et al., ICCV 2021 - DINO), "DINOv2: Learning Robust Visual Features without Supervision" (Oquab et al., TMLR 2024), "BERT: Pre-Training of Deep Bidirectional Transformers" (Devlin et al., NAACL 2019), "Language Models are Unsupervised Multitask Learners" (Radford et al., 2019 - GPT-2)
주요 저자 Kaiming He (MoCo, MAE), Ting Chen & Geoffrey Hinton (SimCLR), Mathilde Caron (DINO), Jean-Baptiste Grill (BYOL), Jacob Devlin (BERT)
핵심 개념 레이블 없는 데이터에서 pretext task를 자동 생성하여 범용 표현(representation)을 학습하는 패러다임
관련 분야 Contrastive Learning, Masked Modeling, Foundation Models, Transfer Learning, Representation Learning

정의

Self-Supervised Learning(SSL)은 레이블 없는 데이터로부터 pretext task(자동 생성된 보조 과제)를 정의하고, 이를 풀면서 데이터의 구조적 표현을 학습하는 방법론이다. 지도학습의 레이블 비용 문제와 비지도학습의 목적 불명확성을 동시에 해결한다.

학습 패러다임 비교

항목 지도학습 비지도학습 자기지도학습
레이블 필요 (비용 높음) 불필요 불필요
학습 신호 사람이 제공한 정답 데이터 분포 데이터에서 자동 생성
목적함수 cross-entropy 등 reconstruction, density pretext task별 상이
표현 품질 태스크 특화 범용적이나 품질 불균일 범용적이며 고품질
대표 사례 ResNet + ImageNet VAE, GAN SimCLR, MAE, BERT
확장성 레이블에 의존 데이터 양에 비례 데이터 양에 비례

핵심 아이디어

Self-Supervised Learning의 2단계 프레임워크:

[1단계: Pretext Task로 사전학습]
+------------------------------------------------------------------+
|                                                                  |
|  대량의 비레이블 데이터  -->  Pretext Task 정의  -->  인코더 학습  |
|                                                                  |
|  예시:                                                           |
|  - 이미지 일부 마스킹 후 복원 (MAE)                              |
|  - 같은 이미지의 다른 augmentation끼리 유사하게 (SimCLR)         |
|  - 문장 내 단어 마스킹 후 예측 (BERT)                            |
|  - 다음 토큰 예측 (GPT)                                         |
|                                                                  |
+------------------------------------------------------------------+
                        |
                        v
[2단계: Downstream Task로 전이]
+------------------------------------------------------------------+
|                                                                  |
|  학습된 인코더  -->  소량 레이블로 fine-tuning 또는 linear probe  |
|                                                                  |
|  분류, 탐지, 분할, QA 등 다양한 태스크에 적용                    |
|                                                                  |
+------------------------------------------------------------------+

SSL 방법론 분류 체계

SSL 방법론은 크게 네 가지 패밀리로 분류된다:

Self-Supervised Learning 방법론 분류:

+---------------------------------------------+
| 1. Contrastive Methods                      |
| - positive/negative pair 구성               |
| - InfoNCE 손실 함수                         |
| - SimCLR, MoCo, CLIP                       |
+---------------------------------------------+
          |
          v
+---------------------------------------------+
| 2. Non-Contrastive (Self-Distillation)      |
| - negative pair 없이 학습                   |
| - EMA teacher + stop-gradient               |
| - BYOL, SimSiam, DINO, DINOv2              |
+---------------------------------------------+
          |
          v
+---------------------------------------------+
| 3. Masked Modeling                          |
| - 입력 일부를 마스킹하고 복원               |
| - 토큰 수준 또는 픽셀 수준                  |
| - BERT, MAE, BEiT, data2vec               |
+---------------------------------------------+
          |
          v
+---------------------------------------------+
| 4. Predictive (Pretext-based)               |
| - 자동 생성 과제 풀기                       |
| - 회전 예측, 직소 퍼즐, 순서 예측           |
| - RotNet, Jigsaw, CPC                      |
+---------------------------------------------+

핵심 방법론

1. Contrastive Methods

SimCLR (Chen et al., ICML 2020)

항목 내용
핵심 동일 이미지의 두 augmentation을 positive pair로 구성
인코더 ResNet-50 (이후 더 큰 모델도 사용)
프로젝션 헤드 2-layer MLP (128-dim)
배치 크기 4096~8192 (큰 배치가 핵심)
손실 함수 NT-Xent (Normalized Temperature-scaled Cross Entropy)

NT-Xent Loss:

\[ \ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k) / \tau)} \]
  • \(z_i, z_j\): positive pair의 프로젝션 출력
  • \(\tau\): temperature 파라미터 (0.1~0.5)
  • \(\text{sim}\): cosine similarity
  • 분모에 배치 내 모든 다른 샘플이 negative로 작용

핵심 발견:

  • Data augmentation 조합이 성능에 가장 큰 영향 (random crop + color jitter가 최적)
  • 비선형 프로젝션 헤드가 선형보다 +10% 성능 향상
  • 배치 크기가 클수록 negative 샘플 다양성 증가 -> 성능 향상

MoCo v1/v2 (He et al., CVPR 2020 / Chen et al., 2020)

항목 내용
핵심 Momentum-updated queue로 large negative pool 유지
큐 크기 65536 (배치 크기와 무관)
모멘텀 계수 m = 0.999 (천천히 업데이트)
장점 작은 배치에서도 많은 negative 사용 가능

Momentum Update:

\[ \theta_k \leftarrow m \cdot \theta_k + (1 - m) \cdot \theta_q \]
  • \(\theta_q\): query encoder (gradient로 업데이트)
  • \(\theta_k\): key encoder (momentum으로 업데이트)
  • 큐에 가장 최근 key representation을 저장하고 FIFO로 관리

MoCo v2 개선점: - SimCLR의 MLP 프로젝션 헤드 차용 - 더 강한 augmentation 적용 - cosine learning rate schedule

CLIP (Radford et al., ICML 2021)

항목 내용
핵심 이미지-텍스트 쌍의 contrastive learning
학습 데이터 WebImageText (4억 이미지-텍스트 쌍)
인코더 Vision Transformer + Text Transformer
특징 Zero-shot 분류/검색 가능

Contrastive Loss (image-text):

\[ \mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \left[\log \frac{\exp(I_i \cdot T_i / \tau)}{\sum_{j=1}^{N} \exp(I_i \cdot T_j / \tau)} + \log \frac{\exp(T_i \cdot I_i / \tau)}{\sum_{j=1}^{N} \exp(T_i \cdot I_j / \tau)}\right] \]

2. Non-Contrastive (Self-Distillation) Methods

BYOL (Grill et al., NeurIPS 2020)

항목 내용
핵심 Negative pair 없이 학습 가능
구조 Online network + Target network (EMA)
추가 구성 Predictor head (online 쪽에만)
의의 Negative pair 없이도 collapse 방지 가능함을 증명
BYOL 아키텍처:

  이미지 x
    |
    +-- aug1 --> Online Network  --> projector --> predictor --> p
    |              (theta)
    |                                                           |
    |                                              L2 loss (p, sg(z'))
    |                                                           |
    +-- aug2 --> Target Network --> projector -----------------> z'
                   (xi, EMA)           (stop-gradient)

  Target update: xi <- tau * xi + (1-tau) * theta
  (tau = 0.996 -> 1.0, cosine schedule)

Collapse 방지 메커니즘:

BYOL이 trivial solution (모든 출력이 동일)으로 수렴하지 않는 이유: 1. Predictor head의 비대칭성 2. EMA target의 느린 업데이트 3. Batch normalization의 implicit regularization

DINO / DINOv2 (Caron et al., ICCV 2021 / Oquab et al., TMLR 2024)

항목 DINO (2021) DINOv2 (2024)
인코더 ViT-S/B ViT-L/g (1B params)
학습 데이터 ImageNet-1K LVD-142M (자동 큐레이션)
핵심 기법 Self-distillation with centering + 자동 데이터 큐레이션 파이프라인
성능 ImageNet linear probe 77.0% ImageNet linear probe 86.5%

DINO의 핵심:

\[ \mathcal{L} = -\sum_{x \in \{x_1^g, x_2^g\}} \sum_{\substack{x' \in V \\ x' \neq x}} p_t(x) \log p_s(x') \]
  • Teacher: 전체 이미지의 global crop 처리
  • Student: local + global crop 모두 처리
  • Centering + sharpening으로 collapse 방지
  • ViT의 [CLS] 토큰이 자연스럽게 segmentation 능력 획득

SimSiam (Chen & He, CVPR 2021)

항목 내용
핵심 Momentum encoder도 없이 Siamese network만으로 학습
구조 대칭 + stop-gradient
의의 SSL의 최소 필수 구성 요소 규명

3. Masked Modeling Methods

MAE (He et al., CVPR 2022)

항목 내용
핵심 이미지 패치의 75%를 마스킹하고 복원
인코더 ViT (visible patches만 처리)
디코더 경량 Transformer (마스킹 패치 복원)
효율성 3배 이상 학습 속도 향상 (마스킹된 패치 스킵)
MAE 아키텍처:

  원본 이미지 --> 패치 분할 (예: 16x16)
                    |
      +------+------+------+------+
      | P1   | P2   | P3   | P4   |  ...
      +------+------+------+------+
        ^            ^                 --> 75% 랜덤 마스킹
       보임         보임
        |            |
        v            v
  +----------------------------+
  |     ViT Encoder            |   <-- visible patches만 입력
  |     (깊고 무거움)          |
  +----------------------------+
              |
              v
  +----------------------------+
  |     Lightweight Decoder    |   <-- mask tokens 추가
  |     (얕고 가벼움)          |
  +----------------------------+
              |
              v
       복원된 전체 이미지

  Loss: MSE(복원 픽셀, 원본 픽셀)  (마스킹된 위치만 계산)

핵심 설계 원칙:

  1. 높은 마스킹 비율 (75%): 이미지는 텍스트보다 정보 중복이 높아 높은 비율 필요
  2. 비대칭 인코더-디코더: 인코더는 visible만 처리하여 효율적
  3. 픽셀 수준 복원: 토큰화 불필요, 간단한 목적함수

BEiT (Bao et al., ICLR 2022)

항목 내용
핵심 Visual tokenizer로 이산 토큰 예측 (BERT 방식)
토크나이저 dVAE (discrete VAE)
마스킹 비율 40%
의의 NLP의 masked language modeling을 vision에 최초 적용

data2vec (Baevski et al., ICML 2022)

항목 내용
핵심 모달리티 무관 통합 SSL 프레임워크
적용 이미지, 텍스트, 음성 동시 지원
예측 대상 Teacher 네트워크의 latent representation
Teacher EMA 기반

4. Predictive (Pretext-based) Methods

초기 SSL 방법론으로, 데이터에서 자동 생성 가능한 과제를 정의한다:

방법 Pretext Task 논문 연도
Rotation 0/90/180/270도 회전 예측 Gidaris et al. (ICLR 2018) 2018
Jigsaw 9개 패치 순서 맞추기 Noroozi & Favaro (ECCV 2016) 2016
Colorization 흑백->컬러 복원 Zhang et al. (ECCV 2016) 2016
CPC 시퀀스의 미래 예측 van den Oord et al. (2018) 2018
Inpainting 제거된 영역 복원 Pathak et al. (CVPR 2016) 2016

NLP에서의 Self-Supervised Learning

NLP 분야에서 SSL은 사실상 표준 사전학습 방법론이 되었다:

Masked Language Modeling (MLM) - BERT 계열

입력:  "The [MASK] sat on the [MASK]"
목표:  [MASK] -> "cat", [MASK] -> "mat"

학습 전략:
- 전체 토큰의 15% 마스킹
  - 80%: [MASK] 토큰으로 교체
  - 10%: 랜덤 토큰으로 교체
  - 10%: 원본 유지

Autoregressive Language Modeling - GPT 계열

입력:  "The cat sat on"
목표:  "the" (다음 토큰 예측)

학습:
  P(x_t | x_1, ..., x_{t-1})  -- 왼쪽에서 오른쪽으로 순차 예측

비교

항목 BERT (MLM) GPT (AR)
마스킹 방향 양방향 왼쪽->오른쪽
적합 태스크 분류, NER, QA 생성, 요약, 대화
컨텍스트 전체 문맥 참조 가능 이전 토큰만 참조
스케일링 BERT -> RoBERTa -> DeBERTa GPT-2 -> GPT-3 -> GPT-4

음성에서의 Self-Supervised Learning

방법 핵심 논문
wav2vec 2.0 음성 파형의 contrastive + masked prediction Baevski et al. (NeurIPS 2020)
HuBERT 오프라인 클러스터링 + masked prediction Hsu et al. (IEEE/ACM 2021)
Whisper 대규모 약한 지도 (자막-음성 매칭) Radford et al. (ICML 2023)

성능 비교

ImageNet Linear Probe (Top-1 Accuracy)

방법 인코더 Epochs Top-1 (%) 연도
SimCLR ResNet-50 1000 69.3 2020
MoCo v2 ResNet-50 800 71.1 2020
BYOL ResNet-50 1000 74.3 2020
SwAV ResNet-50 800 75.3 2020
DINO ViT-B/16 400 78.2 2021
MAE ViT-L/16 1600 75.8 2022
iBOT ViT-L/16 800 81.6 2022
DINOv2 ViT-g/14 - 86.5 2024

주의: MAE는 linear probe보다 fine-tuning에서 더 강함 (ViT-L fine-tune: 85.9%)

ImageNet Fine-tuning (Top-1 Accuracy)

방법 인코더 Top-1 (%)
MAE ViT-H/14 87.8
DINOv2 ViT-g/14 87.0
BEiT v2 ViT-L/16 87.3
Supervised baseline ViT-H/14 87.2

SSL이 지도학습과 동등하거나 초과하는 성능에 도달.


Collapse 문제와 해결

SSL에서 가장 큰 기술적 도전은 representation collapse -- 모든 입력에 대해 동일한 표현을 출력하는 trivial solution:

Collapse 유형

유형 설명 결과
Complete collapse 모든 출력이 상수 벡터 완전히 무의미한 표현
Dimensional collapse 출력의 일부 차원만 사용 표현 용량 낭비
Cluster collapse 소수의 클러스터로 수렴 세밀한 구분 불가

방지 전략

전략 방법 사용 모델
Negative samples 다른 샘플을 밀어내어 균일 분포 유도 SimCLR, MoCo
Stop-gradient 한쪽 branch의 gradient 차단 BYOL, SimSiam
Centering Teacher 출력의 평균을 빼서 편향 제거 DINO
Sharpening Temperature를 낮추어 분포 첨예화 DINO
Variance regularization 표현의 분산 유지 강제 VICReg
Batch normalization 배치 내 통계로 implicit regularization BYOL

최신 동향 (2024-2025)

1. Foundation Model 시대의 SSL

DINOv2, MAE v2 등 대규모 SSL 모델이 범용 vision backbone으로 자리매김. ImageNet 사전학습 대비 더 강건한 out-of-distribution 성능.

2. Multimodal SSL

CLIP, SigLIP, ImageBind 등 다중 모달리티를 통합하는 SSL이 주류. 텍스트-이미지-음성-비디오를 하나의 표현 공간에 정렬.

3. 효율적 SSL

  • Masked modeling의 효율성 (MAE: 75% 마스킹으로 3배 빠른 학습)
  • 작은 데이터셋에서의 SSL 적용 연구
  • Self-supervised pre-training의 few-shot 성능 개선

4. 생성 모델과의 융합

Diffusion model의 내부 표현을 SSL로 추출하여 discriminative task에 활용. Score matching과 contrastive learning의 이론적 연결.


Python 구현 예시

SimCLR 핵심 구현

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms


class SimCLR(nn.Module):
    """SimCLR 프레임워크 핵심 구현."""

    def __init__(self, backbone='resnet50', projection_dim=128, hidden_dim=2048):
        super().__init__()
        # 인코더: pretrained weights 없이 시작
        resnet = getattr(models, backbone)(weights=None)
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])  # avgpool까지

        encoder_dim = resnet.fc.in_features  # 2048 for ResNet-50

        # 프로젝션 헤드: 2-layer MLP
        self.projector = nn.Sequential(
            nn.Linear(encoder_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x).flatten(1)  # representation
        z = self.projector(h)            # projection
        return F.normalize(z, dim=1)     # L2 정규화


def nt_xent_loss(z_i, z_j, temperature=0.5):
    """NT-Xent (Normalized Temperature-scaled Cross-Entropy) Loss.

    Args:
        z_i: [B, D] - 첫 번째 augmentation의 projection
        z_j: [B, D] - 두 번째 augmentation의 projection
        temperature: softmax temperature

    Returns:
        scalar loss
    """
    batch_size = z_i.size(0)

    # 전체 representation 결합: [2B, D]
    z = torch.cat([z_i, z_j], dim=0)

    # 코사인 유사도 행렬: [2B, 2B]
    sim_matrix = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2)
    sim_matrix = sim_matrix / temperature

    # 자기 자신과의 유사도 제거
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    sim_matrix.masked_fill_(mask, -float('inf'))

    # Positive pair 인덱스 구성
    # z_i[k]의 positive는 z_j[k] (인덱스 k+B)
    # z_j[k]의 positive는 z_i[k] (인덱스 k)
    pos_indices = torch.cat([
        torch.arange(batch_size, 2 * batch_size),
        torch.arange(batch_size)
    ]).to(z.device)

    # Cross-entropy loss
    loss = F.cross_entropy(sim_matrix, pos_indices)
    return loss


# 학습 augmentation 파이프라인
simclr_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


# 사용 예시
if __name__ == "__main__":
    model = SimCLR(projection_dim=128)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)

    # 의사 학습 루프
    for epoch in range(100):
        # x1, x2 = 같은 이미지에 다른 augmentation 적용
        x1 = torch.randn(256, 3, 224, 224)  # 실제로는 DataLoader에서
        x2 = torch.randn(256, 3, 224, 224)

        z1 = model(x1)
        z2 = model(x2)

        loss = nt_xent_loss(z1, z2, temperature=0.5)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

MAE 핵심 구현

import torch
import torch.nn as nn
from functools import partial


class MAE(nn.Module):
    """Masked Autoencoder 핵심 구현 (간소화)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        encoder_depth=12,
        decoder_embed_dim=512,
        decoder_depth=4,
        mask_ratio=0.75,
        num_heads=12,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        num_patches = (img_size // patch_size) ** 2

        # Patch Embedding
        self.patch_embed = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim)
        )

        # Encoder (visible patches만 처리)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * 4, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=encoder_depth
        )

        # Decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, decoder_embed_dim)
        )

        decoder_layer = nn.TransformerEncoderLayer(
            d_model=decoder_embed_dim, nhead=8,
            dim_feedforward=decoder_embed_dim * 4, batch_first=True
        )
        self.decoder = nn.TransformerEncoder(
            decoder_layer, num_layers=decoder_depth
        )

        # 픽셀 복원 헤드
        self.pred = nn.Linear(
            decoder_embed_dim, patch_size ** 2 * in_channels
        )

    def random_masking(self, x, mask_ratio):
        """랜덤 마스킹: 패치의 mask_ratio만큼 제거.

        Args:
            x: [B, N, D] 패치 임베딩
            mask_ratio: 마스킹 비율

        Returns:
            x_masked: [B, N*(1-mask_ratio), D] visible patches
            mask: [B, N] binary mask (1=masked)
            ids_restore: 복원 순서
        """
        B, N, D = x.shape
        len_keep = int(N * (1 - mask_ratio))

        # 랜덤 순열 생성
        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # 상위 len_keep개만 유지
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D)
        )

        # Binary mask 생성
        mask = torch.ones(B, N, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, 1, ids_restore)

        return x_masked, mask, ids_restore

    def forward(self, x):
        # 패치 임베딩
        patches = self.patch_embed(x).flatten(2).transpose(1, 2)
        patches = patches + self.pos_embed

        # 랜덤 마스킹
        visible, mask, ids_restore = self.random_masking(
            patches, self.mask_ratio
        )

        # 인코더 (visible patches만)
        encoded = self.encoder(visible)

        # 디코더 입력 준비
        decoded = self.decoder_embed(encoded)

        # mask tokens 추가
        B, N_vis, D = decoded.shape
        N_total = self.pos_embed.shape[1]
        mask_tokens = self.mask_token.expand(B, N_total - N_vis, -1)

        full_tokens = torch.cat([decoded, mask_tokens], dim=1)
        # 원래 순서로 복원
        full_tokens = torch.gather(
            full_tokens, 1,
            ids_restore.unsqueeze(-1).expand(-1, -1, D)
        )
        full_tokens = full_tokens + self.decoder_pos_embed

        # 디코더
        decoded = self.decoder(full_tokens)
        pred = self.pred(decoded)

        # Loss: 마스킹된 패치에 대해서만 MSE
        return pred, mask


def mae_loss(pred, target, mask, patch_size=16):
    """MAE 복원 손실 (마스킹된 패치만).

    Args:
        pred: [B, N, patch_size^2 * 3] 예측 픽셀
        target: [B, 3, H, W] 원본 이미지
        mask: [B, N] binary mask (1=masked)
        patch_size: 패치 크기
    """
    # 원본을 패치로 변환
    B, C, H, W = target.shape
    p = patch_size
    target_patches = target.reshape(B, C, H // p, p, W // p, p)
    target_patches = target_patches.permute(0, 2, 4, 3, 5, 1)
    target_patches = target_patches.reshape(B, -1, p * p * C)

    # 패치별 정규화 (선택적)
    mean = target_patches.mean(dim=-1, keepdim=True)
    var = target_patches.var(dim=-1, keepdim=True)
    target_patches = (target_patches - mean) / (var + 1e-6).sqrt()

    # 마스킹된 패치에 대해서만 MSE
    loss = (pred - target_patches) ** 2
    loss = loss.mean(dim=-1)  # [B, N]
    loss = (loss * mask).sum() / mask.sum()

    return loss

Linear Probe 평가

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def linear_probe(encoder, train_loader, test_loader, 
                 feature_dim=2048, num_classes=1000, epochs=100):
    """SSL 모델의 표현 품질을 평가하는 Linear Probe.

    인코더를 고정하고 선형 분류기만 학습.
    """
    # 인코더 고정
    encoder.eval()
    for param in encoder.parameters():
        param.requires_grad = False

    # 선형 분류기
    classifier = nn.Linear(feature_dim, num_classes).cuda()
    optimizer = torch.optim.SGD(
        classifier.parameters(), lr=0.3,
        momentum=0.9, weight_decay=0
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs
    )

    # 학습
    for epoch in range(epochs):
        classifier.train()
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()

            with torch.no_grad():
                features = encoder(images).flatten(1)

            logits = classifier(features)
            loss = nn.functional.cross_entropy(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()

    # 평가
    classifier.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.cuda(), labels.cuda()
            features = encoder(images).flatten(1)
            logits = classifier(features)
            correct += (logits.argmax(1) == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total * 100
    return accuracy

실무 적용 가이드

어떤 SSL 방법을 선택할 것인가

상황 추천 방법 이유
범용 vision backbone DINOv2 최고 성능, 다양한 downstream task
대규모 ViT 사전학습 MAE 학습 효율성 (75% 마스킹)
작은 데이터셋 MoCo v3 Queue로 배치 크기 무관하게 학습
멀티모달 CLIP / SigLIP 텍스트-이미지 정렬
의료/특수 도메인 MAE + domain data 도메인 특화 사전학습
NLP BERT/RoBERTa 또는 GPT 계열 태스크 특성에 따라 선택
음성 wav2vec 2.0 / HuBERT 레이블 효율적 음성 인식

Hyperparameter 가이드

파라미터 SimCLR MoCo v3 MAE DINO
학습률 0.3 (LARS) 1.5e-4 (AdamW) 1.5e-4 (AdamW) 5e-4 (AdamW)
배치 크기 4096 4096 4096 1024
Epochs 800-1000 300-600 800-1600 300-400
Weight decay 1e-6 0.1 0.05 0.04-0.4
Temperature 0.1-0.5 0.2 N/A 0.04-0.07
EMA momentum N/A 0.99-0.999 N/A 0.996-1.0

참고 자료

자료 링크
SimCLR 논문 https://arxiv.org/abs/2002.05709
MoCo 논문 https://arxiv.org/abs/1911.05722
BYOL 논문 https://arxiv.org/abs/2006.07733
MAE 논문 https://arxiv.org/abs/2111.06377
DINO 논문 https://arxiv.org/abs/2104.14294
DINOv2 논문 https://arxiv.org/abs/2304.07193
CLIP 논문 https://arxiv.org/abs/2103.00020
SSL Survey (Gui et al.) https://arxiv.org/abs/2301.05712
NeurIPS 2024 SSL Workshop https://sslneurips2024.github.io/
lightly (SSL 라이브러리) https://github.com/lightly-ai/lightly
solo-learn (SSL 벤치마크) https://github.com/vturrisi/solo-learn