콘텐츠로 이동
Data Prep
상세

Mechanistic Interpretability

개요

Mechanistic Interpretability (MI)는 신경망의 내부 작동 원리를 역공학(reverse-engineering)하여 학습된 알고리즘과 표현을 인간이 이해할 수 있는 형태로 추출하는 연구 분야다. 단순히 모델의 입출력 행동을 분석하는 것을 넘어, 개별 뉴런, 어텐션 헤드, 레이어가 어떤 계산을 수행하는지 구체적으로 밝히는 것을 목표로 한다.

구분 설명
연구 목표 신경망 내부의 학습된 알고리즘 및 회로(circuits) 식별
핵심 질문 "이 모델은 어떻게 이 답을 도출했는가?"
주요 기관 Anthropic, OpenAI, DeepMind, EleutherAI, Redwood Research
적용 분야 AI 안전성, 디버깅, 모델 개선, 정렬(Alignment)

핵심 개념

1. Features (특징)

신경망이 학습한 의미 있는 방향(direction)이나 개념. 단일 뉴런이 하나의 특징을 인코딩하는 것이 아니라, 여러 뉴런이 분산(distributed) 방식으로 특징을 표현한다.

Feature 예시:
- "대문자로 시작하는 단어" 방향
- "부정적 감정" 방향
- "코드 문법" 방향
- "수학적 연산" 방향

2. Circuits (회로)

특정 행동을 구현하는 뉴런/어텐션 헤드의 연결 패턴. 입력에서 출력까지 정보가 어떻게 흐르고 변환되는지 추적한다.

회로 유형 설명 예시
Induction Head 패턴 복사 및 반복 [A][B]...[A] -> [B] 예측
Indirect Object Identification 문장 내 간접 목적어 식별 "John gave Mary the book. She..."
Greater-Than Circuit 수치 비교 연도, 숫자 크기 비교
Copy Suppression 반복 토큰 억제 동일 단어 재출력 방지

3. Superposition (중첩)

모델이 뉴런 수보다 더 많은 특징을 인코딩하는 현상. 희소하게 활성화되는 특징들이 동일한 뉴런 공간을 공유한다.

Superposition 조건:
- 특징이 희소(sparse)하게 활성화됨
- 특징 간 상관관계가 낮음
- 차원보다 특징 수가 많음 (overcomplete)

4. Polysemanticity

단일 뉴런이 여러 의미적으로 무관한 개념에 반응하는 현상. Superposition의 직접적 결과.

주요 분석 기법

1. Activation Patching

특정 위치의 활성화 값을 다른 입력의 활성화로 교체하여 인과 관계 파악.

import torch

def activation_patch(model, clean_input, corrupted_input, layer_idx, position):
    """
    Clean run의 특정 위치 활성화를 corrupted run에 주입하여
    해당 위치의 인과적 중요도 측정
    """
    # Clean forward pass - 활성화 저장
    clean_cache = {}
    def save_hook(module, input, output):
        clean_cache['activation'] = output.clone()

    handle = model.layers[layer_idx].register_forward_hook(save_hook)
    with torch.no_grad():
        clean_output = model(clean_input)
    handle.remove()

    # Corrupted forward pass - 활성화 교체
    def patch_hook(module, input, output):
        output[:, position, :] = clean_cache['activation'][:, position, :]
        return output

    handle = model.layers[layer_idx].register_forward_hook(patch_hook)
    with torch.no_grad():
        patched_output = model(corrupted_input)
    handle.remove()

    # 원래 corrupted output
    with torch.no_grad():
        corrupted_output = model(corrupted_input)

    # 인과 효과 = patched와 corrupted의 차이
    causal_effect = (patched_output - corrupted_output).abs().mean()
    return causal_effect.item()

2. Logit Lens / Tuned Lens

중간 레이어의 잔차 스트림을 unembedding에 투영하여 각 레이어에서의 "예측" 확인.

def logit_lens(model, hidden_states, layer_idx):
    """
    중간 레이어 hidden state를 어휘 공간으로 투영
    """
    # 레이어 정규화 적용
    normed = model.final_layer_norm(hidden_states[layer_idx])

    # Unembedding (lm_head) 적용
    logits = model.lm_head(normed)

    # 상위 토큰 확인
    probs = torch.softmax(logits[:, -1, :], dim=-1)
    top_tokens = torch.topk(probs, k=10)

    return top_tokens

