콘텐츠로 이동
Data Prep
상세

Test-Time Adaptation (TTA)

개요

Test-Time Adaptation(TTA)은 사전 학습된 모델을 테스트 시점에 레이블 없는 데이터에 적응시켜 분포 이동(distribution shift)으로 인한 성능 저하를 완화하는 패러다임이다.

핵심 개념

왜 TTA가 필요한가?

실제 환경에서 모델을 배포하면 학습 데이터와 다른 분포의 데이터를 마주한다:

  • 조명 조건 변화 (이미지)
  • 새로운 방언/표현 (텍스트)
  • 센서 노화/교체 (IoT)
  • 계절/시간 변화 (시계열)

TTA는 소스 데이터 없이 타겟 도메인의 비지도 데이터만으로 모델을 조정한다는 점에서 기존 Domain Adaptation과 구별된다.

메타 정보

항목 내용
분야 Domain Adaptation, Transfer Learning, Robustness
핵심 논문 Liang et al., IJCV 2024 (Survey)
관련 학회 NeurIPS, ICML, ICLR, CVPR, ECCV
GitHub github.com/tim-learn/awesome-test-time-adaptation
최초 제안 TTT (Sun et al., ICML 2020)

문제 정의

학습 데이터 분포 \(P_{source}(X, Y)\)와 테스트 데이터 분포 \(P_{target}(X, Y)\)가 다를 때:

분포 이동 유형 수식 예시
Covariate Shift \(P_S(X) \neq P_T(X)\), \(P(Y\|X)\) 동일 밝은 이미지 → 어두운 이미지
Label Shift \(P_S(Y) \neq P_T(Y)\) 클래스 비율 변화
Concept Drift \(P(Y\|X)\) 변화 단어 의미 변화 (시간)

TTA의 목표: 테스트 시점에 \(P_{target}(X)\)에 접근하여 모델 \(f_\theta\)를 적응시킴


TTA 분류 체계

summary diagram

1. Source-Free Domain Adaptation (SFDA)

소스 데이터 없이 사전 학습된 모델만으로 타겟 도메인에 적응한다.

방법 핵심 아이디어 학회 코드
SHOT 정보 최대화 + 가상 레이블링 ICML 2020 GitHub
NRC 이웃 기반 클러스터링 NeurIPS 2021 GitHub
AdaContrast 대조 학습 기반 적응 CVPR 2022 GitHub

2. Test-Time Batch Adaptation (TTBA)

배치 단위로 테스트 데이터에 적응한다.

방법 핵심 아이디어 학회 코드
TENT 엔트로피 최소화로 BN 업데이트 ICLR 2021 GitHub
MEMO 다양한 증강에 대한 예측 일관성 NeurIPS 2022 GitHub
T3A 프로토타입 기반 분류기 조정 NeurIPS 2021 GitHub

3. Online Test-Time Adaptation (OTTA)

스트리밍 데이터에 대한 지속적 적응이다.

방법 핵심 아이디어 학회 코드
CoTTA 지속적 적응 + 망각 방지 CVPR 2022 GitHub
EATA 효율적 안티-포겟팅 적응 ICML 2022 GitHub
SAR 신뢰도 기반 샘플 선택 ICLR 2023 GitHub

4. Test-Time Instance Adaptation (TTIA)

개별 샘플 단위로 적응한다 (single sample adaptation).

방법 핵심 아이디어 학회 코드
TTT 자기지도 보조 태스크 ICML 2020 GitHub
TTT++ 대조 학습 보조 태스크 NeurIPS 2021 -
DDA 확산 모델 기반 적응 CVPR 2023 GitHub

핵심 알고리즘 상세

TENT (Test-time ENTropy minimization)

가장 기본적이고 널리 사용되는 TTA 방법이다.

summary diagram

알고리즘 수도코드:

Algorithm: TENT
Input: Pre-trained model f_θ, test batch X_t
Output: Adapted predictions

