콘텐츠로 이동
Data Prep
상세

배치 정규화와 정규화 기법 (Normalization)

신경망 학습을 안정화하고 가속화하는 기법. 내부 공변량 이동(Internal Covariate Shift)을 줄임.

정규화의 필요성

내부 공변량 이동

normalization diagram 1

정규화의 효과

  • 더 큰 학습률 사용 가능
  • 초기화에 덜 민감
  • 정규화 효과 (드롭아웃과 유사)
  • 기울기 흐름 개선

Batch Normalization

미니배치 통계를 사용한 정규화.

수식

\[\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$ $$y_i = \gamma \hat{x}_i + \beta\]
  • \(\mu_B\): 배치 평균
  • \(\sigma_B^2\): 배치 분산
  • \(\gamma, \beta\): 학습 가능한 스케일, 시프트 파라미터

구현

import torch
import torch.nn as nn

class BatchNorm1d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum

        # 학습 파라미터
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        # Running statistics (추론용)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # 배치 통계 계산
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)

            # Running statistics 업데이트
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            # 추론 시 running statistics 사용
            mean = self.running_mean
            var = self.running_var

        # 정규화
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

# PyTorch 내장
bn1d = nn.BatchNorm1d(256)
bn2d = nn.BatchNorm2d(64)  # CNN용

CNN에서의 적용

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)  # Conv 다음, 활성화 이전
        x = self.relu(x)
        return x

BatchNorm의 문제점

문제 설명
배치 크기 의존 작은 배치에서 불안정
시퀀스 길이 가변 길이 시퀀스에 부적합
분산 학습 배치 통계 동기화 필요
추론 불일치 훈련/추론 시 다른 통계 사용

Layer Normalization

시퀀스 모델과 Transformer의 표준.

수식

각 샘플 내에서 정규화:

\[\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$$ $$y = \gamma \odot \hat{x} + \beta\]

\(\mu, \sigma\): 한 샘플의 모든 특성에 대한 통계

구현

class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.eps = eps

        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = normalized_shape

        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        # 마지막 len(normalized_shape) 차원에 대해 정규화
        dims = tuple(range(-len(self.normalized_shape), 0))

        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, keepdim=True, unbiased=False)

        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

# PyTorch 내장
ln = nn.LayerNorm(normalized_shape=768)
ln = nn.LayerNorm(normalized_shape=[768])  # 동일
ln = nn.LayerNorm(normalized_shape=[seq_len, 768])  # 여러 차원

Transformer에서의 적용

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        self.attention = nn.MultiheadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Pre-LN (GPT 스타일) - 더 안정적
        attn_out = self.attention(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        x = x + self.dropout(attn_out)

        ff_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ff_out)

        return x

# Post-LN (원래 Transformer)
class PostLNTransformerBlock(nn.Module):
    def forward(self, x):
        attn_out = self.attention(x, x, x)[0]
        x = self.norm1(x + self.dropout(attn_out))  # Norm 이후

        ff_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ff_out))

        return x

RMSNorm (Root Mean Square Normalization)

LLaMA, Gemma 등에서 사용. LayerNorm의 간소화 버전.

수식

\[\text{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n}x_i^2}$$ $$\hat{x} = \frac{x}{\text{RMS}(x)}$$ $$y = \gamma \odot \hat{x}\]

평균 빼기 없음, 편향 없음.

구현

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMS 계산
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        x_norm = x / rms
        return self.weight * x_norm

# 사용 (LLaMA 스타일)
class LLaMABlock(nn.Module):
    def __init__(self, dim, num_heads, ff_dim):
        super().__init__()
        self.attention = MultiHeadAttention(dim, num_heads)
        self.norm1 = RMSNorm(dim)
        self.ffn = SwiGLU(dim, ff_dim)
        self.norm2 = RMSNorm(dim)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

RMSNorm vs LayerNorm

특성 LayerNorm RMSNorm
평균 빼기 O X
편향 (β) O X
계산 비용 높음 낮음
파라미터 2n n
성능 기준 유사

정규화 비교

정규화 축

normalization diagram 2

적용 가이드

정규화 사용 사례
BatchNorm CNN, 큰 배치
LayerNorm Transformer, RNN, 작은 배치
RMSNorm LLM (효율성)
GroupNorm 작은 배치 CNN
InstanceNorm 스타일 트랜스퍼

정규화 위치

Pre-Norm vs Post-Norm

# Post-Norm (원래 Transformer)
x = x + sublayer(x)
x = norm(x)

# Pre-Norm (GPT, LLaMA)
x = x + sublayer(norm(x))

# Pre-Norm이 더 안정적 (기울기 흐름 개선)

마지막 LayerNorm

class GPT(nn.Module):
    def __init__(self, ...):
        self.transformer_blocks = nn.ModuleList([...])
        self.ln_f = nn.LayerNorm(d_model)  # 최종 LayerNorm
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        for block in self.transformer_blocks:
            x = block(x)
        x = self.ln_f(x)  # 마지막에 정규화
        logits = self.lm_head(x)
        return logits

구현 최적화

Fused LayerNorm

# CUDA 커널 퓨전으로 속도 향상
try:
    from apex.normalization import FusedLayerNorm
    LayerNorm = FusedLayerNorm
except ImportError:
    LayerNorm = nn.LayerNorm

# 또는 PyTorch 2.0+의 compiled 버전
model = torch.compile(model)  # 자동 최적화

메모리 효율적 LayerNorm

# Gradient checkpointing과 함께
from torch.utils.checkpoint import checkpoint

