Deep RL Scaling: 1000 Layer Networks¶
Meta Information¶
| Item | Value |
|---|---|
| Title | 1000 Layer Networks for Self-Supervised RL: Scaling Depth Can Enable New Goal-Reaching Capabilities |
| Authors | Kevin Wang et al. |
| Venue | NeurIPS 2025 (Best Paper) |
| arXiv | 2503.14858 |
| Code | GitHub |
| Project | Project Page |
Core Contribution¶
RL has traditionally used shallow networks of 2-5 layers. This paper scales depth up to 1024 layers and achieves 2x-50x performance gains in self-supervised RL. Beyond raw performance, increasing depth produces qualitatively different behavior patterns (emergent behaviors).
Problem Statement¶
- Background: In NLP and CV, models with hundreds of layers (Llama 3, Stable Diffusion 3) have succeeded
- Problem: RL still relies on shallow networks (2-5 layers)
- Question: Is depth scaling also meaningful in RL?
Methodology¶
1. Self-Supervised Goal-Conditioned RL¶
┌─────────────────┐
│ Environment │
└────────┬────────┘
│ state
▼
┌─────────────┐ ┌───────────────┐ ┌─────────────┐
│ Goal │──────▶│ Deep Critic │◀──────│ Action │
│ Embedding │ │ (1024 layers)│ │ Embedding │
└─────────────┘ └───────────────┘ └─────────────┘
│
▼
┌─────────────────┐
│ InfoNCE Loss │
│ (Contrastive) │
└─────────────────┘
- Setup: Learns from exploration alone, with no demonstrations or rewards
- Objective: Maximize the probability of reaching a given goal state
- Method: Contrastive RL (actor-critic with contrastive objectives)
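In this setup, the critic scores a (state, action, goal) triple as an inner product of two learned embeddings, and positive pairs come from the same trajectory. A sketch of the objective (the symbols φ and ψ are my notation, not necessarily the paper's):

$$f(s, a, g) = \phi(s, a)^\top \psi(g)$$

$$\mathcal{L}_{\text{InfoNCE}} = -\frac{1}{N}\sum_{i=1}^{N} \log \frac{\exp\big(f(s_i, a_i, g_i)\big)}{\sum_{j=1}^{N} \exp\big(f(s_i, a_i, g_j)\big)}$$

where the goal $g_i$ is a state sampled from the future of the same trajectory as $(s_i, a_i)$, and the other goals in the batch act as negatives.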
2. Network Architecture¶
A ResNet-style deep architecture:
Input
│
▼
┌─────────────────────────┐
│ Dense Layer │
│ Layer Normalization │ × 4 (1 Block)
│ Swish Activation │
└───────────┬─────────────┘
│
│◀─── Residual Connection
▼
┌─────────────────────────┐
│ Next Block (×256) │
│ = 1024 layers total │
└─────────────────────────┘
Key techniques:
- Residual connections (one per 4-layer block)
- Layer Normalization
- Swish activation
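Writing one 4-layer block as a function $g_\theta$, the forward pass is (my notation):

$$h_{k+1} = h_k + g_\theta(h_k), \qquad g_\theta = (\text{Dense} \to \text{LayerNorm} \to \text{Swish})^{\times 4}$$

so the gradient always carries an identity path, $\partial h_{k+1}/\partial h_k = I + \partial g_\theta/\partial h_k$, which is what keeps training stable past 1000 layers.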
3. Hindsight Experience Replay (HER)¶
┌──────────────────────────────────────────┐
│ Original Trajectory │
│ Start → ... → Failed (goal not reached)  │
└────────────────┬─────────────────────────┘
│ Relabel
▼
┌──────────────────────────────────────────┐
│ Relabeled Trajectory │
│ Start → ... → Achieved State (new goal)  │
└──────────────────────────────────────────┘
- Failed trajectories are reused as training data
- States actually reached are relabeled as goals
- Mitigates the sparse-reward problem
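The relabeling step above can be sketched as a small buffer transform. This is an illustrative final-state relabeling only; the paper's contrastive objective samples future states as goals inside the loss rather than rewriting a replay buffer, and the `hindsight_relabel` helper and trajectory schema here are my own:

```python
import numpy as np

def hindsight_relabel(trajectory):
    """Final-state relabeling: treat the state actually reached at the
    end of a failed trajectory as the goal it 'succeeded' at."""
    new_goal = trajectory[-1]["next_state"]
    relabeled = []
    for step in trajectory:
        step = dict(step)           # copy so the original buffer is untouched
        step["goal"] = new_goal     # overwrite the original (unreached) goal
        # success signal: did this step land on the new goal?
        step["reached"] = bool(np.allclose(step["next_state"], new_goal))
        relabeled.append(step)
    return relabeled

# A 3-step trajectory that never reached its original goal [9., 9.]
traj = [
    {"state": np.zeros(2),         "action": 0, "next_state": np.array([1., 0.]), "goal": np.array([9., 9.])},
    {"state": np.array([1., 0.]),  "action": 1, "next_state": np.array([1., 1.]), "goal": np.array([9., 9.])},
    {"state": np.array([1., 1.]),  "action": 0, "next_state": np.array([2., 1.]), "goal": np.array([9., 9.])},
]
relabeled = hindsight_relabel(traj)
```

After relabeling, every step carries the achieved final state `[2., 1.]` as its goal, and the last step counts as a success.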
Key Results¶
1. Performance Improvement¶
| Baseline Depth | Deep Model | Improvement |
|---|---|---|
| 2-5 layers | 64 layers | 2-10x |
| 2-5 layers | 256 layers | 5-20x |
| 2-5 layers | 1024 layers | 10-50x |
2. Emergent Capabilities¶
Once depth crosses a certain threshold, new capabilities appear as a discrete jump:
Performance
│
│ ╭──────── 64 layers (new policy)
│ ╭────╯
│ ╭────╯
│ ╭────╯ ← threshold
│─────╯
│
└────────────────────────────────▶ Depth
4 16 32 64 128
3. Depth vs Width¶
At equal parameter counts, depth scaling is more efficient than width scaling:
| Scaling Method | Parameters | Performance |
|---|---|---|
| Width 2x | ~2x | +30% |
| Depth 2x | ~2x | +80% |
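The "~2x parameters" column can be checked with back-of-the-envelope arithmetic for a plain MLP. This is an illustrative count only (the paper's exact numbers depend on its residual-block architecture), and `mlp_params` with its default dimensions is my own construction:

```python
def mlp_params(depth, width, in_dim=64, out_dim=128):
    """Approximate parameter count (weights + biases) for an MLP
    with `depth` Dense layers of `width` units each."""
    n = in_dim * width + width                   # input projection
    n += (depth - 1) * (width * width + width)   # hidden Dense layers
    n += width * out_dim + out_dim               # output head
    return n

base   = mlp_params(depth=16, width=256)
deeper = mlp_params(depth=32, width=256)  # double the depth
wider  = mlp_params(depth=16, width=362)  # widen by ~sqrt(2)

print(deeper / base)  # ≈ 2.0
print(wider / base)   # ≈ 2.0
```

Both routes roughly double the parameters (width must grow only by ~√2 because hidden-layer cost is quadratic in width), so the table compares models of comparable size.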
4. Better Representations¶
Shallow (4 layers):
- Learns roughly the Euclidean distance to the goal
- Ignores walls and obstacles

Deep (64+ layers):
- Accurately captures the maze topology
- More spread-out representation around the goal
- Better generalization
Why Does Depth Work?¶
1. Improved Contrastive Representations¶
Deeper networks learn richer structure in the state-action embeddings:
Q-value visualization (U4-Maze)
Shallow (4L): Deep (64L):
┌─────────────┐ ┌─────────────┐
│ ░░░░░░░░░G │ │ ▓▓▓▓▓▓▓▓▓G │
│ ░░░░░░░░░░ │ │ ▓▓▓▓▓▓▓▓▓▓ │
│ ░░░░░░░░░░ │ │ ▓▓▓▓▓▓▓▓▓▓ │
│ ░░░░██░░░░ │ │ ░░░░██▓▓▓▓ │
│ S░░░██░░░░ │ │ S░░░██▓▓▓▓ │
└─────────────┘ └─────────────┘
(Euclidean) (Topology-aware)
2. Representational Capacity Allocation¶
- Deep networks: allocate more representational capacity to states near the goal
- Shallow networks: embeddings near the goal collapse into a tight cluster
3. Experience Stitching¶
Generalizes to start-goal pairs never seen together during training:
- Combines partial trajectories into new routes
- Stitching ability improves with depth
Python Example¶
import jax
import jax.numpy as jnp
import flax.linen as nn

class DeepRLBlock(nn.Module):
    """One layer: Dense + LayerNorm + Swish."""
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        y = nn.Dense(self.hidden_dim)(x)
        y = nn.LayerNorm()(y)
        return nn.swish(y)

class DeepCritic(nn.Module):
    """Deep critic: 4 layers per block, one residual connection per block.
    With the defaults, 256 blocks * 4 layers = 1024 layers."""
    hidden_dim: int = 256
    num_blocks: int = 256

    @nn.compact
    def __call__(self, state, action, goal):
        # Concatenate inputs and project to the hidden width
        x = jnp.concatenate([state, action, goal], axis=-1)
        x = nn.Dense(self.hidden_dim)(x)
        # Deep residual blocks
        for _ in range(self.num_blocks):
            residual = x
            for _ in range(4):  # 4 layers per block
                x = DeepRLBlock(self.hidden_dim)(x)
            x = x + residual  # residual connection
        # Output embeddings
        state_action_emb = nn.Dense(128)(x)
        goal_emb = nn.Dense(128)(nn.Dense(self.hidden_dim)(goal))
        return state_action_emb, goal_emb

def infonce_loss(state_action_emb, goal_emb, temperature=0.1):
    """InfoNCE contrastive loss; diagonal entries are the positive pairs."""
    # Normalize embeddings to unit length
    sa_norm = state_action_emb / jnp.linalg.norm(state_action_emb, axis=-1, keepdims=True)
    g_norm = goal_emb / jnp.linalg.norm(goal_emb, axis=-1, keepdims=True)
    # Similarity matrix scaled by temperature
    logits = jnp.dot(sa_norm, g_norm.T) / temperature
    # Cross-entropy against the diagonal (each sample's own goal)
    labels = jnp.arange(logits.shape[0])
    log_probs = jax.nn.log_softmax(logits, axis=1)
    return -jnp.mean(log_probs[labels, labels])

# Example usage (num_blocks reduced so the example runs quickly;
# the full model uses the default num_blocks=256)
key = jax.random.PRNGKey(42)
critic = DeepCritic(num_blocks=4)

state = jnp.ones((32, 64))   # batch of states
action = jnp.ones((32, 8))   # batch of actions
goal = jnp.ones((32, 64))    # batch of goals

params = critic.init(key, state, action, goal)
sa_emb, g_emb = critic.apply(params, state, action, goal)
loss = infonce_loss(sa_emb, g_emb)
print(f"InfoNCE Loss: {loss:.4f}")
Practical Considerations¶
When to Use Deep RL Scaling¶
| Scenario | Recommendation |
|---|---|
| Sparse reward | Highly recommended |
| Complex navigation | Highly recommended |
| Dense reward | Moderate benefit |
| Simple control | May be overkill |
Training Tips¶
- Residual connections are essential: they maintain gradient flow beyond 1000 layers
- Layer Normalization: stabilizes activations at every layer
- Swish > ReLU: smoother gradients in very deep networks
- Hindsight relabeling: improves data efficiency in sparse-reward environments
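For reference, Swish is smooth and non-monotonic, and unlike ReLU it keeps a nonzero gradient for negative inputs:

$$\mathrm{swish}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$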
Computational Cost¶
| Depth | Training Time | Memory | Performance |
|---|---|---|---|
| 4L | 1x | 1x | baseline |
| 64L | ~4x | ~8x | 5-10x better |
| 256L | ~12x | ~20x | 10-30x better |
| 1024L | ~40x | ~60x | 20-50x better |
Related Work¶
| Paper | Key Contribution | Relation |
|---|---|---|
| Deep Residual Learning (2016) | Residual connections | Architecture foundation |
| Contrastive RL (2022) | Self-supervised RL | Base algorithm |
| Scaling Laws for RL (2024) | Width scaling analysis | Complementary study |
| Gated Attention (2025) | Attention depth scaling | Parallel finding in LLMs |
Implications¶
- Paradigm shift: depth scaling works in RL as well
- Emergent capabilities: new behaviors appear at sufficient depth
- Parameter efficiency: depth beats width at equal parameter counts
- Foundation models for RL: opens the door to large-scale RL models, as with LLMs
References¶
- Wang, K. et al. (2025). 1000 Layer Networks for Self-Supervised RL. NeurIPS 2025.
- He, K. et al. (2016). Deep Residual Learning for Image Recognition. CVPR.
- Eysenbach, B. et al. (2022). Contrastive Learning as Goal-Conditioned RL. NeurIPS.
- Andrychowicz, M. et al. (2017). Hindsight Experience Replay. NeurIPS.