Test-Time Training (TTT)¶
메타정보¶
| 항목 | 내용 |
|---|---|
| 논문 | Learning to (Learn at Test Time): RNNs with Expressive Hidden States |
| 저자 | Yu Sun, Xinhao Li, Karan Dalal, Chloe Hsu, Sanmi Koyejo, Carlos Guestrin, Xiaolong Wang, Tatsunori Hashimoto, Xinlei Chen |
| 발표 | ICML 2024 (Oral) |
| 후속 | TTT-E2E (2025), In-Place TTT (ICLR 2026), LaCT (2025) |
| arXiv | 2407.04620 |
| 키워드 | Test-Time Training, Sequence Modeling, RNN, Linear Attention, Associative Memory |
개요¶
Test-Time Training (TTT)은 추론 시점에 모델 가중치를 적응적으로 업데이트하는 시퀀스 모델링 패러다임이다. 기존 RNN/Attention의 recurrent state 대신, 학습 가능한 신경망(fast weight)을 사용하여 문맥 정보를 압축한다.
핵심 통찰: - Linear Attention과 DeltaNet은 TTT의 특수한 경우 - Recurrent state를 신경망으로 대체하면 표현력이 극대화됨 - Self-supervised loss로 fast weight를 업데이트 - 상태 크기를 임의로 확장 가능 (Mamba의 16 -> TTT의 1M+ 파라미터)
배경: Attention과 RNN의 트레이드오프¶
Attention의 강점과 한계¶
장점:
- 임의 거리 의존성 모델링
- 병렬 학습 가능
- 뛰어난 표현력
단점:
- O(T^2) 시간/공간 복잡도 (T: 시퀀스 길이)
- 추론 시 KV 캐시 선형 증가
- 긴 시퀀스에서 비효율적
RNN의 강점과 한계¶
TTT의 해결책¶
"RNN의 고정 상태를 신경망으로 대체하면?"
| 모델 | 상태 형태 | 상태 크기 |
|---|---|---|
| LSTM | 벡터 | O(d) |
| Mamba | 행렬 | O(d^2), 실제 16 |
| Linear Attention | 행렬 | O(d^2) |
| TTT | 신경망 | 임의 크기 |
수학적 정의¶
Attention (기준점)¶
q_t = x_t W_q (query)
k_t = x_t W_k (key)
v_t = x_t W_v (value)
y_t = softmax(q_t K_t^T / sqrt(d)) V_t
- KV 캐시 크기: O(T * d)
- 추론 복잡도: O(T) per token
Linear Attention¶
softmax 제거:
- 상태 크기: O(d^2)
- 업데이트: outer product 누적
- 문제: 망각 메커니즘 부재
TTT 정의¶
| 구성요소 | 정의 |
|---|---|
| Fast Weight | W_t (시퀀스 내에서 업데이트되는 가중치) |
| Slow Weight | W_q, W_k, W_v (사전 학습된 고정 가중치) |
| 업데이트 규칙 | W_t = W_{t-1} - eta_t * grad(L) |
| 쿼리 규칙 | y_t = f(q_t, W_t) |
손실 함수:
의미: "k로부터 v를 재구성"하는 연관 기억(associative memory)
TTT가 Linear Attention을 일반화하는 방법¶
예시 1: Linear Attention 유도¶
f를 선형 모델로, 손실을 음의 내적으로 설정:
f(k_t, W_{t-1}) = k_t W_{t-1}
L = -k_t W_{t-1} v_t^T
grad_W L = -k_t^T v_t
W_t = W_{t-1} + eta * k_t^T v_t
Linear Attention의 업데이트 규칙과 동일
예시 2: DeltaNet 유도¶
f를 선형 모델로, 손실을 MSE로 설정:
f(k_t, W_{t-1}) = k_t W_{t-1}
L = (1/2) ||k_t W_{t-1} - v_t||_2^2
grad_W L = k_t^T (k_t W_{t-1} - v_t)
W_t = W_{t-1} - eta * k_t^T (k_t W_{t-1} - v_t)
= W_{t-1} - eta * k_t^T k_t W_{t-1} + eta * k_t^T v_t
|___ forgetting ___| |_ inserting _|
MSE 손실이 자연스럽게 망각 메커니즘 생성
TTT 변형들¶
TTT-Linear¶
TTT-MLP¶
f(x, W) = MLP(x; W)
= W_2 * ReLU(W_1 * x + b_1) + b_2
파라미터: 2층 MLP (수백만 파라미터 가능)
장점: 높은 표현력
단점: 높은 계산 비용
TTT-E2E (End-to-End, 2025)¶
핵심 개선:
- 전체 시퀀스에 대해 end-to-end gradient 계산
- 128K 컨텍스트에서 full attention 대비 2.7배 빠름
- RNN과 유사한 상수 시간 추론
In-Place TTT (ICLR 2026)¶
LaCT (Large Chunk TTT, 2025)¶
문제: TTT의 낮은 arithmetic intensity
원인: 작은 청크 크기 (16)로 인한 memory-bound
해결: 청크 크기 확대 (16 -> 2048+)
부작용: 로컬 의존성 모델링 약화
대책: Sliding Window Attention 레이어 추가
병렬화: Mini-batch Gradient Descent¶
순차적 업데이트의 병렬화:
원래:
W_t = W_{t-1} - eta * grad_W L(W_{t-1}, k_t, v_t)
청크 기반 (병렬화 가능):
W_t = W_{t'} - eta * sum_{i=t'}^{t} grad_W L(W_{t'}, k_i, v_i)
where t' = t - (t mod B) # 청크 시작점
청크 내 gradient 계산이 독립적이므로 병렬 처리 가능
Arithmetic Intensity 분석¶
Fast weight 크기 (h x h), 입력 크기 (b x h):
Arithmetic Intensity r = FLOPs / Memory Access
= 2h^2 b / (2h^2 + 4hb)
= b / (1 + 2b/h)
<= min(h/2, b)
청크 크기 b에 의해 상한이 결정됨
| 청크 크기 | AI | 상태 |
|---|---|---|
| 16 | ~16 | Memory-bound |
| 256 | ~256 | Balanced |
| 2048+ | ~h/2 | Compute-bound |
Python 구현 예시¶
기본 TTT Layer¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class TTTLayer(nn.Module):
"""
Test-Time Training Layer
Args:
d_model: 모델 차원
d_head: 헤드 차원
n_heads: 헤드 수
fast_net: 'linear' 또는 'mlp'
chunk_size: 병렬 처리 청크 크기
"""
def __init__(
self,
d_model: int,
d_head: int = 64,
n_heads: int = 8,
fast_net: str = 'linear',
chunk_size: int = 64
):
super().__init__()
self.d_model = d_model
self.d_head = d_head
self.n_heads = n_heads
self.chunk_size = chunk_size
# Slow weights (QKV projections)
self.W_q = nn.Linear(d_model, d_head * n_heads, bias=False)
self.W_k = nn.Linear(d_model, d_head * n_heads, bias=False)
self.W_v = nn.Linear(d_model, d_head * n_heads, bias=False)
self.W_o = nn.Linear(d_head * n_heads, d_model, bias=False)
# Fast weight (per-head)
if fast_net == 'linear':
# Fast weight: d_head x d_head per head
self.fast_weight_init = nn.Parameter(
torch.zeros(n_heads, d_head, d_head)
)
elif fast_net == 'mlp':
# 2-layer MLP
self.fast_w1 = nn.Parameter(
torch.randn(n_heads, d_head, d_head * 4) * 0.02
)
self.fast_w2 = nn.Parameter(
torch.randn(n_heads, d_head * 4, d_head) * 0.02
)
self.fast_net = fast_net
# Data-dependent learning rate
self.lr_proj = nn.Linear(d_model, n_heads)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, seq_len, d_model)
Returns:
output: (batch, seq_len, d_model)
"""
B, T, D = x.shape
# Project to Q, K, V
q = self.W_q(x).view(B, T, self.n_heads, self.d_head)
k = self.W_k(x).view(B, T, self.n_heads, self.d_head)
v = self.W_v(x).view(B, T, self.n_heads, self.d_head)
# Data-dependent learning rate
eta = torch.sigmoid(self.lr_proj(x)) # (B, T, n_heads)
# Initialize fast weights
W = self.fast_weight_init.unsqueeze(0).expand(
B, -1, -1, -1
).clone() # (B, n_heads, d_head, d_head)
outputs = []
# Process in chunks for parallelization
for chunk_start in range(0, T, self.chunk_size):
chunk_end = min(chunk_start + self.chunk_size, T)
# Get chunk
k_chunk = k[:, chunk_start:chunk_end] # (B, chunk, heads, d)
v_chunk = v[:, chunk_start:chunk_end]
q_chunk = q[:, chunk_start:chunk_end]
eta_chunk = eta[:, chunk_start:chunk_end]
# Compute gradients for all positions in chunk
# (using chunk-start state W)
grad_sum = self._compute_chunk_gradients(
k_chunk, v_chunk, W
)
# Update fast weights
for t in range(chunk_end - chunk_start):
# Query with current fast weight
y_t = self._fast_forward(
q_chunk[:, t], W
) # (B, heads, d)
outputs.append(y_t)
# Update W
W = W - eta_chunk[:, t, :, None, None] * grad_sum[:, t]
# Stack outputs
output = torch.stack(outputs, dim=1) # (B, T, heads, d)
output = output.view(B, T, -1)
return self.W_o(output)
def _fast_forward(
self,
x: torch.Tensor,
W: torch.Tensor
) -> torch.Tensor:
"""
Fast network forward pass
Args:
x: (B, heads, d_head)
W: (B, heads, d_head, d_head)
"""
if self.fast_net == 'linear':
# Linear: y = xW
return torch.einsum('bhd,bhde->bhe', x, W)
else:
# MLP: y = W2 * ReLU(W1 * x)
h = F.relu(torch.einsum('bhd,hde->bhe', x, self.fast_w1))
return torch.einsum('bhe,hed->bhd', h, self.fast_w2)
def _compute_chunk_gradients(
self,
k: torch.Tensor,
v: torch.Tensor,
W: torch.Tensor
) -> torch.Tensor:
"""
Compute gradients for all positions in chunk
Args:
k: (B, chunk, heads, d)
v: (B, chunk, heads, d)
W: (B, heads, d, d)
Returns:
gradients: (B, chunk, heads, d, d)
"""
B, chunk_size, H, D = k.shape
# Predict v from k using current W
# pred_v = k @ W: (B, chunk, heads, d)
pred_v = torch.einsum('bchd,bhde->bche', k, W)
# Error: pred_v - v
error = pred_v - v # (B, chunk, heads, d)
# Gradient: k^T @ error (for MSE loss)
# grad shape: (B, chunk, heads, d, d)
grad = torch.einsum('bchd,bche->bchde', k, error)
return grad
사용 예시¶
# 모델 생성
ttt_layer = TTTLayer(
d_model=512,
d_head=64,
n_heads=8,
fast_net='linear',
chunk_size=64
)
# 입력
x = torch.randn(2, 1024, 512) # (batch, seq, dim)
# Forward
output = ttt_layer(x)
print(output.shape) # (2, 1024, 512)
Benchmark: TTT vs Attention¶
import time
def benchmark(model, x, name, n_iter=100):
# Warmup
for _ in range(10):
_ = model(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(n_iter):
_ = model(x)
torch.cuda.synchronize()
elapsed = (time.time() - start) / n_iter
print(f"{name}: {elapsed*1000:.2f} ms/iter")
# 비교
d_model = 512
seq_lens = [1024, 4096, 16384]
for seq_len in seq_lens:
x = torch.randn(1, seq_len, d_model).cuda()
ttt = TTTLayer(d_model, chunk_size=128).cuda()
attn = nn.MultiheadAttention(d_model, 8).cuda()
print(f"\nSeq len: {seq_len}")
benchmark(ttt, x, "TTT")
benchmark(lambda y: attn(y, y, y)[0], x, "Attention")
성능 비교¶
언어 모델링 (Perplexity)¶
| 모델 | 파라미터 | WikiText PPL | 추론 복잡도 |
|---|---|---|---|
| Transformer | 125M | 29.1 | O(T) per token |
| Mamba | 130M | 28.8 | O(1) per token |
| TTT-Linear | 125M | 28.5 | O(1) per token |
| TTT-MLP | 125M | 27.2 | O(1) per token |
Long Context (SCROLLS benchmark)¶
| 컨텍스트 길이 | Transformer | Mamba | TTT-MLP |
|---|---|---|---|
| 8K | 72.3 | 71.8 | 73.1 |
| 32K | 68.1 | 69.5 | 71.8 |
| 128K | OOM | 66.2 | 70.4 |
Few-shot Learning (BBH, In-Place TTT)¶
| 방법 | 0-shot | 5-shot | TTT |
|---|---|---|---|
| GPT-3 175B | 48.2 | 52.8 | - |
| Llama-2 70B | 51.4 | 55.3 | 61.2 |
장단점¶
장점¶
| 항목 | 설명 |
|---|---|
| 표현력 | 상태 크기를 임의로 확장 가능 (Linear Attention의 한계 극복) |
| 효율성 | 추론 시 O(1) 복잡도 (Attention의 O(T) 대비) |
| 일반성 | Linear Attention, DeltaNet 등을 특수 케이스로 포함 |
| 적응성 | 추론 시점에 동적으로 문맥 적응 |
단점¶
| 항목 | 설명 |
|---|---|
| 학습 비용 | Gradient 계산으로 인한 추가 연산 |
| 구현 복잡도 | 기존 Attention 대비 복잡한 커스텀 커널 필요 |
| Memory-bound | 작은 청크에서 낮은 arithmetic intensity |
| 하이퍼파라미터 | 청크 크기, 학습률 스케줄 등 추가 튜닝 필요 |
관련 연구¶
선행 연구¶
| 연구 | 관계 |
|---|---|
| Linear Attention (2020) | TTT의 특수 케이스 (손실 = 음의 내적) |
| DeltaNet (2021) | TTT의 특수 케이스 (MSE 손실) |
| Fast Weight Programmers (1992) | 개념적 선행 연구 |
| Hopfield Networks | 연관 기억의 원형 |
후속 연구¶
| 연구 | 기여 |
|---|---|
| TTT-E2E (2025) | End-to-end gradient로 성능 향상 |
| In-Place TTT (2026) | 기존 LLM에 TTT 통합 |
| LaCT (2025) | 대형 청크로 효율성 개선 |
| Test-Time Training Done Right (2025) | GPU 효율적 구현 |
핵심 요약¶
- TTT는 추론 시 모델 가중치를 업데이트하는 시퀀스 모델링 패러다임
- Linear Attention과 DeltaNet은 TTT의 특수한 경우
- Fast weight를 MLP로 확장하면 표현력이 극적으로 증가
- O(1) 추론 복잡도로 긴 컨텍스트에 유리
- 청크 크기와 arithmetic intensity 트레이드오프 존재