def tuned_lens(model, hidden_states, layer_idx, translator):
    """
    학습된 변환기를 사용한 개선된 logit lens
    translator: 각 레이어별로 학습된 affine 변환
    """
    translated = translator[layer_idx](hidden_states[layer_idx])
    normed = model.final_layer_norm(translated)
    logits = model.lm_head(normed)
    return logits

3. Sparse Autoencoders (SAE)

Superposition을 풀어 해석 가능한 단일 특징(monosemantic features)으로 분해.

import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """
    활성화 벡터를 해석 가능한 특징으로 분해
    """
    def __init__(self, d_model, n_features, sparsity_coef=1e-3):
        super().__init__()
        self.encoder = nn.Linear(d_model, n_features)
        self.decoder = nn.Linear(n_features, d_model)
        self.sparsity_coef = sparsity_coef

    def forward(self, x):
        # 인코딩 (ReLU로 희소성 유도)
        features = torch.relu(self.encoder(x))

        # 디코딩 (재구성)
        reconstruction = self.decoder(features)

        return features, reconstruction

    def loss(self, x):
        features, reconstruction = self.forward(x)

        # 재구성 손실
        recon_loss = nn.functional.mse_loss(reconstruction, x)

        # L1 희소성 손실
        sparsity_loss = self.sparsity_coef * features.abs().mean()

        return recon_loss + sparsity_loss, features

# 학습 예시
def train_sae(model, dataloader, d_model=768, n_features=32768):
    sae = SparseAutoencoder(d_model, n_features)
    optimizer = torch.optim.Adam(sae.parameters(), lr=1e-4)

    for batch in dataloader:
        # 모델에서 활성화 추출
        with torch.no_grad():
            activations = extract_activations(model, batch)

        # SAE 학습
        loss, features = sae.loss(activations)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return sae

4. Probing

특정 레이어에 선형 분류기를 학습시켜 정보 인코딩 여부 확인.

from sklearn.linear_model import LogisticRegression
import numpy as np

def linear_probe(model, dataset, layer_idx, labels):
    """
    특정 레이어가 특정 속성을 인코딩하는지 확인
    """
    activations = []

    for text in dataset:
        with torch.no_grad():
            outputs = model(text, output_hidden_states=True)
            # 특정 레이어의 [CLS] 또는 마지막 토큰 활성화
            act = outputs.hidden_states[layer_idx][:, -1, :]
            activations.append(act.cpu().numpy())

    X = np.vstack(activations)
    y = np.array(labels)

    # 선형 프로브 학습
    probe = LogisticRegression(max_iter=1000)
    probe.fit(X, y)

    accuracy = probe.score(X, y)
    return probe, accuracy

5. Attention Pattern Analysis

어텐션 가중치 시각화 및 패턴 분석.

def analyze_attention_heads(model, input_ids, layer_idx):
    """
    특정 레이어의 어텐션 헤드 패턴 분석
    """
    with torch.no_grad():
        outputs = model(input_ids, output_attentions=True)

    # [batch, num_heads, seq_len, seq_len]
    attention = outputs.attentions[layer_idx]

    patterns = {}
    num_heads = attention.shape[1]

    for head in range(num_heads):
        head_attn = attention[0, head]  # [seq_len, seq_len]

        # 패턴 특성 분석
        patterns[f"head_{head}"] = {
            "diagonal_score": head_attn.diagonal().mean().item(),  # 자기 자신 주목
            "prev_token_score": head_attn.diagonal(-1).mean().item(),  # 이전 토큰 주목
            "first_token_score": head_attn[:, 0].mean().item(),  # 첫 토큰 주목
            "entropy": -(head_attn * head_attn.log()).sum(-1).mean().item()  # 분산도
        }

    return patterns

주요 연구 성과

Anthropic의 연구

연구 핵심 내용 연도
Toy Models of Superposition 중첩 현상의 수학적 모델링 2022
Towards Monosemanticity SAE로 Claude 특징 추출 2023
Scaling Monosemanticity Claude 3 Sonnet에서 수백만 특징 식별 2024
Circuit Tracing 완전한 회로 추적 방법론 2025

OpenAI의 연구

연구 핵심 내용 연도
Language Models Can Explain Neurons GPT-4로 뉴런 설명 자동 생성 2023
Extracting Concepts from GPT-4 SAE 기반 개념 추출 2024

커뮤니티 연구

