콘텐츠로 이동
Data Prep
상세

Lottery Ticket Hypothesis

개요

Lottery Ticket Hypothesis(LTH)는 "밀집 신경망 내에 처음부터 효율적으로 학습 가능한 희소 서브네트워크(winning ticket)가 존재한다"는 가설이다. MIT의 Jonathan Frankle과 Michael Carlin이 2019년 NeurIPS에서 발표했으며, Best Paper Award를 수상했다.

항목 내용
논문 The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks
저자 Jonathan Frankle, Michael Carlin
발표 NeurIPS 2019 (Best Paper Award)
분야 Neural Network Pruning, Efficient Deep Learning

핵심 아이디어

기존 프루닝의 한계

기존 neural network pruning은 학습 완료 후 가중치를 제거하고, 남은 네트워크를 fine-tuning하는 방식이었다.

기존 방식:
Dense Network (학습) -> Pruning -> Sparse Network (fine-tuning)

문제점:
- 희소 네트워크를 처음부터 학습하면 성능 저하
- "큰 모델로 학습해야 한다"는 통념

Lottery Ticket Hypothesis

LTH는 이 통념을 뒤집는다:

핵심 주장:
밀집 네트워크 내에는 "winning ticket" (당첨 복권)이 존재한다.
이 서브네트워크를 초기 가중치로 재설정하고 학습하면,
원본 네트워크와 동일한 정확도를 같은 학습 횟수 내에 달성할 수 있다.

비유로 설명하면: - 밀집 네트워크 = 복권 다발 - Winning ticket = 당첨 복권 (희소 서브네트워크) - 초기 가중치 = 복권 번호 (이 "번호"가 핵심)

이론적 배경

Winning Ticket 정의

네트워크 f(x; theta)에서 winning ticket은 다음 조건을 만족하는 마스크 m과 초기 가중치 theta_0의 조합이다:

조건:
1. f(x; m * theta_0)를 k iterations 학습 시,
   f(x; theta_0)를 k iterations 학습한 것과 비슷한 test accuracy 달성

2. ||m||_0 << ||theta||_0 (희소성)

여기서:
- m: 이진 마스크 (0 또는 1)
- theta_0: 원본 초기 가중치
- m * theta_0: element-wise 곱 (마스킹된 가중치)

초기 가중치의 중요성

LTH의 핵심 통찰: 어떤 가중치를 남길지(구조)뿐 아니라 초기값이 중요하다.

실험 결과:
- 같은 마스크 m, 랜덤 재초기화 -> 성능 저하
- 같은 마스크 m, 원본 초기값 theta_0 -> 성능 유지

결론: 초기화가 학습 궤적을 결정한다.

Iterative Magnitude Pruning (IMP)

Winning ticket을 찾는 핵심 알고리즘이다.

알고리즘

Algorithm: Iterative Magnitude Pruning

Input:
  - 네트워크 구조
  - 목표 희소도 s
  - 프루닝 비율 p (보통 20%)
  - 학습 iterations T

1. theta_0 ~ 초기화
2. 마스크 m = 1 (모든 가중치 활성화)

3. REPEAT until sparsity >= s:
   a. 네트워크 학습: theta_0 -> theta_T
   b. 마스크 m에서 magnitude가 가장 작은 p% 제거
   c. 남은 가중치를 theta_0로 재설정 (rewinding)

4. RETURN (m, theta_0) as winning ticket

시각화

