Sparse Autoencoders (SAE)
개요
Sparse Autoencoders (SAE)는 신경망의 내부 활성화(activations)를 해석 가능한 희소 특징(sparse features)으로 분해하는 비지도 학습 기법이다. 딥러닝 모델이 학습하는 특징(features)이 중첩(superposition) 상태로 인코딩되어 단일 뉴런 분석으로는 파악하기 어려운 문제를 해결한다.
| 구분 |
설명 |
| 핵심 목표 |
모델 활성화를 해석 가능한 단일의미(monosemantic) 특징으로 분해 |
| 이론적 기반 |
Dictionary Learning, Sparse Coding |
| 주요 적용 |
LLM Interpretability, AI Safety, Model Debugging |
| 핵심 논문 |
Towards Monosemanticity (Anthropic, 2023) |
| 관련 개념 |
Superposition, Polysemanticity, Mechanistic Interpretability |
배경: Superposition 문제
Superposition이란
신경망은 뉴런 수보다 훨씬 많은 특징을 학습한다. 이 특징들이 희소하게 활성화될 경우, 모델은 동일한 뉴런 공간에 여러 특징을 중첩하여 저장한다.
d차원 공간에 n개 특징 저장 (n >> d)
- 특징 i: 방향 벡터 f_i
- 활성화: x = sum_i a_i * f_i (a_i는 희소)
- 특징 간 간섭 최소화를 위해 거의 직교
Polysemanticity
Superposition의 결과로 단일 뉴런이 여러 무관한 개념에 반응한다.
| 현상 |
예시 |
| 다의성 |
한 뉴런이 "학술 인용", "법률 용어", "특정 URL 패턴"에 동시 반응 |
| 해석 불가 |
뉴런별 분석으로는 의미 파악 어려움 |
| 분산 표현 |
하나의 개념이 여러 뉴런에 걸쳐 인코딩 |
SAE의 해결 방식
활성화 x (d차원) --> SAE --> 특징 f (n차원, n >> d)
|
v
각 f_i는 하나의 해석 가능한 개념
아키텍처
기본 구조
Encoder: f = ReLU(W_enc @ (x - b_dec) + b_enc)
Decoder: x_hat = W_dec @ f + b_dec
손실함수: L = ||x - x_hat||_2^2 + lambda * ||f||_1
| 구성요소 |
차원 |
역할 |
| 입력 x |
(batch, d_model) |
모델의 활성화 벡터 |
| 인코더 W_enc |
(d_model, n_features) |
활성화를 특징 공간으로 매핑 |
| 디코더 W_dec |
(n_features, d_model) |
특징을 활성화 공간으로 복원 |
| 특징 f |
(batch, n_features) |
희소 특징 활성화 |
| 확장 비율 |
n_features / d_model |
보통 8x ~ 128x |
Python 구현
import torch
import torch.nn as nn
import torch.nn.functional as F
class SparseAutoencoder(nn.Module):
"""
기본 Sparse Autoencoder 구현
Args:
d_model: 입력 활성화 차원
n_features: 학습할 특징 수 (보통 d_model의 8-64배)
sparsity_coef: L1 희소성 페널티 계수
"""
def __init__(self, d_model: int, n_features: int, sparsity_coef: float = 1e-3):
super().__init__()
# 인코더 (특징 추출)
self.W_enc = nn.Parameter(torch.randn(d_model, n_features) * 0.01)
self.b_enc = nn.Parameter(torch.zeros(n_features))
# 디코더 (재구성)
self.W_dec = nn.Parameter(torch.randn(n_features, d_model) * 0.01)
self.b_dec = nn.Parameter(torch.zeros(d_model))
self.sparsity_coef = sparsity_coef
# 디코더 가중치 정규화 (단위 벡터)
with torch.no_grad():
self.W_dec.data = F.normalize(self.W_dec.data, dim=1)
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""활성화를 희소 특징으로 인코딩"""
x_centered = x - self.b_dec
pre_activation = x_centered @ self.W_enc + self.b_enc
return F.relu(pre_activation)
def decode(self, f: torch.Tensor) -> torch.Tensor:
"""희소 특징을 활성화로 디코딩"""
return f @ self.W_dec + self.b_dec
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""순전파: 인코딩 후 디코딩"""
features = self.encode(x)
reconstruction = self.decode(features)
return features, reconstruction
def loss(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
"""
전체 손실 계산
Returns:
total_loss: 재구성 손실 + 희소성 손실
metrics: 개별 손실 값들
"""
features, reconstruction = self.forward(x)
# 재구성 손실 (MSE)
recon_loss = F.mse_loss(reconstruction, x)
# L1 희소성 손실
sparsity_loss = features.abs().mean()
# 전체 손실
total_loss = recon_loss + self.sparsity_coef * sparsity_loss
metrics = {
"recon_loss": recon_loss.item(),
"sparsity_loss": sparsity_loss.item(),
"total_loss": total_loss.item(),
"l0_sparsity": (features > 0).float().mean().item(), # 활성화된 특징 비율
}
return total_loss, metrics
@torch.no_grad()
def normalize_decoder(self):
"""디코더 가중치를 단위 벡터로 정규화"""
self.W_dec.data = F.normalize(self.W_dec.data, dim=1)
SAE 변형
1. TopK SAE
L1 페널티 대신 상위 K개 특징만 활성화하여 희소성을 직접 제어한다.
class TopKSparseAutoencoder(nn.Module):
"""
TopK 활성화 기반 SAE
L1 페널티 튜닝 없이 정확한 희소성 제어
"""
def __init__(self, d_model: int, n_features: int, k: int = 32):
super().__init__()
self.W_enc = nn.Parameter(torch.randn(d_model, n_features) * 0.01)
self.b_enc = nn.Parameter(torch.zeros(n_features))
self.W_dec = nn.Parameter(torch.randn(n_features, d_model) * 0.01)
self.b_dec = nn.Parameter(torch.zeros(d_model))
self.k = k
with torch.no_grad():
self.W_dec.data = F.normalize(self.W_dec.data, dim=1)
def encode(self, x: torch.Tensor) -> torch.Tensor:
x_centered = x - self.b_dec
pre_act = x_centered @ self.W_enc + self.b_enc
# 상위 K개만 활성화
topk_values, topk_indices = torch.topk(pre_act, self.k, dim=-1)
# 희소 활성화 생성
features = torch.zeros_like(pre_act)
features.scatter_(-1, topk_indices, F.relu(topk_values))
return features
def decode(self, f: torch.Tensor) -> torch.Tensor:
return f @ self.W_dec + self.b_dec
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
features = self.encode(x)
reconstruction = self.decode(features)
return features, reconstruction
def loss(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
features, reconstruction = self.forward(x)
recon_loss = F.mse_loss(reconstruction, x)
# Auxiliary loss: TopK 외 특징들의 기울기 유지
pre_act = (x - self.b_dec) @ self.W_enc + self.b_enc
aux_loss = F.relu(pre_act).mean() * 0.01 # 작은 보조 손실
return recon_loss + aux_loss, {
"recon_loss": recon_loss.item(),
"l0_sparsity": self.k / self.W_enc.shape[1]
}
2. Gated SAE (Anthropic, 2024)
게이트 메커니즘으로 특징 활성화 여부와 크기를 분리하여 학습한다.
class GatedSparseAutoencoder(nn.Module):
"""
Gated SAE: 특징 활성화 여부(gate)와 크기(magnitude)를 분리
shrinkage bias 문제 해결
"""
def __init__(self, d_model: int, n_features: int, sparsity_coef: float = 1e-3):
super().__init__()
# Gate 경로 (활성화 여부 결정)
self.W_gate = nn.Parameter(torch.randn(d_model, n_features) * 0.01)
self.b_gate = nn.Parameter(torch.zeros(n_features))
# Magnitude 경로 (활성화 크기 결정)
self.W_mag = nn.Parameter(torch.randn(d_model, n_features) * 0.01)
self.b_mag = nn.Parameter(torch.zeros(n_features))
self.r_mag = nn.Parameter(torch.ones(n_features)) # rescale
# 디코더
self.W_dec = nn.Parameter(torch.randn(n_features, d_model) * 0.01)
self.b_dec = nn.Parameter(torch.zeros(d_model))
self.sparsity_coef = sparsity_coef
with torch.no_grad():
self.W_dec.data = F.normalize(self.W_dec.data, dim=1)
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x_centered = x - self.b_dec
# Gate: 활성화 여부 (sigmoid -> Heaviside 근사)
gate_pre = x_centered @ self.W_gate + self.b_gate
gate = (gate_pre > 0).float() # Hard gate
# Magnitude: 활성화 크기
mag_pre = x_centered @ self.W_mag + self.b_mag
magnitude = F.relu(mag_pre) * torch.exp(self.r_mag)
# 최종 특징 = gate * magnitude
features = gate * magnitude
return features, gate_pre
def decode(self, f: torch.Tensor) -> torch.Tensor:
return f @ self.W_dec + self.b_dec
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
features, gate_pre = self.encode(x)
reconstruction = self.decode(features)
return features, reconstruction, gate_pre
def loss(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
features, reconstruction, gate_pre = self.forward(x)
# 재구성 손실
recon_loss = F.mse_loss(reconstruction, x)
# Gate에 대한 희소성 손실 (L1 on gate activations)
gate_sparsity = F.relu(gate_pre).mean()
# Auxiliary loss: gate 학습 안정화
aux_loss = F.mse_loss(
self.decode(F.relu(gate_pre)),
x - self.b_dec
) * 0.1
total_loss = recon_loss + self.sparsity_coef * gate_sparsity + aux_loss
return total_loss, {
"recon_loss": recon_loss.item(),
"gate_sparsity": gate_sparsity.item(),
"l0_sparsity": (features > 0).float().mean().item()
}
3. JumpReLU SAE (DeepMind, 2024)
학습 가능한 threshold를 가진 JumpReLU 활성화 함수 사용.
class JumpReLU(torch.autograd.Function):
"""
JumpReLU: x > theta이면 x, 아니면 0
theta는 학습 가능
"""
@staticmethod
def forward(ctx, x, theta):
ctx.save_for_backward(x, theta)
return torch.where(x > theta, x, torch.zeros_like(x))
@staticmethod
def backward(ctx, grad_output):
x, theta = ctx.saved_tensors
# x > theta인 경우에만 기울기 전파
mask = (x > theta).float()
grad_x = grad_output * mask
# theta에 대한 기울기 (STE 근사)
grad_theta = -grad_output * mask
return grad_x, grad_theta.sum(dim=0)
class JumpReLUSAE(nn.Module):
"""
JumpReLU SAE: 학습 가능한 threshold로 희소성 제어
"""
def __init__(self, d_model: int, n_features: int, init_threshold: float = 0.1):
super().__init__()
self.W_enc = nn.Parameter(torch.randn(d_model, n_features) * 0.01)
self.b_enc = nn.Parameter(torch.zeros(n_features))
self.theta = nn.Parameter(torch.full((n_features,), init_threshold))
self.W_dec = nn.Parameter(torch.randn(n_features, d_model) * 0.01)
self.b_dec = nn.Parameter(torch.zeros(d_model))
with torch.no_grad():
self.W_dec.data = F.normalize(self.W_dec.data, dim=1)
def encode(self, x: torch.Tensor) -> torch.Tensor:
x_centered = x - self.b_dec
pre_act = x_centered @ self.W_enc + self.b_enc
return JumpReLU.apply(pre_act, self.theta)
def decode(self, f: torch.Tensor) -> torch.Tensor:
return f @ self.W_dec + self.b_dec
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
features = self.encode(x)
reconstruction = self.decode(features)
return features, reconstruction
SAE 변형 비교
| 변형 |
희소성 제어 |
장점 |
단점 |
| Vanilla (L1) |
lambda 튜닝 |
구현 간단 |
shrinkage bias |
| TopK |
k 직접 지정 |
정확한 희소성 |
기울기 소실 |
| Gated |
분리된 gate |
shrinkage 해결 |
파라미터 2배 |
| JumpReLU |
학습된 threshold |
적응적 희소성 |
학습 불안정 |
학습 방법
데이터 수집
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
def collect_activations(
model_name: str,
layer_idx: int,
hook_point: str, # "resid_pre", "resid_post", "mlp_out", etc.
dataset,
batch_size: int = 32,
max_samples: int = 1_000_000
):
"""
모델에서 특정 레이어의 활성화 수집
Args:
model_name: HuggingFace 모델 이름
layer_idx: 활성화를 추출할 레이어 인덱스
hook_point: 추출 위치 (residual stream, MLP 출력 등)
dataset: 텍스트 데이터셋
Returns:
activations: (n_samples, d_model) 텐서
"""
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
activations = []
# Hook 설정
def hook_fn(module, input, output):
# output shape: (batch, seq_len, d_model)
# 모든 토큰 위치의 활성화 수집
activations.append(output.detach().cpu())
# 레이어에 따른 hook 등록
if hook_point == "resid_post":
handle = model.model.layers[layer_idx].register_forward_hook(hook_fn)
elif hook_point == "mlp_out":
handle = model.model.layers[layer_idx].mlp.register_forward_hook(hook_fn)
model.eval()
with torch.no_grad():
for batch in dataset:
inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
_ = model(**inputs)
if sum(a.shape[0] * a.shape[1] for a in activations) >= max_samples:
break
handle.remove()
# (n_samples, d_model) 형태로 변환
all_activations = torch.cat(activations, dim=0)
all_activations = all_activations.view(-1, all_activations.shape[-1])
return all_activations[:max_samples]
학습 루프
def train_sae(
sae: SparseAutoencoder,
activations: torch.Tensor,
batch_size: int = 4096,
num_epochs: int = 1,
lr: float = 3e-4,
device: str = "cuda"
):
"""
SAE 학습
Args:
sae: SAE 모델
activations: 수집된 활성화 데이터
batch_size: 배치 크기 (보통 4096-8192)
num_epochs: 에폭 수 (보통 1-3)
lr: 학습률
"""
sae = sae.to(device)
optimizer = torch.optim.AdamW(sae.parameters(), lr=lr, betas=(0.9, 0.999))
# 학습률 스케줄러 (warmup + decay)
num_steps = len(activations) // batch_size * num_epochs
warmup_steps = num_steps // 10
def lr_schedule(step):
if step < warmup_steps:
return step / warmup_steps
return max(0.1, 1 - (step - warmup_steps) / (num_steps - warmup_steps))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
# 학습
dataset = torch.utils.data.TensorDataset(activations)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=True
)
for epoch in range(num_epochs):
total_loss = 0
for batch in dataloader:
x = batch[0].to(device)
loss, metrics = sae.loss(x)
optimizer.zero_grad()
loss.backward()
# 기울기 클리핑
torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)
optimizer.step()
scheduler.step()
# 디코더 정규화 (단위 벡터 유지)
sae.normalize_decoder()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")
return sae
Dead Features 처리
일부 특징이 학습 중 활성화되지 않는 "dead features" 문제 해결.
def resample_dead_features(
sae: SparseAutoencoder,
activations: torch.Tensor,
dead_threshold: float = 1e-6,
resample_fraction: float = 0.5
):
"""
Dead features 재초기화
활성화 빈도가 threshold 미만인 특징들을
높은 재구성 오차를 가진 샘플 방향으로 재초기화
"""
with torch.no_grad():
# 특징 활성화 빈도 계산
features, _ = sae.forward(activations)
activation_freq = (features > 0).float().mean(dim=0)
# Dead features 식별
dead_mask = activation_freq < dead_threshold
n_dead = dead_mask.sum().item()
if n_dead == 0:
return 0
# 높은 재구성 오차를 가진 샘플 식별
_, reconstruction = sae.forward(activations)
recon_errors = (activations - reconstruction).pow(2).sum(dim=1)
# 상위 오차 샘플을 dead feature 방향으로 사용
n_resample = min(n_dead, int(len(activations) * resample_fraction))
top_error_indices = torch.topk(recon_errors, n_resample).indices
# Dead features 재초기화
new_directions = activations[top_error_indices]
new_directions = F.normalize(new_directions, dim=1)
dead_indices = torch.where(dead_mask)[0][:n_resample]
sae.W_dec.data[dead_indices] = new_directions
sae.W_enc.data[:, dead_indices] = new_directions.T
sae.b_enc.data[dead_indices] = 0
return n_dead
특징 분석
최대 활성화 샘플 찾기
def find_max_activating_examples(
sae: SparseAutoencoder,
activations: torch.Tensor,
texts: list[str],
feature_idx: int,
top_k: int = 10
):
"""
특정 특징에 가장 강하게 반응하는 예시 찾기
"""
with torch.no_grad():
features, _ = sae.forward(activations)
feature_acts = features[:, feature_idx]
top_indices = torch.topk(feature_acts, top_k).indices
results = []
for idx in top_indices:
results.append({
"text": texts[idx],
"activation": feature_acts[idx].item()
})
return results
자동 특징 해석 (AutoInterp)
def auto_interpret_feature(
max_activating_examples: list[dict],
llm_client # OpenAI/Anthropic API 클라이언트
) -> str:
"""
LLM을 사용하여 특징의 의미 자동 추론
Args:
max_activating_examples: 최대 활성화 예시들
llm_client: API 클라이언트
Returns:
feature_description: 특징에 대한 자연어 설명
"""
prompt = """다음은 신경망의 특정 특징(feature)이 강하게 활성화되는 텍스트 예시들입니다.
이 특징이 감지하는 패턴이나 개념이 무엇인지 한 문장으로 설명해주세요.
예시들:
"""
for ex in max_activating_examples[:10]:
prompt += f"- (활성화: {ex['activation']:.2f}) {ex['text']}\n"
prompt += "\n이 특징이 감지하는 것: "
response = llm_client.generate(prompt)
return response
특징 활성화 시각화
import matplotlib.pyplot as plt
import numpy as np
def visualize_feature_activations(
sae: SparseAutoencoder,
activations: torch.Tensor,
feature_indices: list[int] = None,
n_features: int = 50
):
"""
특징 활성화 패턴 시각화
"""
with torch.no_grad():
features, _ = sae.forward(activations)
if feature_indices is None:
# 가장 자주 활성화되는 특징들
activation_freq = (features > 0).float().mean(dim=0)
feature_indices = torch.topk(activation_freq, n_features).indices.tolist()
# 활성화 히트맵
subset = features[:1000, feature_indices].cpu().numpy()
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# 1. 활성화 히트맵
ax = axes[0, 0]
im = ax.imshow(subset.T, aspect='auto', cmap='viridis')
ax.set_xlabel('Sample')
ax.set_ylabel('Feature')
ax.set_title('Feature Activations')
plt.colorbar(im, ax=ax)
# 2. 활성화 빈도 분포
ax = axes[0, 1]
freq = (features > 0).float().mean(dim=0).cpu().numpy()
ax.hist(freq, bins=50, edgecolor='black')
ax.set_xlabel('Activation Frequency')
ax.set_ylabel('Count')
ax.set_title('Feature Activation Frequency Distribution')
ax.set_yscale('log')
# 3. L0 sparsity 분포 (샘플당 활성화된 특징 수)
ax = axes[1, 0]
l0 = (features > 0).sum(dim=1).cpu().numpy()
ax.hist(l0, bins=50, edgecolor='black')
ax.set_xlabel('Number of Active Features')
ax.set_ylabel('Count')
ax.set_title('L0 Sparsity per Sample')
# 4. 특징 활성화 크기 분포
ax = axes[1, 1]
nonzero = features[features > 0].cpu().numpy()
ax.hist(nonzero, bins=50, edgecolor='black')
ax.set_xlabel('Activation Magnitude')
ax.set_ylabel('Count')
ax.set_title('Non-zero Activation Distribution')
ax.set_yscale('log')
plt.tight_layout()
return fig
주요 연구
Anthropic의 연구
| 논문 |
연도 |
핵심 기여 |
| Toy Models of Superposition |
2022 |
Superposition 현상의 이론적 모델링 |
| Towards Monosemanticity |
2023 |
1-layer transformer에서 SAE로 해석 가능한 특징 추출 |
| Scaling Monosemanticity |
2024 |
Claude 3 Sonnet에서 수백만 개 특징 식별 |
| Gated SAE |
2024 |
Gate 메커니즘으로 shrinkage bias 해결 |
| Circuit Tracing |
2025 |
SAE 특징 기반 회로 추적 방법론 |
OpenAI의 연구
| 논문 |
연도 |
핵심 기여 |
| Language Models Can Explain Neurons |
2023 |
GPT-4로 뉴런 자동 해석 |
| Extracting Concepts from GPT-4 |
2024 |
GPT-4에서 SAE 기반 개념 추출 |
커뮤니티 연구
| 연구 |
소속 |
핵심 기여 |
| SAELens |
EleutherAI |
SAE 학습/분석 오픈소스 라이브러리 |
| TransformerLens |
Neel Nanda |
MI 연구용 기반 라이브러리 |
| Sparse Feature Circuits |
ICML 2024 |
SAE 특징으로 회로 발견 및 편집 |
응용
1. 모델 행동 분석
def analyze_model_behavior(
model,
sae: SparseAutoencoder,
prompt: str,
layer_idx: int
):
"""
특정 입력에 대한 모델 행동을 SAE 특징으로 분석
"""
# 활성화 추출
activations = extract_activations(model, prompt, layer_idx)
# SAE로 특징 분해
features, reconstruction = sae.forward(activations)
# 활성화된 특징 식별
active_features = torch.where(features[0, -1] > 0)[0] # 마지막 토큰
# 각 특징의 해석 조회
interpretations = []
for feat_idx in active_features[:10]:
interp = get_feature_interpretation(feat_idx)
act_value = features[0, -1, feat_idx].item()
interpretations.append({
"feature": feat_idx.item(),
"activation": act_value,
"interpretation": interp
})
return interpretations
2. 특징 조향 (Feature Steering)
def steer_generation(
model,
sae: SparseAutoencoder,
prompt: str,
feature_idx: int,
steering_strength: float = 2.0,
layer_idx: int = 10
):
"""
특정 특징을 증폭/억제하여 생성 조향
Args:
model: 언어 모델
sae: 학습된 SAE
prompt: 입력 프롬프트
feature_idx: 조향할 특징 인덱스
steering_strength: 조향 강도 (양수: 증폭, 음수: 억제)
layer_idx: SAE가 적용된 레이어
"""
# 조향 벡터 = 특징의 디코더 방향
steering_vector = sae.W_dec[feature_idx].clone()
def steering_hook(module, input, output):
# 활성화에 조향 벡터 추가
output = output + steering_strength * steering_vector
return output
# Hook 등록
handle = model.model.layers[layer_idx].register_forward_hook(steering_hook)
# 생성
output = model.generate(prompt, max_new_tokens=100)
handle.remove()
return output
3. 유해 특징 탐지
def detect_harmful_features(
sae: SparseAutoencoder,
harmful_activations: torch.Tensor,
benign_activations: torch.Tensor,
threshold: float = 2.0
):
"""
유해 콘텐츠에서 특이적으로 활성화되는 특징 탐지
Args:
sae: 학습된 SAE
harmful_activations: 유해 콘텐츠의 활성화
benign_activations: 일반 콘텐츠의 활성화
threshold: 특이성 임계값 (배수)
Returns:
harmful_features: 유해 특징 인덱스와 특이성 점수
"""
with torch.no_grad():
harmful_feats, _ = sae.forward(harmful_activations)
benign_feats, _ = sae.forward(benign_activations)
# 평균 활성화 비교
harmful_mean = harmful_feats.mean(dim=0)
benign_mean = benign_feats.mean(dim=0)
# 특이성 점수 = harmful / (benign + epsilon)
specificity = harmful_mean / (benign_mean + 1e-6)
# 임계값 이상인 특징
harmful_features = torch.where(specificity > threshold)[0]
results = []
for idx in harmful_features:
results.append({
"feature_idx": idx.item(),
"specificity": specificity[idx].item(),
"harmful_activation": harmful_mean[idx].item(),
"benign_activation": benign_mean[idx].item()
})
return sorted(results, key=lambda x: -x["specificity"])
평가 지표
| 지표 |
설명 |
목표 |
| Reconstruction Loss |
MSE(x, x_hat) |
낮을수록 좋음 |
| L0 Sparsity |
활성화된 특징 비율 |
낮을수록 좋음 (1-5%) |
| Dead Features |
전혀 활성화 안 되는 특징 비율 |
낮을수록 좋음 (<5%) |
| Explained Variance |
1 - Var(x - x_hat) / Var(x) |
높을수록 좋음 (>90%) |
| Downstream Loss |
SAE 적용 후 모델 성능 |
원본과 유사해야 함 |
한계 및 도전 과제
| 과제 |
설명 |
| Scalability |
대형 모델에서 특징 수가 폭발적으로 증가 |
| Ground Truth |
특징 해석의 정확성 검증 어려움 |
| Completeness |
모든 특징이 해석 가능하지 않음 (10-30%만 명확) |
| Cross-layer |
레이어별로 별도 SAE 필요, 통합 분석 어려움 |
| Computational Cost |
학습 및 추론 시 추가 비용 |
| Feature Splitting |
유사한 개념이 여러 특징으로 분리될 수 있음 |
도구 및 라이브러리
SAELens
from sae_lens import SAE, SAEConfig, SAETrainer
# 설정
config = SAEConfig(
model_name="gpt2-small",
hook_point="blocks.5.hook_resid_post",
d_in=768,
expansion_factor=32,
lr=3e-4,
sparsity_coefficient=1e-3,
batch_size=4096
)
# 학습
trainer = SAETrainer(config)
sae = trainer.train()
# 특징 분석
features = sae.encode(activations)
dashboard = sae.create_dashboard(features)
Neuronpedia
웹 기반 SAE 특징 탐색 도구. Anthropic, OpenAI 등의 SAE 특징을 시각적으로 탐색 가능.
참고 문헌
- Elhage, N., et al. (2022). "Toy Models of Superposition." Anthropic.
- Bricken, T., et al. (2023). "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning." Anthropic.
- Templeton, A., et al. (2024). "Scaling Monosemanticity: Extracting Interpretable Features from Claude 3 Sonnet." Anthropic.
- Rajamanoharan, S., et al. (2024). "Improving Dictionary Learning with Gated Sparse Autoencoders." Anthropic.
- Cunningham, H., et al. (2023). "Sparse Autoencoders Find Highly Interpretable Features in Language Models." ICLR 2024.
- Marks, S., et al. (2024). "Sparse Feature Circuits: Discovering and Editing Interpretable Causal Graphs in Language Models." ICML 2024.
- Sharkey, L., et al. (2024). "Taking Features Out of Superposition with Sparse Autoencoders." arXiv.