콘텐츠로 이동
Data Prep
상세

In-Context Learning Theory

In-Context Learning (ICL)은 Transformer 모델이 추가 학습 없이 프롬프트에 포함된 예시만으로 새로운 태스크를 수행하는 능력이다. 최근 연구에서는 ICL의 메커니즘이 암묵적으로 gradient descent를 수행하는 것과 동등함이 밝혀졌으며, 이를 mesa-optimization이라 부른다.

1. 개요

1.1 배경 및 정의

In-Context Learning은 GPT-3 (Brown et al., 2020)에서 처음 체계적으로 연구되었다. 모델은 다음과 같은 형태의 입력을 받는다:

[x_1, y_1, x_2, y_2, ..., x_n, y_n, x_query] -> y_query
용어 정의
Context 입력 프롬프트에 포함된 (x, y) 예시 쌍들
In-Context Learning 파라미터 업데이트 없이 context로부터 학습
Few-shot 소수의 예시만 주어지는 설정
Zero-shot 예시 없이 태스크 설명만 주어지는 설정

핵심 질문: Transformer가 ICL을 어떻게 수행하는가?

1.2 ICL의 두 가지 관점

  1. Task Recognition: 사전학습 중 본 태스크를 인식하고 해당 능력을 발휘
  2. Task Learning: context에서 실제로 새로운 규칙을 학습

최근 연구는 두 번째 관점, 즉 Transformer가 forward pass 중에 실제로 학습을 수행한다는 증거를 제시한다.

2. 이론적 기초: ICL as Gradient Descent

2.1 핵심 논문: Transformers learn in-context by gradient descent

항목 내용
논문 Transformers learn in-context by gradient descent
출처 ICML 2023
저자 von Oswald, Niklasson, Randazzo, Sacramento, Mordvintsev, Zhmoginov, Vladymyrov
코드 github.com/google-research/self-organising-systems

핵심 발견: 단일 linear self-attention layer가 gradient descent step과 정확히 동등한 연산을 수행할 수 있다.

2.2 Self-Attention과 Gradient Descent의 동치성

Linear regression 문제를 고려하자. 데이터 {(x_i, y_i)}_{i=1}^n에 대해:

Gradient Descent 업데이트:

w_{t+1} = w_t - eta * nabla_w L(w_t)
        = w_t - eta * sum_i (w_t^T x_i - y_i) * x_i

Linear Self-Attention:

Attention(Q, K, V) = softmax(QK^T / sqrt(d)) * V

Linear attention (softmax 제거)의 경우:

Attention(Q, K, V) = QK^T * V

동치성 구성: 적절한 W_Q, W_K, W_V 가중치를 선택하면:

# Input sequence: [(x_1, y_1), (x_2, y_2), ..., (x_n, y_n), (x_query, 0)]
# 각 토큰을 [x; y] 형태로 concatenate

# Linear attention output의 query 위치:
# sum_i (x_query^T x_i) * [x_i; y_i]
# = [sum_i (x_query^T x_i) x_i; sum_i (x_query^T x_i) y_i]

# 이는 OLS 해의 형태와 유사:
# w_OLS = (X^T X)^{-1} X^T y

2.3 Mesa-Optimization

Mesa-optimization은 학습된 모델이 내부적으로 최적화 알고리즘을 수행하는 현상이다.

Base Optimizer (SGD/Adam)
    |
    v
[Training] --> Transformer Weights
                    |
                    v
              [Forward Pass]
                    |
                    v
            Mesa-Optimizer (내부 GD)
                    |
                    v
              Task Solution
용어 정의
Base Optimizer 모델 파라미터를 학습시키는 외부 최적화기
Mesa-Optimizer 학습된 모델 내부에 구현된 최적화기
Objective Mesa-optimizer가 최적화하는 목표 (암묵적)

실험적 증거: - 학습된 Transformer 가중치가 GD 구성과 높은 유사성 - Context 길이가 증가할수록 ICL 성능이 GD 수렴과 유사한 패턴 - Attention pattern이 GD의 data weighting과 대응

2.4 GD++: Plain GD를 넘어서

von Oswald et al.의 분석에서 Transformer는 단순 GD보다 우수한 성능을 보였다. 이를 GD++라 명명한다.