연구 핵심 내용 소속
Induction Heads 문맥 내 학습의 핵심 메커니즘 Anthropic
Indirect Object Identification GPT-2의 IOI 회로 완전 분석 Redwood Research
TransformerLens MI 연구용 라이브러리 Neel Nanda
SAELens SAE 학습 및 분석 도구 EleutherAI

도구 및 라이브러리

TransformerLens

from transformer_lens import HookedTransformer

# 모델 로드 (자동으로 hook 추가)
model = HookedTransformer.from_pretrained("gpt2-small")

# 캐시된 활성화로 추론
logits, cache = model.run_with_cache("Hello, world!")

# 특정 활성화 접근
residual_stream = cache["resid_post", 5]  # 레이어 5 이후 잔차
attn_pattern = cache["pattern", 3]  # 레이어 3 어텐션 패턴
mlp_out = cache["mlp_out", 7]  # 레이어 7 MLP 출력

# Activation patching
def patch_hook(activation, hook):
    activation[:, 5, :] = corrupted_activation[:, 5, :]
    return activation

patched_logits = model.run_with_hooks(
    clean_input,
    fwd_hooks=[("blocks.3.hook_resid_post", patch_hook)]
)

SAELens

from sae_lens import SAE, SAETrainingConfig

# SAE 설정
config = SAETrainingConfig(
    model_name="gpt2-small",
    hook_point="blocks.5.hook_resid_post",
    d_in=768,
    expansion_factor=32,  # n_features = d_in * expansion_factor
    lr=3e-4,
    sparsity_coefficient=1e-3
)

# 학습
sae = SAE.from_config(config)
sae.train(dataloader)

# 특징 활성화 분석
features = sae.encode(activations)
top_features = features.topk(k=10)

한계 및 도전 과제

도전 과제 설명
Scalability 대형 모델(100B+ 파라미터)에 적용 어려움
Faithfulness 발견한 회로가 실제 계산을 반영하는지 검증 필요
Completeness 모델 행동의 일부만 설명 가능
Automation 수작업 분석에서 자동화로 전환 중
Ground Truth 정답이 없어 검증이 어려움

실용적 응용

1. 모델 디버깅

def debug_incorrect_prediction(model, input_text, expected_token, actual_token):
    """
    잘못된 예측의 원인 분석
    """
    # 각 레이어에서의 logit 기여도 분석
    _, cache = model.run_with_cache(input_text)

    contributions = []
    for layer in range(model.cfg.n_layers):
        # 해당 레이어의 잔차 기여도
        layer_contrib = cache["resid_post", layer] - cache["resid_pre", layer]

        # 예상 토큰과 실제 토큰에 대한 로짓 기여도
        expected_logit = (layer_contrib @ model.W_U[:, expected_token]).mean()
        actual_logit = (layer_contrib @ model.W_U[:, actual_token]).mean()

        contributions.append({
            "layer": layer,
            "expected_contribution": expected_logit.item(),
            "actual_contribution": actual_logit.item()
        })

    return contributions

2. 특징 조작 (Steering)

def steer_generation(model, sae, input_text, feature_idx, strength=1.0):
    """
    특정 특징을 증폭/억제하여 생성 조작
    """
    def steering_hook(activation, hook):
        # SAE로 특징 분해
        features = sae.encode(activation)

        # 특정 특징 조작
        features[:, :, feature_idx] += strength

        # 다시 활성화 공간으로
        return sae.decode(features)

    output = model.run_with_hooks(
        input_text,
        fwd_hooks=[("blocks.10.hook_resid_post", steering_hook)]
    )
    return output

참고 문헌

  1. Elhage, N., et al. (2022). "Toy Models of Superposition." Anthropic.
  2. Conmy, A., et al. (2023). "Towards Automated Circuit Discovery for Mechanistic Interpretability." NeurIPS 2023.
  3. Bricken, T., et al. (2023). "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning." Anthropic.
  4. Templeton, A., et al. (2024). "Scaling Monosemanticity: Extracting Interpretable Features from Claude 3 Sonnet." Anthropic.
  5. Bills, S., et al. (2023). "Language Models Can Explain Neurons in Language Models." OpenAI.
  6. Wang, K., et al. (2023). "Interpretability in the Wild: A Circuit for Indirect Object Identification in GPT-2 Small." ICLR 2023.
  7. Nanda, N. (2023). "TransformerLens: A Library for Mechanistic Interpretability." GitHub.
  8. Marks, S., et al. (2024). "Sparse Feature Circuits: Discovering and Editing Interpretable Causal Graphs in Language Models." ICML 2024.