가중치 초기화 (Weight Initialization)¶
신경망 학습의 출발점을 결정하는 핵심 요소. 잘못된 초기화는 기울기 소실/폭발, 학습 실패로 이어진다.
왜 초기화가 중요한가¶
분산 전파 문제¶
각 층을 지날 때 활성화 값의 분산이 변화함:
시각화¶
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
def visualize_activation_distribution(init_scale, num_layers=10, dim=512):
"""각 층의 활성화 분포 시각화"""
x = torch.randn(1000, dim)
activations = [x]
for _ in range(num_layers):
w = torch.randn(dim, dim) * init_scale
x = torch.relu(x @ w)
activations.append(x)
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, (ax, act) in enumerate(zip(axes.flatten(), activations)):
ax.hist(act.flatten().numpy(), bins=50, density=True)
ax.set_title(f'Layer {i}, std={act.std():.4f}')
plt.tight_layout()
return activations
# 비교
visualize_activation_distribution(0.01) # 너무 작음
visualize_activation_distribution(1.0) # 너무 큼
visualize_activation_distribution(0.05) # 적절
초기화 방법들¶
Zero Initialization (사용 금지)¶
# 절대 사용하지 말 것!
nn.init.zeros_(layer.weight)
# 문제점:
# 1. 모든 뉴런이 동일한 기울기 → 동일하게 업데이트
# 2. 대칭성 파괴 불가 → 학습 불가
Random Initialization¶
# 단순 랜덤 (비추천)
nn.init.normal_(layer.weight, mean=0, std=0.01)
nn.init.uniform_(layer.weight, a=-0.1, b=0.1)
# 문제점: 깊은 네트워크에서 분산 조절 어려움
Xavier/Glorot Initialization¶
Sigmoid, Tanh 활성화 함수에 적합.
수식:
\[W \sim \mathcal{U}\left(-\sqrt{\frac{6}{n_{in} + n_{out}}}, \sqrt{\frac{6}{n_{in} + n_{out}}}\right)\]
또는
\[W \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{in} + n_{out}}}\right)\]
직관: - 입출력 분산을 동일하게 유지 - 순전파와 역전파 모두에서 분산 보존
# PyTorch
nn.init.xavier_uniform_(layer.weight)
nn.init.xavier_normal_(layer.weight)
# 직접 구현
import math
def xavier_uniform(tensor, fan_in, fan_out):
bound = math.sqrt(6.0 / (fan_in + fan_out))
tensor.uniform_(-bound, bound)
def xavier_normal(tensor, fan_in, fan_out):
std = math.sqrt(2.0 / (fan_in + fan_out))
tensor.normal_(0, std)
He/Kaiming Initialization¶
ReLU 계열 활성화 함수에 적합. ReLU가 음수를 0으로 만들어 분산이 절반으로 줄어드는 것을 보정.
수식:
\[W \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{in}}}\right)\]
직관: - ReLU는 입력의 절반(음수)을 0으로 만듦 - 분산 보존을 위해 2배 스케일링
# PyTorch
nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
# fan_in vs fan_out
# fan_in: 순전파 분산 보존 (일반적)
# fan_out: 역전파 분산 보존
# Leaky ReLU의 경우
nn.init.kaiming_normal_(layer.weight, a=0.01, nonlinearity='leaky_relu')
# 직접 구현
def he_normal(tensor, fan_in, negative_slope=0):
gain = math.sqrt(2.0 / (1 + negative_slope ** 2))
std = gain / math.sqrt(fan_in)
tensor.normal_(0, std)
LeCun Initialization¶
SELU 활성화 함수에 적합.
\[W \sim \mathcal{N}\left(0, \sqrt{\frac{1}{n_{in}}}\right)\]
Orthogonal Initialization¶
순환 신경망(RNN)에서 효과적.
nn.init.orthogonal_(layer.weight, gain=1.0)
# 직접 구현
def orthogonal_init(tensor, gain=1.0):
rows, cols = tensor.shape
flat = torch.randn(rows, cols)
q, r = torch.linalg.qr(flat)
d = torch.diag(r, 0)
ph = d.sign()
q *= ph
if rows < cols:
q = q.T
tensor.copy_(gain * q[:rows, :cols])
활성화 함수별 권장 초기화¶
| 활성화 함수 | 권장 초기화 | gain |
|---|---|---|
| Sigmoid | Xavier | 1.0 |
| Tanh | Xavier | 5/3 ≈ 1.67 |
| ReLU | He (Kaiming) | √2 |
| Leaky ReLU (a=0.01) | He | √(2/(1+0.01²)) |
| SELU | LeCun | 1.0 |
| GELU | He (경험적) | √2 |
| SiLU/Swish | He (경험적) | √2 |
def init_weights(module):
"""활성화 함수에 맞는 초기화"""
if isinstance(module, nn.Linear):
# Transformer에서는 보통 normal 사용
nn.init.normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
# CNN에서는 He 초기화
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=0.02)
model.apply(init_weights)
특수 케이스¶
Residual Network 초기화¶
잔차 블록이 많을 때 출력 스케일이 커지는 문제 해결.
class ResidualBlock(nn.Module):
def __init__(self, dim, layer_idx, total_layers):
super().__init__()
self.linear1 = nn.Linear(dim, dim)
self.linear2 = nn.Linear(dim, dim)
# 잔차 브랜치 스케일링 (GPT-2 스타일)
# 깊이에 따라 초기화 스케일 조정
scale = 1 / math.sqrt(2 * total_layers)
nn.init.normal_(self.linear2.weight, std=0.02 * scale)
def forward(self, x):
return x + self.linear2(F.relu(self.linear1(x)))
# 또는 ReZero 방식
class ReZeroBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = nn.Linear(dim, dim)
self.alpha = nn.Parameter(torch.zeros(1)) # 0으로 시작
def forward(self, x):
return x + self.alpha * self.linear(x)
Transformer 초기화¶
class TransformerInitMixin:
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
# 출력 프로젝션 스케일링 (GPT-2)
for name, p in module.named_parameters():
if name.endswith('c_proj.weight'):
torch.nn.init.normal_(
p,
mean=0.0,
std=0.02 / math.sqrt(2 * self.config.n_layer)
)
사전학습 모델 위의 추가 층¶
# 사전학습된 인코더 위에 새 층 추가
class FineTuneModel(nn.Module):
def __init__(self, pretrained_encoder, num_classes):
super().__init__()
self.encoder = pretrained_encoder
self.classifier = nn.Linear(768, num_classes)
# 새 층만 초기화 (작은 스케일)
nn.init.normal_(self.classifier.weight, std=0.01)
nn.init.zeros_(self.classifier.bias)
디버깅 가이드¶
초기화 문제 진단¶
def diagnose_initialization(model, sample_input):
"""초기화 상태 진단"""
model.eval()
activations = {}
gradients = {}
def save_activation(name):
def hook(module, input, output):
activations[name] = output.detach()
return hook
def save_gradient(name):
def hook(module, grad_input, grad_output):
gradients[name] = grad_output[0].detach()
return hook
# Hook 등록
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
module.register_forward_hook(save_activation(name))
module.register_full_backward_hook(save_gradient(name))
# Forward & Backward
output = model(sample_input)
loss = output.sum()
loss.backward()
# 분석
print("=== Activation Statistics ===")
for name, act in activations.items():
print(f"{name}: mean={act.mean():.4f}, std={act.std():.4f}, "
f"dead={100*(act==0).float().mean():.1f}%")
print("\n=== Gradient Statistics ===")
for name, grad in gradients.items():
print(f"{name}: mean={grad.mean():.6f}, std={grad.std():.6f}")
# 사용
model = MyModel()
diagnose_initialization(model, torch.randn(32, 784))
일반적인 문제와 해결¶
| 증상 | 원인 | 해결책 |
|---|---|---|
| 모든 출력이 거의 동일 | 초기화가 너무 작거나 0 | He/Xavier 초기화 사용 |
| 학습 초기에 loss=NaN | 초기화가 너무 큼 | 스케일 줄이기, gradient clipping |
| 기울기가 0 (dead neurons) | ReLU + 부적절한 초기화 | He 초기화, Leaky ReLU |
| 깊은 층에서 활성화 0 | 기울기 소실 | 잔차 연결, BatchNorm |
| 학습이 매우 느림 | 분산 불균형 | 적절한 초기화 + LR warmup |
체크리스트¶
def init_checklist(model):
"""초기화 체크리스트"""
issues = []
for name, param in model.named_parameters():
if 'weight' in name:
std = param.std().item()
mean = param.mean().item()
# 체크 1: 분산이 너무 작거나 큼
if std < 1e-4:
issues.append(f"{name}: std too small ({std:.6f})")
if std > 1.0:
issues.append(f"{name}: std too large ({std:.6f})")
# 체크 2: 평균이 0이 아님
if abs(mean) > 0.1:
issues.append(f"{name}: mean not zero ({mean:.4f})")
# 체크 3: 모든 값이 동일
if param.max() == param.min():
issues.append(f"{name}: all values identical!")
return issues
issues = init_checklist(model)
for issue in issues:
print(f"[WARNING] {issue}")
실무 권장사항¶
Framework 기본값¶
# PyTorch nn.Linear 기본 초기화
# weight: Uniform(-1/√fan_in, 1/√fan_in)
# bias: Uniform(-1/√fan_in, 1/√fan_in)
# 대부분의 경우 명시적 초기화 권장
def init_linear(m):
if isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
if m.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(m.bias, -bound, bound)
모던 LLM 초기화 패턴¶
# GPT-2/3 스타일
std = 0.02
for module in model.modules():
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
# LLaMA 스타일
std = 0.02
for name, param in model.named_parameters():
if param.dim() == 1: # bias, layernorm
param.data.zero_()
elif param.dim() == 2: # weight
torch.nn.init.normal_(param, mean=0.0, std=std)
참고 자료¶
- Xavier Initialization Paper
- He Initialization Paper
- Fixup Initialization - 잔차 네트워크 초기화
- ReZero Paper - 단순하고 효과적인 초기화
- Weight Initialization - Wikipedia