Domain Generalization (도메인 일반화)¶
메타 정보¶
| 항목 | 내용 |
|---|---|
| 분류 | Robustness / Out-of-Distribution / Transfer Learning |
| 핵심 논문 | "Invariant Risk Minimization" (Arjovsky et al., 2019), "Domain Generalization: A Survey" (Zhou et al., IEEE TPAMI 2022), "In Search of Lost Domain Generalization" (Gulrajani & Lopez-Paz, ICLR 2021) |
| 주요 저자 | Martin Arjovsky (IRM); Kaiyang Zhou, Ziwei Liu, Chen Change Loy (Survey); Ishaan Gulrajani, David Lopez-Paz (DomainBed) |
| 핵심 개념 | 학습 시 관측하지 못한 도메인에서도 잘 동작하는 모델 구축 -- 도메인 불변 표현 학습 |
| 관련 분야 | Domain Adaptation, Transfer Learning, Causal Inference, Distributionally Robust Optimization |
정의¶
Domain Generalization (DG)은 여러 소스 도메인의 데이터로 학습하여, 학습 시 전혀 접하지 못한 타겟 도메인에서도 강건하게 일반화되는 모델을 구축하는 문제다. Domain Adaptation과 달리 타겟 도메인 데이터에 접근할 수 없다는 점이 핵심 제약이다.
Domain Adaptation vs Domain Generalization:
Domain Adaptation:
Source domains D_s + Target domain D_t (unlabeled) --> 모델 적응
타겟 도메인 데이터 접근 가능
Domain Generalization:
Source domains {D_1, D_2, ..., D_K} --> 모델 학습
타겟 도메인 D_t 접근 불가 (학습 시 존재하지 않음)
목표: argmin_theta max_{D_t} L(f_theta, D_t)
문제 설정¶
형식적 정의:
소스 도메인: S = {D_1, D_2, ..., D_K}, 각 D_k = {(x_i^k, y_i^k)}
각 도메인은 서로 다른 분포: P_k(X, Y) != P_j(X, Y)
목표: 학습되지 않은 도메인 D_t ~ P_t에서 일반화
f* = argmin_theta E_{(x,y)~P_t} [L(f_theta(x), y)]
도메인 시프트 유형:
+---------------------------------------------------+
| Covariate Shift: P_s(X) != P_t(X), P(Y|X) 동일 |
| Label Shift: P_s(Y) != P_t(Y), P(X|Y) 동일 |
| Concept Shift: P_s(Y|X) != P_t(Y|X) |
| 복합 시프트: 위 조합 |
+---------------------------------------------------+
주요 접근법¶
Domain Generalization 분류 체계:
DG Methods
|
+-- Domain Alignment (도메인 정렬)
| CORAL, MMD, Adversarial (DANN)
|
+-- Data Augmentation (데이터 증강)
| Mixup, CrossGrad, Style Transfer
|
+-- Meta-Learning (메타 학습)
| MLDG, MetaReg, MASF
|
+-- Invariant Learning (불변 표현 학습)
| IRM, V-REx, FISH, Fishr
|
+-- Distributionally Robust Optimization
| GroupDRO, EQRM
|
+-- Foundation Model Adaptation
CLIP zero-shot, Prompt Tuning, Adapter
1. Domain Alignment (도메인 정렬)¶
여러 소스 도메인의 특징 분포를 정렬하여 도메인 불변 표현을 학습한다.
CORAL (Correlation Alignment)¶
Sun & Saenko (2016). 소스와 타겟의 2차 통계량(공분산)을 정렬한다.
CORAL Loss:
L_CORAL = (1 / 4d^2) * ||C_s - C_t||_F^2
C_s = (1/(n_s-1)) * (D_s^T D_s - (1/n_s)(1^T D_s)^T (1^T D_s))
C_t: 동일하게 타겟에 대해 계산
Deep CORAL: CNN의 마지막 FC layer 출력에 CORAL loss 적용
Total Loss = L_classification + lambda * L_CORAL
MMD (Maximum Mean Discrepancy)¶
Li et al. (2018). 커널 공간에서 두 분포의 평균 차이를 최소화한다.
MMD^2(P, Q) = E[k(x,x')] - 2E[k(x,y)] + E[k(y,y')]
x, x' ~ P, y, y' ~ Q
k: 커널 함수 (주로 Gaussian RBF)
다중 도메인 확장:
L_MMD = sum_{i<j} MMD^2(D_i, D_j)
2. Invariant Risk Minimization (IRM)¶
Arjovsky et al. (2019). 모든 환경에서 동시에 최적인 불변 예측자를 학습한다.
IRM 목표:
min_{Phi, w} sum_{e in E_tr} R^e(w . Phi)
subject to: w in argmin_{w'} R^e(w' . Phi), for all e in E_tr
E_tr: 학습 환경(도메인) 집합
Phi: 특징 추출기
w: 선형 분류기
R^e: 환경 e에서의 위험
실용적 변형 (IRMv1):
min_{Phi} sum_{e in E_tr} [R^e(Phi) + lambda * ||grad_{w|w=1.0} R^e(w . Phi)||^2]
두 번째 항: 각 환경에서 w=1.0이 최적이 되도록 강제
--> Phi가 불변 특징만 추출하도록 유도
핵심 아이디어:
+-----------------------------------------------+
| 환경 1: 소 + 초원 배경 --> "소" 예측 |
| 환경 2: 소 + 해변 배경 --> "소" 예측 |
| |
| 배경은 spurious correlation (가짜 상관) |
| IRM: 모든 환경에서 불변인 "소의 형태"만 활용 |
+-----------------------------------------------+
3. V-REx (Variance Risk Extrapolation)¶
Krueger et al. (ICML 2021). 도메인 간 손실 분산을 최소화하여 일반화한다.
V-REx Objective:
L = (1/K) * sum_k R^k(theta) + beta * Var({R^1, ..., R^K})
첫째 항: 평균 위험 최소화
둘째 항: 도메인 간 위험 분산 최소화
beta: 정규화 강도
직관: 모든 도메인에서 비슷한 손실 --> 도메인 불변 특징 사용
4. GroupDRO (Distributionally Robust Optimization)¶
Sagawa et al. (ICLR 2020). 최악 그룹의 성능을 최적화한다.
GroupDRO Objective:
min_theta max_{q in Delta_K} sum_{k=1}^{K} q_k * R^k(theta)
Delta_K: K-simplex (가중치 합 = 1)
q: 도메인 가중치 (최악 도메인에 높은 가중치 부여)
온라인 업데이트:
q_k^{t+1} = q_k^t * exp(eta * R^k(theta_t)) (정규화)
결과: worst-group accuracy 크게 개선
5. Data Augmentation 기반¶
Mixup 변형들¶
도메인 간 데이터 혼합:
Vanilla Mixup:
x_mix = lambda * x_i + (1-lambda) * x_j
y_mix = lambda * y_i + (1-lambda) * y_j
CrossGrad (Shankar et al., ICLR 2018):
도메인 분류기의 gradient 방향으로 입력 변형
x' = x + epsilon * sign(grad_x L_domain(x, d))
Style Transfer 기반:
source 이미지의 content + 다른 도메인의 style 결합
AdaIN(z) = sigma_target * (z - mu_source) / sigma_source + mu_target
6. Meta-Learning 기반¶
MLDG (Meta-Learning Domain Generalization)¶
Li et al. (AAAI 2018). 도메인 분할을 meta-train/meta-test로 활용한다.
MLDG 알고리즘:
반복:
1. 소스 도메인을 S(meta-train)과 V(meta-test)로 분할
2. S에서 일반 학습: theta' = theta - alpha * grad L_S(theta)
3. V에서 메타 테스트: L_V(theta')
4. 메타 업데이트: theta = theta - beta * grad [L_S(theta) + gamma * L_V(theta')]
이중 최적화로 새 도메인에 대한 일반화 능력 직접 최적화
7. Foundation Model 활용¶
최근 CLIP 등 사전학습된 대규모 모델의 zero-shot 성능이 기존 DG 기법을 능가하는 경우가 많다.
Foundation Model 기반 DG:
Zero-shot CLIP:
텍스트 프롬프트로 분류 --> DomainBed 벤치마크에서 경쟁력 있는 성능
Prompt Tuning (CoOp, CoCoOp):
고정 텍스트 대신 학습 가능한 soft prompt 사용
CoOp: 도메인 일반화 성능 부족
CoCoOp: 이미지 조건부 프롬프트로 개선
LP-FT (Linear Probing then Fine-Tuning):
1단계: linear probe로 classifier head 학습
2단계: 전체 네트워크 fine-tuning
feature distortion 감소 효과
주요 벤치마크¶
| 벤치마크 | 도메인 수 | 특징 |
|---|---|---|
| PACS | 4 (Photo, Art, Cartoon, Sketch) | 7 클래스, 이미지 스타일 시프트 |
| VLCS | 4 (VOC, LabelMe, Caltech, Sun) | 5 클래스, 데이터셋 간 시프트 |
| OfficeHome | 4 (Art, Clipart, Product, Real) | 65 클래스, 상업 이미지 |
| TerraIncognita | 4 (위치별 카메라 트랩) | 10 클래스, 지리적 시프트 |
| DomainNet | 6 (Clipart, Infograph, Painting, ...) | 345 클래스, 대규모 |
| Wilds | 다수 | 실제 분포 시프트 벤치마크 |
DomainBed 프레임워크¶
Gulrajani & Lopez-Paz (ICLR 2021). DG 알고리즘의 공정한 비교 프레임워크.
DomainBed 주요 발견:
1. 많은 DG 알고리즘이 잘 튜닝된 ERM과 비슷하거나 못함
2. 모델 선택 방법이 결과에 큰 영향
3. 하이퍼파라미터 튜닝과 데이터 증강이 알고리즘 선택보다 중요
DomainBed 평균 정확도 (leave-one-domain-out):
ERM: ~72%
IRM: ~71%
GroupDRO: ~71%
CORAL: ~73%
SWAD: ~76%
CLIP zero-shot:~80%+ (도메인에 따라 다름)
Python 구현 예시¶
ERM Baseline + CORAL¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class CORALLoss(nn.Module):
"""CORAL: Correlation Alignment Loss"""
def forward(self, source, target):
d = source.size(1)
ns, nt = source.size(0), target.size(0)
# 공분산 행렬 계산
source_centered = source - source.mean(0)
target_centered = target - target.mean(0)
cov_s = (source_centered.T @ source_centered) / (ns - 1)
cov_t = (target_centered.T @ target_centered) / (nt - 1)
loss = (cov_s - cov_t).pow(2).sum() / (4 * d * d)
return loss
class DomainGeneralizationModel(nn.Module):
def __init__(self, num_classes, backbone="resnet50"):
super().__init__()
self.featurizer = models.resnet50(pretrained=True)
feat_dim = self.featurizer.fc.in_features
self.featurizer.fc = nn.Identity()
self.classifier = nn.Linear(feat_dim, num_classes)
def forward(self, x):
features = self.featurizer(x)
return self.classifier(features), features
def train_coral(model, train_loaders, optimizer, lambda_coral=1.0):
"""Multi-source CORAL training step"""
model.train()
coral_loss_fn = CORALLoss()
total_loss = 0
# 각 도메인에서 배치 가져오기
domain_batches = [next(iter(loader)) for loader in train_loaders]
all_features = []
cls_loss = 0
for x, y in domain_batches:
logits, features = model(x.cuda())
cls_loss += F.cross_entropy(logits, y.cuda())
all_features.append(features)
cls_loss /= len(domain_batches)
# 도메인 쌍별 CORAL loss
coral = 0
n_pairs = 0
for i in range(len(all_features)):
for j in range(i + 1, len(all_features)):
coral += coral_loss_fn(all_features[i], all_features[j])
n_pairs += 1
coral /= max(n_pairs, 1)
loss = cls_loss + lambda_coral * coral
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()
IRM (IRMv1)¶
class IRMTrainer:
"""Invariant Risk Minimization (IRMv1)"""
def __init__(self, model, lr=1e-3, lambda_irm=1.0, anneal_steps=500):
self.model = model
self.lambda_irm = lambda_irm
self.anneal_steps = anneal_steps
self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
self.step = 0
@staticmethod
def irm_penalty(logits, y):
"""IRMv1 penalty: ||grad_{w|w=1} R^e(w . Phi)||^2"""
scale = torch.tensor(1.0, requires_grad=True, device=logits.device)
loss = F.cross_entropy(logits * scale, y)
grad = torch.autograd.grad(loss, scale, create_graph=True)[0]
return grad ** 2
def train_step(self, domain_batches):
self.model.train()
self.step += 1
env_losses = []
env_penalties = []
for x, y in domain_batches:
x, y = x.cuda(), y.cuda()
logits, _ = self.model(x)
env_losses.append(F.cross_entropy(logits, y))
env_penalties.append(self.irm_penalty(logits, y))
mean_loss = torch.stack(env_losses).mean()
mean_penalty = torch.stack(env_penalties).mean()
# Annealing: 초기에는 ERM, 이후 IRM penalty 점진적 증가
penalty_weight = self.lambda_irm if self.step > self.anneal_steps else 1.0
loss = mean_loss + penalty_weight * mean_penalty
# ERM penalty 보정 (penalty 최소화가 loss도 줄이도록)
if penalty_weight > 1.0:
loss /= penalty_weight
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return {"loss": loss.item(), "penalty": mean_penalty.item()}
GroupDRO¶
class GroupDRO:
"""Group Distributionally Robust Optimization"""
def __init__(self, model, n_domains, lr=1e-3, eta=0.01):
self.model = model
self.n_domains = n_domains
self.eta = eta # 도메인 가중치 학습률
self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# 균등 초기화
self.q = torch.ones(n_domains, device="cuda") / n_domains
def train_step(self, domain_batches):
self.model.train()
env_losses = []
for x, y in domain_batches:
x, y = x.cuda(), y.cuda()
logits, _ = self.model(x)
env_losses.append(F.cross_entropy(logits, y))
env_losses = torch.stack(env_losses)
# 도메인 가중치 업데이트 (최악 도메인에 가중치 증가)
with torch.no_grad():
self.q *= torch.exp(self.eta * env_losses.detach())
self.q /= self.q.sum() # 정규화
# 가중 합 손실
loss = (self.q * env_losses).sum()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return {"loss": loss.item(), "q": self.q.cpu().tolist()}
평가 (Leave-One-Domain-Out)¶
def evaluate_dg(model, test_loader, domain_name):
"""도메인 일반화 평가"""
model.eval()
correct, total = 0, 0
with torch.no_grad():
for x, y in test_loader:
x, y = x.cuda(), y.cuda()
logits, _ = model(x)
preds = logits.argmax(dim=1)
correct += (preds == y).sum().item()
total += y.size(0)
acc = correct / total * 100
print(f"Domain [{domain_name}] Accuracy: {acc:.1f}%")
return acc
def leave_one_domain_out(datasets, model_fn, train_fn, n_epochs=50):
"""Leave-one-domain-out cross-validation"""
results = {}
domain_names = list(datasets.keys())
for test_domain in domain_names:
print(f"\n--- Test domain: {test_domain} ---")
train_domains = [d for d in domain_names if d != test_domain]
train_loaders = [make_loader(datasets[d]) for d in train_domains]
test_loader = make_loader(datasets[test_domain], shuffle=False)
model = model_fn().cuda()
for epoch in range(n_epochs):
train_fn(model, train_loaders)
acc = evaluate_dg(model, test_loader, test_domain)
results[test_domain] = acc
avg = sum(results.values()) / len(results)
print(f"\nAverage: {avg:.1f}%")
return results
최신 동향 (2024-2025)¶
| 방향 | 설명 | 대표 연구 |
|---|---|---|
| Foundation Model 활용 | CLIP zero-shot이 기존 DG 능가 | Cha et al. (2024), DomainBed + CLIP |
| Test-Time Adaptation 결합 | 추론 시 타겟 도메인에 적응 | TENT, TTT++, SAR |
| Causal Representation | 인과적 불변 특징 학습 | CausIRL (Chevalley et al., 2022) |
| Sharpness-Aware 최적화 | 평탄한 손실 지형이 일반화에 유리 | SWAD (Cha et al., NeurIPS 2021) |
| Multi-Modal DG | 텍스트-이미지 결합 DG | CLIP-DG, MaPLe |
| 도메인 메타데이터 활용 | 추론 시에도 도메인 정보 사용 | D3G (Yao et al., 2024) |
실무 가이드라인¶
도메인 일반화 적용 판단:
1. 문제가 DG인가?
- 학습 데이터에 없는 새 도메인이 등재할 것인가? --> Yes: DG
- 타겟 도메인 데이터 (unlabeled이라도) 있는가? --> Yes: DA (Domain Adaptation)
- 단일 도메인, 분포 변화 없음? --> Standard ML
2. 기법 선택 순서:
a. Foundation Model (CLIP) zero-shot 먼저 시도
b. 충분한 소스 도메인 (3+): CORAL, IRM, SWAD
c. 최악 그룹 성능 중요: GroupDRO
d. 소규모 데이터: Meta-Learning (MLDG)
e. 위 모두 시도 후 잘 튜닝된 ERM과 비교 필수
3. 공통 주의사항:
- DomainBed의 교훈: 하이퍼파라미터 > 알고리즘
- 데이터 증강 (RandAugment 등) 효과가 큼
- 모델 선택 (validation) 방법이 결과 좌우
관련 문서¶
참고문헌¶
- Arjovsky, M., et al. (2019). "Invariant Risk Minimization." arXiv:1907.02893.
- Zhou, K., et al. (2022). "Domain Generalization: A Survey." IEEE TPAMI.
- Gulrajani, I. & Lopez-Paz, D. (2021). "In Search of Lost Domain Generalization." ICLR 2021.
- Sagawa, S., et al. (2020). "Distributionally Robust Neural Networks for Group Shifts." ICLR 2020.
- Cha, J., et al. (2021). "SWAD: Domain Generalization by Seeking Flat Minima." NeurIPS 2021.
- Krueger, D., et al. (2021). "Out-of-Distribution Generalization via Risk Extrapolation." ICML 2021.
- Sun, B. & Saenko, K. (2016). "Deep CORAL: Correlation Alignment for Deep Domain Adaptation." ECCV 2016 Workshops.
- Li, D., et al. (2018). "Domain Generalization with Adversarial Feature Learning." CVPR 2018.