
Deep RL Scaling: 1000 Layer Networks

Meta Information

Item    | Value
--------|------
Title   | 1000 Layer Networks for Self-Supervised RL: Scaling Depth Can Enable New Goal-Reaching Capabilities
Authors | Kevin Wang et al.
Venue   | NeurIPS 2025 (Best Paper)
arXiv   | 2503.14858
Code    | GitHub
Project | Project Page

Core Contribution

Conventional reinforcement learning (RL) has relied on shallow networks of 2-5 layers. This paper scales depth up to 1024 layers and achieves 2x-50x performance gains in self-supervised RL. Beyond raw performance, it shows that qualitatively different behavior patterns (emergent behaviors) appear as depth increases.

Problem Statement

  • Background: In NLP and CV, large models with hundreds of layers (Llama 3, Stable Diffusion 3) have succeeded
  • Problem: RL still uses shallow networks (2-5 layers)
  • Question: Is depth scaling meaningful in RL as well?

Methodology

1. Self-Supervised Goal-Conditioned RL

                     ┌─────────────────┐
                     │   Environment   │
                     └────────┬────────┘
                              │ state
┌─────────────┐       ┌───────────────┐       ┌─────────────┐
│    Goal     │──────▶│  Deep Critic  │◀──────│   Action    │
│  Embedding  │       │  (1024 layers)│       │  Embedding  │
└─────────────┘       └───────────────┘       └─────────────┘
                     ┌─────────────────┐
                     │  InfoNCE Loss   │
                     │  (Contrastive)  │
                     └─────────────────┘
  • Setup: learn from exploration alone, with no demonstrations and no rewards
  • Goal: maximize the probability of reaching a given goal state
  • Method: Contrastive RL (actor-critic with contrastive objectives)
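In contrastive RL, the positive "goal" paired with a state is typically sampled from that trajectory's own future. Below is a minimal sketch of such a sampler, assuming a geometric offset distribution; `sample_future_goal` and its parameters are illustrative, not the paper's implementation:

```python
import numpy as np

def sample_future_goal(traj_states, t, gamma=0.99, rng=None):
    """Sample a positive 'goal' for state s_t from the same trajectory's
    future, with a geometrically distributed offset (illustrative scheme,
    not the paper's exact code)."""
    if rng is None:
        rng = np.random.default_rng(0)
    offset = rng.geometric(1.0 - gamma)          # offset >= 1
    idx = min(t + offset, len(traj_states) - 1)  # clip to episode end
    return traj_states[idx]

# Toy trajectory of 10 scalar states: the sampled goal always lies
# strictly after time t (or at the final state).
traj = list(range(10))
goal = sample_future_goal(traj, t=3)
```

States near in time become positive pairs, which is what lets the InfoNCE critic learn reachability rather than a hand-designed reward.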

2. Network Architecture

A ResNet-style deep architecture:

Input
┌─────────────────────────┐
│  Dense Layer            │
│  Layer Normalization    │ × 4 (1 Block)
│  Swish Activation       │
└───────────┬─────────────┘
            │◀─── Residual Connection
┌─────────────────────────┐
│  Next Block (×256)      │
│  = 1024 layers total    │
└─────────────────────────┘

Key techniques:
  • Residual connections (one skip connection per block of 4 layers)
  • Layer Normalization
  • Swish activation

3. Hindsight Experience Replay (HER)

┌──────────────────────────────────────────┐
│  Original Trajectory                     │
│  Start → ... → Failed (goal not reached) │
└────────────────┬─────────────────────────┘
                 │ Relabel
┌──────────────────────────────────────────┐
│  Relabeled Trajectory                    │
│  Start → ... → Achieved State (new goal) │
└──────────────────────────────────────────┘
  • Even failed trajectories become training data
  • The state actually reached is relabeled as the goal
  • Mitigates the sparse-reward problem
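The relabeling step can be sketched in a few lines. This is a minimal illustration; the `(state, action, goal)` transition format and the `her_relabel` helper are assumptions, not the paper's code:

```python
import jax.numpy as jnp

def her_relabel(trajectory, achieved_goal):
    """Rewrite each (state, action, goal) transition of a failed episode
    so the state actually reached at the end becomes the new goal."""
    return [(s, a, achieved_goal) for (s, a, _old_goal) in trajectory]

# A 3-step episode that never reached the original goal [5., 5.]
orig_goal = jnp.array([5.0, 5.0])
traj = [
    (jnp.array([0.0, 0.0]), jnp.array([1.0]), orig_goal),
    (jnp.array([1.0, 0.0]), jnp.array([1.0]), orig_goal),
    (jnp.array([2.0, 1.0]), jnp.array([0.0]), orig_goal),
]
achieved = traj[-1][0]                  # the state we actually ended in
relabeled = her_relabel(traj, achieved)
# With respect to the new goal, the episode is now a success,
# so it yields useful positive examples despite the sparse reward.
```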

Key Results

1. Performance Improvement

Baseline Depth | Deep Model  | Improvement
---------------|-------------|------------
2-5 layers     | 64 layers   | 2-10x
2-5 layers     | 256 layers  | 5-20x
2-5 layers     | 1024 layers | 10-50x

2. Emergent Capabilities

Once depth crosses a certain threshold, new capabilities emerge as a discrete jump:

Performance
    │                    ╭──────── 64 layers (new policy)
    │               ╭────╯
    │          ╭────╯
    │     ╭────╯              ← threshold
    │─────╯
    └────────────────────────────────▶ Depth
         4    16    32    64   128

3. Depth vs Width

At an equal parameter count, depth scaling is more efficient than width scaling:

Scaling Method | Parameters | Performance
---------------|------------|------------
Width 2x       | ~2x        | +30%
Depth 2x       | ~2x        | +80%
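The matched-parameter comparison follows from how MLP parameter counts scale: doubling depth doubles parameters linearly, while parameters grow quadratically in width, so matching a 2x budget via width needs only a sqrt(2) width increase. A back-of-the-envelope sketch (the widths and depths are illustrative, not the paper's configurations):

```python
import math

def mlp_params(width, depth):
    """Rough parameter count: `depth` hidden-to-hidden weight matrices of
    shape (width, width); input/output layers and biases ignored."""
    return depth * width * width

base = mlp_params(256, 4)
deeper = mlp_params(256, 8)                       # 2x depth -> exactly 2x params
wider = mlp_params(round(256 * math.sqrt(2)), 4)  # ~2x params via sqrt(2) width
```

So "2x parameters" buys twice the layers but only ~1.4x the width, and the paper's finding is that the extra layers are the better investment.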

4. Better Representations

Shallow (4 layers):
  • Learns little more than the Euclidean distance to the goal
  • Ignores walls and obstacles

Deep (64+ layers):
  • Accurately captures the maze topology
  • More spread-out representations around the goal
  • Better generalization

Why Does Depth Work?

1. Improved Contrastive Representations

Deep networks learn richer structure in the state-action embeddings:

Q-value visualization (U4-Maze)

Shallow (4L):           Deep (64L):
┌─────────────┐        ┌─────────────┐
│ ░░░░░░░░░G │        │ ▓▓▓▓▓▓▓▓▓G │
│ ░░░░░░░░░░ │        │ ▓▓▓▓▓▓▓▓▓▓ │
│ ░░░░░░░░░░ │        │ ▓▓▓▓▓▓▓▓▓▓ │
│ ░░░░██░░░░ │        │ ░░░░██▓▓▓▓ │
│ S░░░██░░░░ │        │ S░░░██▓▓▓▓ │
└─────────────┘        └─────────────┘
 (Euclidean)           (Topology-aware)

2. Representational Capacity Allocation

  • Deep networks: allocate more representational capacity to states near the goal
  • Shallow networks: embeddings near the goal collapse into a tight cluster

3. Experience Stitching

Generalizes to start-goal pairs never seen in training:
  • New routes are created by composing partial trajectories
  • Stitching ability improves with depth

Python Example

import jax
import jax.numpy as jnp
import flax.linen as nn

class DeepRLBlock(nn.Module):
    """Single block: Dense + LayerNorm + Swish"""
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        y = nn.Dense(self.hidden_dim)(x)
        y = nn.LayerNorm()(y)
        y = nn.swish(y)
        return y

class DeepCritic(nn.Module):
    """1024-layer critic with residual connections every 4 blocks"""
    hidden_dim: int = 256
    num_blocks: int = 256  # 256 * 4 = 1024 layers

    @nn.compact
    def __call__(self, state, action, goal):
        # Concatenate inputs
        x = jnp.concatenate([state, action, goal], axis=-1)
        x = nn.Dense(self.hidden_dim)(x)

        # Deep residual blocks
        for i in range(self.num_blocks):
            residual = x

            # 4 layers per block
            for _ in range(4):
                x = DeepRLBlock(self.hidden_dim)(x)

            # Residual connection
            x = x + residual

        # Output embeddings
        state_action_emb = nn.Dense(128)(x)
        goal_emb = nn.Dense(128)(nn.Dense(self.hidden_dim)(goal))

        return state_action_emb, goal_emb

def infonce_loss(state_action_emb, goal_emb, temperature=0.1):
    """Contrastive loss for goal-conditioned RL"""
    # Normalize embeddings
    sa_norm = state_action_emb / jnp.linalg.norm(state_action_emb, axis=-1, keepdims=True)
    g_norm = goal_emb / jnp.linalg.norm(goal_emb, axis=-1, keepdims=True)

    # Similarity matrix
    logits = jnp.dot(sa_norm, g_norm.T) / temperature

    # InfoNCE loss (diagonal = positive pairs)
    labels = jnp.arange(logits.shape[0])
    loss = -jnp.mean(jax.nn.log_softmax(logits, axis=1)[jnp.arange(len(labels)), labels])

    return loss

# Example usage (the default 1024-layer critic is slow to initialize;
# pass num_blocks=2 for a quick smoke test)
key = jax.random.PRNGKey(42)
critic = DeepCritic()

# Initialize
state = jnp.ones((32, 64))   # batch of states
action = jnp.ones((32, 8))   # batch of actions
goal = jnp.ones((32, 64))    # batch of goals

params = critic.init(key, state, action, goal)
sa_emb, g_emb = critic.apply(params, state, action, goal)

loss = infonce_loss(sa_emb, g_emb)
print(f"InfoNCE Loss: {loss:.4f}")

Practical Considerations

When to Use Deep RL Scaling

Scenario           | Recommendation
-------------------|-------------------
Sparse reward      | Highly recommended
Complex navigation | Highly recommended
Dense reward       | Moderate benefit
Simple control     | May be overkill

Training Tips

  1. Residual connections are essential: they keep gradients flowing through 1000+ layers
  2. Layer Normalization: stabilizes activations at every layer
  3. Swish > ReLU: smoother gradients in very deep networks
  4. Hindsight relabeling: improves data efficiency in sparse-reward environments
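Tip 1 can be checked numerically: differentiating through a stack of saturating layers, with and without skip connections, shows the input gradient vanishing in one case and staying on the order of 1 in the other. This is a toy sketch with a fixed scalar weight, not the paper's architecture:

```python
import jax
import jax.numpy as jnp

def deep_stack(x, n_layers, residual):
    """Pass x through n_layers of a fixed tanh layer, optionally with a
    skip connection. The weight is a fixed scalar: the point is gradient
    flow through depth, not learning."""
    w = 0.9
    for _ in range(n_layers):
        y = jnp.tanh(w * x)
        x = x + y if residual else y
    return jnp.sum(x)

x0 = jnp.ones(4)
g_plain = jax.grad(deep_stack)(x0, 64, False)  # gradient w.r.t. the input
g_res = jax.grad(deep_stack)(x0, 64, True)

# Without skips, the input gradient has all but vanished after 64 layers;
# with residual connections, each layer's Jacobian is I plus a small term,
# so the gradient stays O(1).
print(float(jnp.linalg.norm(g_plain)), float(jnp.linalg.norm(g_res)))
```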

Computational Cost

Depth  | Training Time | Memory    | Performance
-------|---------------|-----------|------------
4L     | 1x            | 1x        | baseline
64L    | ~4x           | ~8x       | 5-10x better
256L   | ~12x          | ~20x      | 10-30x better
1024L  | ~40x          | ~60x      | 20-50x better

Related Work

Paper                         | Key Contribution        | Relation
------------------------------|-------------------------|-------------------------
Deep Residual Learning (2016) | Residual connections    | Architecture foundation
Contrastive RL (2022)         | Self-supervised RL      | Base algorithm
Scaling Laws for RL (2024)    | Width scaling analysis  | Complementary study
Gated Attention (2025)        | Attention depth scaling | Parallel finding in LLMs

Implications

  1. Paradigm shift: depth scaling is effective in RL as well
  2. Emergent capabilities: new behaviors appear once depth is sufficient
  3. Parameter efficiency: depth beats width at equal parameter counts
  4. Foundation models for RL: large-scale RL models become plausible, as with LLMs

References

  1. Wang, K. et al. (2025). 1000 Layer Networks for Self-Supervised RL. NeurIPS 2025.
  2. He, K. et al. (2016). Deep Residual Learning for Image Recognition. CVPR.
  3. Eysenbach, B. et al. (2022). Contrastive Learning as Goal-Conditioned RL. NeurIPS.
  4. Andrychowicz, M. et al. (2017). Hindsight Experience Replay. NeurIPS.