콘텐츠로 이동
Data Prep
상세

Domain Generalization (도메인 일반화)

메타 정보

항목 내용
분류 Robustness / Out-of-Distribution / Transfer Learning
핵심 논문 "Invariant Risk Minimization" (Arjovsky et al., 2019), "Domain Generalization: A Survey" (Zhou et al., IEEE TPAMI 2022), "In Search of Lost Domain Generalization" (Gulrajani & Lopez-Paz, ICLR 2021)
주요 저자 Martin Arjovsky (IRM); Kaiyang Zhou, Ziwei Liu, Chen Change Loy (Survey); Ishaan Gulrajani, David Lopez-Paz (DomainBed)
핵심 개념 학습 시 관측하지 못한 도메인에서도 잘 동작하는 모델 구축 -- 도메인 불변 표현 학습
관련 분야 Domain Adaptation, Transfer Learning, Causal Inference, Distributionally Robust Optimization

정의

Domain Generalization (DG)은 여러 소스 도메인의 데이터로 학습하여, 학습 시 전혀 접하지 못한 타겟 도메인에서도 강건하게 일반화되는 모델을 구축하는 문제다. Domain Adaptation과 달리 타겟 도메인 데이터에 접근할 수 없다는 점이 핵심 제약이다.

Domain Adaptation vs Domain Generalization:

Domain Adaptation:
  Source domains D_s + Target domain D_t (unlabeled) --> 모델 적응
  타겟 도메인 데이터 접근 가능

Domain Generalization:
  Source domains {D_1, D_2, ..., D_K} --> 모델 학습
  타겟 도메인 D_t 접근 불가 (학습 시 존재하지 않음)
  목표: argmin_theta max_{D_t} L(f_theta, D_t)

문제 설정

형식적 정의:

소스 도메인: S = {D_1, D_2, ..., D_K}, 각 D_k = {(x_i^k, y_i^k)}
각 도메인은 서로 다른 분포: P_k(X, Y) != P_j(X, Y)

목표: 학습되지 않은 도메인 D_t ~ P_t에서 일반화
  f* = argmin_theta E_{(x,y)~P_t} [L(f_theta(x), y)]

도메인 시프트 유형:
  +---------------------------------------------------+
  | Covariate Shift: P_s(X) != P_t(X), P(Y|X) 동일   |
  | Label Shift: P_s(Y) != P_t(Y), P(X|Y) 동일       |
  | Concept Shift: P_s(Y|X) != P_t(Y|X)              |
  | 복합 시프트: 위 조합                               |
  +---------------------------------------------------+

주요 접근법

Domain Generalization 분류 체계:

DG Methods
  |
  +-- Domain Alignment (도메인 정렬)
  |     CORAL, MMD, Adversarial (DANN)
  |
  +-- Data Augmentation (데이터 증강)
  |     Mixup, CrossGrad, Style Transfer
  |
  +-- Meta-Learning (메타 학습)
  |     MLDG, MetaReg, MASF
  |
  +-- Invariant Learning (불변 표현 학습)
  |     IRM, V-REx, FISH, Fishr
  |
  +-- Distributionally Robust Optimization
  |     GroupDRO, EQRM
  |
  +-- Foundation Model Adaptation
        CLIP zero-shot, Prompt Tuning, Adapter

1. Domain Alignment (도메인 정렬)

여러 소스 도메인의 특징 분포를 정렬하여 도메인 불변 표현을 학습한다.

CORAL (Correlation Alignment)

Sun & Saenko (2016). 소스와 타겟의 2차 통계량(공분산)을 정렬한다.

CORAL Loss:

  L_CORAL = (1 / 4d^2) * ||C_s - C_t||_F^2

  C_s = (1/(n_s-1)) * (D_s^T D_s - (1/n_s)(1^T D_s)^T (1^T D_s))
  C_t: 동일하게 타겟에 대해 계산

  Deep CORAL: CNN의 마지막 FC layer 출력에 CORAL loss 적용
  Total Loss = L_classification + lambda * L_CORAL

MMD (Maximum Mean Discrepancy)

Li et al. (2018). 커널 공간에서 두 분포의 평균 차이를 최소화한다.

MMD^2(P, Q) = E[k(x,x')] - 2E[k(x,y)] + E[k(y,y')]
  x, x' ~ P,  y, y' ~ Q
  k: 커널 함수 (주로 Gaussian RBF)

다중 도메인 확장:
  L_MMD = sum_{i<j} MMD^2(D_i, D_j)

2. Invariant Risk Minimization (IRM)

Arjovsky et al. (2019). 모든 환경에서 동시에 최적인 불변 예측자를 학습한다.

