콘텐츠로 이동
Data Prep
상세

Test-Time Training (TTT)

메타정보

항목 내용
논문 Learning to (Learn at Test Time): RNNs with Expressive Hidden States
저자 Yu Sun, Xinhao Li, Karan Dalal, Chloe Hsu, Sanmi Koyejo, Carlos Guestrin, Xiaolong Wang, Tatsunori Hashimoto, Xinlei Chen
발표 ICML 2024 (Oral)
후속 TTT-E2E (2025), In-Place TTT (ICLR 2026), LaCT (2025)
arXiv 2407.04620
키워드 Test-Time Training, Sequence Modeling, RNN, Linear Attention, Associative Memory

개요

Test-Time Training (TTT)은 추론 시점에 모델 가중치를 적응적으로 업데이트하는 시퀀스 모델링 패러다임이다. 기존 RNN/Attention의 recurrent state 대신, 학습 가능한 신경망(fast weight)을 사용하여 문맥 정보를 압축한다.

핵심 통찰: - Linear Attention과 DeltaNet은 TTT의 특수한 경우 - Recurrent state를 신경망으로 대체하면 표현력이 극대화됨 - Self-supervised loss로 fast weight를 업데이트 - 상태 크기를 임의로 확장 가능 (Mamba의 16 -> TTT의 1M+ 파라미터)


배경: Attention과 RNN의 트레이드오프

Attention의 강점과 한계

장점:
- 임의 거리 의존성 모델링
- 병렬 학습 가능
- 뛰어난 표현력

단점:
- O(T^2) 시간/공간 복잡도 (T: 시퀀스 길이)
- 추론 시 KV 캐시 선형 증가
- 긴 시퀀스에서 비효율적

RNN의 강점과 한계

장점:
- O(T) 시간 복잡도
- 고정 크기 상태
- 효율적 추론

단점:
- 제한된 상태 크기 (보통 수백~수천 차원)
- 장기 의존성 학습 어려움
- 병렬 학습의 한계

TTT의 해결책

"RNN의 고정 상태를 신경망으로 대체하면?"

모델 상태 형태 상태 크기
LSTM 벡터 O(d)
Mamba 행렬 O(d^2), 실제 16
Linear Attention 행렬 O(d^2)
TTT 신경망 임의 크기

수학적 정의

Attention (기준점)

q_t = x_t W_q    (query)
k_t = x_t W_k    (key)
v_t = x_t W_v    (value)

y_t = softmax(q_t K_t^T / sqrt(d)) V_t
  • KV 캐시 크기: O(T * d)
  • 추론 복잡도: O(T) per token

Linear Attention

softmax 제거:

y_t = q_t (K_t^T V_t)
    = q_t S_t

where S_t = S_{t-1} + k_t^T v_t
  • 상태 크기: O(d^2)
  • 업데이트: outer product 누적
  • 문제: 망각 메커니즘 부재

TTT 정의

구성요소 정의
Fast Weight W_t (시퀀스 내에서 업데이트되는 가중치)
Slow Weight W_q, W_k, W_v (사전 학습된 고정 가중치)
업데이트 규칙 W_t = W_{t-1} - eta_t * grad(L)
쿼리 규칙 y_t = f(q_t, W_t)

손실 함수:

L(k_t, v_t, W_{t-1}) = ||f(k_t, W_{t-1}) - v_t||_2^2

의미: "k로부터 v를 재구성"하는 연관 기억(associative memory)


TTT가 Linear Attention을 일반화하는 방법

예시 1: Linear Attention 유도

f를 선형 모델로, 손실을 음의 내적으로 설정:

f(k_t, W_{t-1}) = k_t W_{t-1}
L = -k_t W_{t-1} v_t^T

grad_W L = -k_t^T v_t

W_t = W_{t-1} + eta * k_t^T v_t

Linear Attention의 업데이트 규칙과 동일

예시 2: DeltaNet 유도

f를 선형 모델로, 손실을 MSE로 설정:

f(k_t, W_{t-1}) = k_t W_{t-1}
L = (1/2) ||k_t W_{t-1} - v_t||_2^2

grad_W L = k_t^T (k_t W_{t-1} - v_t)

W_t = W_{t-1} - eta * k_t^T (k_t W_{t-1} - v_t)
    = W_{t-1} - eta * k_t^T k_t W_{t-1} + eta * k_t^T v_t
              |___ forgetting ___|      |_ inserting _|

MSE 손실이 자연스럽게 망각 메커니즘 생성


TTT 변형들

TTT-Linear

f(x, W) = xW

파라미터: d x d 행렬 (Linear Attention과 동일)
장점: 계산 효율성
단점: 제한된 표현력

TTT-MLP

f(x, W) = MLP(x; W)
        = W_2 * ReLU(W_1 * x + b_1) + b_2

파라미터: 2층 MLP (수백만 파라미터 가능)
장점: 높은 표현력
단점: 높은 계산 비용

TTT-E2E (End-to-End, 2025)

핵심 개선:
- 전체 시퀀스에 대해 end-to-end gradient 계산
- 128K 컨텍스트에서 full attention 대비 2.7배 빠름
- RNN과 유사한 상수 시간 추론

