Knowledge Distillation (지식 증류)¶
메타 정보¶
| 항목 | 내용 |
|---|---|
| 분류 | Model Compression / Transfer Learning |
| 원논문 | "Distilling the Knowledge in a Neural Network" (NeurIPS Workshop 2015) |
| 주요 저자 | Geoffrey Hinton, Oriol Vinyals, Jeff Dean (Google) |
| 핵심 개념 | 대형 Teacher 모델의 지식을 소형 Student 모델로 전이 |
| 관련 분야 | Model Compression, Pruning, Quantization, LLM Serving, Edge AI |
정의¶
Knowledge Distillation(KD)은 대형(또는 앙상블) 모델(Teacher)이 학습한 지식을 소형 모델(Student)로 전달하는 학습 기법이다. Student는 hard label(정답)만 학습하는 것이 아니라 Teacher의 soft output(확률 분포)까지 모방함으로써, 단독 학습 대비 높은 성능을 달성한다.
핵심 통찰: Teacher의 출력 확률 분포에는 클래스 간 유사도, 데이터 구조 등 hard label에는 없는 "dark knowledge"가 담겨 있다.
핵심 아이디어¶
Dark Knowledge¶
Hard Label (One-Hot):
고양이 이미지 -> [1.0, 0.0, 0.0] (고양이, 개, 자동차)
-> 정답 외 클래스 간 관계 정보 없음
Soft Label (Teacher Output, Temperature=1):
고양이 이미지 -> [0.92, 0.07, 0.01]
-> "이 이미지는 개와 약간 유사하고, 자동차와는 전혀 다르다"
-> 클래스 간 관계 정보 포함
Soft Label (Temperature=5):
고양이 이미지 -> [0.58, 0.33, 0.09]
-> Temperature를 높이면 분포가 더 부드러워짐
-> dark knowledge가 더 명확하게 드러남
Temperature Scaling¶
KD의 핵심 메커니즘은 softmax 함수에 temperature 파라미터 T를 도입하는 것이다.
Standard Softmax (T=1):
p_i = exp(z_i) / sum_j exp(z_j)
Softened Softmax (T>1):
q_i = exp(z_i / T) / sum_j exp(z_j / T)
T의 효과:
T -> 0: 가장 큰 logit만 1에 가까워짐 (hard decision)
T = 1: 표준 softmax
T -> inf: 균등 분포에 수렴
예시 (logits = [5.0, 3.0, 1.0]):
T=1: [0.844, 0.114, 0.042] -- 거의 one-hot
T=3: [0.506, 0.307, 0.187] -- 구조 가시화
T=5: [0.424, 0.328, 0.248] -- 더 부드러운 분포
T=10: [0.372, 0.337, 0.291] -- 거의 균등
Distillation Loss¶
Student 학습 시 두 가지 loss를 결합한다.
L_total = alpha * L_hard + (1 - alpha) * L_soft
L_hard = CrossEntropy(y_student, y_true)
-> 정답 레이블에 대한 표준 loss
L_soft = T^2 * KL_Divergence(q_student(T), q_teacher(T))
-> Teacher의 soft output을 모방하는 loss
-> T^2: gradient 크기 보정 (T가 클수록 gradient가 작아지는 것 보상)
alpha: 두 loss의 가중치 (보통 0.1-0.5)
T: Temperature (보통 3-20)
KD의 분류 체계¶
지식의 유형에 따른 분류¶
1. Response-Based KD (출력 기반)
Teacher의 최종 출력(logits/softmax)을 모방
-> 가장 기본적이고 범용적
2. Feature-Based KD (중간 표현 기반)
Teacher의 중간 레이어 activation을 모방
-> FitNets (Romero et al., ICLR 2015)
-> 더 풍부한 지식 전달 가능
3. Relation-Based KD (관계 기반)
샘플 간 또는 레이어 간 관계 구조를 모방
-> RKD (Park et al., CVPR 2019)
-> 데이터의 구조적 정보 보존
Response-Based KD¶
Teacher: x -> [Layer1] -> [Layer2] -> ... -> [Logits] -> [Softmax(T)]
|
soft targets
|
Student: x -> [Layer1] -> [Layer2] -> [Logits] -> [Softmax(T)]
|
L_KL(student, teacher)
특징:
+ 구현이 간단
+ Teacher 내부 구조에 무관
+ 다양한 태스크에 적용 가능
- 중간 레이어의 풍부한 정보 미활용
Feature-Based KD (FitNets)¶
Teacher: x -> [T_Layer1] -> [T_Layer2] -> [T_Layer3] -> output
| | |
hint_1 hint_2 hint_3
| | |
[Adapter] [Adapter] [Adapter] (차원 맞춤)
| | |
Student: x -> [S_Layer1] -> [S_Layer2] -> [S_Layer3] -> output
L_feature = sum_l || f_adapter(F_student^l) - F_teacher^l ||^2
f_adapter: 차원 변환 함수 (Student와 Teacher의 hidden dim이 다를 경우)
장점:
+ 중간 표현의 풍부한 지식 전달
+ Student가 Teacher의 추론 과정을 학습
단점:
- Teacher-Student 레이어 매핑 설계 필요
- Adapter 함수의 설계 및 학습 비용
Relation-Based KD¶
데이터 포인트 간 관계를 보존:
Teacher에서의 관계:
Sample A --0.8-- Sample B
Sample A --0.2-- Sample C
Sample B --0.6-- Sample C
Student도 동일한 관계 구조를 학습:
L_relation = || G_student(x_i, x_j) - G_teacher(x_i, x_j) ||^2
G: 두 샘플 간의 관계 함수
- 거리(distance): ||f(x_i) - f(x_j)||
- 각도(angle): cos(f(x_i), f(x_j))
- 상관(correlation): Gram matrix
장점:
+ 데이터의 구조적 정보 보존
+ Teacher-Student 아키텍처 차이에 강건
단점:
- 배치 내 모든 쌍 계산 (O(N^2) 복잡도)
학습 방식에 따른 분류¶
1. Offline Distillation¶
Phase 1: Teacher 모델 학습 (독립적)
Teacher_model = train(large_model, dataset)
Phase 2: Student 모델 학습 (Teacher 고정)
for batch in dataset:
teacher_output = Teacher_model(batch) # frozen
student_output = Student_model(batch)
loss = alpha * CE(student_output, labels) +
(1-alpha) * KL(student_output/T, teacher_output/T)
특징:
+ 가장 일반적이고 안정적
+ Teacher와 Student를 독립적으로 설계 가능
- Teacher 학습 비용이 큼
- 2단계 학습 파이프라인
2. Online Distillation¶
Teacher와 Student를 동시에 학습:
Model_1 --soft labels--> Model_2
Model_2 --soft labels--> Model_1
Deep Mutual Learning (Zhang et al., CVPR 2018):
두 모델이 서로의 Teacher 역할
L_1 = CE(p_1, y) + KL(p_1, p_2)
L_2 = CE(p_2, y) + KL(p_2, p_1)
-> 사전 학습된 대형 Teacher가 불필요
-> 두 모델 모두 성능 향상
-> 동일 크기 모델끼리도 효과적
3. Self-Distillation¶
동일 모델이 Teacher와 Student 역할:
방법 1: Born-Again Networks (Furlanello et al., 2018)
Generation 0: 모델을 정상 학습
Generation 1: Generation 0을 Teacher로 동일 구조 모델 학습
Generation 2: Generation 1을 Teacher로 학습
...
-> 세대를 거듭할수록 성능 향상 (diminishing returns)
방법 2: Be Your Own Teacher (Zhang et al., 2019)
모델의 깊은 레이어가 얕은 레이어를 가르침:
[Layer 1] -> [Layer 2] -> ... -> [Layer N] -> Final Output
| |
[Classifier_1] [Classifier_N]
| |
+--- L_KL(shallow, deep) ---+
방법 3: Progressive Self-Distillation
학습 중 과거의 자신(EMA)이 Teacher:
theta_teacher = beta * theta_teacher + (1 - beta) * theta_student
LLM을 위한 Knowledge Distillation¶
2024-2025년, LLM 시대에 KD의 중요성이 크게 증가했다.
LLM KD의 특수성¶
기존 KD:
Teacher: CNN/ResNet (수백만 파라미터)
Student: 작은 CNN (수십만 파라미터)
지식: Soft label (분류 확률)
LLM KD:
Teacher: GPT-4, Claude 4 (수천억 파라미터, API만 접근 가능)
Student: LLaMA-7B, Mistral-7B (수십억 파라미터)
지식: 생성 텍스트, reasoning chain, 능력(skill)
핵심 차이:
1. Teacher 내부(logits)에 접근 불가 (black-box)
2. 분류가 아닌 생성(autoregressive) 태스크
3. 단일 태스크가 아닌 다양한 능력 전이
4. 데이터 생성 자체가 KD의 핵심
White-Box vs Black-Box Distillation¶
White-Box (Teacher 내부 접근 가능):
Teacher logits 직접 활용
-> 전통적 KD와 동일한 방식
-> 오픈소스 Teacher에서만 가능
예: LLaMA-70B -> LLaMA-7B
L = CE(student_logits, labels) + KL(student_logits/T, teacher_logits/T)
Black-Box (API만 접근 가능):
Teacher의 생성 텍스트만 활용
-> Teacher에게 데이터 생성 요청
-> 생성된 데이터로 Student 학습 (SFT)
예: GPT-4 -> LLaMA-7B
Step 1: GPT-4로 고품질 (prompt, response) 쌍 생성
Step 2: 생성 데이터로 Student SFT
LLM KD 주요 방법론¶
1. Symbolic Knowledge Distillation
Teacher가 지식을 텍스트 형태로 생성:
(a) 데이터 증강:
Teacher: "수학 문제 100개를 풀이와 함께 생성해줘"
-> 고품질 (문제, 풀이) 데이터셋 생성
-> Student가 이 데이터로 학습
(b) Chain-of-Thought (CoT) Distillation:
Teacher:
Q: "234 + 567 = ?"
A: "200+500=700, 30+60=90, 4+7=11, 700+90+11=801"
Student는 CoT를 포함한 답변을 학습
-> 추론 능력(reasoning)까지 전이
(c) Skill-Specific Distillation:
특정 능력만 선별적으로 증류
- 코드 생성 능력: Code Alpaca
- 수학 능력: WizardMath
- 대화 능력: Vicuna (ShareGPT 데이터)
2. Logit-Based LLM Distillation
오픈소스 Teacher의 next-token logits 활용:
Token-Level KD:
Input: "The capital of France is"
Teacher logits: [Paris: 0.85, Lyon: 0.08, Marseille: 0.03, ...]
Student logits: [Paris: 0.60, London: 0.15, Lyon: 0.05, ...]
L = sum_t KL(student_logits_t / T, teacher_logits_t / T)
각 토큰 위치에서 Teacher의 분포를 모방
MiniLLM (Gu et al., ICLR 2024):
Reverse KL 사용으로 mode-seeking 특성 활용
-> Student가 Teacher의 고확률 영역에 집중
-> 환각(hallucination) 감소
L = KL(p_student || p_teacher) (reverse direction)
3. NVIDIA Minitron 접근법
Pruning + Distillation 결합:
Step 1: 대형 모델에서 중요도가 낮은 뉴런/레이어 제거 (Pruning)
15B -> 8B (레이어/헤드/채널 pruning)
Step 2: Pruned 모델을 원본의 KD로 복구
Teacher: 원본 15B 모델
Student: Pruned 8B 모델
학습: 원래 토큰의 1/40만 사용
결과 (Minitron, NeurIPS 2024):
Nemotron-4 15B -> 8B:
- LM Eval 평균: 원본의 98.7% 성능
- 학습 비용: 원본 학습의 2.5%
- 처음부터 학습한 8B 대비 우수
Self-Distillation for LLM¶
LLM이 스스로의 Teacher 역할:
(a) Rejection Sampling + Self-Training:
Step 1: LLM이 문제를 여러 번 풀기 (N=64)
Step 2: 정답인 풀이만 선별
Step 3: 선별된 풀이로 SFT
-> STaR (Zelikman et al., NeurIPS 2022)
(b) Self-Play / Self-Improvement:
Step 1: 현재 모델로 데이터 생성
Step 2: 생성 데이터를 평가/필터링
Step 3: 고품질 데이터로 재학습
-> 반복하면서 점진적 개선
(c) On-Policy Distillation:
Student가 직접 생성한 텍스트에 대해
Teacher가 토큰별 확률을 제공
-> Distribution mismatch 감소
-> GKD (Agarwal et al., 2024)
고급 기법¶
Attention Transfer¶
Zagoruyko & Komodakis (ICLR 2017). Teacher의 attention map을 Student에게 전달한다.
Teacher Attention Map:
A_T = sum_i |F_T^i|^2 (채널 방향 제곱합)
-> 어디에 주목하는지를 나타냄
Student Attention Map:
A_S = sum_i |F_S^i|^2
Attention Transfer Loss:
L_AT = sum_l || A_S^l / ||A_S^l||_2 - A_T^l / ||A_T^l||_2 ||_2
-> L2 정규화된 attention map 간 거리
장점:
+ 공간적 주의(spatial attention) 패턴 전이
+ Feature map 크기만 맞으면 적용 가능
+ Teacher의 "무엇을 보는가"를 직접 학습
Contrastive Representation Distillation (CRD)¶
Tian et al. (ICLR 2020). Contrastive learning 원리로 KD를 수행한다.
핵심 아이디어:
같은 입력 x에 대해:
Teacher 표현 t = f_T(x) -- positive pair
Student 표현 s = f_S(x) -- positive pair
다른 입력 x'에 대해:
Teacher 표현 t' = f_T(x') -- negative pair
InfoNCE Loss:
L_CRD = -log(exp(s . t / tau) / (exp(s . t / tau) + sum_{t'} exp(s . t' / tau)))
장점:
+ 구조적 관계 보존
+ Teacher-Student 차원 불일치에 강건
+ Response/Feature-based KD 대비 우수한 성능
Multi-Teacher Distillation¶
여러 Teacher의 지식을 결합:
방법 1: 평균 (Simple Averaging)
q_ensemble = (1/K) * sum_k q_teacher_k
-> 모든 Teacher를 동등하게 취급
방법 2: 가중 평균 (Weighted Averaging)
q_ensemble = sum_k w_k * q_teacher_k
w_k: Teacher k의 가중치 (성능/신뢰도 기반)
방법 3: Feature-level Aggregation
각 Teacher의 중간 표현을 개별적으로 증류
L = sum_k lambda_k * L_feature(S, T_k)
방법 4: Task-Specific Selection
태스크별로 가장 적합한 Teacher 선택
-> Routing network가 Teacher 선택
Data-Free Distillation¶
Teacher 학습에 사용된 원본 데이터 없이 KD를 수행한다.
방법 1: Generator 기반 (DAFL, Chen et al., 2019)
Generator G -> 합성 데이터 생성
Teacher가 합성 데이터에 soft label 제공
Student가 soft label로 학습
G 학습: Teacher의 출력 엔트로피 최소화
-> Teacher가 자신 있는 데이터를 생성하도록 유도
방법 2: Batch Normalization 통계 활용
Teacher의 BN 레이어에 저장된 mean/variance 활용
-> 원본 데이터 분포를 근사하는 합성 데이터 생성
방법 3: Feature Map Inversion
Teacher의 중간 feature map을 역변환하여 입력 복원
활용:
- 개인정보 보호 (원본 데이터 공유 불가)
- 데이터 접근 제한 환경
- 의료/금융 등 규제 분야
Python 구현¶
기본 Knowledge Distillation¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
class DistillationLoss(nn.Module):
"""Knowledge Distillation Loss"""
def __init__(
self,
temperature: float = 4.0,
alpha: float = 0.3,
reduction: str = 'mean'
):
"""
Args:
temperature: Softmax temperature (T)
alpha: Hard loss 가중치 (1-alpha = Soft loss 가중치)
reduction: 'mean' or 'sum'
"""
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.reduction = reduction
self.ce_loss = nn.CrossEntropyLoss(reduction=reduction)
def forward(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor
) -> torch.Tensor:
"""
Args:
student_logits: Student 모델 출력 [B, C]
teacher_logits: Teacher 모델 출력 [B, C]
labels: 정답 레이블 [B]
"""
# Hard loss: Student vs Ground Truth
hard_loss = self.ce_loss(student_logits, labels)
# Soft loss: Student vs Teacher (with temperature)
T = self.temperature
student_soft = F.log_softmax(student_logits / T, dim=1)
teacher_soft = F.softmax(teacher_logits / T, dim=1)
soft_loss = F.kl_div(
student_soft,
teacher_soft,
reduction='batchmean'
) * (T ** 2)
# Combined loss
loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss
return loss
class FeatureDistillationLoss(nn.Module):
"""Feature-Based Knowledge Distillation (FitNets 스타일)"""
def __init__(
self,
student_channels: list,
teacher_channels: list
):
"""
Args:
student_channels: Student 중간 레이어 채널 수 리스트
teacher_channels: Teacher 중간 레이어 채널 수 리스트
"""
super().__init__()
assert len(student_channels) == len(teacher_channels)
# 차원 맞춤용 1x1 conv (Student -> Teacher 차원)
self.adapters = nn.ModuleList([
nn.Conv2d(s_ch, t_ch, kernel_size=1, bias=False)
for s_ch, t_ch in zip(student_channels, teacher_channels)
])
def forward(
self,
student_features: list,
teacher_features: list
) -> torch.Tensor:
"""
Args:
student_features: Student 중간 feature maps 리스트
teacher_features: Teacher 중간 feature maps 리스트
"""
loss = 0
for adapter, s_feat, t_feat in zip(
self.adapters, student_features, teacher_features
):
# 차원 맞춤
s_adapted = adapter(s_feat)
# 공간 크기 맞춤 (필요시)
if s_adapted.shape[2:] != t_feat.shape[2:]:
s_adapted = F.adaptive_avg_pool2d(s_adapted, t_feat.shape[2:])
# L2 distance
loss += F.mse_loss(s_adapted, t_feat.detach())
return loss
class AttentionTransferLoss(nn.Module):
"""Attention Transfer (Zagoruyko & Komodakis, 2017)"""
def __init__(self, p: int = 2):
"""
Args:
p: Attention map 계산 시 norm 차수
"""
super().__init__()
self.p = p
def _attention_map(self, feature: torch.Tensor) -> torch.Tensor:
"""Feature map -> Attention map (채널 방향 집약)"""
# feature: [B, C, H, W]
# attention: [B, H*W]
att = feature.pow(self.p).mean(dim=1).view(feature.size(0), -1)
# L2 정규화
att = F.normalize(att, p=2, dim=1)
return att
def forward(
self,
student_features: list,
teacher_features: list
) -> torch.Tensor:
loss = 0
for s_feat, t_feat in zip(student_features, teacher_features):
s_att = self._attention_map(s_feat)
t_att = self._attention_map(t_feat.detach())
loss += (s_att - t_att).pow(2).mean()
return loss
Knowledge Distillation Trainer¶
class KDTrainer:
"""Knowledge Distillation 학습기"""
def __init__(
self,
teacher: nn.Module,
student: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
temperature: float = 4.0,
alpha: float = 0.3,
learning_rate: float = 1e-3,
feature_weight: float = 0.0,
attention_weight: float = 0.0,
device: str = 'cuda'
):
self.teacher = teacher.to(device).eval()
self.student = student.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
# Teacher는 학습하지 않음
for param in self.teacher.parameters():
param.requires_grad = False
# Losses
self.kd_loss = DistillationLoss(temperature, alpha)
self.optimizer = torch.optim.Adam(
self.student.parameters(), lr=learning_rate
)
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer, T_max=100
)
self.feature_weight = feature_weight
self.attention_weight = attention_weight
def train_epoch(self) -> dict:
self.student.train()
total_loss = 0
correct = 0
total = 0
for data, targets in self.train_loader:
data = data.to(self.device)
targets = targets.to(self.device)
# Teacher forward (no grad)
with torch.no_grad():
teacher_logits = self.teacher(data)
# Student forward
student_logits = self.student(data)
# KD Loss
loss = self.kd_loss(student_logits, teacher_logits, targets)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
total_loss += loss.item() * data.size(0)
_, predicted = student_logits.max(1)
correct += predicted.eq(targets).sum().item()
total += targets.size(0)
self.scheduler.step()
return {
'loss': total_loss / total,
'accuracy': correct / total
}
@torch.no_grad()
def evaluate(self) -> dict:
self.student.eval()
correct = 0
total = 0
for data, targets in self.val_loader:
data = data.to(self.device)
targets = targets.to(self.device)
outputs = self.student(data)
_, predicted = outputs.max(1)
correct += predicted.eq(targets).sum().item()
total += targets.size(0)
return {'accuracy': correct / total}
def train(self, num_epochs: int = 100, log_interval: int = 10):
best_acc = 0
for epoch in range(num_epochs):
train_metrics = self.train_epoch()
if (epoch + 1) % log_interval == 0:
val_metrics = self.evaluate()
print(
f"Epoch {epoch+1}/{num_epochs} | "
f"Train Loss: {train_metrics['loss']:.4f} | "
f"Train Acc: {train_metrics['accuracy']:.4f} | "
f"Val Acc: {val_metrics['accuracy']:.4f}"
)
if val_metrics['accuracy'] > best_acc:
best_acc = val_metrics['accuracy']
print(f"\nBest validation accuracy: {best_acc:.4f}")
return best_acc
LLM Black-Box Distillation 파이프라인¶
import json
from dataclasses import dataclass
from typing import Optional
@dataclass
class DistillationSample:
"""LLM KD용 데이터 샘플"""
prompt: str
teacher_response: str
teacher_model: str
task_type: str # 'reasoning', 'code', 'conversation', ...
quality_score: Optional[float] = None
class LLMDistillationPipeline:
"""
Black-Box LLM Distillation 파이프라인
Teacher API에서 고품질 데이터를 생성하고
Student 모델을 SFT로 학습시킴
"""
def __init__(
self,
teacher_api, # OpenAI/Anthropic 등 API 클라이언트
student_model, # HuggingFace 모델
student_tokenizer, # HuggingFace 토크나이저
task_prompts: dict # 태스크별 프롬프트 템플릿
):
self.teacher_api = teacher_api
self.student_model = student_model
self.tokenizer = student_tokenizer
self.task_prompts = task_prompts
def generate_distillation_data(
self,
seed_prompts: list,
task_type: str = 'reasoning',
n_samples: int = 1000,
n_responses_per_prompt: int = 3,
temperature: float = 0.7
) -> list:
"""
Teacher로부터 학습 데이터 생성
Args:
seed_prompts: 시드 프롬프트 리스트
task_type: 태스크 유형
n_samples: 목표 샘플 수
n_responses_per_prompt: 프롬프트당 응답 수
temperature: Teacher 생성 temperature
Returns:
DistillationSample 리스트
"""
samples = []
for prompt in seed_prompts[:n_samples]:
# 시스템 프롬프트로 품질 유도
system_prompt = self.task_prompts.get(
task_type,
"Provide a detailed, accurate, and helpful response."
)
for _ in range(n_responses_per_prompt):
response = self.teacher_api.chat.completions.create(
model="gpt-4",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
],
temperature=temperature,
max_tokens=2048
)
teacher_text = response.choices[0].message.content
sample = DistillationSample(
prompt=prompt,
teacher_response=teacher_text,
teacher_model="gpt-4",
task_type=task_type
)
samples.append(sample)
return samples
def filter_by_quality(
self,
samples: list,
min_length: int = 50,
max_length: int = 4096,
dedup: bool = True
) -> list:
"""생성 데이터 품질 필터링"""
filtered = []
seen_responses = set()
for sample in samples:
# 길이 필터
if len(sample.teacher_response) < min_length:
continue
if len(sample.teacher_response) > max_length:
continue
# 중복 제거
if dedup:
response_hash = hash(sample.teacher_response[:200])
if response_hash in seen_responses:
continue
seen_responses.add(response_hash)
filtered.append(sample)
return filtered
def prepare_sft_data(self, samples: list) -> list:
"""SFT 학습용 데이터 변환"""
sft_data = []
for sample in samples:
# Chat template 적용
messages = [
{"role": "user", "content": sample.prompt},
{"role": "assistant", "content": sample.teacher_response}
]
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False
)
sft_data.append({
"text": text,
"task_type": sample.task_type
})
return sft_data
def save_dataset(self, samples: list, path: str):
"""데이터셋 저장"""
with open(path, 'w', encoding='utf-8') as f:
for sample in samples:
f.write(json.dumps(sample, ensure_ascii=False) + '\n')
print(f"Saved {len(samples)} samples to {path}")
class CoTDistillation:
"""Chain-of-Thought Distillation"""
def __init__(self, teacher_api):
self.teacher_api = teacher_api
def generate_cot_data(
self,
questions: list,
subject: str = 'math'
) -> list:
"""
Teacher로부터 CoT 추론 과정 생성
"""
cot_prompt = """
Solve the following problem step by step.
Show your reasoning process clearly.
Format:
Step 1: [reasoning]
Step 2: [reasoning]
...
Answer: [final answer]
"""
samples = []
for question in questions:
response = self.teacher_api.chat.completions.create(
model="gpt-4",
messages=[
{"role": "system", "content": cot_prompt},
{"role": "user", "content": question}
],
temperature=0.3
)
teacher_cot = response.choices[0].message.content
samples.append({
"question": question,
"cot_response": teacher_cot,
"subject": subject
})
return samples
완전한 학습 예시 (CIFAR-10)¶
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
class TeacherNet(nn.Module):
"""Teacher: 큰 CNN"""
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(256, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
nn.MaxPool2d(2),
)
self.classifier = nn.Sequential(
nn.Linear(512 * 4 * 4, 1024), nn.ReLU(), nn.Dropout(0.5),
nn.Linear(1024, 10)
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
return self.classifier(x)
class StudentNet(nn.Module):
"""Student: 작은 CNN (Teacher의 1/4 크기)"""
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
nn.MaxPool2d(2),
)
self.classifier = nn.Sequential(
nn.Linear(128 * 4 * 4, 256), nn.ReLU(),
nn.Linear(256, 10)
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
return self.classifier(x)
def run_distillation_experiment():
"""KD 전체 실험: Teacher 학습 -> Student KD vs Student 단독"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 데이터
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])
train_set = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
test_set = datasets.CIFAR10('./data', train=False, transform=transform_test)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=100, num_workers=2)
# === Phase 1: Teacher 학습 ===
print("=" * 50)
print("Phase 1: Training Teacher")
print("=" * 50)
teacher = TeacherNet().to(device)
teacher_optimizer = torch.optim.SGD(
teacher.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4
)
teacher_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
teacher_optimizer, T_max=100
)
criterion = nn.CrossEntropyLoss()
teacher_params = sum(p.numel() for p in teacher.parameters())
print(f"Teacher parameters: {teacher_params:,}")
for epoch in range(100):
teacher.train()
for data, targets in train_loader:
data, targets = data.to(device), targets.to(device)
teacher_optimizer.zero_grad()
loss = criterion(teacher(data), targets)
loss.backward()
teacher_optimizer.step()
teacher_scheduler.step()
if (epoch + 1) % 20 == 0:
teacher.eval()
correct = sum(
teacher(d.to(device)).argmax(1).eq(t.to(device)).sum().item()
for d, t in test_loader
)
print(f"Teacher Epoch {epoch+1}: {correct/len(test_set):.4f}")
teacher.eval()
# === Phase 2: Student with KD ===
print("\n" + "=" * 50)
print("Phase 2: Student with Knowledge Distillation")
print("=" * 50)
student_kd = StudentNet().to(device)
student_params = sum(p.numel() for p in student_kd.parameters())
print(f"Student parameters: {student_params:,}")
print(f"Compression ratio: {teacher_params/student_params:.1f}x")
kd_trainer = KDTrainer(
teacher=teacher,
student=student_kd,
train_loader=train_loader,
val_loader=test_loader,
temperature=4.0,
alpha=0.3,
learning_rate=1e-3,
device=device
)
kd_acc = kd_trainer.train(num_epochs=100, log_interval=20)
# === Phase 3: Student without KD (baseline) ===
print("\n" + "=" * 50)
print("Phase 3: Student without KD (baseline)")
print("=" * 50)
student_baseline = StudentNet().to(device)
baseline_optimizer = torch.optim.Adam(student_baseline.parameters(), lr=1e-3)
baseline_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
baseline_optimizer, T_max=100
)
best_baseline = 0
for epoch in range(100):
student_baseline.train()
for data, targets in train_loader:
data, targets = data.to(device), targets.to(device)
baseline_optimizer.zero_grad()
loss = criterion(student_baseline(data), targets)
loss.backward()
baseline_optimizer.step()
baseline_scheduler.step()
if (epoch + 1) % 20 == 0:
student_baseline.eval()
correct = sum(
student_baseline(d.to(device)).argmax(1).eq(t.to(device)).sum().item()
for d, t in test_loader
)
acc = correct / len(test_set)
best_baseline = max(best_baseline, acc)
print(f"Baseline Epoch {epoch+1}: {acc:.4f}")
# === 결과 비교 ===
print("\n" + "=" * 50)
print("Results")
print("=" * 50)
print(f"Student (with KD): {kd_acc:.4f}")
print(f"Student (no KD): {best_baseline:.4f}")
print(f"Improvement: {(kd_acc - best_baseline)*100:.2f}%p")
if __name__ == "__main__":
run_distillation_experiment()
실무 가이드라인¶
하이퍼파라미터 권장값¶
| 파라미터 | 권장 범위 | 설명 |
|---|---|---|
| Temperature (T) | 3 - 20 | 높을수록 soft (보통 4-8) |
| Alpha | 0.1 - 0.5 | Hard loss 가중치 (보통 0.3) |
| Student/Teacher 비율 | 1/4 - 1/10 | 모델 크기 비율 |
| Learning Rate | 1e-4 - 1e-3 | Student 학습률 |
| Feature Loss Weight | 0.01 - 0.1 | Feature KD 시 가중치 |
Temperature 선택 가이드¶
T가 낮을 때 (1-3):
+ Teacher의 top-1 예측에 충실
+ 분류 정확도 중심
- dark knowledge 활용 부족
적합: 쉬운 태스크, 클래스 수 적을 때
T가 높을 때 (10-20):
+ 클래스 간 관계 충분히 전달
+ 더 부드러운 학습 신호
- 정보 희석 가능성
적합: 클래스 수 많을 때, 유사 클래스 구분
일반적 시작점: T=4, alpha=0.3
언제 KD를 사용해야 하는가¶
효과적인 경우:
+ 대형 모델을 Edge/Mobile에 배포해야 할 때
+ 추론 지연시간(latency) 감소가 필요할 때
+ 앙상블 모델을 단일 모델로 압축할 때
+ LLM API 비용 절감이 필요할 때
+ 특정 도메인에 소형 모델을 특화할 때
+ Teacher 성능의 90%+를 유지하면서 비용 1/10
비효과적이거나 불필요한 경우:
- Student와 Teacher의 성능 차이가 너무 클 때
- 데이터가 매우 적을 때 (Teacher 자체가 부정확)
- 단순한 태스크 (Student 단독으로 충분)
- 실시간 추론이 불필요한 오프라인 환경
주의사항¶
1. Teacher 품질이 핵심
- 낮은 품질의 Teacher -> Student도 낮은 품질
- Teacher의 오류/편향이 Student에게 전달됨
- Teacher 검증 후 증류 시작
2. Capacity Gap 문제
- Student가 너무 작으면 Teacher 지식을 수용 불가
- Teacher-Student 크기 비율: 4-10배 권장
- 너무 큰 gap은 중간 Teacher(TA)로 bridging
Teacher -> Teaching Assistant -> Student
3. Task Mismatch
- Teacher의 학습 태스크와 Student의 목표 태스크가 다르면 효과 감소
- 도메인 특화 Teacher 사용 권장
4. 학습 안정성
- Temperature가 너무 높으면 gradient vanishing
- Alpha 값에 민감할 수 있음 -> grid search 권장
- Warmup: 초기에는 hard loss 비중 높이고 점차 soft loss 증가
5. LLM 증류 시 주의
- Teacher API 비용 관리 (데이터 생성 비용)
- 생성 데이터 품질 검증 필수
- 저작권/라이선스 확인 (일부 모델은 출력물 증류 금지)
성능 벤치마크 (대표 결과)¶
ImageNet Classification¶
| Teacher | Student | Method | Top-1 Acc (%) |
|---|---|---|---|
| ResNet-34 | ResNet-18 | Baseline (no KD) | 69.75 |
| ResNet-34 | ResNet-18 | Vanilla KD | 71.03 |
| ResNet-34 | ResNet-18 | FitNets | 71.06 |
| ResNet-34 | ResNet-18 | AT (Attention) | 70.69 |
| ResNet-34 | ResNet-18 | CRD | 71.17 |
| ResNet-34 | ResNet-18 | KD + CRD | 71.38 |
LLM Distillation¶
| Teacher | Student | Method | MMLU (%) |
|---|---|---|---|
| GPT-4 | LLaMA-7B | SFT (Alpaca) | 42.3 |
| GPT-4 | LLaMA-7B | CoT Distillation | 46.8 |
| Nemotron-4 15B | Minitron 8B | Pruning + KD | 67.2 |
| Nemotron-4 15B (baseline) | - | From scratch 8B | 65.1 |
관련 연구 흐름¶
Model Compression (Bucilua et al., 2006)
|
+-- Knowledge Distillation (Hinton et al., 2015)
| |
| +-- FitNets (Romero et al., 2015): Feature-based KD
| |
| +-- Attention Transfer (Zagoruyko, 2017)
| |
| +-- Born-Again Networks (Furlanello et al., 2018): Self-distillation
| |
| +-- Deep Mutual Learning (Zhang et al., 2018): Online KD
| |
| +-- CRD (Tian et al., 2020): Contrastive KD
| |
| +-- Data-Free KD (Chen et al., 2019)
| |
| +-- RKD (Park et al., 2019): Relation-based KD
| |
| +-- LLM Distillation (2023-):
| |
| +-- Alpaca/Vicuna: Black-box distillation
| +-- MiniLLM (Gu et al., 2024): Reverse KL
| +-- Minitron (NVIDIA, 2024): Pruning + KD
| +-- GKD (Agarwal et al., 2024): On-policy
|
+-- Pruning: 불필요한 파라미터 제거
|
+-- Quantization: 비트 수 줄이기
|
+-- Neural Architecture Search (NAS)
참고 자료¶
핵심 논문¶
- Hinton, G. et al. (2015). Distilling the Knowledge in a Neural Network. NeurIPS Workshop.
- Romero, A. et al. (2015). FitNets: Hints for Thin Deep Nets. ICLR 2015.
- Zagoruyko, S. & Komodakis, N. (2017). Paying More Attention to Attention. ICLR 2017.
- Zhang, Y. et al. (2018). Deep Mutual Learning. CVPR 2018.
- Park, W. et al. (2019). Relational Knowledge Distillation. CVPR 2019.
- Tian, Y. et al. (2020). Contrastive Representation Distillation. ICLR 2020.
최신 연구 (LLM)¶
- Gu, Y. et al. (2024). MiniLLM: Knowledge Distillation of Large Language Models. ICLR 2024.
- Agarwal, R. et al. (2024). On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes. ICLR 2024.
- Muralidharan, S. et al. (2024). Compact Language Models via Pruning and Knowledge Distillation (Minitron). NeurIPS 2024.
- arXiv:2402.13116 (2024). A Survey on Knowledge Distillation of Large Language Models.
- arXiv:2503.12067 (2025). A Comprehensive Survey on Knowledge Distillation.
Survey 논문¶
- Gou, J. et al. (2021). Knowledge Distillation: A Survey. IJCV 2021.
- Wang, L. & Yoon, K. (2022). Knowledge Distillation and Student-Teacher Learning for Visual Intelligence. IEEE TPAMI.
관련 개념¶
- Curriculum Learning: 데이터 순서 최적화
- Preference Optimization: RLHF/DPO 학습
- Parameter-Efficient Fine-Tuning: LoRA 등 효율적 학습
- Mixture of Experts: 조건부 연산으로 효율성
- Neural Scaling Laws: 모델 크기-성능 관계