IRM 목표:

  min_{Phi, w} sum_{e in E_tr} R^e(w . Phi)
  subject to: w in argmin_{w'} R^e(w' . Phi), for all e in E_tr

  E_tr: 학습 환경(도메인) 집합
  Phi: 특징 추출기
  w: 선형 분류기
  R^e: 환경 e에서의 위험

  실용적 변형 (IRMv1):
  min_{Phi} sum_{e in E_tr} [R^e(Phi) + lambda * ||grad_{w|w=1.0} R^e(w . Phi)||^2]

  두 번째 항: 각 환경에서 w=1.0이 최적이 되도록 강제
  --> Phi가 불변 특징만 추출하도록 유도

핵심 아이디어:
  +-----------------------------------------------+
  | 환경 1: 소 + 초원 배경 --> "소" 예측           |
  | 환경 2: 소 + 해변 배경 --> "소" 예측           |
  |                                               |
  | 배경은 spurious correlation (가짜 상관)        |
  | IRM: 모든 환경에서 불변인 "소의 형태"만 활용   |
  +-----------------------------------------------+

3. V-REx (Variance Risk Extrapolation)

Krueger et al. (ICML 2021). 도메인 간 손실 분산을 최소화하여 일반화한다.

V-REx Objective:

  L = (1/K) * sum_k R^k(theta) + beta * Var({R^1, ..., R^K})

  첫째 항: 평균 위험 최소화
  둘째 항: 도메인 간 위험 분산 최소화
  beta: 정규화 강도

  직관: 모든 도메인에서 비슷한 손실 --> 도메인 불변 특징 사용

4. GroupDRO (Distributionally Robust Optimization)

Sagawa et al. (ICLR 2020). 최악 그룹의 성능을 최적화한다.

GroupDRO Objective:

  min_theta max_{q in Delta_K} sum_{k=1}^{K} q_k * R^k(theta)

  Delta_K: K-simplex (가중치 합 = 1)
  q: 도메인 가중치 (최악 도메인에 높은 가중치 부여)

  온라인 업데이트:
  q_k^{t+1} = q_k^t * exp(eta * R^k(theta_t))  (정규화)

  결과: worst-group accuracy 크게 개선

5. Data Augmentation 기반

Mixup 변형들

도메인 간 데이터 혼합:

Vanilla Mixup:
  x_mix = lambda * x_i + (1-lambda) * x_j
  y_mix = lambda * y_i + (1-lambda) * y_j

CrossGrad (Shankar et al., ICLR 2018):
  도메인 분류기의 gradient 방향으로 입력 변형
  x' = x + epsilon * sign(grad_x L_domain(x, d))

Style Transfer 기반:
  source 이미지의 content + 다른 도메인의 style 결합
  AdaIN(z) = sigma_target * (z - mu_source) / sigma_source + mu_target

6. Meta-Learning 기반

MLDG (Meta-Learning Domain Generalization)

Li et al. (AAAI 2018). 도메인 분할을 meta-train/meta-test로 활용한다.

MLDG 알고리즘:

