In-Context Learning Theory¶
In-Context Learning (ICL)은 Transformer 모델이 추가 학습 없이 프롬프트에 포함된 예시만으로 새로운 태스크를 수행하는 능력이다. 최근 연구에서는 ICL의 메커니즘이 암묵적으로 gradient descent를 수행하는 것과 동등함이 밝혀졌으며, 이를 mesa-optimization이라 부른다.
1. 개요¶
1.1 배경 및 정의¶
In-Context Learning은 GPT-3 (Brown et al., 2020)에서 처음 체계적으로 연구되었다. 모델은 다음과 같은 형태의 입력을 받는다:
| 용어 | 정의 |
|---|---|
| Context | 입력 프롬프트에 포함된 (x, y) 예시 쌍들 |
| In-Context Learning | 파라미터 업데이트 없이 context로부터 학습 |
| Few-shot | 소수의 예시만 주어지는 설정 |
| Zero-shot | 예시 없이 태스크 설명만 주어지는 설정 |
핵심 질문: Transformer가 ICL을 어떻게 수행하는가?
1.2 ICL의 두 가지 관점¶
- Task Recognition: 사전학습 중 본 태스크를 인식하고 해당 능력을 발휘
- Task Learning: context에서 실제로 새로운 규칙을 학습
최근 연구는 두 번째 관점, 즉 Transformer가 forward pass 중에 실제로 학습을 수행한다는 증거를 제시한다.
2. 이론적 기초: ICL as Gradient Descent¶
2.1 핵심 논문: Transformers learn in-context by gradient descent¶
| 항목 | 내용 |
|---|---|
| 논문 | Transformers learn in-context by gradient descent |
| 출처 | ICML 2023 |
| 저자 | von Oswald, Niklasson, Randazzo, Sacramento, Mordvintsev, Zhmoginov, Vladymyrov |
| 코드 | github.com/google-research/self-organising-systems |
핵심 발견: 단일 linear self-attention layer가 gradient descent step과 정확히 동등한 연산을 수행할 수 있다.
2.2 Self-Attention과 Gradient Descent의 동치성¶
Linear regression 문제를 고려하자. 데이터 {(x_i, y_i)}_{i=1}^n에 대해:
Gradient Descent 업데이트:
Linear Self-Attention:
Linear attention (softmax 제거)의 경우:
동치성 구성: 적절한 W_Q, W_K, W_V 가중치를 선택하면:
# Input sequence: [(x_1, y_1), (x_2, y_2), ..., (x_n, y_n), (x_query, 0)]
# 각 토큰을 [x; y] 형태로 concatenate
# Linear attention output의 query 위치:
# sum_i (x_query^T x_i) * [x_i; y_i]
# = [sum_i (x_query^T x_i) x_i; sum_i (x_query^T x_i) y_i]
# 이는 OLS 해의 형태와 유사:
# w_OLS = (X^T X)^{-1} X^T y
2.3 Mesa-Optimization¶
Mesa-optimization은 학습된 모델이 내부적으로 최적화 알고리즘을 수행하는 현상이다.
Base Optimizer (SGD/Adam)
|
v
[Training] --> Transformer Weights
|
v
[Forward Pass]
|
v
Mesa-Optimizer (내부 GD)
|
v
Task Solution
| 용어 | 정의 |
|---|---|
| Base Optimizer | 모델 파라미터를 학습시키는 외부 최적화기 |
| Mesa-Optimizer | 학습된 모델 내부에 구현된 최적화기 |
| Objective | Mesa-optimizer가 최적화하는 목표 (암묵적) |
실험적 증거: - 학습된 Transformer 가중치가 GD 구성과 높은 유사성 - Context 길이가 증가할수록 ICL 성능이 GD 수렴과 유사한 패턴 - Attention pattern이 GD의 data weighting과 대응
2.4 GD++: Plain GD를 넘어서¶
von Oswald et al.의 분석에서 Transformer는 단순 GD보다 우수한 성능을 보였다. 이를 GD++라 명명한다.
GD++ 특성: 1. Curvature Correction: Hessian 정보를 활용한 적응적 학습률 2. Momentum: 이전 업데이트 방향 활용 3. Preconditioning: 입력 공간 변환
# GD++ 형태 (von Oswald et al. 분석)
# Layer l에서:
delta_w_l = -eta_l * H_l^{-1} * nabla_w L(w)
# 여기서 H_l은 학습된 preconditioner
# 여러 layer가 Newton method의 근사를 구현
3. 일반화 이론¶
3.1 Transformers as Algorithms¶
| 항목 | 내용 |
|---|---|
| 논문 | Transformers as Algorithms: Generalization and Stability in In-context Learning |
| 출처 | ICML 2023 |
| 저자 | Li, Ildiz, Papailiopoulos, Oymak |
핵심 아이디어: ICL을 algorithm learning 문제로 형식화
Transformer f_theta: (context, x_query) -> y_query
# Context = {(x_1, y_1), ..., (x_n, y_n)}
# f_theta는 context로부터 암묵적으로 hypothesis h를 구성
# h(x_query) = y_query
3.2 Stability와 Generalization¶
Algorithmic stability: 입력의 작은 변화에 대해 출력이 안정적
정리 (Li et al., 2023): Transformer가 beta-stable하면:
여기서 n은 context 길이.
Stability 조건: - Softmax attention의 온도 파라미터 - Layer normalization - Residual connection의 스케일링
3.3 Multitask Learning 관점¶
ICL의 일반화는 multitask learning (MTL)과 밀접하게 연관된다.
사전학습: T개의 태스크 {tau_1, ..., tau_T}
테스트: 새로운 태스크 tau_new (context로 정의)
Transfer Risk = Task Complexity / sqrt(T) + Context Error / sqrt(n)
Inductive Bias: 태스크 수 T가 증가하면 새로운 태스크로의 전이가 개선된다.
4. Bayesian 관점¶
4.1 ICL as Bayesian Inference¶
| 항목 | 내용 |
|---|---|
| 논문 | Can Transformers Learn Full Bayesian Inference in Context? |
| 출처 | ICML 2025 |
Transformer가 posterior 계산을 수행한다는 가설:
증거: - 학습된 모델이 predictive posterior와 유사한 분포 출력 - Uncertainty quantification이 Bayesian 예측과 상관
4.2 Posterior Approximation¶
# Bayesian ICL 해석
# Context로부터 posterior 근사:
# p(theta | D_n) propto p(D_n | theta) p(theta)
# Transformer의 내부 표현:
# z = f_encoder(context) # posterior의 sufficient statistics
# y = f_decoder(x_query, z) # predictive distribution
5. Algorithm Selection과 적응¶
5.1 Transformers as Statisticians¶
| 항목 | 내용 |
|---|---|
| 논문 | Transformers as Statisticians: Provable In-Context Learning with In-Context Algorithm Selection |
| 출처 | NeurIPS 2024 |
| 저자 | Bai, Chen, Wang, Xiong, Mei |
Transformer가 context에 따라 다른 알고리즘을 선택:
Context (선형 패턴) --> Linear Regression 구현
Context (비선형 패턴) --> Kernel Regression 구현
Context (노이즈 多) --> Robust Estimator 구현
5.2 Algorithm Selection 메커니즘¶
- Pre-ICL Testing: Context 분석 후 알고리즘 선택
- Post-ICL Validation: 여러 알고리즘 시도 후 최적 선택
# Pre-ICL Testing (개념적)
def select_algorithm(context):
# Context 특성 추출
noise_level = estimate_noise(context)
linearity = test_linearity(context)
if linearity > threshold and noise_level < threshold:
return "OLS"
elif linearity > threshold:
return "Ridge"
else:
return "Kernel"
6. Universal Approximation¶
6.1 최신 이론: 2025¶
| 항목 | 내용 |
|---|---|
| 논문 | Transformers Meet In-Context Learning: A Universal Approximation Theory |
| 출처 | arXiv 2025 |
| 저자 | Gen Li et al. |
정리: 충분한 깊이와 너비의 Transformer는 다양한 함수 클래스에 대해 ICL을 수행할 수 있다.
임의의 함수 클래스 F와 오차 epsilon에 대해,
Transformer T가 존재하여:
E_{f ~ F, context, x} [ ||T(context, x) - f(x)||^2 ] < epsilon
6.2 Depth와 Width의 역할¶
| 구성 요소 | ICL에서의 역할 |
|---|---|
| Depth | 반복적 최적화 스텝, 알고리즘 복잡도 |
| Width | 표현력, 동시 가설 수 |
| Attention Heads | 병렬 알고리즘 실행 |
| MLP | 비선형 변환, 특징 추출 |
7. Mechanistic Interpretability¶
7.1 Induction Heads¶
Olsson et al. (2022)가 발견한 induction head는 ICL의 핵심 메커니즘 중 하나다.
Induction Head 구성: - Head 1: 이전 토큰으로 attention - Head 2: 현재와 매칭되는 이전 context 탐색 - 결합: Copy-like 행동 구현
7.2 ICL Circuit 분석¶
Layer 1-2: Token/Position Encoding
Layer 3-4: Pattern Matching (Induction)
Layer 5-6: Task Vector 형성
Layer 7+: Task 실행, 예측 생성
8. Python 구현¶
8.1 Linear Attention의 GD 동치성 시연¶
import numpy as np
import torch
import torch.nn as nn
class LinearAttentionGD(nn.Module):
"""
Linear attention이 gradient descent를 수행함을 보이는 구현
"""
def __init__(self, d_in, d_out):
super().__init__()
self.d_in = d_in
self.d_out = d_out
# GD 동치 가중치 구성
# W_K, W_Q: 입력 x에 대한 projection
# W_V: (x, y) pair에 대한 projection
self.W_K = nn.Parameter(torch.eye(d_in))
self.W_Q = nn.Parameter(torch.eye(d_in))
self.W_V = nn.Parameter(torch.zeros(d_in + d_out, d_out))
# GD 스타일 초기화
self._init_gd_weights()
def _init_gd_weights(self):
# von Oswald et al. 구성에 따른 초기화
# V가 y 부분만 추출하도록
with torch.no_grad():
self.W_V[self.d_in:, :] = torch.eye(self.d_out)
def forward(self, context_x, context_y, query_x):
"""
context_x: (batch, n_context, d_in)
context_y: (batch, n_context, d_out)
query_x: (batch, d_in)
Returns: predicted y for query
"""
batch_size = context_x.shape[0]
n_context = context_x.shape[1]
# Key, Query 계산
K = context_x @ self.W_K # (batch, n, d_in)
Q = query_x @ self.W_Q # (batch, d_in)
# Attention scores (linear, no softmax)
scores = torch.einsum('bi,bni->bn', Q, K) # (batch, n)
# Value: concatenate (x, y)
context_xy = torch.cat([context_x, context_y], dim=-1) # (batch, n, d_in+d_out)
V = context_xy @ self.W_V # (batch, n, d_out)
# Weighted sum
output = torch.einsum('bn,bno->bo', scores, V) # (batch, d_out)
return output
def compare_with_gd(n_samples=100, d_in=5, d_out=1, n_context=20):
"""
Linear attention과 closed-form OLS 비교
"""
# 랜덤 데이터 생성
np.random.seed(42)
# True weight
w_true = np.random.randn(d_in, d_out)
# Context data
X_context = np.random.randn(n_context, d_in)
y_context = X_context @ w_true + 0.1 * np.random.randn(n_context, d_out)
# Query
x_query = np.random.randn(d_in)
y_true = x_query @ w_true
# OLS solution
w_ols = np.linalg.lstsq(X_context, y_context, rcond=None)[0]
y_ols = x_query @ w_ols
# Linear attention (simplified)
# y_pred = x_query^T @ X^T @ y / (x_query^T @ X^T @ X @ x_query)
# 이는 kernel regression과 유사
attention_scores = X_context @ x_query # (n,)
y_attn = (attention_scores @ y_context) / (np.sum(attention_scores ** 2) + 1e-6)
print(f"True y: {y_true[0]:.4f}")
print(f"OLS prediction: {y_ols[0]:.4f}")
print(f"Attention prediction: {y_attn[0]:.4f}")
return y_true, y_ols, y_attn
# 실행
compare_with_gd()
8.2 ICL 실험: Synthetic Regression¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
class ICLRegressionDataset(Dataset):
"""
각 샘플이 하나의 태스크 (context + query)인 데이터셋
"""
def __init__(self, n_tasks, n_context, d_in, d_out=1, noise_std=0.1):
self.n_tasks = n_tasks
self.n_context = n_context
self.d_in = d_in
self.d_out = d_out
self.noise_std = noise_std
def __len__(self):
return self.n_tasks
def __getitem__(self, idx):
# 각 태스크마다 랜덤 weight
w = torch.randn(self.d_in, self.d_out)
# Context
x_context = torch.randn(self.n_context, self.d_in)
y_context = x_context @ w + self.noise_std * torch.randn(self.n_context, self.d_out)
# Query
x_query = torch.randn(self.d_in)
y_query = x_query @ w
return {
'x_context': x_context,
'y_context': y_context,
'x_query': x_query,
'y_query': y_query,
'w': w
}
class SimpleICLTransformer(nn.Module):
"""
ICL을 위한 간단한 Transformer
"""
def __init__(self, d_in, d_out, d_model=64, n_heads=4, n_layers=4):
super().__init__()
self.d_in = d_in
self.d_out = d_out
self.d_model = d_model
# Input embedding
self.input_embed = nn.Linear(d_in + d_out, d_model)
self.query_embed = nn.Linear(d_in, d_model)
# Transformer layers
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=n_heads,
dim_feedforward=d_model * 4,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
# Output projection
self.output = nn.Linear(d_model, d_out)
def forward(self, x_context, y_context, x_query):
"""
x_context: (batch, n_context, d_in)
y_context: (batch, n_context, d_out)
x_query: (batch, d_in)
"""
batch_size = x_context.shape[0]
n_context = x_context.shape[1]
# Context embedding
context_input = torch.cat([x_context, y_context], dim=-1) # (batch, n, d_in+d_out)
context_emb = self.input_embed(context_input) # (batch, n, d_model)
# Query embedding (y=0으로 padding)
query_input = x_query.unsqueeze(1) # (batch, 1, d_in)
query_emb = self.query_embed(query_input) # (batch, 1, d_model)
# Concatenate
seq = torch.cat([context_emb, query_emb], dim=1) # (batch, n+1, d_model)
# Transformer
out = self.transformer(seq) # (batch, n+1, d_model)
# Query position output
query_out = out[:, -1, :] # (batch, d_model)
return self.output(query_out) # (batch, d_out)
def train_icl_model(n_epochs=100, batch_size=32):
"""
ICL Transformer 학습
"""
d_in, d_out = 10, 1
n_context = 20
# Dataset
train_dataset = ICLRegressionDataset(
n_tasks=10000, n_context=n_context, d_in=d_in, d_out=d_out
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Model
model = SimpleICLTransformer(d_in=d_in, d_out=d_out)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Training
for epoch in range(n_epochs):
total_loss = 0
for batch in train_loader:
optimizer.zero_grad()
pred = model(
batch['x_context'],
batch['y_context'],
batch['x_query']
)
loss = F.mse_loss(pred, batch['y_query'])
loss.backward()
optimizer.step()
total_loss += loss.item()
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.6f}")
return model
def analyze_icl_vs_ols(model, n_tests=100, n_context=20, d_in=10):
"""
학습된 ICL 모델과 OLS 비교
"""
model.eval()
icl_errors = []
ols_errors = []
with torch.no_grad():
for _ in range(n_tests):
# 새로운 태스크
w = torch.randn(d_in, 1)
x_context = torch.randn(n_context, d_in)
y_context = x_context @ w
x_query = torch.randn(d_in)
y_true = (x_query @ w).item()
# ICL prediction
y_icl = model(
x_context.unsqueeze(0),
y_context.unsqueeze(0),
x_query.unsqueeze(0)
).item()
# OLS prediction
w_ols = torch.linalg.lstsq(x_context, y_context).solution
y_ols = (x_query @ w_ols).item()
icl_errors.append((y_icl - y_true) ** 2)
ols_errors.append((y_ols - y_true) ** 2)
print(f"ICL MSE: {np.mean(icl_errors):.6f}")
print(f"OLS MSE: {np.mean(ols_errors):.6f}")
print(f"Ratio (ICL/OLS): {np.mean(icl_errors)/np.mean(ols_errors):.4f}")
8.3 Context Length와 성능 관계 분석¶
import matplotlib.pyplot as plt
def analyze_context_length_effect(model, d_in=10, max_context=50):
"""
Context 길이에 따른 ICL 성능 변화 분석
"""
model.eval()
context_lengths = list(range(5, max_context + 1, 5))
icl_mses = []
ols_mses = []
n_tests = 100
for n_ctx in context_lengths:
icl_errors = []
ols_errors = []
with torch.no_grad():
for _ in range(n_tests):
w = torch.randn(d_in, 1)
x_context = torch.randn(n_ctx, d_in)
y_context = x_context @ w + 0.1 * torch.randn(n_ctx, 1)
x_query = torch.randn(d_in)
y_true = (x_query @ w).item()
# ICL
y_icl = model(
x_context.unsqueeze(0),
y_context.unsqueeze(0),
x_query.unsqueeze(0)
).item()
# OLS
w_ols = torch.linalg.lstsq(x_context, y_context).solution
y_ols = (x_query @ w_ols).item()
icl_errors.append((y_icl - y_true) ** 2)
ols_errors.append((y_ols - y_true) ** 2)
icl_mses.append(np.mean(icl_errors))
ols_mses.append(np.mean(ols_errors))
# Plotting
plt.figure(figsize=(10, 6))
plt.plot(context_lengths, icl_mses, 'b-o', label='ICL Transformer')
plt.plot(context_lengths, ols_mses, 'r--s', label='OLS')
plt.xlabel('Context Length')
plt.ylabel('MSE')
plt.title('ICL Performance vs Context Length')
plt.legend()
plt.grid(True)
plt.savefig('icl_context_length.png', dpi=150, bbox_inches='tight')
return context_lengths, icl_mses, ols_mses
9. 최신 연구 동향 (2024-2025)¶
9.1 주요 연구 방향¶
| 연구 방향 | 핵심 내용 | 대표 논문 |
|---|---|---|
| State Space Models의 ICL | Mamba 등 SSM의 ICL 능력 분석 | Park et al., 2024; Yin & Steinhardt, 2025 |
| Hybrid Architectures | Transformer + SSM 조합 | Understanding ICL Beyond Transformers, 2025 |
| Scaling Laws | ICL 성능과 모델 크기 관계 | Memory Mosaics v2, 2025 |
| Bayesian ICL | 완전한 posterior 추론 | Can Transformers Learn Full Bayesian Inference, ICML 2025 |
9.2 Open Problems¶
- Task Diversity와 ICL: 사전학습 태스크 다양성이 ICL에 미치는 영향
- Compositional ICL: 복잡한 태스크의 분해와 조합
- ICL의 한계: 어떤 태스크는 ICL이 불가능한가?
- Efficient ICL: Context 길이 제한 극복
10. 핵심 요약¶
| 개념 | 설명 |
|---|---|
| In-Context Learning | 파라미터 업데이트 없이 context에서 학습 |
| Mesa-Optimization | 모델 내부에서 최적화 알고리즘 수행 |
| ICL as GD | Linear attention이 gradient descent와 동치 |
| GD++ | Transformer가 수행하는 향상된 최적화 (curvature correction) |
| Stability | ICL 일반화의 핵심 조건 |
| Algorithm Selection | Context에 따른 적응적 알고리즘 선택 |
참고 문헌¶
- Brown, T., et al. (2020). Language Models are Few-Shot Learners. NeurIPS.
- von Oswald, J., et al. (2023). Transformers learn in-context by gradient descent. ICML.
- Li, Y., et al. (2023). Transformers as Algorithms: Generalization and Stability in In-context Learning. ICML.
- Bai, Y., et al. (2024). Transformers as Statisticians: Provable In-Context Learning with In-Context Algorithm Selection. NeurIPS.
- Olsson, C., et al. (2022). In-context Learning and Induction Heads. Anthropic.
- Akyurek, E., et al. (2023). What learning algorithm is in-context learning? Investigations with linear models. ICLR.
- Garg, S., et al. (2022). What Can Transformers Learn In-Context? A Case Study of Simple Function Classes. NeurIPS.
- Li, G., et al. (2025). Transformers Meet In-Context Learning: A Universal Approximation Theory. arXiv.