Lottery Ticket Hypothesis¶
개요¶
Lottery Ticket Hypothesis(LTH)는 "밀집 신경망 내에 처음부터 효율적으로 학습 가능한 희소 서브네트워크(winning ticket)가 존재한다"는 가설이다. MIT의 Jonathan Frankle과 Michael Carlin이 2019년 NeurIPS에서 발표했으며, Best Paper Award를 수상했다.
| 항목 | 내용 |
|---|---|
| 논문 | The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks |
| 저자 | Jonathan Frankle, Michael Carlin |
| 발표 | NeurIPS 2019 (Best Paper Award) |
| 분야 | Neural Network Pruning, Efficient Deep Learning |
핵심 아이디어¶
기존 프루닝의 한계¶
기존 neural network pruning은 학습 완료 후 가중치를 제거하고, 남은 네트워크를 fine-tuning하는 방식이었다.
기존 방식:
Dense Network (학습) -> Pruning -> Sparse Network (fine-tuning)
문제점:
- 희소 네트워크를 처음부터 학습하면 성능 저하
- "큰 모델로 학습해야 한다"는 통념
Lottery Ticket Hypothesis¶
LTH는 이 통념을 뒤집는다:
핵심 주장:
밀집 네트워크 내에는 "winning ticket" (당첨 복권)이 존재한다.
이 서브네트워크를 초기 가중치로 재설정하고 학습하면,
원본 네트워크와 동일한 정확도를 같은 학습 횟수 내에 달성할 수 있다.
비유로 설명하면: - 밀집 네트워크 = 복권 다발 - Winning ticket = 당첨 복권 (희소 서브네트워크) - 초기 가중치 = 복권 번호 (이 "번호"가 핵심)
이론적 배경¶
Winning Ticket 정의¶
네트워크 f(x; theta)에서 winning ticket은 다음 조건을 만족하는 마스크 m과 초기 가중치 theta_0의 조합이다:
조건:
1. f(x; m * theta_0)를 k iterations 학습 시,
f(x; theta_0)를 k iterations 학습한 것과 비슷한 test accuracy 달성
2. ||m||_0 << ||theta||_0 (희소성)
여기서:
- m: 이진 마스크 (0 또는 1)
- theta_0: 원본 초기 가중치
- m * theta_0: element-wise 곱 (마스킹된 가중치)
초기 가중치의 중요성¶
LTH의 핵심 통찰: 어떤 가중치를 남길지(구조)뿐 아니라 초기값이 중요하다.
Iterative Magnitude Pruning (IMP)¶
Winning ticket을 찾는 핵심 알고리즘이다.
알고리즘¶
Algorithm: Iterative Magnitude Pruning
Input:
- 네트워크 구조
- 목표 희소도 s
- 프루닝 비율 p (보통 20%)
- 학습 iterations T
1. theta_0 ~ 초기화
2. 마스크 m = 1 (모든 가중치 활성화)
3. REPEAT until sparsity >= s:
a. 네트워크 학습: theta_0 -> theta_T
b. 마스크 m에서 magnitude가 가장 작은 p% 제거
c. 남은 가중치를 theta_0로 재설정 (rewinding)
4. RETURN (m, theta_0) as winning ticket
시각화¶
Round 1: [################] 100% -> prune 20% -> [############____] 80%
|
rewind to theta_0
v
Round 2: [############____] 80% -> prune 20% -> [##########______] 64%
|
rewind to theta_0
v
Round 3: [##########______] 64% -> prune 20% -> [########________] 51%
...
주요 발견¶
1. 희소도와 정확도 관계¶
| 희소도 | LeNet (MNIST) | VGG-19 (CIFAR-10) | ResNet-18 (CIFAR-10) |
|---|---|---|---|
| 0% | 98.3% | 93.5% | 93.2% |
| 50% | 98.3% | 93.6% | 93.3% |
| 90% | 98.2% | 93.4% | 93.0% |
| 95% | 97.9% | 92.8% | 92.1% |
| 99% | 96.5% | 88.2% | 85.4% |
90% 이상의 파라미터를 제거해도 성능이 거의 유지된다.
2. 학습 속도¶
Winning ticket은 원본 네트워크보다 빠르게 수렴한다.
학습 곡선 비교:
Epoch 1 5 10 20 30
Dense 0.85 0.92 0.94 0.95 0.96
Win-90% 0.87 0.93 0.95 0.96 0.96 (더 빠른 수렴)
Random-90% 0.72 0.83 0.88 0.90 0.91 (느린 수렴, 낮은 최종 성능)
3. 전이 가능성¶
한 태스크에서 찾은 winning ticket이 다른 태스크에서도 유효하다.
Late Rewinding (Stabilizing the Lottery Ticket Hypothesis)¶
대규모 모델(ImageNet, BERT)에서는 theta_0로 완전히 돌아가면 불안정할 수 있다. Late Rewinding은 초기 k iterations 후의 가중치로 돌아간다.
Late Rewinding:
- 기존: theta_0 (iteration 0)로 rewind
- 개선: theta_k (iteration k, 보통 0.1%-1% 지점)로 rewind
효과:
- 대규모 모델에서 안정성 향상
- ImageNet ResNet-50에서 80% pruning 달성
확장과 변형¶
1. One-shot vs Iterative Pruning¶
| 방법 | 프로세스 | 장점 | 단점 |
|---|---|---|---|
| One-shot | 한 번에 목표 희소도까지 pruning | 빠름 | 높은 희소도에서 성능 저하 |
| Iterative | 점진적 pruning + rewinding | 더 좋은 winning ticket | 계산 비용 큼 |
2. Structured vs Unstructured Pruning¶
Unstructured (LTH 원본):
[1.2, 0.0, 0.5, 0.0, 0.8, 0.0, 0.3, 0.0]
-> 개별 가중치 제거
-> 이론적 최적, 하드웨어 가속 어려움
Structured:
[1.2, 0.5, 0.8, 0.3] [0.0, 0.0, 0.0, 0.0]
-> 필터/채널 단위 제거
-> 하드웨어 친화적, 더 많은 손실
3. Pruning at Initialization (PaI)¶
학습 없이 초기화 시점에서 winning ticket을 찾는 방법들:
| 방법 | 논문 | 핵심 아이디어 |
|---|---|---|
| SNIP | ICLR 2019 | Connection sensitivity 기반 |
| GraSP | ICLR 2020 | Gradient flow preservation |
| SynFlow | NeurIPS 2020 | Synaptic flow conservation |
| ProsPr | ICLR 2022 | Prospect pruning |
LLM에서의 적용¶
LLM 시대에 LTH는 새로운 의미를 가진다.
Transformer Pruning¶
적용 전략¶
| 레이어 | Pruning 비율 | 이유 |
|---|---|---|
| Embedding | 낮음 (20-30%) | 입력 표현 보존 |
| Attention | 중간 (50-60%) | 중요 head 선별 |
| FFN | 높음 (70-80%) | 중복 뉴런 많음 |
| LM Head | 낮음 (20-30%) | 출력 품질 보존 |
Python 구현¶
기본 Iterative Magnitude Pruning¶
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import copy
class IterativeMagnitudePruning:
"""Lottery Ticket Hypothesis를 위한 IMP 구현"""
def __init__(
self,
model: nn.Module,
target_sparsity: float = 0.9,
pruning_rate: float = 0.2,
rewinding_epoch: int = 0
):
self.model = model
self.target_sparsity = target_sparsity
self.pruning_rate = pruning_rate
self.rewinding_epoch = rewinding_epoch
# 초기 가중치 저장
self.initial_weights = self._save_weights()
self.rewinding_weights = None
def _save_weights(self):
"""현재 가중치 저장"""
weights = {}
for name, param in self.model.named_parameters():
weights[name] = param.data.clone()
return weights
def _load_weights(self, weights, keep_mask=True):
"""저장된 가중치 복원 (마스크 유지 가능)"""
for name, param in self.model.named_parameters():
if name in weights:
if keep_mask and hasattr(param, '_mask'):
# 마스크가 있으면 마스킹된 위치만 복원
param.data = weights[name] * param._mask
else:
param.data = weights[name].clone()
def save_rewinding_checkpoint(self, epoch):
"""Late rewinding을 위한 체크포인트 저장"""
if epoch == self.rewinding_epoch:
self.rewinding_weights = self._save_weights()
def get_prunable_layers(self):
"""프루닝 가능한 레이어 반환"""
layers = []
for name, module in self.model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
layers.append((module, 'weight'))
return layers
def get_current_sparsity(self):
"""현재 전체 희소도 계산"""
total_params = 0
zero_params = 0
for name, param in self.model.named_parameters():
if 'weight' in name:
total_params += param.numel()
zero_params += (param == 0).sum().item()
return zero_params / total_params if total_params > 0 else 0
def prune_step(self):
"""한 라운드 프루닝 수행"""
layers = self.get_prunable_layers()
# Global magnitude pruning
prune.global_unstructured(
layers,
pruning_method=prune.L1Unstructured,
amount=self.pruning_rate,
)
return self.get_current_sparsity()
def rewind_weights(self):
"""가중치를 초기값으로 되돌리기 (마스크 유지)"""
target_weights = (
self.rewinding_weights
if self.rewinding_weights is not None
else self.initial_weights
)
for name, module in self.model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
# 마스크 가져오기
if hasattr(module, 'weight_mask'):
mask = module.weight_mask
# 프루닝 제거 (가중치에 마스크 적용)
prune.remove(module, 'weight')
# 초기 가중치로 교체 후 마스크 재적용
orig_weight = target_weights[f"{name}.weight"]
module.weight.data = orig_weight * mask
# 마스크 다시 적용
prune.custom_from_mask(module, 'weight', mask)
def find_winning_ticket(self, train_fn, epochs_per_round):
"""
Winning ticket 탐색
Args:
train_fn: 학습 함수 (model, epochs) -> accuracy
epochs_per_round: 각 라운드당 학습 에폭
Returns:
최종 희소도, 정확도
"""
round_num = 0
while self.get_current_sparsity() < self.target_sparsity:
round_num += 1
print(f"\n=== Round {round_num} ===")
# 학습
accuracy = train_fn(self.model, epochs_per_round)
print(f"Accuracy after training: {accuracy:.4f}")
# 프루닝
sparsity = self.prune_step()
print(f"Sparsity after pruning: {sparsity:.2%}")
if sparsity >= self.target_sparsity:
break
# Rewinding
self.rewind_weights()
print("Weights rewound to initial values")
# 최종 학습
print("\n=== Final Training ===")
final_accuracy = train_fn(self.model, epochs_per_round)
final_sparsity = self.get_current_sparsity()
return final_sparsity, final_accuracy
def make_permanent(model):
"""프루닝을 영구적으로 만들기 (마스크 제거)"""
for name, module in model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
if hasattr(module, 'weight_mask'):
prune.remove(module, 'weight')
return model
완전한 실험 예시¶
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 간단한 CNN 모델
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = torch.relu(self.conv1(x))
x = torch.relu(self.conv2(x))
x = torch.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
def train_and_evaluate(model, epochs, train_loader, test_loader, device):
"""학습 및 평가 함수"""
model.train()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# 평가
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += target.size(0)
return correct / total
def run_lottery_ticket_experiment():
"""Lottery Ticket 실험 실행"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 데이터 로드
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000)
# 모델 생성
model = SimpleCNN().to(device)
# IMP 초기화
imp = IterativeMagnitudePruning(
model,
target_sparsity=0.9, # 90% 희소도 목표
pruning_rate=0.2, # 각 라운드 20% 제거
rewinding_epoch=0 # 초기 가중치로 rewind
)
# 학습 함수 정의
def train_fn(m, epochs):
return train_and_evaluate(m, epochs, train_loader, test_loader, device)
# Winning ticket 탐색
sparsity, accuracy = imp.find_winning_ticket(train_fn, epochs_per_round=5)
print(f"\nFinal Results:")
print(f"Sparsity: {sparsity:.2%}")
print(f"Accuracy: {accuracy:.4f}")
# 프루닝 영구화
model = make_permanent(model)
return model
if __name__ == "__main__":
run_lottery_ticket_experiment()
Late Rewinding 구현¶
class LateRewindingIMP(IterativeMagnitudePruning):
"""Late Rewinding을 지원하는 IMP"""
def __init__(
self,
model: nn.Module,
target_sparsity: float = 0.9,
pruning_rate: float = 0.2,
rewinding_ratio: float = 0.01 # 전체 학습의 1% 지점
):
super().__init__(model, target_sparsity, pruning_rate)
self.rewinding_ratio = rewinding_ratio
self.rewinding_saved = False
def find_winning_ticket_with_late_rewind(
self,
train_fn,
total_epochs,
epochs_per_round
):
"""Late rewinding을 사용한 winning ticket 탐색"""
rewinding_epoch = int(total_epochs * self.rewinding_ratio)
round_num = 0
while self.get_current_sparsity() < self.target_sparsity:
round_num += 1
print(f"\n=== Round {round_num} ===")
# 학습 (rewinding checkpoint 저장 포함)
for epoch in range(epochs_per_round):
accuracy = train_fn(self.model, 1)
# Late rewinding checkpoint
if not self.rewinding_saved and epoch == rewinding_epoch:
self.rewinding_weights = self._save_weights()
self.rewinding_saved = True
print(f"Saved rewinding checkpoint at epoch {epoch}")
print(f"Accuracy: {accuracy:.4f}")
# 프루닝
sparsity = self.prune_step()
print(f"Sparsity: {sparsity:.2%}")
if sparsity >= self.target_sparsity:
break
# Late rewinding
self.rewind_weights()
# 최종 학습
final_accuracy = train_fn(self.model, epochs_per_round)
return self.get_current_sparsity(), final_accuracy
프루닝 분석 유틸리티¶
def analyze_sparsity_per_layer(model):
"""레이어별 희소도 분석"""
results = []
for name, module in model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
weight = module.weight.data
total = weight.numel()
zeros = (weight == 0).sum().item()
sparsity = zeros / total
results.append({
'layer': name,
'type': type(module).__name__,
'params': total,
'zeros': zeros,
'sparsity': sparsity
})
return results
def plot_sparsity_distribution(model):
"""희소도 분포 시각화"""
import matplotlib.pyplot as plt
analysis = analyze_sparsity_per_layer(model)
layers = [r['layer'] for r in analysis]
sparsities = [r['sparsity'] for r in analysis]
plt.figure(figsize=(12, 6))
plt.bar(range(len(layers)), sparsities)
plt.xticks(range(len(layers)), layers, rotation=45, ha='right')
plt.ylabel('Sparsity')
plt.title('Sparsity Distribution per Layer')
plt.tight_layout()
plt.show()
def count_parameters(model, count_zeros=False):
"""파라미터 수 카운트"""
total = 0
nonzero = 0
for param in model.parameters():
total += param.numel()
nonzero += (param != 0).sum().item()
if count_zeros:
return total, nonzero, total - nonzero
return nonzero # 유효 파라미터 수
실무 가이드라인¶
하이퍼파라미터 권장값¶
| 파라미터 | 소규모 모델 | 대규모 모델 | 설명 |
|---|---|---|---|
| 목표 희소도 | 90-95% | 70-80% | 모델 크기에 반비례 |
| 프루닝 비율 | 20% | 10-20% | 보수적일수록 안정 |
| Rewinding 지점 | epoch 0 | epoch k (0.1-1%) | Late rewinding |
| 라운드당 에폭 | Full training | 10-20% 학습 | 계산 효율 |
주의사항¶
1. 계산 비용
- IMP는 여러 번 학습해야 함 (n rounds x training)
- 대안: One-shot pruning, PaI methods
2. 하드웨어 가속
- Unstructured sparsity는 GPU 가속 어려움
- 실제 속도 향상은 structured pruning 필요
3. 태스크 전이
- 같은 데이터셋 내 전이: 효과적
- 다른 도메인 전이: 성능 저하 가능
4. 스케일링
- 모델이 클수록 late rewinding 필요
- ImageNet 급에서는 0.1-1% 학습 후 rewind
관련 연구 흐름¶
LTH (2019)
|
+-- Stabilizing LTH (2020): Late Rewinding
|
+-- SNIP, GraSP, SynFlow (2019-2020): Pruning at Initialization
|
+-- Linear Mode Connectivity (2020): 이론적 분석
|
+-- Dual LTH (2022): 학습-비학습 균형
|
+-- LTH for Transformers (2021-2023): BERT, GPT 적용
|
+-- LTH for LLMs (2023-2024): LLaMA, Mistral 적용
참고 자료¶
핵심 논문¶
- Frankle & Carlin (2019). The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. NeurIPS 2019.
- Frankle et al. (2020). Stabilizing the Lottery Ticket Hypothesis. arXiv:1903.01611.
- Frankle et al. (2020). Linear Mode Connectivity and the Lottery Ticket Hypothesis. ICML 2020.
- Chen et al. (2021). The Lottery Ticket Hypothesis for Pre-trained BERT Networks. NeurIPS 2021.
Survey¶
- Hoefler et al. (2021). Sparsity in Deep Learning: Pruning and growth for efficient inference and training in neural networks. JMLR.
- LTH Survey (2024). A Survey of Lottery Ticket Hypothesis. arXiv:2403.04861.
관련 개념¶
- Mixture-of-Experts: 조건부 계산
- Mixture-of-Depths: 동적 토큰 처리
- Knowledge Distillation: 모델 압축