콘텐츠로 이동
Data Prep
상세

Sparse Autoencoders (SAE)

개요

Sparse Autoencoders (SAE)는 신경망의 내부 활성화(activations)를 해석 가능한 희소 특징(sparse features)으로 분해하는 비지도 학습 기법이다. 딥러닝 모델이 학습하는 특징(features)이 중첩(superposition) 상태로 인코딩되어 단일 뉴런 분석으로는 파악하기 어려운 문제를 해결한다.

구분 설명
핵심 목표 모델 활성화를 해석 가능한 단일의미(monosemantic) 특징으로 분해
이론적 기반 Dictionary Learning, Sparse Coding
주요 적용 LLM Interpretability, AI Safety, Model Debugging
핵심 논문 Towards Monosemanticity (Anthropic, 2023)
관련 개념 Superposition, Polysemanticity, Mechanistic Interpretability

배경: Superposition 문제

Superposition이란

신경망은 뉴런 수보다 훨씬 많은 특징을 학습한다. 이 특징들이 희소하게 활성화될 경우, 모델은 동일한 뉴런 공간에 여러 특징을 중첩하여 저장한다.

d차원 공간에 n개 특징 저장 (n >> d)
  - 특징 i: 방향 벡터 f_i
  - 활성화: x = sum_i a_i * f_i  (a_i는 희소)
  - 특징 간 간섭 최소화를 위해 거의 직교

Polysemanticity

Superposition의 결과로 단일 뉴런이 여러 무관한 개념에 반응한다.

현상 예시
다의성 한 뉴런이 "학술 인용", "법률 용어", "특정 URL 패턴"에 동시 반응
해석 불가 뉴런별 분석으로는 의미 파악 어려움
분산 표현 하나의 개념이 여러 뉴런에 걸쳐 인코딩

SAE의 해결 방식

활성화 x (d차원) --> SAE --> 특징 f (n차원, n >> d)
                              |
                              v
                        각 f_i는 하나의 해석 가능한 개념

아키텍처

기본 구조

Encoder: f = ReLU(W_enc @ (x - b_dec) + b_enc)
Decoder: x_hat = W_dec @ f + b_dec

손실함수: L = ||x - x_hat||_2^2 + lambda * ||f||_1
구성요소 차원 역할
입력 x (batch, d_model) 모델의 활성화 벡터
인코더 W_enc (d_model, n_features) 활성화를 특징 공간으로 매핑
디코더 W_dec (n_features, d_model) 특징을 활성화 공간으로 복원
특징 f (batch, n_features) 희소 특징 활성화
확장 비율 n_features / d_model 보통 8x ~ 128x

Python 구현

