콘텐츠로 이동
Data Prep
상세

가중치 초기화 (Weight Initialization)

신경망 학습의 출발점을 결정하는 핵심 요소. 잘못된 초기화는 기울기 소실/폭발, 학습 실패로 이어진다.

왜 초기화가 중요한가

분산 전파 문제

각 층을 지날 때 활성화 값의 분산이 변화함:

weight-initialization diagram 1

시각화

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

def visualize_activation_distribution(init_scale, num_layers=10, dim=512):
    """각 층의 활성화 분포 시각화"""
    x = torch.randn(1000, dim)

    activations = [x]
    for _ in range(num_layers):
        w = torch.randn(dim, dim) * init_scale
        x = torch.relu(x @ w)
        activations.append(x)

    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    for i, (ax, act) in enumerate(zip(axes.flatten(), activations)):
        ax.hist(act.flatten().numpy(), bins=50, density=True)
        ax.set_title(f'Layer {i}, std={act.std():.4f}')
    plt.tight_layout()
    return activations

# 비교
visualize_activation_distribution(0.01)   # 너무 작음
visualize_activation_distribution(1.0)    # 너무 큼
visualize_activation_distribution(0.05)   # 적절

초기화 방법들

Zero Initialization (사용 금지)

# 절대 사용하지 말 것!
nn.init.zeros_(layer.weight)

# 문제점:
# 1. 모든 뉴런이 동일한 기울기 → 동일하게 업데이트
# 2. 대칭성 파괴 불가 → 학습 불가

Random Initialization

# 단순 랜덤 (비추천)
nn.init.normal_(layer.weight, mean=0, std=0.01)
nn.init.uniform_(layer.weight, a=-0.1, b=0.1)

# 문제점: 깊은 네트워크에서 분산 조절 어려움

Xavier/Glorot Initialization

Sigmoid, Tanh 활성화 함수에 적합.

수식:

\[W \sim \mathcal{U}\left(-\sqrt{\frac{6}{n_{in} + n_{out}}}, \sqrt{\frac{6}{n_{in} + n_{out}}}\right)\]

또는

\[W \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{in} + n_{out}}}\right)\]

직관: - 입출력 분산을 동일하게 유지 - 순전파와 역전파 모두에서 분산 보존

# PyTorch
nn.init.xavier_uniform_(layer.weight)
nn.init.xavier_normal_(layer.weight)

# 직접 구현
import math

def xavier_uniform(tensor, fan_in, fan_out):
    bound = math.sqrt(6.0 / (fan_in + fan_out))
    tensor.uniform_(-bound, bound)

def xavier_normal(tensor, fan_in, fan_out):
    std = math.sqrt(2.0 / (fan_in + fan_out))
    tensor.normal_(0, std)

He/Kaiming Initialization

ReLU 계열 활성화 함수에 적합. ReLU가 음수를 0으로 만들어 분산이 절반으로 줄어드는 것을 보정.

수식:

\[W \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{in}}}\right)\]

직관: - ReLU는 입력의 절반(음수)을 0으로 만듦 - 분산 보존을 위해 2배 스케일링

# PyTorch
nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

# fan_in vs fan_out
# fan_in: 순전파 분산 보존 (일반적)
# fan_out: 역전파 분산 보존

# Leaky ReLU의 경우
nn.init.kaiming_normal_(layer.weight, a=0.01, nonlinearity='leaky_relu')

# 직접 구현
def he_normal(tensor, fan_in, negative_slope=0):
    gain = math.sqrt(2.0 / (1 + negative_slope ** 2))
    std = gain / math.sqrt(fan_in)
    tensor.normal_(0, std)

LeCun Initialization

SELU 활성화 함수에 적합.

\[W \sim \mathcal{N}\left(0, \sqrt{\frac{1}{n_{in}}}\right)\]
nn.init.lecun_normal = lambda t: nn.init.kaiming_normal_(t, a=1, mode='fan_in')

Orthogonal Initialization

순환 신경망(RNN)에서 효과적.

nn.init.orthogonal_(layer.weight, gain=1.0)

# 직접 구현
def orthogonal_init(tensor, gain=1.0):
    rows, cols = tensor.shape
    flat = torch.randn(rows, cols)
    q, r = torch.linalg.qr(flat)
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph
    if rows < cols:
        q = q.T
    tensor.copy_(gain * q[:rows, :cols])

활성화 함수별 권장 초기화

활성화 함수 권장 초기화 gain
Sigmoid Xavier 1.0
Tanh Xavier 5/3 ≈ 1.67
ReLU He (Kaiming) √2
Leaky ReLU (a=0.01) He √(2/(1+0.01²))
SELU LeCun 1.0
GELU He (경험적) √2
SiLU/Swish He (경험적) √2
def init_weights(module):
    """활성화 함수에 맞는 초기화"""
    if isinstance(module, nn.Linear):
        # Transformer에서는 보통 normal 사용
        nn.init.normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    elif isinstance(module, nn.Conv2d):
        # CNN에서는 He 초기화
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)

    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=0.02)