1. Forward pass: ŷ = f_θ(X_t)
2. Compute entropy: H(ŷ) = -Σ ŷ log(ŷ)
3. Update only BatchNorm parameters:
   θ_BN ← θ_BN - η∇H(ŷ)
4. Return predictions with adapted BN

Python 구현:

import torch
import torch.nn as nn
from typing import Iterator

class TENT:
    """
    Test-time Entropy minimization

    BatchNorm 파라미터만 업데이트하여 테스트 시점 적응 수행

    Reference: Wang et al., "TENT: Fully Test-Time Adaptation 
               by Entropy Minimization", ICLR 2021
    """

    def __init__(
        self, 
        model: nn.Module, 
        optimizer: torch.optim.Optimizer,
        steps: int = 1,
        episodic: bool = False
    ):
        """
        Args:
            model: 사전 학습된 모델
            optimizer: BN 파라미터용 옵티마이저
            steps: 각 배치당 적응 스텝 수
            episodic: True면 각 배치 후 모델 리셋
        """
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
        self.episodic = episodic

        # 원본 상태 저장 (episodic용)
        if episodic:
            self.model_state = model.state_dict()
            self.optimizer_state = optimizer.state_dict()

        # BatchNorm 설정
        self._configure_model()

    def _configure_model(self):
        """BatchNorm 레이어만 학습 가능하게 설정"""
        self.model.train()  # train mode (BN 통계량 업데이트용)
        self.model.requires_grad_(False)  # 전체 freeze

        for module in self.model.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                # BN 파라미터만 학습
                module.requires_grad_(True)
                # 배치 통계량 사용 (running stats 대신)
                module.track_running_stats = False
                module.running_mean = None
                module.running_var = None

    def reset(self):
        """모델을 원본 상태로 리셋 (episodic adaptation용)"""
        if self.episodic:
            self.model.load_state_dict(self.model_state, strict=True)
            self.optimizer.load_state_dict(self.optimizer_state)
            self._configure_model()

    @torch.enable_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        테스트 배치에 대해 적응 후 예측

        Args:
            x: 입력 텐서 (batch_size, ...)

        Returns:
            적응된 모델의 예측
        """
        if self.episodic:
            self.reset()

        for _ in range(self.steps):
            outputs = self.model(x)
            loss = self.softmax_entropy(outputs).mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return self.model(x)

    @staticmethod
    def softmax_entropy(logits: torch.Tensor) -> torch.Tensor:
        """
        Softmax 엔트로피 계산: H(p) = -Σ p log(p)

        낮은 엔트로피 = 높은 확신 = 좋은 예측
        """
        probs = logits.softmax(dim=1)
        log_probs = logits.log_softmax(dim=1)
        return -(probs * log_probs).sum(dim=1)


def setup_tent(model: nn.Module, lr: float = 0.001) -> TENT:
    """TENT 설정 헬퍼 함수"""
    # BN 파라미터만 수집
    params = []
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            params.extend(module.parameters())

    optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9)
    return TENT(model, optimizer, steps=1)


# 사용 예시
def adapt_and_predict(model, test_loader, device='cuda'):
    tent = setup_tent(model, lr=0.001)

    all_predictions = []
    all_labels = []

    for x, y in test_loader:
        x = x.to(device)

        # 적응 및 예측
        with torch.no_grad():
            outputs = tent.forward(x)
            predictions = outputs.argmax(dim=1)

        all_predictions.append(predictions.cpu())
        all_labels.append(y)

    predictions = torch.cat(all_predictions)
    labels = torch.cat(all_labels)
    accuracy = (predictions == labels).float().mean()

    return accuracy.item()

CoTTA (Continual Test-Time Adaptation)

지속적인 도메인 이동에 대응하며 망각을 방지한다.

summary diagram

핵심 기법:

import torch
import torch.nn as nn
import copy

class CoTTA:
    """
    Continual Test-Time Adaptation

    지속적 도메인 변화에서 망각을 방지하며 적응

    Reference: Wang et al., "Continual Test-Time Domain Adaptation", CVPR 2022
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        ema_decay: float = 0.999,
        restore_prob: float = 0.01,
        augmentation_fn = None
    ):
        self.model = model
        self.model_ema = copy.deepcopy(model)  # Teacher (EMA)
        self.model_anchor = copy.deepcopy(model)  # Source model (복원용)
        self.optimizer = optimizer
        self.ema_decay = ema_decay
        self.restore_prob = restore_prob
        self.augmentation_fn = augmentation_fn

        # Anchor model 고정
        for param in self.model_anchor.parameters():
            param.requires_grad = False

    def update_ema(self):
        """Teacher model EMA 업데이트"""
        with torch.no_grad():
            for ema_param, param in zip(
                self.model_ema.parameters(), 
                self.model.parameters()
            ):
                ema_param.data = (
                    self.ema_decay * ema_param.data + 
                    (1 - self.ema_decay) * param.data
                )

    def stochastic_restore(self):
        """일부 파라미터를 소스 모델로 복원"""
        for (name, param), anchor_param in zip(
            self.model.named_parameters(),
            self.model_anchor.parameters()
        ):
            # 확률적으로 복원
            mask = torch.rand_like(param) < self.restore_prob
            param.data = torch.where(mask, anchor_param.data, param.data)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        적응 및 예측
        """
        # 증강된 예측들의 평균으로 가상 레이블 생성
        with torch.no_grad():
            if self.augmentation_fn is not None:
                # 여러 증강에 대한 Teacher 예측 앙상블
                aug_outputs = []
                for _ in range(4):  # 4개 증강
                    x_aug = self.augmentation_fn(x)
                    out = self.model_ema(x_aug)
                    aug_outputs.append(out.softmax(dim=1))

                pseudo_labels = torch.stack(aug_outputs).mean(dim=0)
            else:
                pseudo_labels = self.model_ema(x).softmax(dim=1)

        # Student 모델 학습
        outputs = self.model(x)
        loss = self.soft_cross_entropy(outputs, pseudo_labels)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # EMA 업데이트
        self.update_ema()

        # 확률적 복원 (망각 방지)
        self.stochastic_restore()

        return outputs

    @staticmethod
    def soft_cross_entropy(pred, target):
        """Soft label cross entropy"""
        log_pred = pred.log_softmax(dim=1)
        return -(target * log_pred).sum(dim=1).mean()

T3A (Test-Time Template Adjustment)

프로토타입 기반으로 분류기를 조정한다.

import torch
import torch.nn.functional as F

class T3A:
    """
    Test-Time Template Adjuster

    클래스별 프로토타입을 온라인으로 업데이트하여 분류기 조정

    Reference: Iwasawa & Matsuo, "Test-Time Classifier Adjustment 
               Module for Model-Agnostic Domain Generalization", NeurIPS 2021
    """

    def __init__(
        self, 
        model: torch.nn.Module,
        num_classes: int,
        feature_dim: int,
        filter_k: int = 100
    ):
        """
        Args:
            model: 사전 학습된 모델 (feature extractor + classifier)
            num_classes: 클래스 수
            feature_dim: 피처 차원
            filter_k: 프로토타입 계산에 사용할 최근 샘플 수
        """
        self.model = model
        self.num_classes = num_classes
        self.filter_k = filter_k

        # 클래스별 프로토타입 저장소
        self.prototypes = torch.zeros(num_classes, feature_dim)
        self.prototype_counts = torch.zeros(num_classes)

        # 최근 피처 저장 (filter_k개 유지)
        self.feature_bank = []
        self.label_bank = []

        self.model.eval()

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        모델에서 피처 추출 (classifier 직전 레이어)

        Note: 모델 구조에 따라 수정 필요
        """
        # ResNet 예시
        with torch.no_grad():
            x = self.model.conv1(x)
            x = self.model.bn1(x)
            x = self.model.relu(x)
            x = self.model.maxpool(x)
            x = self.model.layer1(x)
            x = self.model.layer2(x)
            x = self.model.layer3(x)
            x = self.model.layer4(x)
            x = self.model.avgpool(x)
            features = x.flatten(1)
        return features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        프로토타입 기반 예측
        """
        with torch.no_grad():
            # 피처 추출
            features = self.get_features(x)
            device = features.device

            # 초기 예측으로 가상 레이블 생성
            initial_logits = self.model(x)
            pseudo_labels = initial_logits.argmax(dim=1)

            # 프로토타입을 device로 이동
            self.prototypes = self.prototypes.to(device)
            self.prototype_counts = self.prototype_counts.to(device)

            # 프로토타입 업데이트
            for feat, label in zip(features, pseudo_labels):
                self.prototypes[label] += feat
                self.prototype_counts[label] += 1

                # 피처 뱅크에 저장
                self.feature_bank.append(feat.cpu())
                self.label_bank.append(label.cpu())

                # 오래된 샘플 제거
                if len(self.feature_bank) > self.filter_k:
                    old_feat = self.feature_bank.pop(0)
                    old_label = self.label_bank.pop(0)
                    self.prototypes[old_label] -= old_feat.to(device)
                    self.prototype_counts[old_label] -= 1

            # 정규화된 프로토타입 계산
            valid_mask = self.prototype_counts > 0
            normed_prototypes = torch.zeros_like(self.prototypes)
            normed_prototypes[valid_mask] = F.normalize(
                self.prototypes[valid_mask] / 
                self.prototype_counts[valid_mask].unsqueeze(1),
                dim=1
            )

            # 코사인 유사도 기반 예측
            normed_features = F.normalize(features, dim=1)
            adjusted_logits = normed_features @ normed_prototypes.T

        return adjusted_logits

    def reset(self):
        """프로토타입 초기화"""
        self.prototypes.zero_()
        self.prototype_counts.zero_()
        self.feature_bank.clear()
        self.label_bank.clear()

실험 프레임워크

벤치마크 평가 코드

import torch
import copy
from torchvision import datasets, transforms
from typing import Callable, Dict, List

def evaluate_tta(
    model: torch.nn.Module,
    tta_method_cls: type,
    corruption_types: List[str],
    severities: List[int] = [1, 2, 3, 4, 5],
    data_root: str = './data',
    batch_size: int = 64,
    device: str = 'cuda'
) -> Dict[str, Dict[int, float]]:
    """
    ImageNet-C 스타일 벤치마크 평가

    Args:
        model: 사전 학습된 모델
        tta_method_cls: TTA 알고리즘 클래스
        corruption_types: 부패 유형 리스트
        severities: 부패 강도 리스트
        data_root: 데이터 경로
        batch_size: 배치 크기
        device: 디바이스

    Returns:
        corruption별, severity별 정확도 딕셔너리
    """
    results = {}

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    for corruption in corruption_types:
        results[corruption] = {}

        for severity in severities:
            # 모델 복사 (각 corruption마다 리셋)
            model_copy = copy.deepcopy(model).to(device)

            # TTA 설정
            adapter = tta_method_cls(model_copy)

            # 데이터 로드
            dataset = datasets.ImageFolder(
                f'{data_root}/{corruption}/{severity}',
                transform=transform
            )
            loader = torch.utils.data.DataLoader(
                dataset, 
                batch_size=batch_size, 
                shuffle=False,
                num_workers=4
            )

            # 평가
            correct = 0
            total = 0

            for images, labels in loader:
                images = images.to(device)
                labels = labels.to(device)

                # TTA 적응 및 예측
                with torch.no_grad():
                    outputs = adapter.forward(images)
                    _, predicted = outputs.max(1)

                correct += (predicted == labels).sum().item()
                total += labels.size(0)

            accuracy = 100 * correct / total
            results[corruption][severity] = accuracy
            print(f'{corruption}-{severity}: {accuracy:.2f}%')

    return results


# 부패 유형 정의 (ImageNet-C 기준)
CORRUPTION_TYPES = {
    'noise': ['gaussian_noise', 'shot_noise', 'impulse_noise'],
    'blur': ['defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur'],
    'weather': ['snow', 'frost', 'fog', 'brightness'],
    'digital': ['contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']
}

# 사용 예시
# results = evaluate_tta(
#     model=resnet50,
#     tta_method_cls=TENT,
#     corruption_types=CORRUPTION_TYPES['noise'],
#     data_root='./ImageNet-C'
# )

주요 벤치마크

데이터셋

데이터셋 도메인 이동 유형 클래스 수 용도
ImageNet-C 15가지 부패 유형 1000 자연 이미지
CIFAR-10-C 19가지 부패 유형 10 소규모 실험
Office-Home 4개 도메인 65 도메인 적응
DomainNet 6개 도메인 345 대규모 도메인
PACS 4개 도메인 (스타일) 7 도메인 일반화
Cityscapes-C 도시 환경 부패 19 Segmentation

성능 비교 (ImageNet-C, ResNet-50)

방법 Mean Error (%) 유형 추가 학습
Source Only 76.7 Baseline X
BN Adapt 65.4 Statistics X
TENT 62.2 Entropy O (BN만)
MEMO 59.8 Augmentation X
CoTTA 57.1 Continual O
SAR 55.8 Selective O

실무 적용 가이드

방법 선택 플로우

summary diagram

주의사항

문제 원인 해결책
배치 크기 의존성 TENT는 작은 배치에서 불안정 MEMO 사용 또는 배치 크기 증가
망각 문제 지속적 적응 시 원본 지식 손실 CoTTA, EATA의 restoration 사용
적대적 취약성 TTA는 적대적 공격에 취약 SAR의 신뢰도 필터링
계산 비용 추론 시간 증가 적응 스텝 수 제한

에러 케이스와 디버깅

# 흔한 문제 1: BN 통계량 불안정
# 해결: 배치 크기 확인, 충분한 샘플 확보
if batch_size < 16:
    print("Warning: Small batch may cause unstable BN stats")
    # MEMO나 instance normalization 고려

# 흔한 문제 2: 성능 하락
# 원인: 과도한 적응, 소스 지식 망각
def check_adaptation_quality(model, source_val_loader, device):
    """소스 데이터 성능 모니터링"""
    model.eval()
    correct = 0
    total = 0
    for x, y in source_val_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)

    acc = correct / total
    if acc < 0.5:  # 임계값
        print(f"Warning: Source accuracy dropped to {acc:.2%}")
        print("Consider: reducing learning rate or using restoration")
    return acc

# 흔한 문제 3: 메모리 부족
# 해결: 그래디언트 체크포인팅, mixed precision
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
with autocast():
    outputs = model(x)
    loss = compute_loss(outputs)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

최신 연구 동향 (2024-2025)

연구 방향 설명 대표 논문
Vision-Language TTA CLIP 기반 모델 적응 TPT, RLCF (2023)
LLM TTA 대규모 언어 모델 분포 이동 대응 -
Active TTA 제한된 인간 피드백 활용 AETTA (2024)
Robust TTA 적대적 환경 안정적 적응 Anti-Adv TTA (2024)
Multi-modal TTA 다중 모달리티 공동 적응 MM-TTA (2024)

참고문헌

논문 학회 핵심 기여
Liang et al., "Survey on TTA" IJCV 2024 종합 서베이
Wang et al., "TENT" ICLR 2021 엔트로피 최소화
Zhang et al., "MEMO" NeurIPS 2022 단일 샘플 적응
Wang et al., "CoTTA" CVPR 2022 연속 적응
Niu et al., "SAR" ICLR 2023 신뢰도 기반 선택
Iwasawa & Matsuo, "T3A" NeurIPS 2021 프로토타입 조정

추가 리소스: