콘텐츠로 이동
Data Prep
상세

Knowledge Distillation (지식 증류)

메타 정보

항목 내용
분류 Model Compression / Transfer Learning
원논문 "Distilling the Knowledge in a Neural Network" (NeurIPS Workshop 2015)
주요 저자 Geoffrey Hinton, Oriol Vinyals, Jeff Dean (Google)
핵심 개념 대형 Teacher 모델의 지식을 소형 Student 모델로 전이
관련 분야 Model Compression, Pruning, Quantization, LLM Serving, Edge AI

정의

Knowledge Distillation(KD)은 대형(또는 앙상블) 모델(Teacher)이 학습한 지식을 소형 모델(Student)로 전달하는 학습 기법이다. Student는 hard label(정답)만 학습하는 것이 아니라 Teacher의 soft output(확률 분포)까지 모방함으로써, 단독 학습 대비 높은 성능을 달성한다.

핵심 통찰: Teacher의 출력 확률 분포에는 클래스 간 유사도, 데이터 구조 등 hard label에는 없는 "dark knowledge"가 담겨 있다.

핵심 아이디어

Dark Knowledge

Hard Label (One-Hot):
  고양이 이미지 -> [1.0, 0.0, 0.0]  (고양이, 개, 자동차)
  -> 정답 외 클래스 간 관계 정보 없음

Soft Label (Teacher Output, Temperature=1):
  고양이 이미지 -> [0.92, 0.07, 0.01]
  -> "이 이미지는 개와 약간 유사하고, 자동차와는 전혀 다르다"
  -> 클래스 간 관계 정보 포함

Soft Label (Temperature=5):
  고양이 이미지 -> [0.58, 0.33, 0.09]
  -> Temperature를 높이면 분포가 더 부드러워짐
  -> dark knowledge가 더 명확하게 드러남

Temperature Scaling

KD의 핵심 메커니즘은 softmax 함수에 temperature 파라미터 T를 도입하는 것이다.

Standard Softmax (T=1):
  p_i = exp(z_i) / sum_j exp(z_j)

Softened Softmax (T>1):
  q_i = exp(z_i / T) / sum_j exp(z_j / T)

T의 효과:
  T -> 0: 가장 큰 logit만 1에 가까워짐 (hard decision)
  T = 1: 표준 softmax
  T -> inf: 균등 분포에 수렴

예시 (logits = [5.0, 3.0, 1.0]):
  T=1:  [0.844, 0.114, 0.042]  -- 거의 one-hot
  T=3:  [0.506, 0.307, 0.187]  -- 구조 가시화
  T=5:  [0.424, 0.328, 0.248]  -- 더 부드러운 분포
  T=10: [0.372, 0.337, 0.291]  -- 거의 균등

Distillation Loss

Student 학습 시 두 가지 loss를 결합한다.

L_total = alpha * L_hard + (1 - alpha) * L_soft

L_hard = CrossEntropy(y_student, y_true)
  -> 정답 레이블에 대한 표준 loss

L_soft = T^2 * KL_Divergence(q_student(T), q_teacher(T))
  -> Teacher의 soft output을 모방하는 loss
  -> T^2: gradient 크기 보정 (T가 클수록 gradient가 작아지는 것 보상)

alpha: 두 loss의 가중치 (보통 0.1-0.5)
T: Temperature (보통 3-20)

KD의 분류 체계

지식의 유형에 따른 분류

1. Response-Based KD (출력 기반)
   Teacher의 최종 출력(logits/softmax)을 모방
   -> 가장 기본적이고 범용적

2. Feature-Based KD (중간 표현 기반)
   Teacher의 중간 레이어 activation을 모방
   -> FitNets (Romero et al., ICLR 2015)
   -> 더 풍부한 지식 전달 가능

3. Relation-Based KD (관계 기반)
   샘플 간 또는 레이어 간 관계 구조를 모방
   -> RKD (Park et al., CVPR 2019)
   -> 데이터의 구조적 정보 보존

Response-Based KD

Teacher:  x -> [Layer1] -> [Layer2] -> ... -> [Logits] -> [Softmax(T)]
                                                              |
                                                        soft targets
                                                              |
Student:  x -> [Layer1] -> [Layer2] -> [Logits] -> [Softmax(T)]
                                                       |
                                                  L_KL(student, teacher)

특징:
  + 구현이 간단
  + Teacher 내부 구조에 무관
  + 다양한 태스크에 적용 가능
  - 중간 레이어의 풍부한 정보 미활용

Feature-Based KD (FitNets)

Teacher:  x -> [T_Layer1] -> [T_Layer2] -> [T_Layer3] -> output
                   |              |              |
                hint_1         hint_2         hint_3
                   |              |              |
               [Adapter]     [Adapter]     [Adapter]  (차원 맞춤)
                   |              |              |
Student:  x -> [S_Layer1] -> [S_Layer2] -> [S_Layer3] -> output

L_feature = sum_l || f_adapter(F_student^l) - F_teacher^l ||^2

f_adapter: 차원 변환 함수 (Student와 Teacher의 hidden dim이 다를 경우)

장점:
  + 중간 표현의 풍부한 지식 전달
  + Student가 Teacher의 추론 과정을 학습

단점:
  - Teacher-Student 레이어 매핑 설계 필요
  - Adapter 함수의 설계 및 학습 비용

Relation-Based KD

데이터 포인트 간 관계를 보존:

Teacher에서의 관계:
  Sample A --0.8-- Sample B
  Sample A --0.2-- Sample C
  Sample B --0.6-- Sample C

Student도 동일한 관계 구조를 학습:
  L_relation = || G_student(x_i, x_j) - G_teacher(x_i, x_j) ||^2

G: 두 샘플 간의 관계 함수
  - 거리(distance): ||f(x_i) - f(x_j)||
  - 각도(angle): cos(f(x_i), f(x_j))
  - 상관(correlation): Gram matrix

장점:
  + 데이터의 구조적 정보 보존
  + Teacher-Student 아키텍처 차이에 강건

단점:
  - 배치 내 모든 쌍 계산 (O(N^2) 복잡도)

학습 방식에 따른 분류

1. Offline Distillation

Phase 1: Teacher 모델 학습 (독립적)
  Teacher_model = train(large_model, dataset)

Phase 2: Student 모델 학습 (Teacher 고정)
  for batch in dataset:
      teacher_output = Teacher_model(batch)  # frozen
      student_output = Student_model(batch)
      loss = alpha * CE(student_output, labels) +
             (1-alpha) * KL(student_output/T, teacher_output/T)

특징:
  + 가장 일반적이고 안정적
  + Teacher와 Student를 독립적으로 설계 가능
  - Teacher 학습 비용이 큼
  - 2단계 학습 파이프라인

2. Online Distillation

Teacher와 Student를 동시에 학습:

  Model_1 --soft labels--> Model_2
  Model_2 --soft labels--> Model_1

Deep Mutual Learning (Zhang et al., CVPR 2018):
  두 모델이 서로의 Teacher 역할

  L_1 = CE(p_1, y) + KL(p_1, p_2)
  L_2 = CE(p_2, y) + KL(p_2, p_1)

  -> 사전 학습된 대형 Teacher가 불필요
  -> 두 모델 모두 성능 향상
  -> 동일 크기 모델끼리도 효과적

3. Self-Distillation

동일 모델이 Teacher와 Student 역할:

방법 1: Born-Again Networks (Furlanello et al., 2018)
  Generation 0: 모델을 정상 학습
  Generation 1: Generation 0을 Teacher로 동일 구조 모델 학습
  Generation 2: Generation 1을 Teacher로 학습
  ...
  -> 세대를 거듭할수록 성능 향상 (diminishing returns)

방법 2: Be Your Own Teacher (Zhang et al., 2019)
  모델의 깊은 레이어가 얕은 레이어를 가르침:

  [Layer 1] -> [Layer 2] -> ... -> [Layer N] -> Final Output
       |                                |
  [Classifier_1]                  [Classifier_N]
       |                                |
       +--- L_KL(shallow, deep) ---+

방법 3: Progressive Self-Distillation
  학습 중 과거의 자신(EMA)이 Teacher:

  theta_teacher = beta * theta_teacher + (1 - beta) * theta_student

LLM을 위한 Knowledge Distillation

2024-2025년, LLM 시대에 KD의 중요성이 크게 증가했다.

LLM KD의 특수성

기존 KD:
  Teacher: CNN/ResNet (수백만 파라미터)
  Student: 작은 CNN (수십만 파라미터)
  지식: Soft label (분류 확률)

LLM KD:
  Teacher: GPT-4, Claude 4 (수천억 파라미터, API만 접근 가능)
  Student: LLaMA-7B, Mistral-7B (수십억 파라미터)
  지식: 생성 텍스트, reasoning chain, 능력(skill)

핵심 차이:
  1. Teacher 내부(logits)에 접근 불가 (black-box)
  2. 분류가 아닌 생성(autoregressive) 태스크
  3. 단일 태스크가 아닌 다양한 능력 전이
  4. 데이터 생성 자체가 KD의 핵심

White-Box vs Black-Box Distillation

White-Box (Teacher 내부 접근 가능):
  Teacher logits 직접 활용
  -> 전통적 KD와 동일한 방식
  -> 오픈소스 Teacher에서만 가능

  예: LLaMA-70B -> LLaMA-7B
  L = CE(student_logits, labels) + KL(student_logits/T, teacher_logits/T)

Black-Box (API만 접근 가능):
  Teacher의 생성 텍스트만 활용
  -> Teacher에게 데이터 생성 요청
  -> 생성된 데이터로 Student 학습 (SFT)

  예: GPT-4 -> LLaMA-7B
  Step 1: GPT-4로 고품질 (prompt, response) 쌍 생성
  Step 2: 생성 데이터로 Student SFT

LLM KD 주요 방법론

1. Symbolic Knowledge Distillation

Teacher가 지식을 텍스트 형태로 생성:

(a) 데이터 증강:
  Teacher: "수학 문제 100개를 풀이와 함께 생성해줘"
  -> 고품질 (문제, 풀이) 데이터셋 생성
  -> Student가 이 데이터로 학습

(b) Chain-of-Thought (CoT) Distillation:
  Teacher: 
    Q: "234 + 567 = ?"
    A: "200+500=700, 30+60=90, 4+7=11, 700+90+11=801"

  Student는 CoT를 포함한 답변을 학습
  -> 추론 능력(reasoning)까지 전이

(c) Skill-Specific Distillation:
  특정 능력만 선별적으로 증류
  - 코드 생성 능력: Code Alpaca
  - 수학 능력: WizardMath
  - 대화 능력: Vicuna (ShareGPT 데이터)

2. Logit-Based LLM Distillation

오픈소스 Teacher의 next-token logits 활용:

Token-Level KD:
  Input: "The capital of France is"

  Teacher logits: [Paris: 0.85, Lyon: 0.08, Marseille: 0.03, ...]
  Student logits: [Paris: 0.60, London: 0.15, Lyon: 0.05, ...]

  L = sum_t KL(student_logits_t / T, teacher_logits_t / T)

  각 토큰 위치에서 Teacher의 분포를 모방

MiniLLM (Gu et al., ICLR 2024):
  Reverse KL 사용으로 mode-seeking 특성 활용
  -> Student가 Teacher의 고확률 영역에 집중
  -> 환각(hallucination) 감소

  L = KL(p_student || p_teacher)  (reverse direction)

3. NVIDIA Minitron 접근법

Pruning + Distillation 결합:

Step 1: 대형 모델에서 중요도가 낮은 뉴런/레이어 제거 (Pruning)
  15B -> 8B (레이어/헤드/채널 pruning)

Step 2: Pruned 모델을 원본의 KD로 복구
  Teacher: 원본 15B 모델
  Student: Pruned 8B 모델
  학습: 원래 토큰의 1/40만 사용

결과 (Minitron, NeurIPS 2024):
  Nemotron-4 15B -> 8B:
    - LM Eval 평균: 원본의 98.7% 성능
    - 학습 비용: 원본 학습의 2.5%
    - 처음부터 학습한 8B 대비 우수

Self-Distillation for LLM

LLM이 스스로의 Teacher 역할:

(a) Rejection Sampling + Self-Training:
  Step 1: LLM이 문제를 여러 번 풀기 (N=64)
  Step 2: 정답인 풀이만 선별
  Step 3: 선별된 풀이로 SFT
  -> STaR (Zelikman et al., NeurIPS 2022)

(b) Self-Play / Self-Improvement:
  Step 1: 현재 모델로 데이터 생성
  Step 2: 생성 데이터를 평가/필터링
  Step 3: 고품질 데이터로 재학습
  -> 반복하면서 점진적 개선

(c) On-Policy Distillation:
  Student가 직접 생성한 텍스트에 대해
  Teacher가 토큰별 확률을 제공
  -> Distribution mismatch 감소
  -> GKD (Agarwal et al., 2024)

고급 기법

Attention Transfer

Zagoruyko & Komodakis (ICLR 2017). Teacher의 attention map을 Student에게 전달한다.

Teacher Attention Map:
  A_T = sum_i |F_T^i|^2  (채널 방향 제곱합)
  -> 어디에 주목하는지를 나타냄

Student Attention Map:
  A_S = sum_i |F_S^i|^2

Attention Transfer Loss:
  L_AT = sum_l || A_S^l / ||A_S^l||_2 - A_T^l / ||A_T^l||_2 ||_2

  -> L2 정규화된 attention map 간 거리

장점:
  + 공간적 주의(spatial attention) 패턴 전이
  + Feature map 크기만 맞으면 적용 가능
  + Teacher의 "무엇을 보는가"를 직접 학습

Contrastive Representation Distillation (CRD)

Tian et al. (ICLR 2020). Contrastive learning 원리로 KD를 수행한다.

핵심 아이디어:
  같은 입력 x에 대해:
    Teacher 표현 t = f_T(x)  -- positive pair
    Student 표현 s = f_S(x)  -- positive pair

  다른 입력 x'에 대해:
    Teacher 표현 t' = f_T(x')  -- negative pair

InfoNCE Loss:
  L_CRD = -log(exp(s . t / tau) / (exp(s . t / tau) + sum_{t'} exp(s . t' / tau)))

장점:
  + 구조적 관계 보존
  + Teacher-Student 차원 불일치에 강건
  + Response/Feature-based KD 대비 우수한 성능

Multi-Teacher Distillation

여러 Teacher의 지식을 결합:

방법 1: 평균 (Simple Averaging)
  q_ensemble = (1/K) * sum_k q_teacher_k
  -> 모든 Teacher를 동등하게 취급

방법 2: 가중 평균 (Weighted Averaging)
  q_ensemble = sum_k w_k * q_teacher_k
  w_k: Teacher k의 가중치 (성능/신뢰도 기반)

방법 3: Feature-level Aggregation
  각 Teacher의 중간 표현을 개별적으로 증류
  L = sum_k lambda_k * L_feature(S, T_k)

방법 4: Task-Specific Selection
  태스크별로 가장 적합한 Teacher 선택
  -> Routing network가 Teacher 선택

Data-Free Distillation

Teacher 학습에 사용된 원본 데이터 없이 KD를 수행한다.

방법 1: Generator 기반 (DAFL, Chen et al., 2019)
  Generator G -> 합성 데이터 생성
  Teacher가 합성 데이터에 soft label 제공
  Student가 soft label로 학습

  G 학습: Teacher의 출력 엔트로피 최소화
  -> Teacher가 자신 있는 데이터를 생성하도록 유도

방법 2: Batch Normalization 통계 활용
  Teacher의 BN 레이어에 저장된 mean/variance 활용
  -> 원본 데이터 분포를 근사하는 합성 데이터 생성

방법 3: Feature Map Inversion
  Teacher의 중간 feature map을 역변환하여 입력 복원

활용:
  - 개인정보 보호 (원본 데이터 공유 불가)
  - 데이터 접근 제한 환경
  - 의료/금융 등 규제 분야

Python 구현

기본 Knowledge Distillation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