GD++ 특성: 1. Curvature Correction: Hessian 정보를 활용한 적응적 학습률 2. Momentum: 이전 업데이트 방향 활용 3. Preconditioning: 입력 공간 변환

# GD++ 형태 (von Oswald et al. 분석)
# Layer l에서:

delta_w_l = -eta_l * H_l^{-1} * nabla_w L(w)

# 여기서 H_l은 학습된 preconditioner
# 여러 layer가 Newton method의 근사를 구현

3. 일반화 이론

3.1 Transformers as Algorithms

항목 내용
논문 Transformers as Algorithms: Generalization and Stability in In-context Learning
출처 ICML 2023
저자 Li, Ildiz, Papailiopoulos, Oymak

핵심 아이디어: ICL을 algorithm learning 문제로 형식화

Transformer f_theta: (context, x_query) -> y_query

# Context = {(x_1, y_1), ..., (x_n, y_n)}
# f_theta는 context로부터 암묵적으로 hypothesis h를 구성
# h(x_query) = y_query

3.2 Stability와 Generalization

Algorithmic stability: 입력의 작은 변화에 대해 출력이 안정적

||f_theta(context, x) - f_theta(context', x)|| <= beta * ||context - context'||

정리 (Li et al., 2023): Transformer가 beta-stable하면:

E[excess risk] <= O(beta * sqrt(n) + 1/sqrt(n))

여기서 n은 context 길이.

Stability 조건: - Softmax attention의 온도 파라미터 - Layer normalization - Residual connection의 스케일링

3.3 Multitask Learning 관점

ICL의 일반화는 multitask learning (MTL)과 밀접하게 연관된다.

사전학습: T개의 태스크 {tau_1, ..., tau_T}
테스트: 새로운 태스크 tau_new (context로 정의)

Transfer Risk = Task Complexity / sqrt(T) + Context Error / sqrt(n)

Inductive Bias: 태스크 수 T가 증가하면 새로운 태스크로의 전이가 개선된다.

4. Bayesian 관점

4.1 ICL as Bayesian Inference

항목 내용
논문 Can Transformers Learn Full Bayesian Inference in Context?
출처 ICML 2025

Transformer가 posterior 계산을 수행한다는 가설:

p(y_query | x_query, context) = integral p(y | x, theta) p(theta | context) d_theta

증거: - 학습된 모델이 predictive posterior와 유사한 분포 출력 - Uncertainty quantification이 Bayesian 예측과 상관

4.2 Posterior Approximation

# Bayesian ICL 해석

# Context로부터 posterior 근사:
# p(theta | D_n) propto p(D_n | theta) p(theta)

# Transformer의 내부 표현:
# z = f_encoder(context)  # posterior의 sufficient statistics
# y = f_decoder(x_query, z)  # predictive distribution

5. Algorithm Selection과 적응

5.1 Transformers as Statisticians

항목 내용
논문 Transformers as Statisticians: Provable In-Context Learning with In-Context Algorithm Selection
출처 NeurIPS 2024
저자 Bai, Chen, Wang, Xiong, Mei

Transformer가 context에 따라 다른 알고리즘을 선택:

Context (선형 패턴) --> Linear Regression 구현
Context (비선형 패턴) --> Kernel Regression 구현
Context (노이즈 多) --> Robust Estimator 구현

5.2 Algorithm Selection 메커니즘

  1. Pre-ICL Testing: Context 분석 후 알고리즘 선택
  2. Post-ICL Validation: 여러 알고리즘 시도 후 최적 선택
# Pre-ICL Testing (개념적)
def select_algorithm(context):
    # Context 특성 추출
    noise_level = estimate_noise(context)
    linearity = test_linearity(context)

    if linearity > threshold and noise_level < threshold:
        return "OLS"
    elif linearity > threshold:
        return "Ridge"
    else:
        return "Kernel"

6. Universal Approximation

6.1 최신 이론: 2025

항목 내용
논문 Transformers Meet In-Context Learning: A Universal Approximation Theory
출처 arXiv 2025
저자 Gen Li et al.

정리: 충분한 깊이와 너비의 Transformer는 다양한 함수 클래스에 대해 ICL을 수행할 수 있다.

임의의 함수 클래스 F와 오차 epsilon에 대해,
Transformer T가 존재하여:

E_{f ~ F, context, x} [ ||T(context, x) - f(x)||^2 ] < epsilon

6.2 Depth와 Width의 역할

구성 요소 ICL에서의 역할
Depth 반복적 최적화 스텝, 알고리즘 복잡도
Width 표현력, 동시 가설 수
Attention Heads 병렬 알고리즘 실행
MLP 비선형 변환, 특징 추출

7. Mechanistic Interpretability

7.1 Induction Heads

Olsson et al. (2022)가 발견한 induction head는 ICL의 핵심 메커니즘 중 하나다.

패턴: [A] [B] ... [A] --> [B]

"이전에 A 다음에 B가 왔으니, 다시 A가 오면 B를 예측"

Induction Head 구성: - Head 1: 이전 토큰으로 attention - Head 2: 현재와 매칭되는 이전 context 탐색 - 결합: Copy-like 행동 구현

7.2 ICL Circuit 분석

Layer 1-2: Token/Position Encoding
Layer 3-4: Pattern Matching (Induction)
Layer 5-6: Task Vector 형성
Layer 7+: Task 실행, 예측 생성

8. Python 구현

8.1 Linear Attention의 GD 동치성 시연

import numpy as np
import torch
import torch.nn as nn

class LinearAttentionGD(nn.Module):
    """
    Linear attention이 gradient descent를 수행함을 보이는 구현
    """
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out

        # GD 동치 가중치 구성
        # W_K, W_Q: 입력 x에 대한 projection
        # W_V: (x, y) pair에 대한 projection

        self.W_K = nn.Parameter(torch.eye(d_in))
        self.W_Q = nn.Parameter(torch.eye(d_in))
        self.W_V = nn.Parameter(torch.zeros(d_in + d_out, d_out))

        # GD 스타일 초기화
        self._init_gd_weights()

    def _init_gd_weights(self):
        # von Oswald et al. 구성에 따른 초기화
        # V가 y 부분만 추출하도록
        with torch.no_grad():
            self.W_V[self.d_in:, :] = torch.eye(self.d_out)

    def forward(self, context_x, context_y, query_x):
        """
        context_x: (batch, n_context, d_in)
        context_y: (batch, n_context, d_out)
        query_x: (batch, d_in)

        Returns: predicted y for query
        """
        batch_size = context_x.shape[0]
        n_context = context_x.shape[1]

        # Key, Query 계산
        K = context_x @ self.W_K  # (batch, n, d_in)
        Q = query_x @ self.W_Q    # (batch, d_in)

        # Attention scores (linear, no softmax)
        scores = torch.einsum('bi,bni->bn', Q, K)  # (batch, n)

        # Value: concatenate (x, y)
        context_xy = torch.cat([context_x, context_y], dim=-1)  # (batch, n, d_in+d_out)
        V = context_xy @ self.W_V  # (batch, n, d_out)

        # Weighted sum
        output = torch.einsum('bn,bno->bo', scores, V)  # (batch, d_out)

        return output


def compare_with_gd(n_samples=100, d_in=5, d_out=1, n_context=20):
    """
    Linear attention과 closed-form OLS 비교
    """
    # 랜덤 데이터 생성
    np.random.seed(42)

    # True weight
    w_true = np.random.randn(d_in, d_out)

    # Context data
    X_context = np.random.randn(n_context, d_in)
    y_context = X_context @ w_true + 0.1 * np.random.randn(n_context, d_out)

    # Query
    x_query = np.random.randn(d_in)
    y_true = x_query @ w_true

    # OLS solution
    w_ols = np.linalg.lstsq(X_context, y_context, rcond=None)[0]
    y_ols = x_query @ w_ols

    # Linear attention (simplified)
    # y_pred = x_query^T @ X^T @ y / (x_query^T @ X^T @ X @ x_query)
    # 이는 kernel regression과 유사

    attention_scores = X_context @ x_query  # (n,)
    y_attn = (attention_scores @ y_context) / (np.sum(attention_scores ** 2) + 1e-6)

    print(f"True y: {y_true[0]:.4f}")
    print(f"OLS prediction: {y_ols[0]:.4f}")
    print(f"Attention prediction: {y_attn[0]:.4f}")

    return y_true, y_ols, y_attn

# 실행
compare_with_gd()

8.2 ICL 실험: Synthetic Regression

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

class ICLRegressionDataset(Dataset):
    """
    각 샘플이 하나의 태스크 (context + query)인 데이터셋
    """
    def __init__(self, n_tasks, n_context, d_in, d_out=1, noise_std=0.1):
        self.n_tasks = n_tasks
        self.n_context = n_context
        self.d_in = d_in
        self.d_out = d_out
        self.noise_std = noise_std

    def __len__(self):
        return self.n_tasks

    def __getitem__(self, idx):
        # 각 태스크마다 랜덤 weight
        w = torch.randn(self.d_in, self.d_out)

        # Context
        x_context = torch.randn(self.n_context, self.d_in)
        y_context = x_context @ w + self.noise_std * torch.randn(self.n_context, self.d_out)

        # Query
        x_query = torch.randn(self.d_in)
        y_query = x_query @ w

        return {
            'x_context': x_context,
            'y_context': y_context,
            'x_query': x_query,
            'y_query': y_query,
            'w': w
        }


class SimpleICLTransformer(nn.Module):
    """
    ICL을 위한 간단한 Transformer
    """
    def __init__(self, d_in, d_out, d_model=64, n_heads=4, n_layers=4):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.d_model = d_model

        # Input embedding
        self.input_embed = nn.Linear(d_in + d_out, d_model)
        self.query_embed = nn.Linear(d_in, d_model)

        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Output projection
        self.output = nn.Linear(d_model, d_out)

    def forward(self, x_context, y_context, x_query):
        """
        x_context: (batch, n_context, d_in)
        y_context: (batch, n_context, d_out)
        x_query: (batch, d_in)
        """
        batch_size = x_context.shape[0]
        n_context = x_context.shape[1]

        # Context embedding
        context_input = torch.cat([x_context, y_context], dim=-1)  # (batch, n, d_in+d_out)
        context_emb = self.input_embed(context_input)  # (batch, n, d_model)

        # Query embedding (y=0으로 padding)
        query_input = x_query.unsqueeze(1)  # (batch, 1, d_in)
        query_emb = self.query_embed(query_input)  # (batch, 1, d_model)

        # Concatenate
        seq = torch.cat([context_emb, query_emb], dim=1)  # (batch, n+1, d_model)

        # Transformer
        out = self.transformer(seq)  # (batch, n+1, d_model)

        # Query position output
        query_out = out[:, -1, :]  # (batch, d_model)

        return self.output(query_out)  # (batch, d_out)


def train_icl_model(n_epochs=100, batch_size=32):
    """
    ICL Transformer 학습
    """
    d_in, d_out = 10, 1
    n_context = 20

    # Dataset
    train_dataset = ICLRegressionDataset(
        n_tasks=10000, n_context=n_context, d_in=d_in, d_out=d_out
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Model
    model = SimpleICLTransformer(d_in=d_in, d_out=d_out)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training
    for epoch in range(n_epochs):
        total_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()

            pred = model(
                batch['x_context'],
                batch['y_context'],
                batch['x_query']
            )

            loss = F.mse_loss(pred, batch['y_query'])
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.6f}")

    return model


def analyze_icl_vs_ols(model, n_tests=100, n_context=20, d_in=10):
    """
    학습된 ICL 모델과 OLS 비교
    """
    model.eval()

    icl_errors = []
    ols_errors = []

    with torch.no_grad():
        for _ in range(n_tests):
            # 새로운 태스크
            w = torch.randn(d_in, 1)

            x_context = torch.randn(n_context, d_in)
            y_context = x_context @ w

            x_query = torch.randn(d_in)
            y_true = (x_query @ w).item()

            # ICL prediction
            y_icl = model(
                x_context.unsqueeze(0),
                y_context.unsqueeze(0),
                x_query.unsqueeze(0)
            ).item()

            # OLS prediction
            w_ols = torch.linalg.lstsq(x_context, y_context).solution
            y_ols = (x_query @ w_ols).item()

            icl_errors.append((y_icl - y_true) ** 2)
            ols_errors.append((y_ols - y_true) ** 2)

    print(f"ICL MSE: {np.mean(icl_errors):.6f}")
    print(f"OLS MSE: {np.mean(ols_errors):.6f}")
    print(f"Ratio (ICL/OLS): {np.mean(icl_errors)/np.mean(ols_errors):.4f}")

8.3 Context Length와 성능 관계 분석

import matplotlib.pyplot as plt

def analyze_context_length_effect(model, d_in=10, max_context=50):
    """
    Context 길이에 따른 ICL 성능 변화 분석
    """
    model.eval()

    context_lengths = list(range(5, max_context + 1, 5))
    icl_mses = []
    ols_mses = []

    n_tests = 100

    for n_ctx in context_lengths:
        icl_errors = []
        ols_errors = []

        with torch.no_grad():
            for _ in range(n_tests):
                w = torch.randn(d_in, 1)

                x_context = torch.randn(n_ctx, d_in)
                y_context = x_context @ w + 0.1 * torch.randn(n_ctx, 1)

                x_query = torch.randn(d_in)
                y_true = (x_query @ w).item()

                # ICL
                y_icl = model(
                    x_context.unsqueeze(0),
                    y_context.unsqueeze(0),
                    x_query.unsqueeze(0)
                ).item()

                # OLS
                w_ols = torch.linalg.lstsq(x_context, y_context).solution
                y_ols = (x_query @ w_ols).item()

                icl_errors.append((y_icl - y_true) ** 2)
                ols_errors.append((y_ols - y_true) ** 2)

        icl_mses.append(np.mean(icl_errors))
        ols_mses.append(np.mean(ols_errors))

    # Plotting
    plt.figure(figsize=(10, 6))
    plt.plot(context_lengths, icl_mses, 'b-o', label='ICL Transformer')
    plt.plot(context_lengths, ols_mses, 'r--s', label='OLS')
    plt.xlabel('Context Length')
    plt.ylabel('MSE')
    plt.title('ICL Performance vs Context Length')
    plt.legend()
    plt.grid(True)
    plt.savefig('icl_context_length.png', dpi=150, bbox_inches='tight')

    return context_lengths, icl_mses, ols_mses

9. 최신 연구 동향 (2024-2025)

9.1 주요 연구 방향

연구 방향 핵심 내용 대표 논문
State Space Models의 ICL Mamba 등 SSM의 ICL 능력 분석 Park et al., 2024; Yin & Steinhardt, 2025
Hybrid Architectures Transformer + SSM 조합 Understanding ICL Beyond Transformers, 2025
Scaling Laws ICL 성능과 모델 크기 관계 Memory Mosaics v2, 2025
Bayesian ICL 완전한 posterior 추론 Can Transformers Learn Full Bayesian Inference, ICML 2025

9.2 Open Problems

  1. Task Diversity와 ICL: 사전학습 태스크 다양성이 ICL에 미치는 영향
  2. Compositional ICL: 복잡한 태스크의 분해와 조합
  3. ICL의 한계: 어떤 태스크는 ICL이 불가능한가?
  4. Efficient ICL: Context 길이 제한 극복

10. 핵심 요약

개념 설명
In-Context Learning 파라미터 업데이트 없이 context에서 학습
Mesa-Optimization 모델 내부에서 최적화 알고리즘 수행
ICL as GD Linear attention이 gradient descent와 동치
GD++ Transformer가 수행하는 향상된 최적화 (curvature correction)
Stability ICL 일반화의 핵심 조건
Algorithm Selection Context에 따른 적응적 알고리즘 선택

참고 문헌

  1. Brown, T., et al. (2020). Language Models are Few-Shot Learners. NeurIPS.
  2. von Oswald, J., et al. (2023). Transformers learn in-context by gradient descent. ICML.
  3. Li, Y., et al. (2023). Transformers as Algorithms: Generalization and Stability in In-context Learning. ICML.
  4. Bai, Y., et al. (2024). Transformers as Statisticians: Provable In-Context Learning with In-Context Algorithm Selection. NeurIPS.
  5. Olsson, C., et al. (2022). In-context Learning and Induction Heads. Anthropic.
  6. Akyurek, E., et al. (2023). What learning algorithm is in-context learning? Investigations with linear models. ICLR.
  7. Garg, S., et al. (2022). What Can Transformers Learn In-Context? A Case Study of Simple Function Classes. NeurIPS.
  8. Li, G., et al. (2025). Transformers Meet In-Context Learning: A Universal Approximation Theory. arXiv.