콘텐츠로 이동
Data Prep
상세

Nested Learning

개요

Nested Learning은 NeurIPS 2025에서 Google Research가 발표한 새로운 ML 패러다임이다. 모델의 아키텍처와 최적화 알고리즘을 동일한 개념의 다른 "레벨"로 통합하여, 지속 학습(Continual Learning)에서의 파국적 망각(Catastrophic Forgetting) 문제를 해결한다.

핵심 아이디어

기존 접근법의 한계

문제 기존 해결책 한계
파국적 망각 EWC, PackNet 확장성 제한
지식 유지 리플레이 버퍼 메모리 비용
모델 성장 프로그레시브 네트워크 복잡도 증가

Nested Learning의 관점

[최적화 알고리즘]           [신경망 아키텍처]
       │                         │
       └─────── 동일한 개념 ──────┘
              다른 "레벨"

레벨 1: 가중치 업데이트 (빠름)
레벨 2: 모멘텀/적응적 학습률 (중간)
레벨 3: 아키텍처 파라미터 (느림)

핵심 통찰: - 아키텍처와 최적화는 근본적으로 같은 것 - 각 레벨은 고유한 "컨텍스트 흐름"을 가짐 - 업데이트 빈도가 레벨을 정의

이론적 기반

연관 기억 (Associative Memory) 관점

역전파 = 연관 기억:

입력(x) → 기대 오차(error) 매핑

모델이 각 데이터 포인트의 "놀라움" 정도를 기억

Attention = 연관 기억:

Query-Key → Value 매핑

시퀀스 내 토큰 간 관계를 기억

레벨 계층 구조

레벨 구성요소 업데이트 빈도 역할
1 가중치 매 배치 즉각적 학습
2 모멘텀 상태 배치마다 누적 그래디언트 안정화
3 학습률 에폭 단위 적응적 조정
4 아키텍처 태스크 단위 구조적 변화

Continuum Memory System (CMS)

기존 Transformer 메모리

[단기 기억]                    [장기 기억]
    │                              │
  Attention                   FFN 가중치
(현재 컨텍스트)              (사전학습 지식)
    │                              │
    └─────── 2가지 레벨만 ─────────┘

CMS 확장

[메모리 스펙트럼]

초단기 ─ 단기 ─ 중기 ─ 장기 ─ 영구
   │      │      │      │       │
 현재   세션   에피소드  지식   아키텍처
 토큰   컨텍스트  기억    베이스   구조

CMS 블록 구조:

class CMSBlock:
    def __init__(self, n_levels=5):
        self.memories = [
            Memory(update_freq=1),      # 매 스텝
            Memory(update_freq=10),     # 10 스텝마다
            Memory(update_freq=100),    # 100 스텝마다
            Memory(update_freq=1000),   # 1000 스텝마다
            Memory(update_freq=float('inf'))  # 고정
        ]

    def forward(self, x, step):
        outputs = []
        for i, mem in enumerate(self.memories):
            if step % mem.update_freq == 0:
                mem.update(x)
            outputs.append(mem.read(x))
        return aggregate(outputs)

Hope 아키텍처

구조

Hope는 Nested Learning 원리를 적용한 자기 수정(self-modifying) 아키텍처:

[입력] → [CMS 블록] → [Self-Modifying Layer] → [출력]
              ↑              │
              └──── 자기 참조 ────┘

핵심 특징: - Titans 아키텍처 기반 - 무한 레벨의 인-컨텍스트 학습 - 자기 참조 프로세스로 메모리 최적화

성능 결과

태스크 Transformer Mamba Titans Hope
언어 모델링 (PPL) 15.2 14.8 14.1 13.6
상식 추론 (Acc) 72.3 73.1 74.2 75.8
NIAH (F1) 0.82 0.85 0.91 0.96

NIAH: Needle-In-A-Haystack (장문맥 검색)

Deep Optimizers

개선된 모멘텀

기존 모멘텀:

v_t = β * v_{t-1} + (1-β) * g_t

Nested Learning 모멘텀:

def nested_momentum(gradient, memory, beta=0.9):
    # L2 회귀 기반 모멘텀 (데이터 샘플 간 관계 고려)
    similarity = compute_sample_similarity(gradient, memory)
    weighted_memory = similarity @ memory

    v = beta * weighted_memory + (1 - beta) * gradient

    memory.update(gradient)
    return v

효과: - 불완전한 데이터에 더 견고 - 노이즈 그래디언트 자동 필터링 - 학습 안정성 향상

실용적 시사점

적용 가능 영역

영역 적용 방법 기대 효과
지속 학습 CMS 블록 추가 망각 감소
장문맥 메모리 계층화 컨텍스트 확장
온라인 학습 적응형 최적화기 빠른 적응
멀티태스크 레벨별 파라미터 분리 태스크 간섭 감소

구현 고려사항

# 간단한 CMS 구현 예시
class SimpleCMS(nn.Module):
    def __init__(self, dim, n_levels=3):
        super().__init__()
        self.levels = nn.ModuleList([
            nn.Linear(dim, dim) for _ in range(n_levels)
        ])
        self.update_freqs = [1, 10, 100]
        self.step = 0

    def forward(self, x):
        outputs = []
        for i, (level, freq) in enumerate(zip(self.levels, self.update_freqs)):
            if self.training and self.step % freq == 0:
                # 해당 레벨 업데이트 허용
                out = level(x)
            else:
                with torch.no_grad():
                    out = level(x)
            outputs.append(out)

        self.step += 1
        return sum(outputs) / len(outputs)

한계 및 향후 연구

현재 한계

  • 계산 복잡도 증가
  • 하이퍼파라미터 (레벨 수, 업데이트 빈도) 튜닝 필요
  • 대규모 모델에서의 검증 부족

향후 연구 방향

  • 자동 레벨 구성 (AutoML)
  • 하드웨어 최적화
  • 더 긴 시간 스케일의 학습

참고 자료


마지막 업데이트: 2026-02-11