class DistillationLoss(nn.Module):
    """Knowledge Distillation Loss"""

    def __init__(
        self,
        temperature: float = 4.0,
        alpha: float = 0.3,
        reduction: str = 'mean'
    ):
        """
        Args:
            temperature: Softmax temperature (T)
            alpha: Hard loss 가중치 (1-alpha = Soft loss 가중치)
            reduction: 'mean' or 'sum'
        """
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.reduction = reduction
        self.ce_loss = nn.CrossEntropyLoss(reduction=reduction)

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            student_logits: Student 모델 출력 [B, C]
            teacher_logits: Teacher 모델 출력 [B, C]
            labels: 정답 레이블 [B]
        """
        # Hard loss: Student vs Ground Truth
        hard_loss = self.ce_loss(student_logits, labels)

        # Soft loss: Student vs Teacher (with temperature)
        T = self.temperature
        student_soft = F.log_softmax(student_logits / T, dim=1)
        teacher_soft = F.softmax(teacher_logits / T, dim=1)

        soft_loss = F.kl_div(
            student_soft,
            teacher_soft,
            reduction='batchmean'
        ) * (T ** 2)

        # Combined loss
        loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss

        return loss


class FeatureDistillationLoss(nn.Module):
    """Feature-Based Knowledge Distillation (FitNets 스타일)"""

    def __init__(
        self,
        student_channels: list,
        teacher_channels: list
    ):
        """
        Args:
            student_channels: Student 중간 레이어 채널 수 리스트
            teacher_channels: Teacher 중간 레이어 채널 수 리스트
        """
        super().__init__()
        assert len(student_channels) == len(teacher_channels)

        # 차원 맞춤용 1x1 conv (Student -> Teacher 차원)
        self.adapters = nn.ModuleList([
            nn.Conv2d(s_ch, t_ch, kernel_size=1, bias=False)
            for s_ch, t_ch in zip(student_channels, teacher_channels)
        ])

    def forward(
        self,
        student_features: list,
        teacher_features: list
    ) -> torch.Tensor:
        """
        Args:
            student_features: Student 중간 feature maps 리스트
            teacher_features: Teacher 중간 feature maps 리스트
        """
        loss = 0
        for adapter, s_feat, t_feat in zip(
            self.adapters, student_features, teacher_features
        ):
            # 차원 맞춤
            s_adapted = adapter(s_feat)

            # 공간 크기 맞춤 (필요시)
            if s_adapted.shape[2:] != t_feat.shape[2:]:
                s_adapted = F.adaptive_avg_pool2d(s_adapted, t_feat.shape[2:])

            # L2 distance
            loss += F.mse_loss(s_adapted, t_feat.detach())

        return loss


class AttentionTransferLoss(nn.Module):
    """Attention Transfer (Zagoruyko & Komodakis, 2017)"""

    def __init__(self, p: int = 2):
        """
        Args:
            p: Attention map 계산 시 norm 차수
        """
        super().__init__()
        self.p = p

    def _attention_map(self, feature: torch.Tensor) -> torch.Tensor:
        """Feature map -> Attention map (채널 방향 집약)"""
        # feature: [B, C, H, W]
        # attention: [B, H*W]
        att = feature.pow(self.p).mean(dim=1).view(feature.size(0), -1)
        # L2 정규화
        att = F.normalize(att, p=2, dim=1)
        return att

    def forward(
        self,
        student_features: list,
        teacher_features: list
    ) -> torch.Tensor:
        loss = 0
        for s_feat, t_feat in zip(student_features, teacher_features):
            s_att = self._attention_map(s_feat)
            t_att = self._attention_map(t_feat.detach())
            loss += (s_att - t_att).pow(2).mean()
        return loss

Knowledge Distillation Trainer

class KDTrainer:
    """Knowledge Distillation 학습기"""

    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        temperature: float = 4.0,
        alpha: float = 0.3,
        learning_rate: float = 1e-3,
        feature_weight: float = 0.0,
        attention_weight: float = 0.0,
        device: str = 'cuda'
    ):
        self.teacher = teacher.to(device).eval()
        self.student = student.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device

        # Teacher는 학습하지 않음
        for param in self.teacher.parameters():
            param.requires_grad = False

        # Losses
        self.kd_loss = DistillationLoss(temperature, alpha)
        self.optimizer = torch.optim.Adam(
            self.student.parameters(), lr=learning_rate
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=100
        )

        self.feature_weight = feature_weight
        self.attention_weight = attention_weight

    def train_epoch(self) -> dict:
        self.student.train()
        total_loss = 0
        correct = 0
        total = 0

        for data, targets in self.train_loader:
            data = data.to(self.device)
            targets = targets.to(self.device)

            # Teacher forward (no grad)
            with torch.no_grad():
                teacher_logits = self.teacher(data)

            # Student forward
            student_logits = self.student(data)

            # KD Loss
            loss = self.kd_loss(student_logits, teacher_logits, targets)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item() * data.size(0)
            _, predicted = student_logits.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)

        self.scheduler.step()

        return {
            'loss': total_loss / total,
            'accuracy': correct / total
        }

    @torch.no_grad()
    def evaluate(self) -> dict:
        self.student.eval()
        correct = 0
        total = 0

        for data, targets in self.val_loader:
            data = data.to(self.device)
            targets = targets.to(self.device)
            outputs = self.student(data)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)

        return {'accuracy': correct / total}

    def train(self, num_epochs: int = 100, log_interval: int = 10):
        best_acc = 0

        for epoch in range(num_epochs):
            train_metrics = self.train_epoch()

            if (epoch + 1) % log_interval == 0:
                val_metrics = self.evaluate()
                print(
                    f"Epoch {epoch+1}/{num_epochs} | "
                    f"Train Loss: {train_metrics['loss']:.4f} | "
                    f"Train Acc: {train_metrics['accuracy']:.4f} | "
                    f"Val Acc: {val_metrics['accuracy']:.4f}"
                )
                if val_metrics['accuracy'] > best_acc:
                    best_acc = val_metrics['accuracy']

        print(f"\nBest validation accuracy: {best_acc:.4f}")
        return best_acc

LLM Black-Box Distillation 파이프라인

import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class DistillationSample:
    """LLM KD용 데이터 샘플"""
    prompt: str
    teacher_response: str
    teacher_model: str
    task_type: str  # 'reasoning', 'code', 'conversation', ...
    quality_score: Optional[float] = None


class LLMDistillationPipeline:
    """
    Black-Box LLM Distillation 파이프라인

    Teacher API에서 고품질 데이터를 생성하고
    Student 모델을 SFT로 학습시킴
    """

    def __init__(
        self,
        teacher_api,        # OpenAI/Anthropic 등 API 클라이언트
        student_model,      # HuggingFace 모델
        student_tokenizer,  # HuggingFace 토크나이저
        task_prompts: dict  # 태스크별 프롬프트 템플릿
    ):
        self.teacher_api = teacher_api
        self.student_model = student_model
        self.tokenizer = student_tokenizer
        self.task_prompts = task_prompts

    def generate_distillation_data(
        self,
        seed_prompts: list,
        task_type: str = 'reasoning',
        n_samples: int = 1000,
        n_responses_per_prompt: int = 3,
        temperature: float = 0.7
    ) -> list:
        """
        Teacher로부터 학습 데이터 생성

        Args:
            seed_prompts: 시드 프롬프트 리스트
            task_type: 태스크 유형
            n_samples: 목표 샘플 수
            n_responses_per_prompt: 프롬프트당 응답 수
            temperature: Teacher 생성 temperature

        Returns:
            DistillationSample 리스트
        """
        samples = []

        for prompt in seed_prompts[:n_samples]:
            # 시스템 프롬프트로 품질 유도
            system_prompt = self.task_prompts.get(
                task_type,
                "Provide a detailed, accurate, and helpful response."
            )

            for _ in range(n_responses_per_prompt):
                response = self.teacher_api.chat.completions.create(
                    model="gpt-4",
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=temperature,
                    max_tokens=2048
                )

                teacher_text = response.choices[0].message.content

                sample = DistillationSample(
                    prompt=prompt,
                    teacher_response=teacher_text,
                    teacher_model="gpt-4",
                    task_type=task_type
                )
                samples.append(sample)

        return samples

    def filter_by_quality(
        self,
        samples: list,
        min_length: int = 50,
        max_length: int = 4096,
        dedup: bool = True
    ) -> list:
        """생성 데이터 품질 필터링"""
        filtered = []
        seen_responses = set()

        for sample in samples:
            # 길이 필터
            if len(sample.teacher_response) < min_length:
                continue
            if len(sample.teacher_response) > max_length:
                continue

            # 중복 제거
            if dedup:
                response_hash = hash(sample.teacher_response[:200])
                if response_hash in seen_responses:
                    continue
                seen_responses.add(response_hash)

            filtered.append(sample)

        return filtered

    def prepare_sft_data(self, samples: list) -> list:
        """SFT 학습용 데이터 변환"""
        sft_data = []

        for sample in samples:
            # Chat template 적용
            messages = [
                {"role": "user", "content": sample.prompt},
                {"role": "assistant", "content": sample.teacher_response}
            ]

            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )

            sft_data.append({
                "text": text,
                "task_type": sample.task_type
            })

        return sft_data

    def save_dataset(self, samples: list, path: str):
        """데이터셋 저장"""
        with open(path, 'w', encoding='utf-8') as f:
            for sample in samples:
                f.write(json.dumps(sample, ensure_ascii=False) + '\n')
        print(f"Saved {len(samples)} samples to {path}")


class CoTDistillation:
    """Chain-of-Thought Distillation"""

    def __init__(self, teacher_api):
        self.teacher_api = teacher_api

    def generate_cot_data(
        self,
        questions: list,
        subject: str = 'math'
    ) -> list:
        """
        Teacher로부터 CoT 추론 과정 생성
        """
        cot_prompt = """