class EfficientTransformer(nn.Module):
    def forward(self, x):
        for block in self.blocks:
            x = checkpoint(block, x)  # 메모리 절약
        return x

실무 가이드

정규화 레이어 선택

normalization diagram 3

성능 영향 분석

# 정규화 유무에 따른 학습 곡선 비교
def compare_normalization(model_with_norm, model_without_norm, train_loader, num_epochs=10):
    """정규화의 효과 비교"""

    results = {'with_norm': [], 'without_norm': []}

    for model, key in [(model_with_norm, 'with_norm'), (model_without_norm, 'without_norm')]:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        for epoch in range(num_epochs):
            epoch_loss = 0
            for x, y in train_loader:
                optimizer.zero_grad()
                loss = F.cross_entropy(model(x), y)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            results[key].append(epoch_loss / len(train_loader))

    # 시각화
    plt.plot(results['with_norm'], label='With Normalization')
    plt.plot(results['without_norm'], label='Without Normalization')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Effect of Normalization on Training')
    plt.show()

추론 시 BatchNorm 주의사항

# 훈련 모드와 추론 모드의 차이
model.train()  # BatchNorm: 배치 통계 사용
model.eval()   # BatchNorm: running 통계 사용 (중요!)

# 잘못된 추론 (흔한 실수)
output = model(input)  # model.eval() 호출 안 함 → 배치 통계 사용

# 올바른 추론
model.eval()
with torch.no_grad():
    output = model(input)

# 배치 크기 1로 추론 시 문제
# BatchNorm + batch_size=1 → 분산이 0 → NaN 가능
# 해결: eval() 모드 사용 또는 LayerNorm으로 교체

# Running statistics 업데이트 확인
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        print(f"running_mean: {module.running_mean[:5]}")
        print(f"running_var: {module.running_var[:5]}")
        print(f"num_batches_tracked: {module.num_batches_tracked}")

디버깅 가이드

정규화 통계 모니터링

class NormalizationDebugger:
    """정규화 레이어 디버깅 도구"""

    def __init__(self, model):
        self.model = model
        self.stats_history = []

    def capture_stats(self):
        """현재 정규화 통계 캡처"""
        stats = {}

        for name, module in self.model.named_modules():
            if isinstance(module, nn.BatchNorm2d):
                stats[name] = {
                    'type': 'BatchNorm',
                    'running_mean': module.running_mean.mean().item(),
                    'running_var': module.running_var.mean().item(),
                    'weight_mean': module.weight.mean().item(),
                    'bias_mean': module.bias.mean().item()
                }
            elif isinstance(module, nn.LayerNorm):
                stats[name] = {
                    'type': 'LayerNorm',
                    'weight_mean': module.weight.mean().item(),
                    'weight_std': module.weight.std().item(),
                    'bias_mean': module.bias.mean().item()
                }

        self.stats_history.append(stats)
        return stats

    def check_health(self, stats=None):
        """정규화 레이어 건강 체크"""
        if stats is None:
            stats = self.capture_stats()

        issues = []

        for name, info in stats.items():
            if info['type'] == 'BatchNorm':
                if info['running_var'] < 1e-5:
                    issues.append(f"[WARNING] {name}: running_var very small ({info['running_var']:.2e})")
                if info['running_var'] > 100:
                    issues.append(f"[WARNING] {name}: running_var very large ({info['running_var']:.2e})")

            if abs(info.get('weight_mean', 1) - 1) > 0.5:
                issues.append(f"[INFO] {name}: weight mean far from 1 ({info['weight_mean']:.3f})")

        return issues

# 사용
debugger = NormalizationDebugger(model)

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader)

    issues = debugger.check_health()
    if issues:
        print(f"Epoch {epoch}:")
        for issue in issues:
            print(f"  {issue}")

일반적인 문제와 해결

문제 증상 해결책
배치 크기가 너무 작음 BatchNorm 불안정 GroupNorm 또는 LayerNorm 사용
eval() 미호출 추론 시 다른 결과 model.eval() 호출 확인
분산이 0 NaN 또는 Inf eps 증가 또는 입력 확인
학습 느림 정규화 후 분포 이상 gamma/beta 초기화 확인

Pre-LN vs Post-LN 비교

# Post-LN (원래 Transformer) - 불안정할 수 있음
class PostLNBlock(nn.Module):
    def forward(self, x):
        x = self.norm1(x + self.attn(x))  # Norm이 뒤에
        x = self.norm2(x + self.ffn(x))
        return x

# Pre-LN (GPT 스타일) - 더 안정적
class PreLNBlock(nn.Module):
    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # Norm이 앞에
        x = x + self.ffn(self.norm2(x))
        return x

# Pre-LN 장점:
# - 기울기 흐름이 더 안정적 (잔차 연결을 통해 직접 전파)
# - 학습률에 덜 민감
# - Warmup 없이도 학습 가능한 경우 있음

# 대부분의 현대 LLM은 Pre-LN + RMSNorm 사용

분산 학습에서의 BatchNorm

# 단일 GPU BatchNorm - 각 GPU의 배치만 사용
bn = nn.BatchNorm2d(64)

# SyncBatchNorm - 모든 GPU의 통계 동기화
bn = nn.SyncBatchNorm(64)

# 모델 전체를 SyncBatchNorm으로 변환
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# 주의: SyncBatchNorm은 통신 오버헤드 발생
# 대규모 분산 학습에서는 LayerNorm 선호

참고 자료