Mechanistic Interpretability¶
개요¶
Mechanistic Interpretability (MI)는 신경망의 내부 작동 원리를 역공학(reverse-engineering)하여 학습된 알고리즘과 표현을 인간이 이해할 수 있는 형태로 추출하는 연구 분야다. 단순히 모델의 입출력 행동을 분석하는 것을 넘어, 개별 뉴런, 어텐션 헤드, 레이어가 어떤 계산을 수행하는지 구체적으로 밝히는 것을 목표로 한다.
| 구분 | 설명 |
|---|---|
| 연구 목표 | 신경망 내부의 학습된 알고리즘 및 회로(circuits) 식별 |
| 핵심 질문 | "이 모델은 어떻게 이 답을 도출했는가?" |
| 주요 기관 | Anthropic, OpenAI, DeepMind, EleutherAI, Redwood Research |
| 적용 분야 | AI 안전성, 디버깅, 모델 개선, 정렬(Alignment) |
핵심 개념¶
1. Features (특징)¶
신경망이 학습한 의미 있는 방향(direction)이나 개념. 단일 뉴런이 하나의 특징을 인코딩하는 것이 아니라, 여러 뉴런이 분산(distributed) 방식으로 특징을 표현한다.
2. Circuits (회로)¶
특정 행동을 구현하는 뉴런/어텐션 헤드의 연결 패턴. 입력에서 출력까지 정보가 어떻게 흐르고 변환되는지 추적한다.
| 회로 유형 | 설명 | 예시 |
|---|---|---|
| Induction Head | 패턴 복사 및 반복 | [A][B]...[A] -> [B] 예측 |
| Indirect Object Identification | 문장 내 간접 목적어 식별 | "John gave Mary the book. She..." |
| Greater-Than Circuit | 수치 비교 | 연도, 숫자 크기 비교 |
| Copy Suppression | 반복 토큰 억제 | 동일 단어 재출력 방지 |
3. Superposition (중첩)¶
모델이 뉴런 수보다 더 많은 특징을 인코딩하는 현상. 희소하게 활성화되는 특징들이 동일한 뉴런 공간을 공유한다.
4. Polysemanticity¶
단일 뉴런이 여러 의미적으로 무관한 개념에 반응하는 현상. Superposition의 직접적 결과.
주요 분석 기법¶
1. Activation Patching¶
특정 위치의 활성화 값을 다른 입력의 활성화로 교체하여 인과 관계 파악.
import torch
def activation_patch(model, clean_input, corrupted_input, layer_idx, position):
"""
Clean run의 특정 위치 활성화를 corrupted run에 주입하여
해당 위치의 인과적 중요도 측정
"""
# Clean forward pass - 활성화 저장
clean_cache = {}
def save_hook(module, input, output):
clean_cache['activation'] = output.clone()
handle = model.layers[layer_idx].register_forward_hook(save_hook)
with torch.no_grad():
clean_output = model(clean_input)
handle.remove()
# Corrupted forward pass - 활성화 교체
def patch_hook(module, input, output):
output[:, position, :] = clean_cache['activation'][:, position, :]
return output
handle = model.layers[layer_idx].register_forward_hook(patch_hook)
with torch.no_grad():
patched_output = model(corrupted_input)
handle.remove()
# 원래 corrupted output
with torch.no_grad():
corrupted_output = model(corrupted_input)
# 인과 효과 = patched와 corrupted의 차이
causal_effect = (patched_output - corrupted_output).abs().mean()
return causal_effect.item()
2. Logit Lens / Tuned Lens¶
중간 레이어의 잔차 스트림을 unembedding에 투영하여 각 레이어에서의 "예측" 확인.
def logit_lens(model, hidden_states, layer_idx):
"""
중간 레이어 hidden state를 어휘 공간으로 투영
"""
# 레이어 정규화 적용
normed = model.final_layer_norm(hidden_states[layer_idx])
# Unembedding (lm_head) 적용
logits = model.lm_head(normed)
# 상위 토큰 확인
probs = torch.softmax(logits[:, -1, :], dim=-1)
top_tokens = torch.topk(probs, k=10)
return top_tokens
def tuned_lens(model, hidden_states, layer_idx, translator):
"""
학습된 변환기를 사용한 개선된 logit lens
translator: 각 레이어별로 학습된 affine 변환
"""
translated = translator[layer_idx](hidden_states[layer_idx])
normed = model.final_layer_norm(translated)
logits = model.lm_head(normed)
return logits
3. Sparse Autoencoders (SAE)¶
Superposition을 풀어 해석 가능한 단일 특징(monosemantic features)으로 분해.
import torch
import torch.nn as nn
class SparseAutoencoder(nn.Module):
"""
활성화 벡터를 해석 가능한 특징으로 분해
"""
def __init__(self, d_model, n_features, sparsity_coef=1e-3):
super().__init__()
self.encoder = nn.Linear(d_model, n_features)
self.decoder = nn.Linear(n_features, d_model)
self.sparsity_coef = sparsity_coef
def forward(self, x):
# 인코딩 (ReLU로 희소성 유도)
features = torch.relu(self.encoder(x))
# 디코딩 (재구성)
reconstruction = self.decoder(features)
return features, reconstruction
def loss(self, x):
features, reconstruction = self.forward(x)
# 재구성 손실
recon_loss = nn.functional.mse_loss(reconstruction, x)
# L1 희소성 손실
sparsity_loss = self.sparsity_coef * features.abs().mean()
return recon_loss + sparsity_loss, features
# 학습 예시
def train_sae(model, dataloader, d_model=768, n_features=32768):
sae = SparseAutoencoder(d_model, n_features)
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-4)
for batch in dataloader:
# 모델에서 활성화 추출
with torch.no_grad():
activations = extract_activations(model, batch)
# SAE 학습
loss, features = sae.loss(activations)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return sae
4. Probing¶
특정 레이어에 선형 분류기를 학습시켜 정보 인코딩 여부 확인.
from sklearn.linear_model import LogisticRegression
import numpy as np
def linear_probe(model, dataset, layer_idx, labels):
"""
특정 레이어가 특정 속성을 인코딩하는지 확인
"""
activations = []
for text in dataset:
with torch.no_grad():
outputs = model(text, output_hidden_states=True)
# 특정 레이어의 [CLS] 또는 마지막 토큰 활성화
act = outputs.hidden_states[layer_idx][:, -1, :]
activations.append(act.cpu().numpy())
X = np.vstack(activations)
y = np.array(labels)
# 선형 프로브 학습
probe = LogisticRegression(max_iter=1000)
probe.fit(X, y)
accuracy = probe.score(X, y)
return probe, accuracy
5. Attention Pattern Analysis¶
어텐션 가중치 시각화 및 패턴 분석.
def analyze_attention_heads(model, input_ids, layer_idx):
"""
특정 레이어의 어텐션 헤드 패턴 분석
"""
with torch.no_grad():
outputs = model(input_ids, output_attentions=True)
# [batch, num_heads, seq_len, seq_len]
attention = outputs.attentions[layer_idx]
patterns = {}
num_heads = attention.shape[1]
for head in range(num_heads):
head_attn = attention[0, head] # [seq_len, seq_len]
# 패턴 특성 분석
patterns[f"head_{head}"] = {
"diagonal_score": head_attn.diagonal().mean().item(), # 자기 자신 주목
"prev_token_score": head_attn.diagonal(-1).mean().item(), # 이전 토큰 주목
"first_token_score": head_attn[:, 0].mean().item(), # 첫 토큰 주목
"entropy": -(head_attn * head_attn.log()).sum(-1).mean().item() # 분산도
}
return patterns
주요 연구 성과¶
Anthropic의 연구¶
| 연구 | 핵심 내용 | 연도 |
|---|---|---|
| Toy Models of Superposition | 중첩 현상의 수학적 모델링 | 2022 |
| Towards Monosemanticity | SAE로 Claude 특징 추출 | 2023 |
| Scaling Monosemanticity | Claude 3 Sonnet에서 수백만 특징 식별 | 2024 |
| Circuit Tracing | 완전한 회로 추적 방법론 | 2025 |
OpenAI의 연구¶
| 연구 | 핵심 내용 | 연도 |
|---|---|---|
| Language Models Can Explain Neurons | GPT-4로 뉴런 설명 자동 생성 | 2023 |
| Extracting Concepts from GPT-4 | SAE 기반 개념 추출 | 2024 |
커뮤니티 연구¶
| 연구 | 핵심 내용 | 소속 |
|---|---|---|
| Induction Heads | 문맥 내 학습의 핵심 메커니즘 | Anthropic |
| Indirect Object Identification | GPT-2의 IOI 회로 완전 분석 | Redwood Research |
| TransformerLens | MI 연구용 라이브러리 | Neel Nanda |
| SAELens | SAE 학습 및 분석 도구 | EleutherAI |
도구 및 라이브러리¶
TransformerLens¶
from transformer_lens import HookedTransformer
# 모델 로드 (자동으로 hook 추가)
model = HookedTransformer.from_pretrained("gpt2-small")
# 캐시된 활성화로 추론
logits, cache = model.run_with_cache("Hello, world!")
# 특정 활성화 접근
residual_stream = cache["resid_post", 5] # 레이어 5 이후 잔차
attn_pattern = cache["pattern", 3] # 레이어 3 어텐션 패턴
mlp_out = cache["mlp_out", 7] # 레이어 7 MLP 출력
# Activation patching
def patch_hook(activation, hook):
activation[:, 5, :] = corrupted_activation[:, 5, :]
return activation
patched_logits = model.run_with_hooks(
clean_input,
fwd_hooks=[("blocks.3.hook_resid_post", patch_hook)]
)
SAELens¶
from sae_lens import SAE, SAETrainingConfig
# SAE 설정
config = SAETrainingConfig(
model_name="gpt2-small",
hook_point="blocks.5.hook_resid_post",
d_in=768,
expansion_factor=32, # n_features = d_in * expansion_factor
lr=3e-4,
sparsity_coefficient=1e-3
)
# 학습
sae = SAE.from_config(config)
sae.train(dataloader)
# 특징 활성화 분석
features = sae.encode(activations)
top_features = features.topk(k=10)
한계 및 도전 과제¶
| 도전 과제 | 설명 |
|---|---|
| Scalability | 대형 모델(100B+ 파라미터)에 적용 어려움 |
| Faithfulness | 발견한 회로가 실제 계산을 반영하는지 검증 필요 |
| Completeness | 모델 행동의 일부만 설명 가능 |
| Automation | 수작업 분석에서 자동화로 전환 중 |
| Ground Truth | 정답이 없어 검증이 어려움 |
실용적 응용¶
1. 모델 디버깅¶
def debug_incorrect_prediction(model, input_text, expected_token, actual_token):
"""
잘못된 예측의 원인 분석
"""
# 각 레이어에서의 logit 기여도 분석
_, cache = model.run_with_cache(input_text)
contributions = []
for layer in range(model.cfg.n_layers):
# 해당 레이어의 잔차 기여도
layer_contrib = cache["resid_post", layer] - cache["resid_pre", layer]
# 예상 토큰과 실제 토큰에 대한 로짓 기여도
expected_logit = (layer_contrib @ model.W_U[:, expected_token]).mean()
actual_logit = (layer_contrib @ model.W_U[:, actual_token]).mean()
contributions.append({
"layer": layer,
"expected_contribution": expected_logit.item(),
"actual_contribution": actual_logit.item()
})
return contributions
2. 특징 조작 (Steering)¶
def steer_generation(model, sae, input_text, feature_idx, strength=1.0):
"""
특정 특징을 증폭/억제하여 생성 조작
"""
def steering_hook(activation, hook):
# SAE로 특징 분해
features = sae.encode(activation)
# 특정 특징 조작
features[:, :, feature_idx] += strength
# 다시 활성화 공간으로
return sae.decode(features)
output = model.run_with_hooks(
input_text,
fwd_hooks=[("blocks.10.hook_resid_post", steering_hook)]
)
return output
참고 문헌¶
- Elhage, N., et al. (2022). "Toy Models of Superposition." Anthropic.
- Conmy, A., et al. (2023). "Towards Automated Circuit Discovery for Mechanistic Interpretability." NeurIPS 2023.
- Bricken, T., et al. (2023). "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning." Anthropic.
- Templeton, A., et al. (2024). "Scaling Monosemanticity: Extracting Interpretable Features from Claude 3 Sonnet." Anthropic.
- Bills, S., et al. (2023). "Language Models Can Explain Neurons in Language Models." OpenAI.
- Wang, K., et al. (2023). "Interpretability in the Wild: A Circuit for Indirect Object Identification in GPT-2 Small." ICLR 2023.
- Nanda, N. (2023). "TransformerLens: A Library for Mechanistic Interpretability." GitHub.
- Marks, S., et al. (2024). "Sparse Feature Circuits: Discovering and Editing Interpretable Causal Graphs in Language Models." ICML 2024.