반복:
  1. 소스 도메인을 S(meta-train)과 V(meta-test)로 분할
  2. S에서 일반 학습: theta' = theta - alpha * grad L_S(theta)
  3. V에서 메타 테스트: L_V(theta')
  4. 메타 업데이트: theta = theta - beta * grad [L_S(theta) + gamma * L_V(theta')]

  이중 최적화로 새 도메인에 대한 일반화 능력 직접 최적화

7. Foundation Model 활용

최근 CLIP 등 사전학습된 대규모 모델의 zero-shot 성능이 기존 DG 기법을 능가하는 경우가 많다.

Foundation Model 기반 DG:

Zero-shot CLIP:
  텍스트 프롬프트로 분류 --> DomainBed 벤치마크에서 경쟁력 있는 성능

Prompt Tuning (CoOp, CoCoOp):
  고정 텍스트 대신 학습 가능한 soft prompt 사용
  CoOp: 도메인 일반화 성능 부족
  CoCoOp: 이미지 조건부 프롬프트로 개선

LP-FT (Linear Probing then Fine-Tuning):
  1단계: linear probe로 classifier head 학습
  2단계: 전체 네트워크 fine-tuning
  feature distortion 감소 효과

주요 벤치마크

벤치마크 도메인 수 특징
PACS 4 (Photo, Art, Cartoon, Sketch) 7 클래스, 이미지 스타일 시프트
VLCS 4 (VOC, LabelMe, Caltech, Sun) 5 클래스, 데이터셋 간 시프트
OfficeHome 4 (Art, Clipart, Product, Real) 65 클래스, 상업 이미지
TerraIncognita 4 (위치별 카메라 트랩) 10 클래스, 지리적 시프트
DomainNet 6 (Clipart, Infograph, Painting, ...) 345 클래스, 대규모
Wilds 다수 실제 분포 시프트 벤치마크

DomainBed 프레임워크

Gulrajani & Lopez-Paz (ICLR 2021). DG 알고리즘의 공정한 비교 프레임워크.

DomainBed 주요 발견:

1. 많은 DG 알고리즘이 잘 튜닝된 ERM과 비슷하거나 못함
2. 모델 선택 방법이 결과에 큰 영향
3. 하이퍼파라미터 튜닝과 데이터 증강이 알고리즘 선택보다 중요

DomainBed 평균 정확도 (leave-one-domain-out):
  ERM:           ~72%
  IRM:           ~71%
  GroupDRO:      ~71%
  CORAL:         ~73%
  SWAD:          ~76%
  CLIP zero-shot:~80%+ (도메인에 따라 다름)

Python 구현 예시

ERM Baseline + CORAL

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class CORALLoss(nn.Module):
    """CORAL: Correlation Alignment Loss"""
    def forward(self, source, target):
        d = source.size(1)
        ns, nt = source.size(0), target.size(0)

        # 공분산 행렬 계산
        source_centered = source - source.mean(0)
        target_centered = target - target.mean(0)
        cov_s = (source_centered.T @ source_centered) / (ns - 1)
        cov_t = (target_centered.T @ target_centered) / (nt - 1)

        loss = (cov_s - cov_t).pow(2).sum() / (4 * d * d)
        return loss

class DomainGeneralizationModel(nn.Module):
    def __init__(self, num_classes, backbone="resnet50"):
        super().__init__()
        self.featurizer = models.resnet50(pretrained=True)
        feat_dim = self.featurizer.fc.in_features
        self.featurizer.fc = nn.Identity()
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        features = self.featurizer(x)
        return self.classifier(features), features

def train_coral(model, train_loaders, optimizer, lambda_coral=1.0):
    """Multi-source CORAL training step"""
    model.train()
    coral_loss_fn = CORALLoss()
    total_loss = 0

    # 각 도메인에서 배치 가져오기
    domain_batches = [next(iter(loader)) for loader in train_loaders]

    all_features = []
    cls_loss = 0
    for x, y in domain_batches:
        logits, features = model(x.cuda())
        cls_loss += F.cross_entropy(logits, y.cuda())
        all_features.append(features)

    cls_loss /= len(domain_batches)

    # 도메인 쌍별 CORAL loss
    coral = 0
    n_pairs = 0
    for i in range(len(all_features)):
        for j in range(i + 1, len(all_features)):
            coral += coral_loss_fn(all_features[i], all_features[j])
            n_pairs += 1
    coral /= max(n_pairs, 1)

    loss = cls_loss + lambda_coral * coral
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

IRM (IRMv1)

class IRMTrainer:
    """Invariant Risk Minimization (IRMv1)"""
    def __init__(self, model, lr=1e-3, lambda_irm=1.0, anneal_steps=500):
        self.model = model
        self.lambda_irm = lambda_irm
        self.anneal_steps = anneal_steps
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.step = 0

    @staticmethod
    def irm_penalty(logits, y):
        """IRMv1 penalty: ||grad_{w|w=1} R^e(w . Phi)||^2"""
        scale = torch.tensor(1.0, requires_grad=True, device=logits.device)
        loss = F.cross_entropy(logits * scale, y)
        grad = torch.autograd.grad(loss, scale, create_graph=True)[0]
        return grad ** 2

    def train_step(self, domain_batches):
        self.model.train()
        self.step += 1

        env_losses = []
        env_penalties = []

        for x, y in domain_batches:
            x, y = x.cuda(), y.cuda()
            logits, _ = self.model(x)
            env_losses.append(F.cross_entropy(logits, y))
            env_penalties.append(self.irm_penalty(logits, y))

        mean_loss = torch.stack(env_losses).mean()
        mean_penalty = torch.stack(env_penalties).mean()

        # Annealing: 초기에는 ERM, 이후 IRM penalty 점진적 증가
        penalty_weight = self.lambda_irm if self.step > self.anneal_steps else 1.0
        loss = mean_loss + penalty_weight * mean_penalty

        # ERM penalty 보정 (penalty 최소화가 loss도 줄이도록)
        if penalty_weight > 1.0:
            loss /= penalty_weight

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item(), "penalty": mean_penalty.item()}

GroupDRO

class GroupDRO:
    """Group Distributionally Robust Optimization"""
    def __init__(self, model, n_domains, lr=1e-3, eta=0.01):
        self.model = model
        self.n_domains = n_domains
        self.eta = eta  # 도메인 가중치 학습률
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        # 균등 초기화
        self.q = torch.ones(n_domains, device="cuda") / n_domains

    def train_step(self, domain_batches):
        self.model.train()
        env_losses = []

        for x, y in domain_batches:
            x, y = x.cuda(), y.cuda()
            logits, _ = self.model(x)
            env_losses.append(F.cross_entropy(logits, y))

        env_losses = torch.stack(env_losses)

        # 도메인 가중치 업데이트 (최악 도메인에 가중치 증가)
        with torch.no_grad():
            self.q *= torch.exp(self.eta * env_losses.detach())
            self.q /= self.q.sum()  # 정규화

        # 가중 합 손실
        loss = (self.q * env_losses).sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item(), "q": self.q.cpu().tolist()}

평가 (Leave-One-Domain-Out)

def evaluate_dg(model, test_loader, domain_name):
    """도메인 일반화 평가"""
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()
            logits, _ = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total * 100
    print(f"Domain [{domain_name}] Accuracy: {acc:.1f}%")
    return acc

def leave_one_domain_out(datasets, model_fn, train_fn, n_epochs=50):
    """Leave-one-domain-out cross-validation"""
    results = {}
    domain_names = list(datasets.keys())

    for test_domain in domain_names:
        print(f"\n--- Test domain: {test_domain} ---")
        train_domains = [d for d in domain_names if d != test_domain]
        train_loaders = [make_loader(datasets[d]) for d in train_domains]
        test_loader = make_loader(datasets[test_domain], shuffle=False)

        model = model_fn().cuda()
        for epoch in range(n_epochs):
            train_fn(model, train_loaders)

        acc = evaluate_dg(model, test_loader, test_domain)
        results[test_domain] = acc

    avg = sum(results.values()) / len(results)
    print(f"\nAverage: {avg:.1f}%")
    return results

최신 동향 (2024-2025)

방향 설명 대표 연구
Foundation Model 활용 CLIP zero-shot이 기존 DG 능가 Cha et al. (2024), DomainBed + CLIP
Test-Time Adaptation 결합 추론 시 타겟 도메인에 적응 TENT, TTT++, SAR
Causal Representation 인과적 불변 특징 학습 CausIRL (Chevalley et al., 2022)
Sharpness-Aware 최적화 평탄한 손실 지형이 일반화에 유리 SWAD (Cha et al., NeurIPS 2021)
Multi-Modal DG 텍스트-이미지 결합 DG CLIP-DG, MaPLe
도메인 메타데이터 활용 추론 시에도 도메인 정보 사용 D3G (Yao et al., 2024)

실무 가이드라인

도메인 일반화 적용 판단:

1. 문제가 DG인가?
   - 학습 데이터에 없는 새 도메인이 등재할 것인가?    --> Yes: DG
   - 타겟 도메인 데이터 (unlabeled이라도) 있는가?    --> Yes: DA (Domain Adaptation)
   - 단일 도메인, 분포 변화 없음?                    --> Standard ML

2. 기법 선택 순서:
   a. Foundation Model (CLIP) zero-shot 먼저 시도
   b. 충분한 소스 도메인 (3+): CORAL, IRM, SWAD
   c. 최악 그룹 성능 중요: GroupDRO
   d. 소규모 데이터: Meta-Learning (MLDG)
   e. 위 모두 시도 후 잘 튜닝된 ERM과 비교 필수

3. 공통 주의사항:
   - DomainBed의 교훈: 하이퍼파라미터 > 알고리즘
   - 데이터 증강 (RandAugment 등) 효과가 큼
   - 모델 선택 (validation) 방법이 결과 좌우

관련 문서

참고문헌

  1. Arjovsky, M., et al. (2019). "Invariant Risk Minimization." arXiv:1907.02893.
  2. Zhou, K., et al. (2022). "Domain Generalization: A Survey." IEEE TPAMI.
  3. Gulrajani, I. & Lopez-Paz, D. (2021). "In Search of Lost Domain Generalization." ICLR 2021.
  4. Sagawa, S., et al. (2020). "Distributionally Robust Neural Networks for Group Shifts." ICLR 2020.
  5. Cha, J., et al. (2021). "SWAD: Domain Generalization by Seeking Flat Minima." NeurIPS 2021.
  6. Krueger, D., et al. (2021). "Out-of-Distribution Generalization via Risk Extrapolation." ICML 2021.
  7. Sun, B. & Saenko, K. (2016). "Deep CORAL: Correlation Alignment for Deep Domain Adaptation." ECCV 2016 Workshops.
  8. Li, D., et al. (2018). "Domain Generalization with Adversarial Feature Learning." CVPR 2018.