배치 정규화와 정규화 기법 (Normalization)¶
신경망 학습을 안정화하고 가속화하는 기법. 내부 공변량 이동(Internal Covariate Shift)을 줄임.
정규화의 필요성¶
내부 공변량 이동¶
정규화의 효과¶
- 더 큰 학습률 사용 가능
- 초기화에 덜 민감
- 정규화 효과 (드롭아웃과 유사)
- 기울기 흐름 개선
Batch Normalization¶
미니배치 통계를 사용한 정규화.
수식¶
\[\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$
$$y_i = \gamma \hat{x}_i + \beta\]
- \(\mu_B\): 배치 평균
- \(\sigma_B^2\): 배치 분산
- \(\gamma, \beta\): 학습 가능한 스케일, 시프트 파라미터
구현¶
import torch
import torch.nn as nn
class BatchNorm1d(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
self.eps = eps
self.momentum = momentum
# 학습 파라미터
self.gamma = nn.Parameter(torch.ones(num_features))
self.beta = nn.Parameter(torch.zeros(num_features))
# Running statistics (추론용)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
def forward(self, x):
if self.training:
# 배치 통계 계산
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
# Running statistics 업데이트
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
else:
# 추론 시 running statistics 사용
mean = self.running_mean
var = self.running_var
# 정규화
x_norm = (x - mean) / torch.sqrt(var + self.eps)
return self.gamma * x_norm + self.beta
# PyTorch 내장
bn1d = nn.BatchNorm1d(256)
bn2d = nn.BatchNorm2d(64) # CNN용
CNN에서의 적용¶
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x) # Conv 다음, 활성화 이전
x = self.relu(x)
return x
BatchNorm의 문제점¶
| 문제 | 설명 |
|---|---|
| 배치 크기 의존 | 작은 배치에서 불안정 |
| 시퀀스 길이 | 가변 길이 시퀀스에 부적합 |
| 분산 학습 | 배치 통계 동기화 필요 |
| 추론 불일치 | 훈련/추론 시 다른 통계 사용 |
Layer Normalization¶
시퀀스 모델과 Transformer의 표준.
수식¶
각 샘플 내에서 정규화:
\[\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$$
$$y = \gamma \odot \hat{x} + \beta\]
\(\mu, \sigma\): 한 샘플의 모든 특성에 대한 통계
구현¶
class LayerNorm(nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
super().__init__()
self.eps = eps
if isinstance(normalized_shape, int):
normalized_shape = (normalized_shape,)
self.normalized_shape = normalized_shape
self.gamma = nn.Parameter(torch.ones(normalized_shape))
self.beta = nn.Parameter(torch.zeros(normalized_shape))
def forward(self, x):
# 마지막 len(normalized_shape) 차원에 대해 정규화
dims = tuple(range(-len(self.normalized_shape), 0))
mean = x.mean(dim=dims, keepdim=True)
var = x.var(dim=dims, keepdim=True, unbiased=False)
x_norm = (x - mean) / torch.sqrt(var + self.eps)
return self.gamma * x_norm + self.beta
# PyTorch 내장
ln = nn.LayerNorm(normalized_shape=768)
ln = nn.LayerNorm(normalized_shape=[768]) # 동일
ln = nn.LayerNorm(normalized_shape=[seq_len, 768]) # 여러 차원
Transformer에서의 적용¶
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, num_heads)
self.norm1 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model)
)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# Pre-LN (GPT 스타일) - 더 안정적
attn_out = self.attention(self.norm1(x), self.norm1(x), self.norm1(x))[0]
x = x + self.dropout(attn_out)
ff_out = self.ffn(self.norm2(x))
x = x + self.dropout(ff_out)
return x
# Post-LN (원래 Transformer)
class PostLNTransformerBlock(nn.Module):
def forward(self, x):
attn_out = self.attention(x, x, x)[0]
x = self.norm1(x + self.dropout(attn_out)) # Norm 이후
ff_out = self.ffn(x)
x = self.norm2(x + self.dropout(ff_out))
return x
RMSNorm (Root Mean Square Normalization)¶
LLaMA, Gemma 등에서 사용. LayerNorm의 간소화 버전.
수식¶
\[\text{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n}x_i^2}$$
$$\hat{x} = \frac{x}{\text{RMS}(x)}$$
$$y = \gamma \odot \hat{x}\]
평균 빼기 없음, 편향 없음.
구현¶
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
# RMS 계산
rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
x_norm = x / rms
return self.weight * x_norm
# 사용 (LLaMA 스타일)
class LLaMABlock(nn.Module):
def __init__(self, dim, num_heads, ff_dim):
super().__init__()
self.attention = MultiHeadAttention(dim, num_heads)
self.norm1 = RMSNorm(dim)
self.ffn = SwiGLU(dim, ff_dim)
self.norm2 = RMSNorm(dim)
def forward(self, x):
x = x + self.attention(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
RMSNorm vs LayerNorm¶
| 특성 | LayerNorm | RMSNorm |
|---|---|---|
| 평균 빼기 | O | X |
| 편향 (β) | O | X |
| 계산 비용 | 높음 | 낮음 |
| 파라미터 | 2n | n |
| 성능 | 기준 | 유사 |
정규화 비교¶
정규화 축¶
적용 가이드¶
| 정규화 | 사용 사례 |
|---|---|
| BatchNorm | CNN, 큰 배치 |
| LayerNorm | Transformer, RNN, 작은 배치 |
| RMSNorm | LLM (효율성) |
| GroupNorm | 작은 배치 CNN |
| InstanceNorm | 스타일 트랜스퍼 |
정규화 위치¶
Pre-Norm vs Post-Norm¶
# Post-Norm (원래 Transformer)
x = x + sublayer(x)
x = norm(x)
# Pre-Norm (GPT, LLaMA)
x = x + sublayer(norm(x))
# Pre-Norm이 더 안정적 (기울기 흐름 개선)
마지막 LayerNorm¶
class GPT(nn.Module):
def __init__(self, ...):
self.transformer_blocks = nn.ModuleList([...])
self.ln_f = nn.LayerNorm(d_model) # 최종 LayerNorm
self.lm_head = nn.Linear(d_model, vocab_size)
def forward(self, x):
for block in self.transformer_blocks:
x = block(x)
x = self.ln_f(x) # 마지막에 정규화
logits = self.lm_head(x)
return logits
구현 최적화¶
Fused LayerNorm¶
# CUDA 커널 퓨전으로 속도 향상
try:
from apex.normalization import FusedLayerNorm
LayerNorm = FusedLayerNorm
except ImportError:
LayerNorm = nn.LayerNorm
# 또는 PyTorch 2.0+의 compiled 버전
model = torch.compile(model) # 자동 최적화
메모리 효율적 LayerNorm¶
# Gradient checkpointing과 함께
from torch.utils.checkpoint import checkpoint
class EfficientTransformer(nn.Module):
def forward(self, x):
for block in self.blocks:
x = checkpoint(block, x) # 메모리 절약
return x
실무 가이드¶
정규화 레이어 선택¶
성능 영향 분석¶
# 정규화 유무에 따른 학습 곡선 비교
def compare_normalization(model_with_norm, model_without_norm, train_loader, num_epochs=10):
"""정규화의 효과 비교"""
results = {'with_norm': [], 'without_norm': []}
for model, key in [(model_with_norm, 'with_norm'), (model_without_norm, 'without_norm')]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(num_epochs):
epoch_loss = 0
for x, y in train_loader:
optimizer.zero_grad()
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
results[key].append(epoch_loss / len(train_loader))
# 시각화
plt.plot(results['with_norm'], label='With Normalization')
plt.plot(results['without_norm'], label='Without Normalization')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Effect of Normalization on Training')
plt.show()
추론 시 BatchNorm 주의사항¶
# 훈련 모드와 추론 모드의 차이
model.train() # BatchNorm: 배치 통계 사용
model.eval() # BatchNorm: running 통계 사용 (중요!)
# 잘못된 추론 (흔한 실수)
output = model(input) # model.eval() 호출 안 함 → 배치 통계 사용
# 올바른 추론
model.eval()
with torch.no_grad():
output = model(input)
# 배치 크기 1로 추론 시 문제
# BatchNorm + batch_size=1 → 분산이 0 → NaN 가능
# 해결: eval() 모드 사용 또는 LayerNorm으로 교체
# Running statistics 업데이트 확인
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
print(f"running_mean: {module.running_mean[:5]}")
print(f"running_var: {module.running_var[:5]}")
print(f"num_batches_tracked: {module.num_batches_tracked}")
디버깅 가이드¶
정규화 통계 모니터링¶
class NormalizationDebugger:
"""정규화 레이어 디버깅 도구"""
def __init__(self, model):
self.model = model
self.stats_history = []
def capture_stats(self):
"""현재 정규화 통계 캡처"""
stats = {}
for name, module in self.model.named_modules():
if isinstance(module, nn.BatchNorm2d):
stats[name] = {
'type': 'BatchNorm',
'running_mean': module.running_mean.mean().item(),
'running_var': module.running_var.mean().item(),
'weight_mean': module.weight.mean().item(),
'bias_mean': module.bias.mean().item()
}
elif isinstance(module, nn.LayerNorm):
stats[name] = {
'type': 'LayerNorm',
'weight_mean': module.weight.mean().item(),
'weight_std': module.weight.std().item(),
'bias_mean': module.bias.mean().item()
}
self.stats_history.append(stats)
return stats
def check_health(self, stats=None):
"""정규화 레이어 건강 체크"""
if stats is None:
stats = self.capture_stats()
issues = []
for name, info in stats.items():
if info['type'] == 'BatchNorm':
if info['running_var'] < 1e-5:
issues.append(f"[WARNING] {name}: running_var very small ({info['running_var']:.2e})")
if info['running_var'] > 100:
issues.append(f"[WARNING] {name}: running_var very large ({info['running_var']:.2e})")
if abs(info.get('weight_mean', 1) - 1) > 0.5:
issues.append(f"[INFO] {name}: weight mean far from 1 ({info['weight_mean']:.3f})")
return issues
# 사용
debugger = NormalizationDebugger(model)
for epoch in range(num_epochs):
train_one_epoch(model, train_loader)
issues = debugger.check_health()
if issues:
print(f"Epoch {epoch}:")
for issue in issues:
print(f" {issue}")
일반적인 문제와 해결¶
| 문제 | 증상 | 해결책 |
|---|---|---|
| 배치 크기가 너무 작음 | BatchNorm 불안정 | GroupNorm 또는 LayerNorm 사용 |
| eval() 미호출 | 추론 시 다른 결과 | model.eval() 호출 확인 |
| 분산이 0 | NaN 또는 Inf | eps 증가 또는 입력 확인 |
| 학습 느림 | 정규화 후 분포 이상 | gamma/beta 초기화 확인 |
Pre-LN vs Post-LN 비교¶
# Post-LN (원래 Transformer) - 불안정할 수 있음
class PostLNBlock(nn.Module):
def forward(self, x):
x = self.norm1(x + self.attn(x)) # Norm이 뒤에
x = self.norm2(x + self.ffn(x))
return x
# Pre-LN (GPT 스타일) - 더 안정적
class PreLNBlock(nn.Module):
def forward(self, x):
x = x + self.attn(self.norm1(x)) # Norm이 앞에
x = x + self.ffn(self.norm2(x))
return x
# Pre-LN 장점:
# - 기울기 흐름이 더 안정적 (잔차 연결을 통해 직접 전파)
# - 학습률에 덜 민감
# - Warmup 없이도 학습 가능한 경우 있음
# 대부분의 현대 LLM은 Pre-LN + RMSNorm 사용
분산 학습에서의 BatchNorm¶
# 단일 GPU BatchNorm - 각 GPU의 배치만 사용
bn = nn.BatchNorm2d(64)
# SyncBatchNorm - 모든 GPU의 통계 동기화
bn = nn.SyncBatchNorm(64)
# 모델 전체를 SyncBatchNorm으로 변환
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
# 주의: SyncBatchNorm은 통신 오버헤드 발생
# 대규모 분산 학습에서는 LayerNorm 선호