Solve the following problem step by step.
Show your reasoning process clearly.
Format:
Step 1: [reasoning]
Step 2: [reasoning]
...
Answer: [final answer]
"""
        samples = []

        for question in questions:
            response = self.teacher_api.chat.completions.create(
                model="gpt-4",
                messages=[
                    {"role": "system", "content": cot_prompt},
                    {"role": "user", "content": question}
                ],
                temperature=0.3
            )

            teacher_cot = response.choices[0].message.content

            samples.append({
                "question": question,
                "cot_response": teacher_cot,
                "subject": subject
            })

        return samples

완전한 학습 예시 (CIFAR-10)

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


class TeacherNet(nn.Module):
    """Teacher: 큰 CNN"""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 1024), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(1024, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)


class StudentNet(nn.Module):
    """Student: 작은 CNN (Teacher의 1/4 크기)"""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4, 256), nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)


def run_distillation_experiment():
    """KD 전체 실험: Teacher 학습 -> Student KD vs Student 단독"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 데이터
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    train_set = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
    test_set = datasets.CIFAR10('./data', train=False, transform=transform_test)

    train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=100, num_workers=2)

    # === Phase 1: Teacher 학습 ===
    print("=" * 50)
    print("Phase 1: Training Teacher")
    print("=" * 50)

    teacher = TeacherNet().to(device)
    teacher_optimizer = torch.optim.SGD(
        teacher.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4
    )
    teacher_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        teacher_optimizer, T_max=100
    )
    criterion = nn.CrossEntropyLoss()

    teacher_params = sum(p.numel() for p in teacher.parameters())
    print(f"Teacher parameters: {teacher_params:,}")

    for epoch in range(100):
        teacher.train()
        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)
            teacher_optimizer.zero_grad()
            loss = criterion(teacher(data), targets)
            loss.backward()
            teacher_optimizer.step()
        teacher_scheduler.step()

        if (epoch + 1) % 20 == 0:
            teacher.eval()
            correct = sum(
                teacher(d.to(device)).argmax(1).eq(t.to(device)).sum().item()
                for d, t in test_loader
            )
            print(f"Teacher Epoch {epoch+1}: {correct/len(test_set):.4f}")

    teacher.eval()

    # === Phase 2: Student with KD ===
    print("\n" + "=" * 50)
    print("Phase 2: Student with Knowledge Distillation")
    print("=" * 50)

    student_kd = StudentNet().to(device)
    student_params = sum(p.numel() for p in student_kd.parameters())
    print(f"Student parameters: {student_params:,}")
    print(f"Compression ratio: {teacher_params/student_params:.1f}x")

    kd_trainer = KDTrainer(
        teacher=teacher,
        student=student_kd,
        train_loader=train_loader,
        val_loader=test_loader,
        temperature=4.0,
        alpha=0.3,
        learning_rate=1e-3,
        device=device
    )
    kd_acc = kd_trainer.train(num_epochs=100, log_interval=20)

    # === Phase 3: Student without KD (baseline) ===
    print("\n" + "=" * 50)
    print("Phase 3: Student without KD (baseline)")
    print("=" * 50)

    student_baseline = StudentNet().to(device)
    baseline_optimizer = torch.optim.Adam(student_baseline.parameters(), lr=1e-3)
    baseline_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        baseline_optimizer, T_max=100
    )

    best_baseline = 0
    for epoch in range(100):
        student_baseline.train()
        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)
            baseline_optimizer.zero_grad()
            loss = criterion(student_baseline(data), targets)
            loss.backward()
            baseline_optimizer.step()
        baseline_scheduler.step()

        if (epoch + 1) % 20 == 0:
            student_baseline.eval()
            correct = sum(
                student_baseline(d.to(device)).argmax(1).eq(t.to(device)).sum().item()
                for d, t in test_loader
            )
            acc = correct / len(test_set)
            best_baseline = max(best_baseline, acc)
            print(f"Baseline Epoch {epoch+1}: {acc:.4f}")

    # === 결과 비교 ===
    print("\n" + "=" * 50)
    print("Results")
    print("=" * 50)
    print(f"Student (with KD):    {kd_acc:.4f}")
    print(f"Student (no KD):      {best_baseline:.4f}")
    print(f"Improvement:          {(kd_acc - best_baseline)*100:.2f}%p")


