Test-Time Adaptation (TTA)¶
개요¶
Test-Time Adaptation(TTA)은 사전 학습된 모델을 테스트 시점에 레이블 없는 데이터에 적응시켜 분포 이동(distribution shift)으로 인한 성능 저하를 완화하는 패러다임이다.
핵심 개념¶
왜 TTA가 필요한가?
실제 환경에서 모델을 배포하면 학습 데이터와 다른 분포의 데이터를 마주한다:
- 조명 조건 변화 (이미지)
- 새로운 방언/표현 (텍스트)
- 센서 노화/교체 (IoT)
- 계절/시간 변화 (시계열)
TTA는 소스 데이터 없이 타겟 도메인의 비지도 데이터만으로 모델을 조정한다는 점에서 기존 Domain Adaptation과 구별된다.
메타 정보¶
| 항목 | 내용 |
|---|---|
| 분야 | Domain Adaptation, Transfer Learning, Robustness |
| 핵심 논문 | Liang et al., IJCV 2024 (Survey) |
| 관련 학회 | NeurIPS, ICML, ICLR, CVPR, ECCV |
| GitHub | github.com/tim-learn/awesome-test-time-adaptation |
| 최초 제안 | TTT (Sun et al., ICML 2020) |
문제 정의¶
학습 데이터 분포 \(P_{source}(X, Y)\)와 테스트 데이터 분포 \(P_{target}(X, Y)\)가 다를 때:
| 분포 이동 유형 | 수식 | 예시 |
|---|---|---|
| Covariate Shift | \(P_S(X) \neq P_T(X)\), \(P(Y\|X)\) 동일 | 밝은 이미지 → 어두운 이미지 |
| Label Shift | \(P_S(Y) \neq P_T(Y)\) | 클래스 비율 변화 |
| Concept Drift | \(P(Y\|X)\) 변화 | 단어 의미 변화 (시간) |
TTA의 목표: 테스트 시점에 \(P_{target}(X)\)에 접근하여 모델 \(f_\theta\)를 적응시킴
TTA 분류 체계¶
1. Source-Free Domain Adaptation (SFDA)¶
소스 데이터 없이 사전 학습된 모델만으로 타겟 도메인에 적응한다.
| 방법 | 핵심 아이디어 | 학회 | 코드 |
|---|---|---|---|
| SHOT | 정보 최대화 + 가상 레이블링 | ICML 2020 | GitHub |
| NRC | 이웃 기반 클러스터링 | NeurIPS 2021 | GitHub |
| AdaContrast | 대조 학습 기반 적응 | CVPR 2022 | GitHub |
2. Test-Time Batch Adaptation (TTBA)¶
배치 단위로 테스트 데이터에 적응한다.
| 방법 | 핵심 아이디어 | 학회 | 코드 |
|---|---|---|---|
| TENT | 엔트로피 최소화로 BN 업데이트 | ICLR 2021 | GitHub |
| MEMO | 다양한 증강에 대한 예측 일관성 | NeurIPS 2022 | GitHub |
| T3A | 프로토타입 기반 분류기 조정 | NeurIPS 2021 | GitHub |
3. Online Test-Time Adaptation (OTTA)¶
스트리밍 데이터에 대한 지속적 적응이다.
| 방법 | 핵심 아이디어 | 학회 | 코드 |
|---|---|---|---|
| CoTTA | 지속적 적응 + 망각 방지 | CVPR 2022 | GitHub |
| EATA | 효율적 안티-포겟팅 적응 | ICML 2022 | GitHub |
| SAR | 신뢰도 기반 샘플 선택 | ICLR 2023 | GitHub |
4. Test-Time Instance Adaptation (TTIA)¶
개별 샘플 단위로 적응한다 (single sample adaptation).
| 방법 | 핵심 아이디어 | 학회 | 코드 |
|---|---|---|---|
| TTT | 자기지도 보조 태스크 | ICML 2020 | GitHub |
| TTT++ | 대조 학습 보조 태스크 | NeurIPS 2021 | - |
| DDA | 확산 모델 기반 적응 | CVPR 2023 | GitHub |
핵심 알고리즘 상세¶
TENT (Test-time ENTropy minimization)¶
가장 기본적이고 널리 사용되는 TTA 방법이다.
알고리즘 수도코드:
Algorithm: TENT
Input: Pre-trained model f_θ, test batch X_t
Output: Adapted predictions
1. Forward pass: ŷ = f_θ(X_t)
2. Compute entropy: H(ŷ) = -Σ ŷ log(ŷ)
3. Update only BatchNorm parameters:
θ_BN ← θ_BN - η∇H(ŷ)
4. Return predictions with adapted BN
Python 구현:
import torch
import torch.nn as nn
from typing import Iterator
class TENT:
"""
Test-time Entropy minimization
BatchNorm 파라미터만 업데이트하여 테스트 시점 적응 수행
Reference: Wang et al., "TENT: Fully Test-Time Adaptation
by Entropy Minimization", ICLR 2021
"""
def __init__(
self,
model: nn.Module,
optimizer: torch.optim.Optimizer,
steps: int = 1,
episodic: bool = False
):
"""
Args:
model: 사전 학습된 모델
optimizer: BN 파라미터용 옵티마이저
steps: 각 배치당 적응 스텝 수
episodic: True면 각 배치 후 모델 리셋
"""
self.model = model
self.optimizer = optimizer
self.steps = steps
self.episodic = episodic
# 원본 상태 저장 (episodic용)
if episodic:
self.model_state = model.state_dict()
self.optimizer_state = optimizer.state_dict()
# BatchNorm 설정
self._configure_model()
def _configure_model(self):
"""BatchNorm 레이어만 학습 가능하게 설정"""
self.model.train() # train mode (BN 통계량 업데이트용)
self.model.requires_grad_(False) # 전체 freeze
for module in self.model.modules():
if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
# BN 파라미터만 학습
module.requires_grad_(True)
# 배치 통계량 사용 (running stats 대신)
module.track_running_stats = False
module.running_mean = None
module.running_var = None
def reset(self):
"""모델을 원본 상태로 리셋 (episodic adaptation용)"""
if self.episodic:
self.model.load_state_dict(self.model_state, strict=True)
self.optimizer.load_state_dict(self.optimizer_state)
self._configure_model()
@torch.enable_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
테스트 배치에 대해 적응 후 예측
Args:
x: 입력 텐서 (batch_size, ...)
Returns:
적응된 모델의 예측
"""
if self.episodic:
self.reset()
for _ in range(self.steps):
outputs = self.model(x)
loss = self.softmax_entropy(outputs).mean()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return self.model(x)
@staticmethod
def softmax_entropy(logits: torch.Tensor) -> torch.Tensor:
"""
Softmax 엔트로피 계산: H(p) = -Σ p log(p)
낮은 엔트로피 = 높은 확신 = 좋은 예측
"""
probs = logits.softmax(dim=1)
log_probs = logits.log_softmax(dim=1)
return -(probs * log_probs).sum(dim=1)
def setup_tent(model: nn.Module, lr: float = 0.001) -> TENT:
"""TENT 설정 헬퍼 함수"""
# BN 파라미터만 수집
params = []
for module in model.modules():
if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
params.extend(module.parameters())
optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9)
return TENT(model, optimizer, steps=1)
# 사용 예시
def adapt_and_predict(model, test_loader, device='cuda'):
tent = setup_tent(model, lr=0.001)
all_predictions = []
all_labels = []
for x, y in test_loader:
x = x.to(device)
# 적응 및 예측
with torch.no_grad():
outputs = tent.forward(x)
predictions = outputs.argmax(dim=1)
all_predictions.append(predictions.cpu())
all_labels.append(y)
predictions = torch.cat(all_predictions)
labels = torch.cat(all_labels)
accuracy = (predictions == labels).float().mean()
return accuracy.item()
CoTTA (Continual Test-Time Adaptation)¶
지속적인 도메인 이동에 대응하며 망각을 방지한다.
핵심 기법:
import torch
import torch.nn as nn
import copy
class CoTTA:
"""
Continual Test-Time Adaptation
지속적 도메인 변화에서 망각을 방지하며 적응
Reference: Wang et al., "Continual Test-Time Domain Adaptation", CVPR 2022
"""
def __init__(
self,
model: nn.Module,
optimizer: torch.optim.Optimizer,
ema_decay: float = 0.999,
restore_prob: float = 0.01,
augmentation_fn = None
):
self.model = model
self.model_ema = copy.deepcopy(model) # Teacher (EMA)
self.model_anchor = copy.deepcopy(model) # Source model (복원용)
self.optimizer = optimizer
self.ema_decay = ema_decay
self.restore_prob = restore_prob
self.augmentation_fn = augmentation_fn
# Anchor model 고정
for param in self.model_anchor.parameters():
param.requires_grad = False
def update_ema(self):
"""Teacher model EMA 업데이트"""
with torch.no_grad():
for ema_param, param in zip(
self.model_ema.parameters(),
self.model.parameters()
):
ema_param.data = (
self.ema_decay * ema_param.data +
(1 - self.ema_decay) * param.data
)
def stochastic_restore(self):
"""일부 파라미터를 소스 모델로 복원"""
for (name, param), anchor_param in zip(
self.model.named_parameters(),
self.model_anchor.parameters()
):
# 확률적으로 복원
mask = torch.rand_like(param) < self.restore_prob
param.data = torch.where(mask, anchor_param.data, param.data)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
적응 및 예측
"""
# 증강된 예측들의 평균으로 가상 레이블 생성
with torch.no_grad():
if self.augmentation_fn is not None:
# 여러 증강에 대한 Teacher 예측 앙상블
aug_outputs = []
for _ in range(4): # 4개 증강
x_aug = self.augmentation_fn(x)
out = self.model_ema(x_aug)
aug_outputs.append(out.softmax(dim=1))
pseudo_labels = torch.stack(aug_outputs).mean(dim=0)
else:
pseudo_labels = self.model_ema(x).softmax(dim=1)
# Student 모델 학습
outputs = self.model(x)
loss = self.soft_cross_entropy(outputs, pseudo_labels)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# EMA 업데이트
self.update_ema()
# 확률적 복원 (망각 방지)
self.stochastic_restore()
return outputs
@staticmethod
def soft_cross_entropy(pred, target):
"""Soft label cross entropy"""
log_pred = pred.log_softmax(dim=1)
return -(target * log_pred).sum(dim=1).mean()
T3A (Test-Time Template Adjustment)¶
프로토타입 기반으로 분류기를 조정한다.
import torch
import torch.nn.functional as F
class T3A:
"""
Test-Time Template Adjuster
클래스별 프로토타입을 온라인으로 업데이트하여 분류기 조정
Reference: Iwasawa & Matsuo, "Test-Time Classifier Adjustment
Module for Model-Agnostic Domain Generalization", NeurIPS 2021
"""
def __init__(
self,
model: torch.nn.Module,
num_classes: int,
feature_dim: int,
filter_k: int = 100
):
"""
Args:
model: 사전 학습된 모델 (feature extractor + classifier)
num_classes: 클래스 수
feature_dim: 피처 차원
filter_k: 프로토타입 계산에 사용할 최근 샘플 수
"""
self.model = model
self.num_classes = num_classes
self.filter_k = filter_k
# 클래스별 프로토타입 저장소
self.prototypes = torch.zeros(num_classes, feature_dim)
self.prototype_counts = torch.zeros(num_classes)
# 최근 피처 저장 (filter_k개 유지)
self.feature_bank = []
self.label_bank = []
self.model.eval()
def get_features(self, x: torch.Tensor) -> torch.Tensor:
"""
모델에서 피처 추출 (classifier 직전 레이어)
Note: 모델 구조에 따라 수정 필요
"""
# ResNet 예시
with torch.no_grad():
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)
x = self.model.layer1(x)
x = self.model.layer2(x)
x = self.model.layer3(x)
x = self.model.layer4(x)
x = self.model.avgpool(x)
features = x.flatten(1)
return features
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
프로토타입 기반 예측
"""
with torch.no_grad():
# 피처 추출
features = self.get_features(x)
device = features.device
# 초기 예측으로 가상 레이블 생성
initial_logits = self.model(x)
pseudo_labels = initial_logits.argmax(dim=1)
# 프로토타입을 device로 이동
self.prototypes = self.prototypes.to(device)
self.prototype_counts = self.prototype_counts.to(device)
# 프로토타입 업데이트
for feat, label in zip(features, pseudo_labels):
self.prototypes[label] += feat
self.prototype_counts[label] += 1
# 피처 뱅크에 저장
self.feature_bank.append(feat.cpu())
self.label_bank.append(label.cpu())
# 오래된 샘플 제거
if len(self.feature_bank) > self.filter_k:
old_feat = self.feature_bank.pop(0)
old_label = self.label_bank.pop(0)
self.prototypes[old_label] -= old_feat.to(device)
self.prototype_counts[old_label] -= 1
# 정규화된 프로토타입 계산
valid_mask = self.prototype_counts > 0
normed_prototypes = torch.zeros_like(self.prototypes)
normed_prototypes[valid_mask] = F.normalize(
self.prototypes[valid_mask] /
self.prototype_counts[valid_mask].unsqueeze(1),
dim=1
)
# 코사인 유사도 기반 예측
normed_features = F.normalize(features, dim=1)
adjusted_logits = normed_features @ normed_prototypes.T
return adjusted_logits
def reset(self):
"""프로토타입 초기화"""
self.prototypes.zero_()
self.prototype_counts.zero_()
self.feature_bank.clear()
self.label_bank.clear()
실험 프레임워크¶
벤치마크 평가 코드¶
import torch
import copy
from torchvision import datasets, transforms
from typing import Callable, Dict, List
def evaluate_tta(
model: torch.nn.Module,
tta_method_cls: type,
corruption_types: List[str],
severities: List[int] = [1, 2, 3, 4, 5],
data_root: str = './data',
batch_size: int = 64,
device: str = 'cuda'
) -> Dict[str, Dict[int, float]]:
"""
ImageNet-C 스타일 벤치마크 평가
Args:
model: 사전 학습된 모델
tta_method_cls: TTA 알고리즘 클래스
corruption_types: 부패 유형 리스트
severities: 부패 강도 리스트
data_root: 데이터 경로
batch_size: 배치 크기
device: 디바이스
Returns:
corruption별, severity별 정확도 딕셔너리
"""
results = {}
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
for corruption in corruption_types:
results[corruption] = {}
for severity in severities:
# 모델 복사 (각 corruption마다 리셋)
model_copy = copy.deepcopy(model).to(device)
# TTA 설정
adapter = tta_method_cls(model_copy)
# 데이터 로드
dataset = datasets.ImageFolder(
f'{data_root}/{corruption}/{severity}',
transform=transform
)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=4
)
# 평가
correct = 0
total = 0
for images, labels in loader:
images = images.to(device)
labels = labels.to(device)
# TTA 적응 및 예측
with torch.no_grad():
outputs = adapter.forward(images)
_, predicted = outputs.max(1)
correct += (predicted == labels).sum().item()
total += labels.size(0)
accuracy = 100 * correct / total
results[corruption][severity] = accuracy
print(f'{corruption}-{severity}: {accuracy:.2f}%')
return results
# 부패 유형 정의 (ImageNet-C 기준)
CORRUPTION_TYPES = {
'noise': ['gaussian_noise', 'shot_noise', 'impulse_noise'],
'blur': ['defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur'],
'weather': ['snow', 'frost', 'fog', 'brightness'],
'digital': ['contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']
}
# 사용 예시
# results = evaluate_tta(
# model=resnet50,
# tta_method_cls=TENT,
# corruption_types=CORRUPTION_TYPES['noise'],
# data_root='./ImageNet-C'
# )
주요 벤치마크¶
데이터셋¶
| 데이터셋 | 도메인 이동 유형 | 클래스 수 | 용도 |
|---|---|---|---|
| ImageNet-C | 15가지 부패 유형 | 1000 | 자연 이미지 |
| CIFAR-10-C | 19가지 부패 유형 | 10 | 소규모 실험 |
| Office-Home | 4개 도메인 | 65 | 도메인 적응 |
| DomainNet | 6개 도메인 | 345 | 대규모 도메인 |
| PACS | 4개 도메인 (스타일) | 7 | 도메인 일반화 |
| Cityscapes-C | 도시 환경 부패 | 19 | Segmentation |
성능 비교 (ImageNet-C, ResNet-50)¶
| 방법 | Mean Error (%) | 유형 | 추가 학습 |
|---|---|---|---|
| Source Only | 76.7 | Baseline | X |
| BN Adapt | 65.4 | Statistics | X |
| TENT | 62.2 | Entropy | O (BN만) |
| MEMO | 59.8 | Augmentation | X |
| CoTTA | 57.1 | Continual | O |
| SAR | 55.8 | Selective | O |
실무 적용 가이드¶
방법 선택 플로우¶
주의사항¶
| 문제 | 원인 | 해결책 |
|---|---|---|
| 배치 크기 의존성 | TENT는 작은 배치에서 불안정 | MEMO 사용 또는 배치 크기 증가 |
| 망각 문제 | 지속적 적응 시 원본 지식 손실 | CoTTA, EATA의 restoration 사용 |
| 적대적 취약성 | TTA는 적대적 공격에 취약 | SAR의 신뢰도 필터링 |
| 계산 비용 | 추론 시간 증가 | 적응 스텝 수 제한 |
에러 케이스와 디버깅¶
# 흔한 문제 1: BN 통계량 불안정
# 해결: 배치 크기 확인, 충분한 샘플 확보
if batch_size < 16:
print("Warning: Small batch may cause unstable BN stats")
# MEMO나 instance normalization 고려
# 흔한 문제 2: 성능 하락
# 원인: 과도한 적응, 소스 지식 망각
def check_adaptation_quality(model, source_val_loader, device):
"""소스 데이터 성능 모니터링"""
model.eval()
correct = 0
total = 0
for x, y in source_val_loader:
x, y = x.to(device), y.to(device)
pred = model(x).argmax(1)
correct += (pred == y).sum().item()
total += y.size(0)
acc = correct / total
if acc < 0.5: # 임계값
print(f"Warning: Source accuracy dropped to {acc:.2%}")
print("Consider: reducing learning rate or using restoration")
return acc
# 흔한 문제 3: 메모리 부족
# 해결: 그래디언트 체크포인팅, mixed precision
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
with autocast():
outputs = model(x)
loss = compute_loss(outputs)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
최신 연구 동향 (2024-2025)¶
| 연구 방향 | 설명 | 대표 논문 |
|---|---|---|
| Vision-Language TTA | CLIP 기반 모델 적응 | TPT, RLCF (2023) |
| LLM TTA | 대규모 언어 모델 분포 이동 대응 | - |
| Active TTA | 제한된 인간 피드백 활용 | AETTA (2024) |
| Robust TTA | 적대적 환경 안정적 적응 | Anti-Adv TTA (2024) |
| Multi-modal TTA | 다중 모달리티 공동 적응 | MM-TTA (2024) |
참고문헌¶
| 논문 | 학회 | 핵심 기여 |
|---|---|---|
| Liang et al., "Survey on TTA" | IJCV 2024 | 종합 서베이 |
| Wang et al., "TENT" | ICLR 2021 | 엔트로피 최소화 |
| Zhang et al., "MEMO" | NeurIPS 2022 | 단일 샘플 적응 |
| Wang et al., "CoTTA" | CVPR 2022 | 연속 적응 |
| Niu et al., "SAR" | ICLR 2023 | 신뢰도 기반 선택 |
| Iwasawa & Matsuo, "T3A" | NeurIPS 2021 | 프로토타입 조정 |
추가 리소스: