
Long Context Extension

Overview

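This document covers techniques for extending the context length of LLMs. Standard Transformer attention needs memory and compute that grow quadratically with sequence length, so many methods have been developed to process long contexts efficiently.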

Key Concepts

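Why Position Encoding Matters

Self-attention by itself is permutation-invariant, so a Transformer has no notion of token order without a position encoding. For long contexts, the position encoding must also generalize to positions beyond the range seen during training.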

Evolution of position encoding:
Sinusoidal -> Learned -> RoPE -> ALiBi -> RoPE extensions (PI, NTK-aware, YaRN, LongRoPE)

RoPE (Rotary Position Embedding)

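Encodes position by rotating each 2D pair of query/key components by an angle proportional to the token position (equivalently, treating the pair as a complex number and multiplying it by e^(i*m*theta)):

q_m = R_m * q
k_n = R_n * k

where R_m is the rotation matrix for angle m*theta:
R_m = [[cos(m*theta), -sin(m*theta)],
       [sin(m*theta),  cos(m*theta)]]

Attention score becomes:
a_{m,n} = (R_m * q)^T (R_n * k)
        = q^T R_m^T R_n k
        = q^T R_{n-m} k      # since R_m^T = R_{-m}
        = f(q, k, n-m)       # depends only on the relative position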
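
A minimal pure-Python check of the relative-position property for a single 2D pair: rotating q by m*theta and k by n*theta gives the same score as rotating only k by (n-m)*theta. The helper names `rotate`, `dot` and `rope_score`, the vectors, and the frequency theta = 0.1 are illustrative assumptions, not part of any library.

```python
import math


def rotate(vec: tuple[float, float], angle: float) -> tuple[float, float]:
    """Apply the 2D rotation matrix R(angle) to vec = (x, y)."""
    x, y = vec
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)


def dot(a: tuple[float, float], b: tuple[float, float]) -> float:
    return a[0] * b[0] + a[1] * b[1]


def rope_score(q, k, m: int, n: int, theta: float = 0.1) -> float:
    """Score (R_m q)^T (R_n k) for a query at position m and a key at position n."""
    return dot(rotate(q, m * theta), rotate(k, n * theta))


q, k = (0.3, -1.2), (0.8, 0.5)

# Same offset n - m = 3 at very different absolute positions
near = rope_score(q, k, m=2, n=5)
far = rope_score(q, k, m=100, n=103)

# Both should equal q^T R_{n-m} k, i.e. q . (k rotated by (n - m) * theta)
relative = dot(q, rotate(k, 3 * 0.1))
print(near, far, relative)  # the three values agree up to float rounding
```

Changing the offset (for example n - m = 4) changes the score, while shifting both positions by the same amount does not.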

ALiBi (Attention with Linear Biases)

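Adds a linear, distance-based bias to the attention scores instead of adding position embeddings to the tokens:

Attention(Q, K, V) = softmax(QK^T / sqrt(d) - m * |i-j|) V

where:
  m = head-specific slope (fixed, not learned; for n heads, m_h = 2^(-8h/n), h = 1..n)
  |i-j| = distance between query position i and key position j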
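
For a power-of-two head count the slopes form the geometric sequence m_h = 2^(-8h/n). A minimal pure-Python sketch that builds the slopes and one head's bias matrix; `alibi_slopes` and `alibi_bias` are hypothetical helper names, and non-power-of-two head counts (which interleave two such sequences) are not covered here.

```python
def alibi_slopes(n_heads: int) -> list[float]:
    """Per-head slopes m_h = 2^(-8h/n), h = 1..n (assumes n_heads is a power of two)."""
    return [2 ** (-8 * h / n_heads) for h in range(1, n_heads + 1)]


def alibi_bias(seq_len: int, slope: float) -> list[list[float]]:
    """bias[i][j] = -m * |i - j|, added to the raw attention scores.

    Entries with j > i are removed afterwards by the causal mask.
    """
    return [[0.0 - slope * abs(i - j) for j in range(seq_len)] for i in range(seq_len)]


slopes = alibi_slopes(8)
print(slopes[0], slopes[-1])        # 0.5 0.00390625
print(alibi_bias(4, slopes[0])[3])  # [-1.5, -1.0, -0.5, 0.0]
```

Heads with small slopes penalize distance only weakly and can still look far back; heads with large slopes behave almost like local attention.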

Architecture Diagrams

How RoPE Is Applied

    Query q, Key k (both: d-dimensional)
                         |
                         v
    +-----------------------------------------+
    | Split into pairs                        |
    | [q1,q2], [q3,q4], ...                   |
    +--------------------+--------------------+
                         |
                         v
    +-----------------------------------------+
    | Apply 2D rotation at position m         |
    |                                         |
    | [q1'] = [cos(m*t1)  -sin(m*t1)] [q1]    |
    | [q2']   [sin(m*t1)   cos(m*t1)] [q2]    |
    |                                         |
    | t_i = theta_i = 10000^(-2(i-1)/d)       |
    | for pair i = 1..d/2                     |
    +--------------------+--------------------+
                         |
                         v
    Rotated Query q'_m
    (similarly for Key k'_n)
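
The pairwise rotation in the diagram can be sketched in pure Python, with frequencies theta_j = 10000^(-2j/d) for the 0-based pair index j. This is a minimal illustration assuming a small head dimension d = 8 and made-up vectors; `rope_thetas`, `apply_rope_adjacent` and `rope_dot` are hypothetical names. It also prints each pair's wavelength (positions per full rotation), which is what the extension methods below manipulate.

```python
import math


def rope_thetas(d: int, base: float = 10000.0) -> list[float]:
    """theta_j = base^(-2j/d) for 0-based pair index j = 0 .. d/2 - 1."""
    return [base ** (-2 * j / d) for j in range(d // 2)]


def apply_rope_adjacent(x: list[float], m: int, base: float = 10000.0) -> list[float]:
    """Rotate each adjacent pair (x[2j], x[2j+1]) by the angle m * theta_j."""
    out: list[float] = []
    for j, theta in enumerate(rope_thetas(len(x), base)):
        a, b = x[2 * j], x[2 * j + 1]
        c, s = math.cos(m * theta), math.sin(m * theta)
        out += [c * a - s * b, s * a + c * b]
    return out


def rope_dot(q: list[float], k: list[float], m: int, n: int) -> float:
    """Attention score between q at position m and k at position n."""
    return sum(a * b for a, b in zip(apply_rope_adjacent(q, m), apply_rope_adjacent(k, n)))


d = 8
thetas = rope_thetas(d)
# Positions per full turn of each pair; with d = 8 it grows 10x from pair to pair
wavelengths = [2 * math.pi / t for t in thetas]
print([round(w) for w in wavelengths])  # [6, 63, 628, 6283]

q = [0.5, -0.1, 0.3, 0.8, -0.6, 0.2, 0.1, 0.4]
k = [0.2, 0.7, -0.4, 0.1, 0.9, -0.3, 0.5, 0.6]
print(rope_dot(q, k, 3, 10), rope_dot(q, k, 50, 57))  # same offset -> same score
```

Because each step is a rotation, the vector norm is preserved; only the relative angle between q and k carries position information.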

Position Interpolation (PI)

    Original RoPE: positions [0, 1, 2, ..., L-1]

    For context extension to 4L (scale = 4):

    +-----------------------------------------------+
    |  Linear Interpolation (PI)                    |
    |                                               |
    |  New positions: [0, 0.25, 0.5, 0.75, 1, ...]  |
    |                                               |
    |  angle'(pos, i) = (pos / scale) * theta_i     |
    |  where scale = target_len / original_len      |
    +-----------------------------------------------+

    Problem: every dimension is compressed by the same factor, including the
    high-frequency ones that distinguish nearby tokens, so fine-grained local
    position information is degraded.
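
A minimal pure-Python sketch of Position Interpolation on the position indices alone, assuming a toy 4x extension from 4 to 16 positions; `pi_positions` is a hypothetical helper. It shows that every interpolated position stays inside the trained range, and how much the highest-frequency pair (theta = 1) is squeezed.

```python
def pi_positions(target_len: int, original_len: int) -> list[float]:
    """Position Interpolation: map positions 0..target_len-1 into [0, original_len)."""
    scale = target_len / original_len  # e.g. 4.0 for a 4x extension
    return [p / scale for p in range(target_len)]


pos = pi_positions(target_len=16, original_len=4)  # 4x extension
print(pos[:5])  # [0.0, 0.25, 0.5, 0.75, 1.0]

# Highest-frequency pair (theta = 1): neighbouring tokens now differ by 0.25 rad
# instead of 1 rad, which is the compression of high frequencies described above.
theta_high = 1.0
step = (pos[1] - pos[0]) * theta_high
print(max(pos), step)  # 3.75 0.25
```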

YaRN (Yet another RoPE extensioN)

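    YaRN strategy: treat RoPE dimensions differently by frequency
    (theta_i = base^(-2i/d), so a small index i means a high frequency)

    RoPE dimension pairs i = 0 .. d/2-1:
    +---------------+---------------+---------------+
    | High freq     | Medium freq   | Low freq      |
    | (small i)     |               | (large i)     |
    +---------------+---------------+---------------+
            |               |               |
            v               v               v
       No scaling      Linear ramp      Linear PI
       (preserve)      between both     (theta / s)

    Formula (per pair i, wavelength lambda_i = 2*pi / theta_i):
    r_i      = L / lambda_i     # rotations within the original context length L
    gamma(r) = 0                             if r < alpha
               (r - alpha) / (beta - alpha)  if alpha <= r <= beta
               1                             if r > beta
    theta'_i = (1 - gamma(r_i)) * theta_i / s + gamma(r_i) * theta_i

    where s = scale factor (target_len / L), and alpha, beta bound the ramp
    (the YaRN paper uses alpha = 1, beta = 32 for Llama models). YaRN also
    scales the attention logits by 1/t, with sqrt(1/t) = 0.1 * ln(s) + 1.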
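
YaRN's per-pair rule can be checked with a short pure-Python sketch. It follows the ramp from the YaRN paper: with r_i = L / lambda_i rotations inside the original context L, gamma = clamp((r_i - alpha) / (beta - alpha), 0, 1) and theta'_i = (1 - gamma) * theta_i / s + gamma * theta_i. `yarn_thetas` and `yarn_mscale` are hypothetical helpers, and the defaults (d = 128, L = 4096, s = 4, alpha = 1, beta = 32) are illustrative assumptions.

```python
import math


def yarn_thetas(
    d: int,
    base: float = 10000.0,
    s: float = 4.0,            # scale factor: target_len / original_len
    original_len: int = 4096,  # L, the original context length
    alpha: float = 1.0,        # below alpha rotations: full interpolation
    beta: float = 32.0,        # above beta rotations: no interpolation
) -> list[tuple[float, float]]:
    """Return (theta_i, theta'_i) for each RoPE pair."""
    pairs = []
    for j in range(d // 2):
        theta = base ** (-2 * j / d)
        r = original_len / (2 * math.pi / theta)  # rotations within the original context
        gamma = min(1.0, max(0.0, (r - alpha) / (beta - alpha)))
        pairs.append((theta, (1 - gamma) * theta / s + gamma * theta))
    return pairs


def yarn_mscale(s: float) -> float:
    """sqrt(1/t) = 0.1 * ln(s) + 1; q and k are each multiplied by this factor."""
    return 0.1 * math.log(s) + 1.0 if s > 1.0 else 1.0


pairs = yarn_thetas(d=128)
first, last = pairs[0], pairs[-1]
print(first)              # highest frequency: unchanged
print(last[1] / last[0])  # lowest frequency: divided by s
print(yarn_mscale(4.0))
```

Pairs whose wavelength is far shorter than L keep their original frequency, so local position resolution is untouched; only pairs that never completed a full turn during training are interpolated.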

Sliding Window Attention

    Full Attention (O(n^2)):
    +---+---+---+---+---+---+---+---+
    | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
    | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
    | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
    | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
    | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
    | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
    | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
    | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
    +---+---+---+---+---+---+---+---+

    Sliding Window (O(n*w)):
    +---+---+---+---+---+---+---+---+
    | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 |
    | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 |
    | 1 | 1 | 1 | 1 | 1 | 0 | 0 | 0 |
    | 0 | 1 | 1 | 1 | 1 | 1 | 0 | 0 |
    | 0 | 0 | 1 | 1 | 1 | 1 | 1 | 0 |
    | 0 | 0 | 0 | 1 | 1 | 1 | 1 | 1 |
    | 0 | 0 | 0 | 0 | 1 | 1 | 1 | 1 |
    | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 1 |
    +---+---+---+---+---+---+---+---+

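    w = window size (e.g., 4096 in Mistral 7B). The band above is bidirectional,
    with each token seeing up to 2 neighbors on each side; decoder-only models
    such as Mistral also apply a causal mask, so token i attends only to keys j
    with i - w < j <= i.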
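
Both mask variants can be built in a few lines. A minimal pure-Python sketch (`sliding_window_mask` is a hypothetical helper); with `window=5, causal=False` it reproduces the 8x8 band drawn above, and the causal variant shows how the number of allowed pairs stays at or below n * w.

```python
def sliding_window_mask(seq_len: int, window: int, causal: bool = True) -> list[list[int]]:
    """0/1 mask where mask[i][j] == 1 means query i may attend to key j.

    causal=True:  j in (i - window, i], as in decoder-only models.
    causal=False: |i - j| <= window // 2, the symmetric band drawn above.
    """
    if causal:
        return [[1 if 0 <= i - j < window else 0 for j in range(seq_len)] for i in range(seq_len)]
    half = window // 2
    return [[1 if abs(i - j) <= half else 0 for j in range(seq_len)] for i in range(seq_len)]


band = sliding_window_mask(8, window=5, causal=False)
print(band[3])  # [0, 1, 1, 1, 1, 1, 0, 0], the fourth row of the diagram

causal = sliding_window_mask(8, window=4)
allowed = sum(map(sum, causal))
print(allowed)  # 26 allowed pairs: at most n * w = 32, versus n * n = 64 for full attention
```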

Multi-Scale Attention (Longformer style)

    Combining local and global attention:

    +---+---+---+---+---+---+---+---+
    | G | G | G | G | G | G | G | G |  <- Global tokens
    | G | L | L | L | 0 | 0 | 0 | 0 |  <- Local window
    | G | L | L | L | L | 0 | 0 | 0 |     for each
    | G | L | L | L | L | L | 0 | 0 |     position
    | G | 0 | L | L | L | L | L | 0 |
    | G | 0 | 0 | L | L | L | L | L |
    | G | 0 | 0 | 0 | L | L | L | L |
    | G | 0 | 0 | 0 | 0 | L | L | L |
    +---+---+---+---+---+---+---+---+

    G = Global attention (to/from special tokens)
    L = Local sliding window attention
    0 = Masked (no attention)
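
The Longformer-style pattern is a symmetric local band plus full rows and columns for the global tokens. A minimal pure-Python sketch, assuming position 0 is the only global token and a half-window of 2 (`longformer_mask` is a hypothetical name); it prints exactly the 8x8 grid above.

```python
def longformer_mask(seq_len: int, half_window: int, global_idx: set[int]) -> list[list[str]]:
    """Build the pattern above: 'G' global, 'L' local window, '0' masked."""
    mask = [
        ["L" if abs(i - j) <= half_window else "0" for j in range(seq_len)]
        for i in range(seq_len)
    ]
    for g in global_idx:
        for t in range(seq_len):
            mask[g][t] = "G"  # the global token attends to every position
            mask[t][g] = "G"  # every position attends to the global token
    return mask


for row in longformer_mask(8, half_window=2, global_idx={0}):
    print(" ".join(row))
```

The cost is O(n * w) for the band plus O(n * g) for g global tokens, so it stays linear in n as long as g is small.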

Ring Attention (Distributed)

    Ring Attention for very long sequences across devices:

    Device 0          Device 1          Device 2
    +--------+        +--------+        +--------+
    | Q[0:n] |        | Q[n:2n]|        |Q[2n:3n]|
    | K[0:n] |        | K[n:2n]|        |K[2n:3n]|
    | V[0:n] |        | V[n:2n]|        |V[2n:3n]|
    +---+----+        +---+----+        +---+----+
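        |                 |                 |
        +-----------------+-----------------+
          Ring communication: at each step, every device
          passes its current K/V block to the next device
        +-----------------+-----------------+
        |                 |                 |
        v                 v                 v
    Compute           Compute           Compute
    attention         attention         attention
    (local Q vs       (local Q vs       (local Q vs
    held K/V)         held K/V)         held K/V)
        |                 |                 |
        v                 v                 v
    Accumulate        Accumulate        Accumulate
    (online softmax)  (online softmax)  (online softmax)

    After N steps (N = number of devices), every Q block has seen every K/V
    block, and no device ever holds the full sequence.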
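
Each device's accumulation step can be done with an online (streaming) softmax: partial results from successive K/V blocks are rescaled by the running maximum and merged, which gives exactly the same output as attention over the whole sequence. A minimal single-query, pure-Python sketch under simplifying assumptions: no 1/sqrt(d) scaling, no causal mask, and the ring is simulated by iterating over K/V blocks in order. The function names are hypothetical, not taken from the Ring Attention code.

```python
import math


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def full_attention(q, keys, values):
    """Reference: softmax(q . k_j)-weighted sum of v_j over the whole sequence."""
    scores = [dot(q, k) for k in keys]
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    return [sum(wi * v[c] for wi, v in zip(w, values)) / sum(w) for c in range(len(values[0]))]


def ring_attention_one_query(q, kv_blocks):
    """Accumulate attention for one local query as K/V blocks arrive, one ring step each.

    Keeps a running max, normalizer and unnormalized output (online softmax),
    so earlier partial results are rescaled instead of recomputed.
    """
    run_max, norm = -math.inf, 0.0
    acc = [0.0] * len(kv_blocks[0][1][0])
    for keys, values in kv_blocks:
        scores = [dot(q, k) for k in keys]
        new_max = max(run_max, max(scores))
        correction = math.exp(run_max - new_max)  # 0.0 on the first block
        w = [math.exp(s - new_max) for s in scores]
        acc = [a * correction + sum(wi * v[c] for wi, v in zip(w, values)) for c, a in enumerate(acc)]
        norm = norm * correction + sum(w)
        run_max = new_max
    return [a / norm for a in acc]


keys = [[0.1, 0.4], [0.9, -0.2], [-0.5, 0.3], [0.7, 0.7], [0.0, -1.0], [0.3, 0.2]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0], [0.2, 0.8]]
q = [1.5, -0.5]

# Three simulated devices, each holding two positions of K/V
blocks = [(keys[i:i + 2], values[i:i + 2]) for i in range(0, 6, 2)]
print(ring_attention_one_query(q, blocks))
print(full_attention(q, keys, values))  # same values up to float rounding
```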

LongRoPE Architecture

    LongRoPE: Two-stage extension

    Stage 1: Search optimal rescale factors
    +----------------------------------------+
    | Input: Base model (e.g., 4K context)   |
    | Target: Extend to 256K                 |
    |                                        |
    | Search space:                          |
    | - Lambda factors for each RoPE dim     |
    | - Non-uniform interpolation            |
    | - Initial positions left unscaled      |
    +----------------------------------------+
               |
               v
    Stage 2: Progressive extension
    +----------------------------------------+
    | Search + fine-tune up to 256K          |
    | 256K -> 2M: second search on the       |
    | fine-tuned model, no more fine-tuning  |
    |                                        |
    | Readjust factors at 8K to recover      |
    | short-context performance              |
    +----------------------------------------+

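Comparison of Techniques

Technique                    | Max context     | Fine-tuning | Extrapolation
-----------------------------|-----------------|-------------|--------------
Sinusoidal                   | Training length | Required    | None
RoPE                         | Training length | Required    | Limited
ALiBi                        | No hard limit   | Not needed  | Good
PI (Position Interpolation)  | ~4-8x           | Required    | Moderate
NTK-aware                    | ~4-8x           | Optional    | Moderate
YaRN                         | ~16-32x         | Required    | Good
LongRoPE                     | 2M+             | Required    | Very good

Ratios such as ~4-8x are relative to the original training length. ALiBi has no
hard limit, but its linear penalty down-weights distant tokens, so a longer input
does not automatically mean that far-away context is actually used.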

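Representative Models

Model          | Context length | Technique
---------------|----------------|----------------------------------
GPT-4 Turbo    | 128K           | Undisclosed
Claude 3       | 200K           | Undisclosed
Gemini 1.5 Pro | 1M+            | Undisclosed
Llama 3.1      | 128K           | RoPE + frequency scaling
Mistral        | 32K            | Sliding Window
Yi-34B         | 200K           | YaRN
Command R      | 128K           | Undisclosed
Jamba 1.5      | 256K           | Hybrid (Transformer + Mamba SSM)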

Pros and Cons

Position Interpolation (PI)

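Pros:
- Simple to implement
- Extends the context with only a small amount of fine-tuning

Cons:
- Loses high-frequency (fine-grained positional) information
- Limited extension ratio (about 4-8x)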

ALiBi

Pros:
- Extrapolates beyond the training length without fine-tuning
- Simple to implement

Cons:
- Lower quality on some tasks
- Weaker in-context learning than RoPE

YaRN

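Pros:
- Supports large extension ratios
- Minimal quality degradation

Cons:
- Hyperparameters (scale factor, ramp bounds alpha and beta) need tuning
- Requires fine-tuning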

Sliding Window

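Pros:
- Memory grows as O(n*w), i.e. linearly in n for a fixed window w
- Efficient for local context

Cons:
- A token cannot directly attend to anything outside its window; distant information reaches it only indirectly, through stacked layers
- Global information may be lost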

Code Examples

RoPE Implementation

import torch
import torch.nn as nn
import math

class RotaryPositionEmbedding(nn.Module):
    def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Compute inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Precompute rotary embeddings
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        positions = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)

        # [seq_len, dim/2] -> [seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

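    def forward(self, x, seq_len: int = None):
        # x: (batch, n_heads, seq_len, head_dim); only its length and dtype are used
        if seq_len is None:
            seq_len = x.shape[2]
        if seq_len > self.cos_cached.shape[0]:
            raise ValueError(
                f"seq_len={seq_len} exceeds the cached length {self.cos_cached.shape[0]}"
            )

        return (
            self.cos_cached[:seq_len].to(x.dtype),
            self.sin_cached[:seq_len].to(x.dtype),
        )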


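def rotate_half(x):
    """Map [x1, x2] -> [-x2, x1], where x1 and x2 are the two halves of the last dim.

    Used with apply_rotary_pos_emb, this rotates the pair (j, j + d/2) by m * theta_j
    rather than the adjacent pair (2j, 2j+1) shown in the RoPE diagram; the two
    layouts are equivalent up to a fixed permutation of the head dimensions.
    """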
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embedding to queries and keys."""
    # cos, sin: (seq_len, dim)
    # q, k: (batch, n_heads, seq_len, head_dim)

    cos = cos.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, dim)
    sin = sin.unsqueeze(0).unsqueeze(0)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed
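
The RoPE diagram pairs adjacent dimensions (q1, q2), while `rotate_half` pairs dimension j with j + d/2. A minimal pure-Python check (no PyTorch; the helper names are hypothetical) that the two layouts apply the same rotation once the dimensions are permuted, even-indexed first and odd-indexed second:

```python
import math


def thetas(d: int, base: float = 10000.0) -> list[float]:
    return [base ** (-2 * j / d) for j in range(d // 2)]


def rope_half_split(x: list[float], m: int) -> list[float]:
    """Mirror of apply_rotary_pos_emb for one vector: x * cos + rotate_half(x) * sin."""
    h = len(x) // 2
    out = [0.0] * len(x)
    for j, t in enumerate(thetas(len(x))):
        c, s = math.cos(m * t), math.sin(m * t)
        out[j] = x[j] * c - x[j + h] * s      # first half: x1 * cos + (-x2) * sin
        out[j + h] = x[j + h] * c + x[j] * s  # second half: x2 * cos + x1 * sin
    return out


def rope_adjacent(x: list[float], m: int) -> list[float]:
    """Diagram layout: rotate (x[2j], x[2j+1]) by m * theta_j."""
    out: list[float] = []
    for j, t in enumerate(thetas(len(x))):
        c, s = math.cos(m * t), math.sin(m * t)
        a, b = x[2 * j], x[2 * j + 1]
        out += [a * c - b * s, a * s + b * c]
    return out


def adjacent_to_half_split(x: list[float]) -> list[float]:
    """Fixed permutation: even-indexed dims first, then odd-indexed dims."""
    return x[0::2] + x[1::2]


x = [0.5, -0.1, 0.3, 0.8, -0.6, 0.2, 0.1, 0.4]
lhs = adjacent_to_half_split(rope_adjacent(x, m=7))
rhs = rope_half_split(adjacent_to_half_split(x), m=7)
print(all(math.isclose(a, b) for a, b in zip(lhs, rhs)))  # True
```

In practice this means weights trained with one convention would need their query/key projection rows permuted before being used with the other.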

Position Interpolation

class ScaledRotaryEmbedding(RotaryPositionEmbedding):
    """RoPE with Position Interpolation for context extension."""

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 8192,
        base: float = 10000.0,
        scaling_factor: float = 1.0  # target_len / original_len
    ):
        # Set before super().__init__(), which calls the overridden _build_cache()
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_seq_len, base)

    def _build_cache(self, seq_len: int):
        # Scale positions for interpolation
        positions = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        positions = positions / self.scaling_factor  # Key difference

        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())
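
The comparison table also lists NTK-aware scaling. Instead of dividing positions as the class above does, it enlarges the RoPE base, commonly as base' = base * s^(d/(d-2)), so the highest-frequency pair is untouched and the lowest-frequency pair is slowed down by exactly s. A minimal pure-Python sketch with hypothetical helper names, assuming d = 128 and s = 4:

```python
def ntk_aware_base(base: float, s: float, d: int) -> float:
    """NTK-aware scaling: enlarge the RoPE base instead of dividing positions by s."""
    return base * s ** (d / (d - 2))


def rope_thetas(d: int, base: float) -> list[float]:
    return [base ** (-2 * j / d) for j in range(d // 2)]


d, s = 128, 4.0
orig = rope_thetas(d, 10000.0)
ntk = rope_thetas(d, ntk_aware_base(10000.0, s, d))

print(ntk[0] / orig[0])    # 1.0: highest frequency unchanged
print(ntk[-1] / orig[-1])  # about 0.25: lowest frequency divided by s, like PI
```

Frequencies in between are scaled by a smoothly increasing factor, which is why this method degrades local resolution less than uniform PI.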

YaRN Implementation

class YaRNRotaryEmbedding(nn.Module):
    """
    YaRN: Yet another RoPE extensioN

    Per-dimension interpolation ("NTK-by-parts") plus an attention temperature:
    - High frequencies (r > beta_fast rotations in the original context): no scaling
    - Low frequencies (r < beta_slow rotations): linear interpolation, theta / scale
    - In between: linear ramp in r between the two
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 8192,
        base: float = 10000.0,
        scale: float = 1.0,                # target_len / original_max_seq_len
        original_max_seq_len: int = 4096,
        beta_fast: float = 32.0,           # beta in the YaRN paper
        beta_slow: float = 1.0,            # alpha in the YaRN paper
    ):
        super().__init__()
        self.dim = dim
        self.scale = scale
        self.original_max_seq_len = original_max_seq_len

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

        # r: full rotations each pair completes within the original context
        wavelen = 2 * math.pi / inv_freq
        r = original_max_seq_len / wavelen

        # gamma = 1 -> keep theta (high freq), gamma = 0 -> theta / scale (low freq)
        gamma = ((r - beta_slow) / (beta_fast - beta_slow)).clamp(0.0, 1.0)
        inv_freq = (1.0 - gamma) * inv_freq / scale + gamma * inv_freq

        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        positions = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        # Attention temperature: sqrt(1/t) = 0.1 * ln(scale) + 1. Scaling both cos
        # and sin multiplies q and k by mscale, so the logits grow by mscale**2 = 1/t.
        mscale = 0.1 * math.log(self.scale) + 1.0 if self.scale > 1.0 else 1.0

        self.register_buffer("cos_cached", emb.cos() * mscale)
        self.register_buffer("sin_cached", emb.sin() * mscale)

    def forward(self, x, seq_len: int = None):
        if seq_len is None:
            seq_len = x.shape[2]
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]

Sliding Window Attention

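class SlidingWindowAttention(nn.Module):
    """Causal sliding-window attention for long sequences.

    Reference version: it still materializes the full (seq_len x seq_len) score
    matrix and masks it, so it shows the attention pattern but not the O(n*w)
    memory savings; efficient kernels compute only the blocks inside the window.
    """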

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        window_size: int = 4096,
        dropout: float = 0.1
    ):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.window_size = window_size

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _create_sliding_window_mask(self, seq_len: int, device: torch.device):
        """Causal sliding window mask: query i may attend to key j iff i - window_size < j <= i."""
        i = torch.arange(seq_len, device=device).unsqueeze(1)  # query positions, (seq_len, 1)
        j = torch.arange(seq_len, device=device).unsqueeze(0)  # key positions, (1, seq_len)
        mask = (j <= i) & (i - j < self.window_size)            # bool, (seq_len, seq_len)

        return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply sliding window mask
        mask = self._create_sliding_window_mask(seq_len, x.device)
        scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.W_o(context)

ALiBi Implementation

class ALiBiAttention(nn.Module):
    """Attention with Linear Biases for position encoding."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # Compute ALiBi slopes
        self.register_buffer("slopes", self._get_alibi_slopes(n_heads))

    def _get_alibi_slopes(self, n_heads: int):
        """Get ALiBi slopes for each attention head."""
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * (ratio ** i) for i in range(n)]

        if math.log2(n_heads).is_integer():
            return torch.tensor(get_slopes_power_of_2(n_heads))
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
            slopes_1 = get_slopes_power_of_2(closest_power_of_2)
            slopes_2 = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:n_heads - closest_power_of_2]
            return torch.tensor(slopes_1 + slopes_2)

    def _get_alibi_bias(self, seq_len: int, device: torch.device):
        """Compute ALiBi bias matrix."""
        # Distance matrix
        positions = torch.arange(seq_len, device=device)
        distance = positions.unsqueeze(0) - positions.unsqueeze(1)  # (seq_len, seq_len)
        distance = distance.abs().float()

        # Apply slopes
        alibi = distance.unsqueeze(0) * self.slopes.unsqueeze(-1).unsqueeze(-1).to(device)

        return -alibi  # Negative because we subtract from attention scores

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Attention scores with ALiBi bias
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores + self._get_alibi_bias(seq_len, x.device)

        # Causal mask
        if mask is None:
            mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
            mask = mask.unsqueeze(0).unsqueeze(0)

        scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.W_o(context)

References

  1. Su, J., et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding." arXiv: https://arxiv.org/abs/2104.09864
  2. Press, O., et al. (2021). "Train Short, Test Long: Attention with Linear Biases Enables Input Length Generalization." (ALiBi) arXiv: https://arxiv.org/abs/2108.12409
  3. Chen, S., et al. (2023). "Extending Context Window of Large Language Models via Positional Interpolation." (PI) arXiv: https://arxiv.org/abs/2306.15595
  4. Peng, B., et al. (2023). "YaRN: Efficient Context Window Extension of Large Language Models." arXiv: https://arxiv.org/abs/2309.00071
  5. Ding, Y., et al. (2024). "LongRoPE: Extending LLM Context Window Beyond 2 Million Tokens." arXiv: https://arxiv.org/abs/2402.13753
  6. Liu, X., et al. (2023). "Scaling Laws of RoPE-based Extrapolation." arXiv: https://arxiv.org/abs/2310.05209
  7. Dao, T. (2024). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv: https://arxiv.org/abs/2307.08691
  8. Liu, H., et al. (2023). "Ring Attention with Blockwise Transformers for Near-Infinite Context." arXiv: https://arxiv.org/abs/2310.01889