if __name__ == "__main__":
    run_distillation_experiment()

실무 가이드라인

하이퍼파라미터 권장값

파라미터 권장 범위 설명
Temperature (T) 3 - 20 높을수록 soft (보통 4-8)
Alpha 0.1 - 0.5 Hard loss 가중치 (보통 0.3)
Student/Teacher 비율 1/4 - 1/10 모델 크기 비율
Learning Rate 1e-4 - 1e-3 Student 학습률
Feature Loss Weight 0.01 - 0.1 Feature KD 시 가중치

Temperature 선택 가이드

T가 낮을 때 (1-3):
  + Teacher의 top-1 예측에 충실
  + 분류 정확도 중심
  - dark knowledge 활용 부족
  적합: 쉬운 태스크, 클래스 수 적을 때

T가 높을 때 (10-20):
  + 클래스 간 관계 충분히 전달
  + 더 부드러운 학습 신호
  - 정보 희석 가능성
  적합: 클래스 수 많을 때, 유사 클래스 구분

일반적 시작점: T=4, alpha=0.3

언제 KD를 사용해야 하는가

효과적인 경우:
  + 대형 모델을 Edge/Mobile에 배포해야 할 때
  + 추론 지연시간(latency) 감소가 필요할 때
  + 앙상블 모델을 단일 모델로 압축할 때
  + LLM API 비용 절감이 필요할 때
  + 특정 도메인에 소형 모델을 특화할 때
  + Teacher 성능의 90%+를 유지하면서 비용 1/10

비효과적이거나 불필요한 경우:
  - Student와 Teacher의 성능 차이가 너무 클 때
  - 데이터가 매우 적을 때 (Teacher 자체가 부정확)
  - 단순한 태스크 (Student 단독으로 충분)
  - 실시간 추론이 불필요한 오프라인 환경

주의사항

1. Teacher 품질이 핵심
   - 낮은 품질의 Teacher -> Student도 낮은 품질
   - Teacher의 오류/편향이 Student에게 전달됨
   - Teacher 검증 후 증류 시작