import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    """
    기본 Sparse Autoencoder 구현

    Args:
        d_model: 입력 활성화 차원
        n_features: 학습할 특징 수 (보통 d_model의 8-64배)
        sparsity_coef: L1 희소성 페널티 계수
    """
    def __init__(self, d_model: int, n_features: int, sparsity_coef: float = 1e-3):
        super().__init__()

        # 인코더 (특징 추출)
        self.W_enc = nn.Parameter(torch.randn(d_model, n_features) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(n_features))

        # 디코더 (재구성)
        self.W_dec = nn.Parameter(torch.randn(n_features, d_model) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

        self.sparsity_coef = sparsity_coef

        # 디코더 가중치 정규화 (단위 벡터)
        with torch.no_grad():
            self.W_dec.data = F.normalize(self.W_dec.data, dim=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """활성화를 희소 특징으로 인코딩"""
        x_centered = x - self.b_dec
        pre_activation = x_centered @ self.W_enc + self.b_enc
        return F.relu(pre_activation)

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        """희소 특징을 활성화로 디코딩"""
        return f @ self.W_dec + self.b_dec

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """순전파: 인코딩 후 디코딩"""
        features = self.encode(x)
        reconstruction = self.decode(features)
        return features, reconstruction

    def loss(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """
        전체 손실 계산

        Returns:
            total_loss: 재구성 손실 + 희소성 손실
            metrics: 개별 손실 값들
        """
        features, reconstruction = self.forward(x)

        # 재구성 손실 (MSE)
        recon_loss = F.mse_loss(reconstruction, x)

        # L1 희소성 손실
        sparsity_loss = features.abs().mean()

        # 전체 손실
        total_loss = recon_loss + self.sparsity_coef * sparsity_loss

        metrics = {
            "recon_loss": recon_loss.item(),
            "sparsity_loss": sparsity_loss.item(),
            "total_loss": total_loss.item(),
            "l0_sparsity": (features > 0).float().mean().item(),  # 활성화된 특징 비율
        }

        return total_loss, metrics

    @torch.no_grad()
    def normalize_decoder(self):
        """디코더 가중치를 단위 벡터로 정규화"""
        self.W_dec.data = F.normalize(self.W_dec.data, dim=1)

SAE 변형

1. TopK SAE

L1 페널티 대신 상위 K개 특징만 활성화하여 희소성을 직접 제어한다.

class TopKSparseAutoencoder(nn.Module):
    """
    TopK 활성화 기반 SAE
    L1 페널티 튜닝 없이 정확한 희소성 제어
    """
    def __init__(self, d_model: int, n_features: int, k: int = 32):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, n_features) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(n_features))
        self.W_dec = nn.Parameter(torch.randn(n_features, d_model) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        self.k = k

        with torch.no_grad():
            self.W_dec.data = F.normalize(self.W_dec.data, dim=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        x_centered = x - self.b_dec
        pre_act = x_centered @ self.W_enc + self.b_enc

        # 상위 K개만 활성화
        topk_values, topk_indices = torch.topk(pre_act, self.k, dim=-1)

        # 희소 활성화 생성
        features = torch.zeros_like(pre_act)
        features.scatter_(-1, topk_indices, F.relu(topk_values))

        return features

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        return f @ self.W_dec + self.b_dec

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        features = self.encode(x)
        reconstruction = self.decode(features)
        return features, reconstruction

    def loss(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        features, reconstruction = self.forward(x)
        recon_loss = F.mse_loss(reconstruction, x)

        # Auxiliary loss: TopK 외 특징들의 기울기 유지
        pre_act = (x - self.b_dec) @ self.W_enc + self.b_enc
        aux_loss = F.relu(pre_act).mean() * 0.01  # 작은 보조 손실

        return recon_loss + aux_loss, {
            "recon_loss": recon_loss.item(),
            "l0_sparsity": self.k / self.W_enc.shape[1]
        }

2. Gated SAE (Anthropic, 2024)

게이트 메커니즘으로 특징 활성화 여부와 크기를 분리하여 학습한다.

class GatedSparseAutoencoder(nn.Module):
    """
    Gated SAE: 특징 활성화 여부(gate)와 크기(magnitude)를 분리
    shrinkage bias 문제 해결
    """
    def __init__(self, d_model: int, n_features: int, sparsity_coef: float = 1e-3):
        super().__init__()

        # Gate 경로 (활성화 여부 결정)
        self.W_gate = nn.Parameter(torch.randn(d_model, n_features) * 0.01)
        self.b_gate = nn.Parameter(torch.zeros(n_features))

        # Magnitude 경로 (활성화 크기 결정)
        self.W_mag = nn.Parameter(torch.randn(d_model, n_features) * 0.01)
        self.b_mag = nn.Parameter(torch.zeros(n_features))
        self.r_mag = nn.Parameter(torch.ones(n_features))  # rescale

        # 디코더
        self.W_dec = nn.Parameter(torch.randn(n_features, d_model) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

        self.sparsity_coef = sparsity_coef

        with torch.no_grad():
            self.W_dec.data = F.normalize(self.W_dec.data, dim=1)

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        x_centered = x - self.b_dec

        # Gate: 활성화 여부 (sigmoid -> Heaviside 근사)
        gate_pre = x_centered @ self.W_gate + self.b_gate
        gate = (gate_pre > 0).float()  # Hard gate

        # Magnitude: 활성화 크기
        mag_pre = x_centered @ self.W_mag + self.b_mag
        magnitude = F.relu(mag_pre) * torch.exp(self.r_mag)

        # 최종 특징 = gate * magnitude
        features = gate * magnitude

        return features, gate_pre

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        return f @ self.W_dec + self.b_dec

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        features, gate_pre = self.encode(x)
        reconstruction = self.decode(features)
        return features, reconstruction, gate_pre

    def loss(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        features, reconstruction, gate_pre = self.forward(x)

        # 재구성 손실
        recon_loss = F.mse_loss(reconstruction, x)

        # Gate에 대한 희소성 손실 (L1 on gate activations)
        gate_sparsity = F.relu(gate_pre).mean()

        # Auxiliary loss: gate 학습 안정화
        aux_loss = F.mse_loss(
            self.decode(F.relu(gate_pre)),
            x - self.b_dec
        ) * 0.1

        total_loss = recon_loss + self.sparsity_coef * gate_sparsity + aux_loss

        return total_loss, {
            "recon_loss": recon_loss.item(),
            "gate_sparsity": gate_sparsity.item(),
            "l0_sparsity": (features > 0).float().mean().item()
        }

3. JumpReLU SAE (DeepMind, 2024)

학습 가능한 threshold를 가진 JumpReLU 활성화 함수 사용.

class JumpReLU(torch.autograd.Function):
    """
    JumpReLU: x > theta이면 x, 아니면 0
    theta는 학습 가능
    """
    @staticmethod
    def forward(ctx, x, theta):
        ctx.save_for_backward(x, theta)
        return torch.where(x > theta, x, torch.zeros_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        x, theta = ctx.saved_tensors
        # x > theta인 경우에만 기울기 전파
        mask = (x > theta).float()
        grad_x = grad_output * mask
        # theta에 대한 기울기 (STE 근사)
        grad_theta = -grad_output * mask
        return grad_x, grad_theta.sum(dim=0)


class JumpReLUSAE(nn.Module):
    """
    JumpReLU SAE: 학습 가능한 threshold로 희소성 제어
    """
    def __init__(self, d_model: int, n_features: int, init_threshold: float = 0.1):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, n_features) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(n_features))
        self.theta = nn.Parameter(torch.full((n_features,), init_threshold))

        self.W_dec = nn.Parameter(torch.randn(n_features, d_model) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

        with torch.no_grad():
            self.W_dec.data = F.normalize(self.W_dec.data, dim=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        x_centered = x - self.b_dec
        pre_act = x_centered @ self.W_enc + self.b_enc
        return JumpReLU.apply(pre_act, self.theta)

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        return f @ self.W_dec + self.b_dec

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        features = self.encode(x)
        reconstruction = self.decode(features)
        return features, reconstruction

SAE 변형 비교

변형 희소성 제어 장점 단점
Vanilla (L1) lambda 튜닝 구현 간단 shrinkage bias
TopK k 직접 지정 정확한 희소성 기울기 소실
Gated 분리된 gate shrinkage 해결 파라미터 2배
JumpReLU 학습된 threshold 적응적 희소성 학습 불안정

학습 방법

데이터 수집

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def collect_activations(
    model_name: str,
    layer_idx: int,
    hook_point: str,  # "resid_pre", "resid_post", "mlp_out", etc.
    dataset,
    batch_size: int = 32,
    max_samples: int = 1_000_000
):
    """
    모델에서 특정 레이어의 활성화 수집

    Args:
        model_name: HuggingFace 모델 이름
        layer_idx: 활성화를 추출할 레이어 인덱스
        hook_point: 추출 위치 (residual stream, MLP 출력 등)
        dataset: 텍스트 데이터셋

    Returns:
        activations: (n_samples, d_model) 텐서
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    activations = []

    # Hook 설정
    def hook_fn(module, input, output):
        # output shape: (batch, seq_len, d_model)
        # 모든 토큰 위치의 활성화 수집
        activations.append(output.detach().cpu())

    # 레이어에 따른 hook 등록
    if hook_point == "resid_post":
        handle = model.model.layers[layer_idx].register_forward_hook(hook_fn)
    elif hook_point == "mlp_out":
        handle = model.model.layers[layer_idx].mlp.register_forward_hook(hook_fn)

    model.eval()
    with torch.no_grad():
        for batch in dataset:
            inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
            _ = model(**inputs)

            if sum(a.shape[0] * a.shape[1] for a in activations) >= max_samples:
                break

    handle.remove()

    # (n_samples, d_model) 형태로 변환
    all_activations = torch.cat(activations, dim=0)
    all_activations = all_activations.view(-1, all_activations.shape[-1])

    return all_activations[:max_samples]

학습 루프

def train_sae(
    sae: SparseAutoencoder,
    activations: torch.Tensor,
    batch_size: int = 4096,
    num_epochs: int = 1,
    lr: float = 3e-4,
    device: str = "cuda"
):
    """
    SAE 학습

    Args:
        sae: SAE 모델
        activations: 수집된 활성화 데이터
        batch_size: 배치 크기 (보통 4096-8192)
        num_epochs: 에폭 수 (보통 1-3)
        lr: 학습률
    """
    sae = sae.to(device)
    optimizer = torch.optim.AdamW(sae.parameters(), lr=lr, betas=(0.9, 0.999))

    # 학습률 스케줄러 (warmup + decay)
    num_steps = len(activations) // batch_size * num_epochs
    warmup_steps = num_steps // 10

    def lr_schedule(step):
        if step < warmup_steps:
            return step / warmup_steps
        return max(0.1, 1 - (step - warmup_steps) / (num_steps - warmup_steps))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)

    # 학습
    dataset = torch.utils.data.TensorDataset(activations)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True
    )

    for epoch in range(num_epochs):
        total_loss = 0
        for batch in dataloader:
            x = batch[0].to(device)

            loss, metrics = sae.loss(x)

            optimizer.zero_grad()
            loss.backward()

            # 기울기 클리핑
            torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            # 디코더 정규화 (단위 벡터 유지)
            sae.normalize_decoder()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

    return sae

Dead Features 처리

일부 특징이 학습 중 활성화되지 않는 "dead features" 문제 해결.

def resample_dead_features(
    sae: SparseAutoencoder,
    activations: torch.Tensor,
    dead_threshold: float = 1e-6,
    resample_fraction: float = 0.5
):
    """
    Dead features 재초기화

    활성화 빈도가 threshold 미만인 특징들을
    높은 재구성 오차를 가진 샘플 방향으로 재초기화
    """
    with torch.no_grad():
        # 특징 활성화 빈도 계산
        features, _ = sae.forward(activations)
        activation_freq = (features > 0).float().mean(dim=0)

        # Dead features 식별
        dead_mask = activation_freq < dead_threshold
        n_dead = dead_mask.sum().item()

        if n_dead == 0:
            return 0

        # 높은 재구성 오차를 가진 샘플 식별
        _, reconstruction = sae.forward(activations)
        recon_errors = (activations - reconstruction).pow(2).sum(dim=1)

        # 상위 오차 샘플을 dead feature 방향으로 사용
        n_resample = min(n_dead, int(len(activations) * resample_fraction))
        top_error_indices = torch.topk(recon_errors, n_resample).indices

        # Dead features 재초기화
        new_directions = activations[top_error_indices]
        new_directions = F.normalize(new_directions, dim=1)

        dead_indices = torch.where(dead_mask)[0][:n_resample]

        sae.W_dec.data[dead_indices] = new_directions
        sae.W_enc.data[:, dead_indices] = new_directions.T
        sae.b_enc.data[dead_indices] = 0

    return n_dead

특징 분석

최대 활성화 샘플 찾기

def find_max_activating_examples(
    sae: SparseAutoencoder,
    activations: torch.Tensor,
    texts: list[str],
    feature_idx: int,
    top_k: int = 10
):
    """
    특정 특징에 가장 강하게 반응하는 예시 찾기
    """
    with torch.no_grad():
        features, _ = sae.forward(activations)
        feature_acts = features[:, feature_idx]

        top_indices = torch.topk(feature_acts, top_k).indices

    results = []
    for idx in top_indices:
        results.append({
            "text": texts[idx],
            "activation": feature_acts[idx].item()
        })

    return results

자동 특징 해석 (AutoInterp)

def auto_interpret_feature(
    max_activating_examples: list[dict],
    llm_client  # OpenAI/Anthropic API 클라이언트
) -> str:
    """
    LLM을 사용하여 특징의 의미 자동 추론

    Args:
        max_activating_examples: 최대 활성화 예시들
        llm_client: API 클라이언트

    Returns:
        feature_description: 특징에 대한 자연어 설명
    """
    prompt = """다음은 신경망의 특정 특징(feature)이 강하게 활성화되는 텍스트 예시들입니다.
이 특징이 감지하는 패턴이나 개념이 무엇인지 한 문장으로 설명해주세요.

예시들:
"""
    for ex in max_activating_examples[:10]:
        prompt += f"- (활성화: {ex['activation']:.2f}) {ex['text']}\n"

    prompt += "\n이 특징이 감지하는 것: "

    response = llm_client.generate(prompt)
    return response

특징 활성화 시각화

import matplotlib.pyplot as plt
import numpy as np

def visualize_feature_activations(
    sae: SparseAutoencoder,
    activations: torch.Tensor,
    feature_indices: list[int] = None,
    n_features: int = 50
):
    """
    특징 활성화 패턴 시각화
    """
    with torch.no_grad():
        features, _ = sae.forward(activations)

    if feature_indices is None:
        # 가장 자주 활성화되는 특징들
        activation_freq = (features > 0).float().mean(dim=0)
        feature_indices = torch.topk(activation_freq, n_features).indices.tolist()

    # 활성화 히트맵
    subset = features[:1000, feature_indices].cpu().numpy()

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # 1. 활성화 히트맵
    ax = axes[0, 0]
    im = ax.imshow(subset.T, aspect='auto', cmap='viridis')
    ax.set_xlabel('Sample')
    ax.set_ylabel('Feature')
    ax.set_title('Feature Activations')
    plt.colorbar(im, ax=ax)

    # 2. 활성화 빈도 분포
    ax = axes[0, 1]
    freq = (features > 0).float().mean(dim=0).cpu().numpy()
    ax.hist(freq, bins=50, edgecolor='black')
    ax.set_xlabel('Activation Frequency')
    ax.set_ylabel('Count')
    ax.set_title('Feature Activation Frequency Distribution')
    ax.set_yscale('log')

    # 3. L0 sparsity 분포 (샘플당 활성화된 특징 수)
    ax = axes[1, 0]
    l0 = (features > 0).sum(dim=1).cpu().numpy()
    ax.hist(l0, bins=50, edgecolor='black')
    ax.set_xlabel('Number of Active Features')
    ax.set_ylabel('Count')
    ax.set_title('L0 Sparsity per Sample')

    # 4. 특징 활성화 크기 분포
    ax = axes[1, 1]
    nonzero = features[features > 0].cpu().numpy()
    ax.hist(nonzero, bins=50, edgecolor='black')
    ax.set_xlabel('Activation Magnitude')
    ax.set_ylabel('Count')
    ax.set_title('Non-zero Activation Distribution')
    ax.set_yscale('log')

    plt.tight_layout()
    return fig

주요 연구

Anthropic의 연구

논문 연도 핵심 기여
Toy Models of Superposition 2022 Superposition 현상의 이론적 모델링
Towards Monosemanticity 2023 1-layer transformer에서 SAE로 해석 가능한 특징 추출
Scaling Monosemanticity 2024 Claude 3 Sonnet에서 수백만 개 특징 식별
Gated SAE 2024 Gate 메커니즘으로 shrinkage bias 해결
Circuit Tracing 2025 SAE 특징 기반 회로 추적 방법론

OpenAI의 연구

논문 연도 핵심 기여
Language Models Can Explain Neurons 2023 GPT-4로 뉴런 자동 해석
Extracting Concepts from GPT-4 2024 GPT-4에서 SAE 기반 개념 추출

커뮤니티 연구

연구 소속 핵심 기여
SAELens EleutherAI SAE 학습/분석 오픈소스 라이브러리
TransformerLens Neel Nanda MI 연구용 기반 라이브러리
Sparse Feature Circuits ICML 2024 SAE 특징으로 회로 발견 및 편집

응용

1. 모델 행동 분석

def analyze_model_behavior(
    model,
    sae: SparseAutoencoder,
    prompt: str,
    layer_idx: int
):
    """
    특정 입력에 대한 모델 행동을 SAE 특징으로 분석
    """
    # 활성화 추출
    activations = extract_activations(model, prompt, layer_idx)

    # SAE로 특징 분해
    features, reconstruction = sae.forward(activations)

    # 활성화된 특징 식별
    active_features = torch.where(features[0, -1] > 0)[0]  # 마지막 토큰

    # 각 특징의 해석 조회
    interpretations = []
    for feat_idx in active_features[:10]:
        interp = get_feature_interpretation(feat_idx)
        act_value = features[0, -1, feat_idx].item()
        interpretations.append({
            "feature": feat_idx.item(),
            "activation": act_value,
            "interpretation": interp
        })

    return interpretations

2. 특징 조향 (Feature Steering)

def steer_generation(
    model,
    sae: SparseAutoencoder,
    prompt: str,
    feature_idx: int,
    steering_strength: float = 2.0,
    layer_idx: int = 10
):
    """
    특정 특징을 증폭/억제하여 생성 조향

    Args:
        model: 언어 모델
        sae: 학습된 SAE
        prompt: 입력 프롬프트
        feature_idx: 조향할 특징 인덱스
        steering_strength: 조향 강도 (양수: 증폭, 음수: 억제)
        layer_idx: SAE가 적용된 레이어
    """
    # 조향 벡터 = 특징의 디코더 방향
    steering_vector = sae.W_dec[feature_idx].clone()

    def steering_hook(module, input, output):
        # 활성화에 조향 벡터 추가
        output = output + steering_strength * steering_vector
        return output

    # Hook 등록
    handle = model.model.layers[layer_idx].register_forward_hook(steering_hook)

    # 생성
    output = model.generate(prompt, max_new_tokens=100)

    handle.remove()

    return output

3. 유해 특징 탐지

def detect_harmful_features(
    sae: SparseAutoencoder,
    harmful_activations: torch.Tensor,
    benign_activations: torch.Tensor,
    threshold: float = 2.0
):
    """
    유해 콘텐츠에서 특이적으로 활성화되는 특징 탐지

    Args:
        sae: 학습된 SAE
        harmful_activations: 유해 콘텐츠의 활성화
        benign_activations: 일반 콘텐츠의 활성화
        threshold: 특이성 임계값 (배수)

    Returns:
        harmful_features: 유해 특징 인덱스와 특이성 점수
    """
    with torch.no_grad():
        harmful_feats, _ = sae.forward(harmful_activations)
        benign_feats, _ = sae.forward(benign_activations)

    # 평균 활성화 비교
    harmful_mean = harmful_feats.mean(dim=0)
    benign_mean = benign_feats.mean(dim=0)

    # 특이성 점수 = harmful / (benign + epsilon)
    specificity = harmful_mean / (benign_mean + 1e-6)

    # 임계값 이상인 특징
    harmful_features = torch.where(specificity > threshold)[0]

    results = []
    for idx in harmful_features:
        results.append({
            "feature_idx": idx.item(),
            "specificity": specificity[idx].item(),
            "harmful_activation": harmful_mean[idx].item(),
            "benign_activation": benign_mean[idx].item()
        })

    return sorted(results, key=lambda x: -x["specificity"])

평가 지표

지표 설명 목표
Reconstruction Loss MSE(x, x_hat) 낮을수록 좋음
L0 Sparsity 활성화된 특징 비율 낮을수록 좋음 (1-5%)
Dead Features 전혀 활성화 안 되는 특징 비율 낮을수록 좋음 (<5%)
Explained Variance 1 - Var(x - x_hat) / Var(x) 높을수록 좋음 (>90%)
Downstream Loss SAE 적용 후 모델 성능 원본과 유사해야 함

한계 및 도전 과제

과제 설명
Scalability 대형 모델에서 특징 수가 폭발적으로 증가
Ground Truth 특징 해석의 정확성 검증 어려움
Completeness 모든 특징이 해석 가능하지 않음 (10-30%만 명확)
Cross-layer 레이어별로 별도 SAE 필요, 통합 분석 어려움
Computational Cost 학습 및 추론 시 추가 비용
Feature Splitting 유사한 개념이 여러 특징으로 분리될 수 있음

도구 및 라이브러리

SAELens

from sae_lens import SAE, SAEConfig, SAETrainer

# 설정
config = SAEConfig(
    model_name="gpt2-small",
    hook_point="blocks.5.hook_resid_post",
    d_in=768,
    expansion_factor=32,
    lr=3e-4,
    sparsity_coefficient=1e-3,
    batch_size=4096
)

# 학습
trainer = SAETrainer(config)
sae = trainer.train()

# 특징 분석
features = sae.encode(activations)
dashboard = sae.create_dashboard(features)

Neuronpedia

웹 기반 SAE 특징 탐색 도구. Anthropic, OpenAI 등의 SAE 특징을 시각적으로 탐색 가능.

참고 문헌

  1. Elhage, N., et al. (2022). "Toy Models of Superposition." Anthropic.
  2. Bricken, T., et al. (2023). "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning." Anthropic.
  3. Templeton, A., et al. (2024). "Scaling Monosemanticity: Extracting Interpretable Features from Claude 3 Sonnet." Anthropic.
  4. Rajamanoharan, S., et al. (2024). "Improving Dictionary Learning with Gated Sparse Autoencoders." Anthropic.
  5. Cunningham, H., et al. (2023). "Sparse Autoencoders Find Highly Interpretable Features in Language Models." ICLR 2024.
  6. Marks, S., et al. (2024). "Sparse Feature Circuits: Discovering and Editing Interpretable Causal Graphs in Language Models." ICML 2024.
  7. Sharkey, L., et al. (2024). "Taking Features Out of Superposition with Sparse Autoencoders." arXiv.