model.apply(init_weights)

특수 케이스

Residual Network 초기화

잔차 블록이 많을 때 출력 스케일이 커지는 문제 해결.

class ResidualBlock(nn.Module):
    def __init__(self, dim, layer_idx, total_layers):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)

        # 잔차 브랜치 스케일링 (GPT-2 스타일)
        # 깊이에 따라 초기화 스케일 조정
        scale = 1 / math.sqrt(2 * total_layers)
        nn.init.normal_(self.linear2.weight, std=0.02 * scale)

    def forward(self, x):
        return x + self.linear2(F.relu(self.linear1(x)))

# 또는 ReZero 방식
class ReZeroBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.alpha = nn.Parameter(torch.zeros(1))  # 0으로 시작

    def forward(self, x):
        return x + self.alpha * self.linear(x)

Transformer 초기화

class TransformerInitMixin:
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

        # 출력 프로젝션 스케일링 (GPT-2)
        for name, p in module.named_parameters():
            if name.endswith('c_proj.weight'):
                torch.nn.init.normal_(
                    p, 
                    mean=0.0, 
                    std=0.02 / math.sqrt(2 * self.config.n_layer)
                )

사전학습 모델 위의 추가 층

# 사전학습된 인코더 위에 새 층 추가
class FineTuneModel(nn.Module):
    def __init__(self, pretrained_encoder, num_classes):
        super().__init__()
        self.encoder = pretrained_encoder
        self.classifier = nn.Linear(768, num_classes)

        # 새 층만 초기화 (작은 스케일)
        nn.init.normal_(self.classifier.weight, std=0.01)
        nn.init.zeros_(self.classifier.bias)

디버깅 가이드

초기화 문제 진단

def diagnose_initialization(model, sample_input):
    """초기화 상태 진단"""
    model.eval()

    activations = {}
    gradients = {}

    def save_activation(name):
        def hook(module, input, output):
            activations[name] = output.detach()
        return hook

    def save_gradient(name):
        def hook(module, grad_input, grad_output):
            gradients[name] = grad_output[0].detach()
        return hook

    # Hook 등록
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            module.register_forward_hook(save_activation(name))
            module.register_full_backward_hook(save_gradient(name))

    # Forward & Backward
    output = model(sample_input)
    loss = output.sum()
    loss.backward()

    # 분석
    print("=== Activation Statistics ===")
    for name, act in activations.items():
        print(f"{name}: mean={act.mean():.4f}, std={act.std():.4f}, "
              f"dead={100*(act==0).float().mean():.1f}%")

    print("\n=== Gradient Statistics ===")
    for name, grad in gradients.items():
        print(f"{name}: mean={grad.mean():.6f}, std={grad.std():.6f}")

# 사용
model = MyModel()
diagnose_initialization(model, torch.randn(32, 784))

일반적인 문제와 해결

증상 원인 해결책
모든 출력이 거의 동일 초기화가 너무 작거나 0 He/Xavier 초기화 사용
학습 초기에 loss=NaN 초기화가 너무 큼 스케일 줄이기, gradient clipping
기울기가 0 (dead neurons) ReLU + 부적절한 초기화 He 초기화, Leaky ReLU
깊은 층에서 활성화 0 기울기 소실 잔차 연결, BatchNorm
학습이 매우 느림 분산 불균형 적절한 초기화 + LR warmup

체크리스트

def init_checklist(model):
    """초기화 체크리스트"""
    issues = []

    for name, param in model.named_parameters():
        if 'weight' in name:
            std = param.std().item()
            mean = param.mean().item()

            # 체크 1: 분산이 너무 작거나 큼
            if std < 1e-4:
                issues.append(f"{name}: std too small ({std:.6f})")
            if std > 1.0:
                issues.append(f"{name}: std too large ({std:.6f})")

            # 체크 2: 평균이 0이 아님
            if abs(mean) > 0.1:
                issues.append(f"{name}: mean not zero ({mean:.4f})")

            # 체크 3: 모든 값이 동일
            if param.max() == param.min():
                issues.append(f"{name}: all values identical!")

    return issues

issues = init_checklist(model)
for issue in issues:
    print(f"[WARNING] {issue}")

실무 권장사항

Framework 기본값

# PyTorch nn.Linear 기본 초기화
# weight: Uniform(-1/√fan_in, 1/√fan_in)
# bias: Uniform(-1/√fan_in, 1/√fan_in)

# 대부분의 경우 명시적 초기화 권장
def init_linear(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
        if m.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(m.bias, -bound, bound)

모던 LLM 초기화 패턴

# GPT-2/3 스타일
std = 0.02
for module in model.modules():
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)

# LLaMA 스타일
std = 0.02
for name, param in model.named_parameters():
    if param.dim() == 1:  # bias, layernorm
        param.data.zero_()
    elif param.dim() == 2:  # weight
        torch.nn.init.normal_(param, mean=0.0, std=std)

참고 자료