2. Capacity Gap 문제
   - Student가 너무 작으면 Teacher 지식을 수용 불가
   - Teacher-Student 크기 비율: 4-10배 권장
   - 너무 큰 gap은 중간 Teacher(TA)로 bridging
     Teacher -> Teaching Assistant -> Student

3. Task Mismatch
   - Teacher의 학습 태스크와 Student의 목표 태스크가 다르면 효과 감소
   - 도메인 특화 Teacher 사용 권장

4. 학습 안정성
   - Temperature가 너무 높으면 gradient vanishing
   - Alpha 값에 민감할 수 있음 -> grid search 권장
   - Warmup: 초기에는 hard loss 비중 높이고 점차 soft loss 증가

5. LLM 증류 시 주의
   - Teacher API 비용 관리 (데이터 생성 비용)
   - 생성 데이터 품질 검증 필수
   - 저작권/라이선스 확인 (일부 모델은 출력물 증류 금지)

성능 벤치마크 (대표 결과)

ImageNet Classification

Teacher Student Method Top-1 Acc (%)
ResNet-34 ResNet-18 Baseline (no KD) 69.75
ResNet-34 ResNet-18 Vanilla KD 71.03
ResNet-34 ResNet-18 FitNets 71.06
ResNet-34 ResNet-18 AT (Attention) 70.69
ResNet-34 ResNet-18 CRD 71.17
ResNet-34 ResNet-18 KD + CRD 71.38

LLM Distillation

Teacher Student Method MMLU (%)
GPT-4 LLaMA-7B SFT (Alpaca) 42.3
GPT-4 LLaMA-7B CoT Distillation 46.8
Nemotron-4 15B Minitron 8B Pruning + KD 67.2
Nemotron-4 15B (baseline) - From scratch 8B 65.1

관련 연구 흐름

Model Compression (Bucilua et al., 2006)
  |
  +-- Knowledge Distillation (Hinton et al., 2015)
  |       |
  |       +-- FitNets (Romero et al., 2015): Feature-based KD
  |       |
  |       +-- Attention Transfer (Zagoruyko, 2017)
  |       |
  |       +-- Born-Again Networks (Furlanello et al., 2018): Self-distillation
  |       |
  |       +-- Deep Mutual Learning (Zhang et al., 2018): Online KD
  |       |
  |       +-- CRD (Tian et al., 2020): Contrastive KD
  |       |
  |       +-- Data-Free KD (Chen et al., 2019)
  |       |
  |       +-- RKD (Park et al., 2019): Relation-based KD
  |       |
  |       +-- LLM Distillation (2023-):
  |               |
  |               +-- Alpaca/Vicuna: Black-box distillation
  |               +-- MiniLLM (Gu et al., 2024): Reverse KL
  |               +-- Minitron (NVIDIA, 2024): Pruning + KD
  |               +-- GKD (Agarwal et al., 2024): On-policy
  |
  +-- Pruning: 불필요한 파라미터 제거
  |
  +-- Quantization: 비트 수 줄이기
  |
  +-- Neural Architecture Search (NAS)

참고 자료

핵심 논문

  1. Hinton, G. et al. (2015). Distilling the Knowledge in a Neural Network. NeurIPS Workshop.
  2. Romero, A. et al. (2015). FitNets: Hints for Thin Deep Nets. ICLR 2015.
  3. Zagoruyko, S. & Komodakis, N. (2017). Paying More Attention to Attention. ICLR 2017.
  4. Zhang, Y. et al. (2018). Deep Mutual Learning. CVPR 2018.
  5. Park, W. et al. (2019). Relational Knowledge Distillation. CVPR 2019.
  6. Tian, Y. et al. (2020). Contrastive Representation Distillation. ICLR 2020.

최신 연구 (LLM)

  1. Gu, Y. et al. (2024). MiniLLM: Knowledge Distillation of Large Language Models. ICLR 2024.
  2. Agarwal, R. et al. (2024). On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes. ICLR 2024.
  3. Muralidharan, S. et al. (2024). Compact Language Models via Pruning and Knowledge Distillation (Minitron). NeurIPS 2024.
  4. arXiv:2402.13116 (2024). A Survey on Knowledge Distillation of Large Language Models.
  5. arXiv:2503.12067 (2025). A Comprehensive Survey on Knowledge Distillation.

Survey 논문

  1. Gou, J. et al. (2021). Knowledge Distillation: A Survey. IJCV 2021.
  2. Wang, L. & Yoon, K. (2022). Knowledge Distillation and Student-Teacher Learning for Visual Intelligence. IEEE TPAMI.

관련 개념