Round 1:  [################] 100% -> prune 20% -> [############____] 80%
                                                        |
                                          rewind to theta_0
                                                        v
Round 2:  [############____] 80%  -> prune 20% -> [##########______] 64%
                                                        |
                                          rewind to theta_0
                                                        v
Round 3:  [##########______] 64%  -> prune 20% -> [########________] 51%
                ...

주요 발견

1. 희소도와 정확도 관계

희소도 LeNet (MNIST) VGG-19 (CIFAR-10) ResNet-18 (CIFAR-10)
0% 98.3% 93.5% 93.2%
50% 98.3% 93.6% 93.3%
90% 98.2% 93.4% 93.0%
95% 97.9% 92.8% 92.1%
99% 96.5% 88.2% 85.4%

90% 이상의 파라미터를 제거해도 성능이 거의 유지된다.

2. 학습 속도

Winning ticket은 원본 네트워크보다 빠르게 수렴한다.

학습 곡선 비교:
Epoch      1    5    10   20   30
Dense     0.85  0.92  0.94  0.95  0.96
Win-90%   0.87  0.93  0.95  0.96  0.96  (더 빠른 수렴)
Random-90% 0.72  0.83  0.88  0.90  0.91  (느린 수렴, 낮은 최종 성능)

3. 전이 가능성

한 태스크에서 찾은 winning ticket이 다른 태스크에서도 유효하다.

Late Rewinding (Stabilizing the Lottery Ticket Hypothesis)

대규모 모델(ImageNet, BERT)에서는 theta_0로 완전히 돌아가면 불안정할 수 있다. Late Rewinding은 초기 k iterations 후의 가중치로 돌아간다.

Late Rewinding:
- 기존: theta_0 (iteration 0)로 rewind
- 개선: theta_k (iteration k, 보통 0.1%-1% 지점)로 rewind

효과:
- 대규모 모델에서 안정성 향상
- ImageNet ResNet-50에서 80% pruning 달성

확장과 변형

1. One-shot vs Iterative Pruning

방법 프로세스 장점 단점
One-shot 한 번에 목표 희소도까지 pruning 빠름 높은 희소도에서 성능 저하
Iterative 점진적 pruning + rewinding 더 좋은 winning ticket 계산 비용 큼

2. Structured vs Unstructured Pruning

Unstructured (LTH 원본):
[1.2, 0.0, 0.5, 0.0, 0.8, 0.0, 0.3, 0.0]
-> 개별 가중치 제거
-> 이론적 최적, 하드웨어 가속 어려움

Structured:
[1.2, 0.5, 0.8, 0.3] [0.0, 0.0, 0.0, 0.0]
-> 필터/채널 단위 제거
-> 하드웨어 친화적, 더 많은 손실

3. Pruning at Initialization (PaI)

학습 없이 초기화 시점에서 winning ticket을 찾는 방법들:

방법 논문 핵심 아이디어
SNIP ICLR 2019 Connection sensitivity 기반
GraSP ICLR 2020 Gradient flow preservation
SynFlow NeurIPS 2020 Synaptic flow conservation
ProsPr ICLR 2022 Prospect pruning

LLM에서의 적용

LLM 시대에 LTH는 새로운 의미를 가진다.

Transformer Pruning

LLaMA 적용 결과 (JMLR 2023):
- 70% 파라미터 pruning
- Perplexity 유지
- 추론 속도 2.5x 향상

적용 전략

레이어 Pruning 비율 이유
Embedding 낮음 (20-30%) 입력 표현 보존
Attention 중간 (50-60%) 중요 head 선별
FFN 높음 (70-80%) 중복 뉴런 많음
LM Head 낮음 (20-30%) 출력 품질 보존

Python 구현

기본 Iterative Magnitude Pruning

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import copy

class IterativeMagnitudePruning:
    """Lottery Ticket Hypothesis를 위한 IMP 구현"""

    def __init__(
        self,
        model: nn.Module,
        target_sparsity: float = 0.9,
        pruning_rate: float = 0.2,
        rewinding_epoch: int = 0
    ):
        self.model = model
        self.target_sparsity = target_sparsity
        self.pruning_rate = pruning_rate
        self.rewinding_epoch = rewinding_epoch

        # 초기 가중치 저장
        self.initial_weights = self._save_weights()
        self.rewinding_weights = None

    def _save_weights(self):
        """현재 가중치 저장"""
        weights = {}
        for name, param in self.model.named_parameters():
            weights[name] = param.data.clone()
        return weights

    def _load_weights(self, weights, keep_mask=True):
        """저장된 가중치 복원 (마스크 유지 가능)"""
        for name, param in self.model.named_parameters():
            if name in weights:
                if keep_mask and hasattr(param, '_mask'):
                    # 마스크가 있으면 마스킹된 위치만 복원
                    param.data = weights[name] * param._mask
                else:
                    param.data = weights[name].clone()

    def save_rewinding_checkpoint(self, epoch):
        """Late rewinding을 위한 체크포인트 저장"""
        if epoch == self.rewinding_epoch:
            self.rewinding_weights = self._save_weights()

    def get_prunable_layers(self):
        """프루닝 가능한 레이어 반환"""
        layers = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                layers.append((module, 'weight'))
        return layers

    def get_current_sparsity(self):
        """현재 전체 희소도 계산"""
        total_params = 0
        zero_params = 0

        for name, param in self.model.named_parameters():
            if 'weight' in name:
                total_params += param.numel()
                zero_params += (param == 0).sum().item()

        return zero_params / total_params if total_params > 0 else 0

    def prune_step(self):
        """한 라운드 프루닝 수행"""
        layers = self.get_prunable_layers()

        # Global magnitude pruning
        prune.global_unstructured(
            layers,
            pruning_method=prune.L1Unstructured,
            amount=self.pruning_rate,
        )

        return self.get_current_sparsity()

    def rewind_weights(self):
        """가중치를 초기값으로 되돌리기 (마스크 유지)"""
        target_weights = (
            self.rewinding_weights 
            if self.rewinding_weights is not None 
            else self.initial_weights
        )

        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                # 마스크 가져오기
                if hasattr(module, 'weight_mask'):
                    mask = module.weight_mask
                    # 프루닝 제거 (가중치에 마스크 적용)
                    prune.remove(module, 'weight')
                    # 초기 가중치로 교체 후 마스크 재적용
                    orig_weight = target_weights[f"{name}.weight"]
                    module.weight.data = orig_weight * mask
                    # 마스크 다시 적용
                    prune.custom_from_mask(module, 'weight', mask)

    def find_winning_ticket(self, train_fn, epochs_per_round):
        """
        Winning ticket 탐색

        Args:
            train_fn: 학습 함수 (model, epochs) -> accuracy
            epochs_per_round: 각 라운드당 학습 에폭

        Returns:
            최종 희소도, 정확도
        """
        round_num = 0

        while self.get_current_sparsity() < self.target_sparsity:
            round_num += 1
            print(f"\n=== Round {round_num} ===")

            # 학습
            accuracy = train_fn(self.model, epochs_per_round)
            print(f"Accuracy after training: {accuracy:.4f}")

            # 프루닝
            sparsity = self.prune_step()
            print(f"Sparsity after pruning: {sparsity:.2%}")

            if sparsity >= self.target_sparsity:
                break

            # Rewinding
            self.rewind_weights()
            print("Weights rewound to initial values")

        # 최종 학습
        print("\n=== Final Training ===")
        final_accuracy = train_fn(self.model, epochs_per_round)
        final_sparsity = self.get_current_sparsity()

        return final_sparsity, final_accuracy


def make_permanent(model):
    """프루닝을 영구적으로 만들기 (마스크 제거)"""
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            if hasattr(module, 'weight_mask'):
                prune.remove(module, 'weight')
    return model

완전한 실험 예시

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 간단한 CNN 모델
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def train_and_evaluate(model, epochs, train_loader, test_loader, device):
    """학습 및 평가 함수"""
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

    # 평가
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    return correct / total


def run_lottery_ticket_experiment():
    """Lottery Ticket 실험 실행"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 데이터 로드
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1000)

    # 모델 생성
    model = SimpleCNN().to(device)

    # IMP 초기화
    imp = IterativeMagnitudePruning(
        model,
        target_sparsity=0.9,  # 90% 희소도 목표
        pruning_rate=0.2,     # 각 라운드 20% 제거
        rewinding_epoch=0     # 초기 가중치로 rewind
    )

    # 학습 함수 정의
    def train_fn(m, epochs):
        return train_and_evaluate(m, epochs, train_loader, test_loader, device)

    # Winning ticket 탐색
    sparsity, accuracy = imp.find_winning_ticket(train_fn, epochs_per_round=5)

    print(f"\nFinal Results:")
    print(f"Sparsity: {sparsity:.2%}")
    print(f"Accuracy: {accuracy:.4f}")

    # 프루닝 영구화
    model = make_permanent(model)

    return model


if __name__ == "__main__":
    run_lottery_ticket_experiment()

Late Rewinding 구현

class LateRewindingIMP(IterativeMagnitudePruning):
    """Late Rewinding을 지원하는 IMP"""

    def __init__(
        self,
        model: nn.Module,
        target_sparsity: float = 0.9,
        pruning_rate: float = 0.2,
        rewinding_ratio: float = 0.01  # 전체 학습의 1% 지점
    ):
        super().__init__(model, target_sparsity, pruning_rate)
        self.rewinding_ratio = rewinding_ratio
        self.rewinding_saved = False

    def find_winning_ticket_with_late_rewind(
        self,
        train_fn,
        total_epochs,
        epochs_per_round
    ):
        """Late rewinding을 사용한 winning ticket 탐색"""
        rewinding_epoch = int(total_epochs * self.rewinding_ratio)
        round_num = 0

        while self.get_current_sparsity() < self.target_sparsity:
            round_num += 1
            print(f"\n=== Round {round_num} ===")

            # 학습 (rewinding checkpoint 저장 포함)
            for epoch in range(epochs_per_round):
                accuracy = train_fn(self.model, 1)

                # Late rewinding checkpoint
                if not self.rewinding_saved and epoch == rewinding_epoch:
                    self.rewinding_weights = self._save_weights()
                    self.rewinding_saved = True
                    print(f"Saved rewinding checkpoint at epoch {epoch}")

            print(f"Accuracy: {accuracy:.4f}")

            # 프루닝
            sparsity = self.prune_step()
            print(f"Sparsity: {sparsity:.2%}")

            if sparsity >= self.target_sparsity:
                break

            # Late rewinding
            self.rewind_weights()

        # 최종 학습
        final_accuracy = train_fn(self.model, epochs_per_round)

        return self.get_current_sparsity(), final_accuracy

프루닝 분석 유틸리티

def analyze_sparsity_per_layer(model):
    """레이어별 희소도 분석"""
    results = []

    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            weight = module.weight.data
            total = weight.numel()
            zeros = (weight == 0).sum().item()
            sparsity = zeros / total

            results.append({
                'layer': name,
                'type': type(module).__name__,
                'params': total,
                'zeros': zeros,
                'sparsity': sparsity
            })

    return results


def plot_sparsity_distribution(model):
    """희소도 분포 시각화"""
    import matplotlib.pyplot as plt

    analysis = analyze_sparsity_per_layer(model)

    layers = [r['layer'] for r in analysis]
    sparsities = [r['sparsity'] for r in analysis]

    plt.figure(figsize=(12, 6))
    plt.bar(range(len(layers)), sparsities)
    plt.xticks(range(len(layers)), layers, rotation=45, ha='right')
    plt.ylabel('Sparsity')
    plt.title('Sparsity Distribution per Layer')
    plt.tight_layout()
    plt.show()


def count_parameters(model, count_zeros=False):
    """파라미터 수 카운트"""
    total = 0
    nonzero = 0

    for param in model.parameters():
        total += param.numel()
        nonzero += (param != 0).sum().item()

    if count_zeros:
        return total, nonzero, total - nonzero
    return nonzero  # 유효 파라미터 수

실무 가이드라인

하이퍼파라미터 권장값

파라미터 소규모 모델 대규모 모델 설명
목표 희소도 90-95% 70-80% 모델 크기에 반비례
프루닝 비율 20% 10-20% 보수적일수록 안정
Rewinding 지점 epoch 0 epoch k (0.1-1%) Late rewinding
라운드당 에폭 Full training 10-20% 학습 계산 효율

주의사항

1. 계산 비용
   - IMP는 여러 번 학습해야 함 (n rounds x training)
   - 대안: One-shot pruning, PaI methods

2. 하드웨어 가속
   - Unstructured sparsity는 GPU 가속 어려움
   - 실제 속도 향상은 structured pruning 필요

3. 태스크 전이
   - 같은 데이터셋 내 전이: 효과적
   - 다른 도메인 전이: 성능 저하 가능

4. 스케일링
   - 모델이 클수록 late rewinding 필요
   - ImageNet 급에서는 0.1-1% 학습 후 rewind

관련 연구 흐름

LTH (2019)
    |
    +-- Stabilizing LTH (2020): Late Rewinding
    |
    +-- SNIP, GraSP, SynFlow (2019-2020): Pruning at Initialization
    |
    +-- Linear Mode Connectivity (2020): 이론적 분석
    |
    +-- Dual LTH (2022): 학습-비학습 균형
    |
    +-- LTH for Transformers (2021-2023): BERT, GPT 적용
    |
    +-- LTH for LLMs (2023-2024): LLaMA, Mistral 적용

참고 자료

핵심 논문

  1. Frankle & Carlin (2019). The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. NeurIPS 2019.
  2. Frankle et al. (2020). Stabilizing the Lottery Ticket Hypothesis. arXiv:1903.01611.
  3. Frankle et al. (2020). Linear Mode Connectivity and the Lottery Ticket Hypothesis. ICML 2020.
  4. Chen et al. (2021). The Lottery Ticket Hypothesis for Pre-trained BERT Networks. NeurIPS 2021.

Survey

  • Hoefler et al. (2021). Sparsity in Deep Learning: Pruning and growth for efficient inference and training in neural networks. JMLR.
  • LTH Survey (2024). A Survey of Lottery Ticket Hypothesis. arXiv:2403.04861.

관련 개념