In-Place TTT (ICLR 2026)

핵심 개선:
- 기존 LLM에 TTT 능력 추가 (아키텍처 변경 없음)
- Test-time에 in-place로 모델 적응
- Few-shot 성능 대폭 향상

LaCT (Large Chunk TTT, 2025)

문제: TTT의 낮은 arithmetic intensity
원인: 작은 청크 크기 (16)로 인한 memory-bound

해결: 청크 크기 확대 (16 -> 2048+)
부작용: 로컬 의존성 모델링 약화
대책: Sliding Window Attention 레이어 추가

병렬화: Mini-batch Gradient Descent

순차적 업데이트의 병렬화:

원래:
W_t = W_{t-1} - eta * grad_W L(W_{t-1}, k_t, v_t)

청크 기반 (병렬화 가능):
W_t = W_{t'} - eta * sum_{i=t'}^{t} grad_W L(W_{t'}, k_i, v_i)

where t' = t - (t mod B)  # 청크 시작점

청크 내 gradient 계산이 독립적이므로 병렬 처리 가능


Arithmetic Intensity 분석

Fast weight 크기 (h x h), 입력 크기 (b x h):

Arithmetic Intensity r = FLOPs / Memory Access
                       = 2h^2 b / (2h^2 + 4hb)
                       = b / (1 + 2b/h)
                       <= min(h/2, b)

청크 크기 b에 의해 상한이 결정됨

청크 크기 AI 상태
16 ~16 Memory-bound
256 ~256 Balanced
2048+ ~h/2 Compute-bound

Python 구현 예시

기본 TTT Layer

import torch
import torch.nn as nn
import torch.nn.functional as F

class TTTLayer(nn.Module):
    """
    Test-Time Training Layer

    Args:
        d_model: 모델 차원
        d_head: 헤드 차원
        n_heads: 헤드 수
        fast_net: 'linear' 또는 'mlp'
        chunk_size: 병렬 처리 청크 크기
    """
    def __init__(
        self, 
        d_model: int, 
        d_head: int = 64, 
        n_heads: int = 8,
        fast_net: str = 'linear',
        chunk_size: int = 64
    ):
        super().__init__()
        self.d_model = d_model
        self.d_head = d_head
        self.n_heads = n_heads
        self.chunk_size = chunk_size

        # Slow weights (QKV projections)
        self.W_q = nn.Linear(d_model, d_head * n_heads, bias=False)
        self.W_k = nn.Linear(d_model, d_head * n_heads, bias=False)
        self.W_v = nn.Linear(d_model, d_head * n_heads, bias=False)
        self.W_o = nn.Linear(d_head * n_heads, d_model, bias=False)

        # Fast weight (per-head)
        if fast_net == 'linear':
            # Fast weight: d_head x d_head per head
            self.fast_weight_init = nn.Parameter(
                torch.zeros(n_heads, d_head, d_head)
            )
        elif fast_net == 'mlp':
            # 2-layer MLP
            self.fast_w1 = nn.Parameter(
                torch.randn(n_heads, d_head, d_head * 4) * 0.02
            )
            self.fast_w2 = nn.Parameter(
                torch.randn(n_heads, d_head * 4, d_head) * 0.02
            )

        self.fast_net = fast_net

        # Data-dependent learning rate
        self.lr_proj = nn.Linear(d_model, n_heads)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            output: (batch, seq_len, d_model)
        """
        B, T, D = x.shape

        # Project to Q, K, V
        q = self.W_q(x).view(B, T, self.n_heads, self.d_head)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_head)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_head)

        # Data-dependent learning rate
        eta = torch.sigmoid(self.lr_proj(x))  # (B, T, n_heads)

        # Initialize fast weights
        W = self.fast_weight_init.unsqueeze(0).expand(
            B, -1, -1, -1
        ).clone()  # (B, n_heads, d_head, d_head)

        outputs = []

        # Process in chunks for parallelization
        for chunk_start in range(0, T, self.chunk_size):
            chunk_end = min(chunk_start + self.chunk_size, T)

            # Get chunk
            k_chunk = k[:, chunk_start:chunk_end]  # (B, chunk, heads, d)
            v_chunk = v[:, chunk_start:chunk_end]
            q_chunk = q[:, chunk_start:chunk_end]
            eta_chunk = eta[:, chunk_start:chunk_end]

            # Compute gradients for all positions in chunk
            # (using chunk-start state W)
            grad_sum = self._compute_chunk_gradients(
                k_chunk, v_chunk, W
            )

            # Update fast weights
            for t in range(chunk_end - chunk_start):
                # Query with current fast weight
                y_t = self._fast_forward(
                    q_chunk[:, t], W
                )  # (B, heads, d)
                outputs.append(y_t)

                # Update W
                W = W - eta_chunk[:, t, :, None, None] * grad_sum[:, t]

        # Stack outputs
        output = torch.stack(outputs, dim=1)  # (B, T, heads, d)
        output = output.view(B, T, -1)

        return self.W_o(output)

    def _fast_forward(
        self, 
        x: torch.Tensor, 
        W: torch.Tensor
    ) -> torch.Tensor:
        """
        Fast network forward pass

        Args:
            x: (B, heads, d_head)
            W: (B, heads, d_head, d_head)
        """
        if self.fast_net == 'linear':
            # Linear: y = xW
            return torch.einsum('bhd,bhde->bhe', x, W)
        else:
            # MLP: y = W2 * ReLU(W1 * x)
            h = F.relu(torch.einsum('bhd,hde->bhe', x, self.fast_w1))
            return torch.einsum('bhe,hed->bhd', h, self.fast_w2)

    def _compute_chunk_gradients(
        self,
        k: torch.Tensor,
        v: torch.Tensor,
        W: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute gradients for all positions in chunk

        Args:
            k: (B, chunk, heads, d)
            v: (B, chunk, heads, d)
            W: (B, heads, d, d)
        Returns:
            gradients: (B, chunk, heads, d, d)
        """
        B, chunk_size, H, D = k.shape

        # Predict v from k using current W
        # pred_v = k @ W: (B, chunk, heads, d)
        pred_v = torch.einsum('bchd,bhde->bche', k, W)

        # Error: pred_v - v
        error = pred_v - v  # (B, chunk, heads, d)

        # Gradient: k^T @ error (for MSE loss)
        # grad shape: (B, chunk, heads, d, d)
        grad = torch.einsum('bchd,bche->bchde', k, error)

        return grad

사용 예시

# 모델 생성
ttt_layer = TTTLayer(
    d_model=512,
    d_head=64,
    n_heads=8,
    fast_net='linear',
    chunk_size=64
)

# 입력
x = torch.randn(2, 1024, 512)  # (batch, seq, dim)

# Forward
output = ttt_layer(x)
print(output.shape)  # (2, 1024, 512)

Benchmark: TTT vs Attention

import time

def benchmark(model, x, name, n_iter=100):
    # Warmup
    for _ in range(10):
        _ = model(x)

    torch.cuda.synchronize()
    start = time.time()

    for _ in range(n_iter):
        _ = model(x)

    torch.cuda.synchronize()
    elapsed = (time.time() - start) / n_iter

    print(f"{name}: {elapsed*1000:.2f} ms/iter")

# 비교
d_model = 512
seq_lens = [1024, 4096, 16384]

for seq_len in seq_lens:
    x = torch.randn(1, seq_len, d_model).cuda()

    ttt = TTTLayer(d_model, chunk_size=128).cuda()
    attn = nn.MultiheadAttention(d_model, 8).cuda()

    print(f"\nSeq len: {seq_len}")
    benchmark(ttt, x, "TTT")
    benchmark(lambda y: attn(y, y, y)[0], x, "Attention")

성능 비교

언어 모델링 (Perplexity)

모델 파라미터 WikiText PPL 추론 복잡도
Transformer 125M 29.1 O(T) per token
Mamba 130M 28.8 O(1) per token
TTT-Linear 125M 28.5 O(1) per token
TTT-MLP 125M 27.2 O(1) per token

Long Context (SCROLLS benchmark)

컨텍스트 길이 Transformer Mamba TTT-MLP
8K 72.3 71.8 73.1
32K 68.1 69.5 71.8
128K OOM 66.2 70.4

Few-shot Learning (BBH, In-Place TTT)

방법 0-shot 5-shot TTT
GPT-3 175B 48.2 52.8 -
Llama-2 70B 51.4 55.3 61.2

장단점

장점

항목 설명
표현력 상태 크기를 임의로 확장 가능 (Linear Attention의 한계 극복)
효율성 추론 시 O(1) 복잡도 (Attention의 O(T) 대비)
일반성 Linear Attention, DeltaNet 등을 특수 케이스로 포함
적응성 추론 시점에 동적으로 문맥 적응

단점

항목 설명
학습 비용 Gradient 계산으로 인한 추가 연산
구현 복잡도 기존 Attention 대비 복잡한 커스텀 커널 필요
Memory-bound 작은 청크에서 낮은 arithmetic intensity
하이퍼파라미터 청크 크기, 학습률 스케줄 등 추가 튜닝 필요

관련 연구

선행 연구

연구 관계
Linear Attention (2020) TTT의 특수 케이스 (손실 = 음의 내적)
DeltaNet (2021) TTT의 특수 케이스 (MSE 손실)
Fast Weight Programmers (1992) 개념적 선행 연구
Hopfield Networks 연관 기억의 원형

후속 연구

연구 기여
TTT-E2E (2025) End-to-end gradient로 성능 향상
In-Place TTT (2026) 기존 LLM에 TTT 통합
LaCT (2025) 대형 청크로 효율성 개선
Test-Time Training Done Right (2025) GPU 효율적 구현

핵심 요약

  1. TTT는 추론 시 모델 가중치를 업데이트하는 시퀀스 모델링 패러다임
  2. Linear Attention과 DeltaNet은 TTT의 특수한 경우
  3. Fast weight를 MLP로 확장하면 표현력이 극적으로 증가
  4. O(1) 추론 복잡도로 긴 컨텍스트에 유리
  5. 청크 크기와 arithmetic intensity 트레이드오프 존재

참고 자료