콘텐츠로 이동
Data Prep
상세

Gated Attention

개요

Gated Attention은 NeurIPS 2025 Best Paper로 선정된 Qwen 팀의 연구다. Scaled Dot-Product Attention (SDPA) 출력에 head-specific sigmoid 게이트를 추가하는 간단한 수정으로 LLM 성능을 일관되게 향상시킨다.

핵심 발견

기존 Softmax Attention

Attention(Q, K, V) = softmax(QK^T / √d) · V

Gated Attention

GatedAttention(Q, K, V) = σ(g) ⊙ softmax(QK^T / √d) · V

σ: sigmoid 함수
g: 학습 가능한 게이트 파라미터 (head별)
⊙: element-wise 곱

구현:

import torch
import torch.nn as nn

class GatedAttention(nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

        # Head-specific gate
        self.gate = nn.Parameter(torch.zeros(n_heads, 1, 1))

    def forward(self, x):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Standard attention
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = attn.softmax(dim=-1)

        # SDPA output
        out = attn @ v

        # Apply sigmoid gate (핵심!)
        gate = torch.sigmoid(self.gate)
        out = gate * out

        out = out.transpose(1, 2).reshape(B, N, C)
        return self.proj(out)

실험 결과

대규모 비교 (30+ 변형)

모델 파라미터 토큰 게이트 없음 게이트 추가 향상
Dense 1.7B 3.5T 2.89 PPL 2.81 PPL +2.8%
MoE 15B 3.5T 2.54 PPL 2.47 PPL +2.8%

게이트 위치별 비교

게이트 위치 효과 비용
Q 이전 낮음 낮음
K 이전 낮음 낮음
V 이전 중간 낮음
SDPA 이후 최고 낮음
Attention 이후 중간 중간

효과 분석

1. 비선형성 도입

Softmax attention의 저랭크 매핑에 비선형성 추가:

기존: Linear(V) → Low-rank Attention → Output
개선: Linear(V) → Low-rank Attention → σ(gate) → Output
                                   비선형 변환

수학적 해석: - Softmax attention은 본질적으로 저랭크(low-rank) 연산 - 게이트는 출력 공간에 비선형 변환 추가 - 표현력(expressiveness) 향상

2. 희소 게이팅

Query-dependent한 희소 게이팅 효과:

# 게이트 활성화 분석
gate_activations = torch.sigmoid(model.gate)

# 많은 헤드에서 0에 가까운 게이트 값 관찰
# → 불필요한 헤드 자동 비활성화

효과: - 자동 헤드 프루닝 - 계산 효율성 잠재적 향상 - 중요한 헤드에 집중

3. Attention Sink 완화

Attention Sink 문제: - 첫 번째 토큰에 과도한 attention 집중 - 정보 병목 발생 - 장문맥에서 성능 저하

게이트의 완화 효과:

[기존]
토큰 위치: [1] [2] [3] [4] [5] ...
Attention:  0.9 0.02 0.03 0.02 0.03 ...  ← Sink!

[게이트 추가]
토큰 위치: [1] [2] [3] [4] [5] ...
Attention:  0.3 0.18 0.17 0.18 0.17 ...  ← 분산됨

4. 컨텍스트 확장 성능

컨텍스트 길이 기존 게이트 추가
4K 2.89 2.81
8K 3.12 2.95
16K 3.45 3.10
32K 4.21 3.52

게이트가 장문맥 외삽(extrapolation) 성능 향상

학습 안정성

더 큰 학습률 허용

학습률 기존 게이트 추가
1e-4 수렴 수렴
3e-4 수렴 수렴
5e-4 불안정 수렴
1e-3 발산 수렴

의미: - 학습률 튜닝 여유 증가 - 빠른 수렴 가능 - 하이퍼파라미터 민감도 감소

Scaling 특성

게이트 추가 시 더 나은 스케일링 법칙:

Loss ∝ N^(-α)

기존: α ≈ 0.076
게이트: α ≈ 0.082

→ 모델 크기 증가에 따른 성능 향상 가속

실제 적용

Qwen3-Next 모델

논문의 SDPA 출력 게이팅이 Qwen3-Next에 적용됨:

Qwen3-Next = Qwen3 + Gated Attention

공개 자료: - 코드: github.com/qiuzh20/gated_attention - 모델: huggingface.co/QwQZh/gated_attention - Qwen3-Next: huggingface.co/collections/Qwen/qwen3-next

기존 모델에 적용

# 기존 Transformer에 게이트 추가하기
def add_gating_to_attention(model):
    for layer in model.layers:
        attention = layer.self_attn
        n_heads = attention.num_heads

        # 게이트 파라미터 추가
        attention.gate = nn.Parameter(
            torch.zeros(n_heads, 1, 1)
        )

        # forward 수정
        original_forward = attention.forward

        def gated_forward(self, *args, **kwargs):
            out = original_forward(*args, **kwargs)
            gate = torch.sigmoid(self.gate)
            return gate * out

        attention.forward = gated_forward.__get__(attention)

    return model

관련 연구

게이팅 메커니즘 역사

시기 모델 게이팅 방식
1997 LSTM Input/Forget/Output Gate
2015 Highway Network Transform Gate
2020 gMLP Spatial Gate
2024 Mamba Selection Gate
2025 Gated Attention SDPA Output Gate

다른 Attention 변형과 비교

방법 접근 복잡도 증가
Multi-Query KV 공유 감소
Flash Attention IO 최적화 없음
Linear Attention Softmax 제거 감소
Gated Attention 게이트 추가 최소

요약

핵심 포인트

  1. 단순함: SDPA 출력에 sigmoid 게이트 하나 추가
  2. 효과적: 일관된 성능 향상 (2-3%)
  3. 안정적: 학습 안정성 향상, 큰 학습률 허용
  4. 확장 가능: 장문맥 성능 개선
  5. 실용적: 이미 Qwen3-Next에 적용됨

구현 체크리스트

  • [ ] Head별 학습 가능한 게이트 파라미터 추가
  • [ ] SDPA 출력에 sigmoid(gate) 곱하기
  • [ ] 게이트 초기값 0 (학습 시작 시 0.5 게이트)
  • [ ] 기존 학습 설정 유지 가능

참고 자료


마지막 업데이트: 2026-02-11