Inference-Time Scaling (추론 시 연산 확장)¶
메타 정보¶
| 항목 | 내용 |
|---|---|
| 분류 | LLM Reasoning / Test-Time Compute / Scaling Laws |
| 핵심 논문 | "Scaling LLM Test-Time Compute Optimally can be More Effective than Scaling Model Parameters" (Snell et al., NeurIPS 2024, arXiv 2408.03314) |
| 주요 저자 | Charlie Snell, Jaehoon Lee, Kelvin Xu, Aviral Kumar (Google DeepMind) |
| 핵심 개념 | 사전학습 단계의 모델 크기/데이터 확장 대신, 추론 시점의 연산량을 늘려 성능을 향상시키는 패러다임 |
| 관련 시스템 | OpenAI o1/o3, DeepSeek-R1, s1, QwQ, Gemini 2.0 Flash Thinking |
| 관련 분야 | Process Reward Models, MCTS, Chain-of-Thought, Reinforcement Learning, Self-Play |
정의¶
Inference-Time Scaling (ITS)은 학습이 완료된 LLM에 대해 추론 단계에서 추가 연산을 투입하여 응답 품질을 향상시키는 기법의 총칭이다. 전통적인 Scaling Law가 학습 시 모델 크기(N), 데이터(D), 연산량(C)의 확장에 초점을 맞췄다면, ITS는 고정된 모델에서 추론 시 연산(C_inference)을 확장하는 새로운 축을 제시한다.
전통적 Scaling Law (Kaplan et al., 2020; Hoffmann et al., 2022):
Performance ~ f(N, D, C_train)
성능 개선 = 더 큰 모델 + 더 많은 데이터 + 더 많은 학습 연산
Inference-Time Scaling (Snell et al., 2024):
Performance ~ g(N, C_train, C_inference)
성능 개선 = 고정된 모델 + 추론 시 더 많은 연산
핵심 발견:
추론 연산 확장이 모델 크기 확장보다 효율적일 수 있음
(특히 어려운 문제에서)
핵심 아이디어¶
왜 추론 시 연산을 확장하는가¶
인간의 문제 해결 과정:
쉬운 문제: 즉시 답변 (System 1)
어려운 문제: 숙고, 검증, 재시도 (System 2)
기존 LLM:
모든 문제에 동일한 연산량 (1회 forward pass)
-> 어려운 문제에서 체계적 실패
ITS 적용 LLM:
쉬운 문제: 빠른 응답 (적은 연산)
어려운 문제: 더 많은 토큰 생성, 검증, 탐색 (많은 연산)
-> 문제 난이도에 따른 적응적 연산 배분
두 가지 확장 축¶
| 축 | 설명 | 예시 |
|---|---|---|
| 학습 시 확장 (Train-Time) | 모델 파라미터, 학습 데이터, 학습 연산 증가 | GPT-4 -> GPT-5 |
| 추론 시 확장 (Test-Time) | 고정된 모델에서 추론 연산 증가 | o1의 "thinking" 토큰 |
Snell et al. (2024)의 핵심 결과:
실험 설정: MATH 벤치마크
소형 모델 + 추론 연산 확장 vs 대형 모델 + 단일 추론
결과:
Llama-3-8B + 최적 추론 확장 >= Llama-3-70B + 단일 추론
-> 14배 이상의 FLOPs 절감 가능
-> 어려운 문제일수록 추론 확장의 이점이 큼
추론 시 연산 확장 방법론¶
크게 두 가지 접근으로 분류된다.
1. 탐색 기반 (Search-Based)¶
여러 후보 답변을 생성하고 검증기를 통해 최적 답변을 선택한다.
1.1 Best-of-N (Rejection Sampling)¶
Algorithm: Best-of-N
Input: 문제 x, 생성 모델 G, 검증 모델 V, 샘플 수 N
Output: 최적 답변
1. N개의 독립적 답변 생성:
y_1, y_2, ..., y_N ~ G(y|x)
2. 각 답변에 대해 검증 점수 계산:
s_i = V(x, y_i) for i = 1, ..., N
3. 최고 점수 답변 반환:
y* = argmax_{y_i} s_i
연산 비용: O(N * C_generate + N * C_verify)
성능 변화: log(N)에 비례하여 개선 (수확 체감)
| 검증 방법 | 설명 | 장점 | 단점 |
|---|---|---|---|
| Self-Consistency | 다수결 투표 | 별도 검증기 불필요 | 정답이 소수일 때 실패 |
| ORM (Outcome RM) | 최종 답변의 정확성 평가 | 학습 데이터 구축 용이 | 중간 과정 무시 |
| PRM (Process RM) | 각 추론 단계별 평가 | 세밀한 신용 할당 | 단계별 레이블 필요 |
1.2 Beam Search with Reward Model¶
Algorithm: Beam Search + PRM
Input: 문제 x, 생성 모델 G, PRM V, beam 폭 B, 최대 단계 T
Output: 최적 추론 경로
1. 초기 beam: B_0 = {x}
2. for t = 1 to T:
a. 각 beam에서 K개 후보 다음 단계 생성:
candidates = Union_{b in B_{t-1}} {G(step|b)_1, ..., G(step|b)_K}
b. PRM으로 각 후보 평가:
scores = {V(candidate) for candidate in candidates}
c. 상위 B개 선택:
B_t = top_B(candidates, scores)
3. 최종 beam에서 최고 점수 경로 반환
연산 비용: O(T * B * K * C_generate + T * B * K * C_verify)
1.3 MCTS (Monte Carlo Tree Search)¶
MCTS for LLM Reasoning:
[문제]
/ \
[단계1a] [단계1b] -- 확장
/ \ |
[2a] [2b] [2c] -- 시뮬레이션
| | |
답A 답B 답C -- 평가
0.8 0.3 0.6 -- PRM 점수
4단계 반복:
Selection: UCB1으로 유망한 노드 선택
Expansion: LLM으로 다음 단계 생성
Simulation: 끝까지 생성 (rollout)
Backprop: 결과를 부모 노드로 전파
UCB1(s) = Q(s)/N(s) + c * sqrt(ln(N_parent) / N(s))
ReST-MCTS* (NeurIPS 2024): PRM을 가치 함수로 사용하는 MCTS 기반 자기 학습 프레임워크. MCTS로 고품질 추론 경로를 탐색하고, 이를 학습 데이터로 활용하여 모델과 PRM을 반복 개선한다.
2. 내재화 기반 (Internalized Reasoning)¶
모델 자체가 긴 추론 체인을 생성하도록 학습한다.
2.1 Long Chain-of-Thought (Extended CoT)¶
일반 CoT (수백 토큰):
문제 -> [짧은 풀이] -> 답
Extended CoT (수천~수만 토큰):
문제 -> [단계 1] -> [검증] -> [단계 2] -> [오류 발견]
-> [재시도] -> [단계 2'] -> [검증] -> ... -> 답
특징:
- 자기 검증 (self-verification)
- 오류 수정 (backtracking)
- 다중 접근법 시도
- 명시적 불확실성 표현
2.2 Budget Forcing (s1)¶
Muennighoff et al. (2025)의 s1 모델이 제안한 기법이다.
Budget Forcing:
1. 연산 예산 T_budget 설정 (원하는 thinking 토큰 수)
2. 과소 사용 방지:
if thinking_tokens < T_budget:
"Wait" 토큰을 삽입하여 추가 사고 유도
-> 모델이 검증, 재확인, 대안 탐색
3. 과다 사용 방지:
if thinking_tokens > T_budget:
강제로 end-of-thinking 토큰 삽입
-> 즉시 최종 답변 생성
효과 (AIME 2024):
Budget forcing 없이: 50%
Budget forcing 적용: 57% (+7%p)
2.3 RL 기반 추론 학습¶
DeepSeek-R1 (2025)의 접근법이다.
DeepSeek-R1 학습 파이프라인:
Phase 1: Cold Start
- 소량의 고품질 long-CoT 데이터로 SFT
- 추론 형식(think -> answer) 학습
Phase 2: 순수 RL (GRPO)
- Reward: 정답 여부 (outcome reward)
- Policy: 추론 체인 생성
- 사전 정의된 추론 형식 없이 RL만으로 학습
GRPO (Group Relative Policy Optimization):
동일 문제에 대해 그룹 내 상대적 보상으로 학습
PPO 대비 별도 가치 네트워크 불필요 -> 메모리 효율적
Phase 3: Rejection Sampling + SFT
- RL 모델로 대량 추론 경로 생성
- 정답인 경로만 필터링
- 필터링된 데이터로 재학습
Phase 4: 2차 RL
- 안전성, 유용성 보상 추가
- 최종 모델 완성
Process Reward Model (PRM) 상세¶
ITS의 핵심 구성 요소인 PRM을 상세히 다룬다.
ORM vs PRM¶
Outcome Reward Model (ORM):
입력: (문제, 전체 풀이)
출력: 최종 답변의 정확도 점수
장점: 학습 데이터 구축이 쉬움 (정답/오답만 필요)
단점: 어디서 틀렸는지 모름 -> sparse credit assignment
Process Reward Model (PRM):
입력: (문제, 풀이의 각 단계)
출력: 각 단계의 정확도 점수
장점: 세밀한 피드백 -> dense credit assignment
단점: 단계별 레이블이 필요 -> 비용이 높음
PRM 학습 데이터 구축¶
| 방법 | 설명 | 품질 | 비용 |
|---|---|---|---|
| 인간 주석 | 전문가가 각 단계를 정답/오답으로 레이블 | 최상 | 매우 높음 |
| MC Estimation | 각 단계에서 completion을 여러 번 샘플링하여 정답 비율 계산 | 중간 | 높음 |
| LLM-as-Judge | GPT-4 등으로 각 단계 평가 | 중상 | 중간 |
| Self-labeling | 학습 모델 자체가 평가 | 낮음 | 낮음 |
MC Estimation 방식:
문제: 2x + 3 = 7에서 x는?
풀이 단계 1: 양변에서 3을 뺀다 -> 2x = 4
-> 이 단계 이후 100회 completion 수행
-> 78회 정답 도달
-> PRM score = 0.78
풀이 단계 2: 양변을 2로 나눈다 -> x = 2
-> 이 단계 이후 100회 completion 수행
-> 95회 정답 도달
-> PRM score = 0.95
문제점 (Luo et al., 2025):
- Completion 모델의 능력에 의존
- 현재 단계가 맞더라도 이후 단계에서 실패할 수 있음
- 단계의 "올바름"과 "정답 도달 가능성"이 다를 수 있음
PRM 활용 패턴¶
1. Best-of-N Reranking:
N개 답변 생성 -> PRM으로 각 답변의 최소 단계 점수 기반 순위 매김
score(y) = min_{t} PRM(x, y_{1:t}) 또는 prod_{t} PRM(x, y_{1:t})
2. Step-level Beam Search:
각 추론 단계마다 PRM 점수 기반으로 beam 유지
3. MCTS Value Function:
PRM을 MCTS의 노드 가치 추정에 활용
4. RL Training Signal:
PRM을 dense reward로 사용하여 policy 학습
주요 시스템 비교¶
추론 모델 계보¶
2024.09 OpenAI o1-preview -- 최초 상용 추론 모델
2024.12 OpenAI o1 -- 정식 출시
2024.12 Google Gemini 2.0 Flash Thinking
2025.01 DeepSeek-R1 -- 오픈소스 추론 모델
2025.01 s1 (Simple Test-Time Scaling)
2025.01 QwQ (Alibaba)
2025.02 OpenAI o3 -- IOI 금메달 수준
2025.04 Claude 3.5 Sonnet (Extended Thinking)
2025.06 OpenAI o4-mini
2025~ 다수의 오픈소스 추론 모델 등장
기술 비교¶
| 시스템 | 추론 방식 | 학습 방법 | 오픈소스 | 핵심 특징 |
|---|---|---|---|---|
| o1/o3 | Internal CoT | RL (추정) | X | Hidden reasoning, 토큰 수 자동 조절 |
| DeepSeek-R1 | Explicit long CoT | GRPO (RL) | O | Cold start + 순수 RL, 자기 검증 출현 |
| s1 | Budget forcing | SFT (1K examples) | O | 극도로 적은 학습 데이터, 예산 제어 |
| QwQ | Extended thinking | SFT + RL | O (부분) | Qwen 기반 추론 특화 |
성능 비교 (AIME 2024 기준)¶
| 모델 | AIME 2024 (%) | 추론 방식 |
|---|---|---|
| GPT-4o | ~13 | Standard |
| Claude 3.5 Sonnet | ~16 | Standard |
| o1-preview | ~44 | Internal CoT |
| DeepSeek-R1 | ~79 | Long CoT + RL |
| o1 | ~83 | Internal CoT |
| o3 (high compute) | ~96 | Internal CoT + Search |
| s1 (budget forced) | ~57 | Budget forcing |
Compute-Optimal Inference¶
Snell et al. (2024)의 핵심 기여는 "주어진 추론 연산 예산을 어떻게 최적 배분하는가"이다.
문제 난이도별 최적 전략¶
Easy Questions (낮은 난이도):
최적 전략: 단일 생성, 추가 연산 불필요
추가 연산의 한계 효용: 낮음
Medium Questions (중간 난이도):
최적 전략: Best-of-N 또는 짧은 beam search
추가 연산의 한계 효용: 중간
Hard Questions (높은 난이도):
최적 전략: PRM 기반 tree search 또는 extended CoT
추가 연산의 한계 효용: 높음
핵심 통찰:
- 어려운 문제에 더 많은 연산을 투입해야 함
- 난이도를 사전에 예측하여 연산을 적응적으로 배분
- "Compute-optimal" = 난이도별 최적 연산 배분
연산 배분 함수¶
주어진 총 예산 C_total과 문제 집합 {x_1, ..., x_M}에 대해:
max_{c_1, ..., c_M} sum_i P(correct | x_i, c_i)
s.t. sum_i c_i <= C_total
여기서 c_i는 문제 x_i에 배분된 추론 연산량
최적 해: 한계 효용이 균등하도록 배분
dP/dc_i = lambda (all i)
실용적 근사:
1. 난이도 추정기로 각 문제의 난이도 d_i 예측
2. 난이도에 따라 연산 예산 배분:
c_i = C_total * w(d_i) / sum_j w(d_j)
3. 가중 함수 w(d)는 어려운 문제에 더 많은 예산 부여
Python 구현¶
Best-of-N with Reward Model¶
import torch
import torch.nn as nn
from typing import List, Tuple, Optional
from dataclasses import dataclass
@dataclass
class GenerationResult:
"""생성 결과"""
text: str
steps: List[str]
score: float = 0.0
tokens_used: int = 0
class BestOfNSampler:
"""Best-of-N 추론 시 연산 확장"""
def __init__(
self,
generator, # LLM 생성 모델
verifier, # 보상 모델 (ORM 또는 PRM)
n_samples: int = 16,
temperature: float = 0.7,
use_prm: bool = True
):
self.generator = generator
self.verifier = verifier
self.n_samples = n_samples
self.temperature = temperature
self.use_prm = use_prm
def generate_candidates(
self,
prompt: str,
n: int
) -> List[GenerationResult]:
"""N개의 후보 답변 생성"""
candidates = []
for _ in range(n):
# 온도 샘플링으로 다양한 답변 생성
output = self.generator.generate(
prompt,
temperature=self.temperature,
max_tokens=2048
)
# 추론 단계 파싱
steps = self._parse_steps(output.text)
candidates.append(GenerationResult(
text=output.text,
steps=steps,
tokens_used=output.num_tokens
))
return candidates
def score_candidates(
self,
prompt: str,
candidates: List[GenerationResult]
) -> List[GenerationResult]:
"""보상 모델로 후보 점수 매기기"""
for candidate in candidates:
if self.use_prm:
# PRM: 각 단계별 점수의 최솟값 또는 곱
step_scores = []
for i, step in enumerate(candidate.steps):
partial = "\n".join(candidate.steps[:i+1])
score = self.verifier.score(prompt, partial)
step_scores.append(score)
# 최솟값 전략 (가장 약한 단계 기준)
candidate.score = min(step_scores) if step_scores else 0.0
else:
# ORM: 전체 답변 점수
candidate.score = self.verifier.score(
prompt, candidate.text
)
return candidates
def solve(self, prompt: str) -> GenerationResult:
"""Best-of-N으로 문제 풀기"""
# 1. N개 후보 생성
candidates = self.generate_candidates(prompt, self.n_samples)
# 2. 점수 매기기
candidates = self.score_candidates(prompt, candidates)
# 3. 최고 점수 선택
best = max(candidates, key=lambda c: c.score)
return best
def _parse_steps(self, text: str) -> List[str]:
"""추론 텍스트를 단계별로 분리"""
# 줄바꿈 또는 "Step N:" 패턴으로 분리
lines = text.strip().split("\n")
steps = []
current_step = []
for line in lines:
if line.strip().startswith("Step") or line.strip().startswith("단계"):
if current_step:
steps.append("\n".join(current_step))
current_step = [line]
else:
current_step.append(line)
if current_step:
steps.append("\n".join(current_step))
return steps if steps else [text]
class SelfConsistency:
"""Self-Consistency (다수결 투표) 기반 추론 확장"""
def __init__(self, generator, n_samples: int = 40, temperature: float = 0.7):
self.generator = generator
self.n_samples = n_samples
self.temperature = temperature
def solve(self, prompt: str) -> str:
"""다수결 투표로 답변 결정"""
answers = []
for _ in range(self.n_samples):
output = self.generator.generate(
prompt,
temperature=self.temperature,
max_tokens=2048
)
answer = self._extract_answer(output.text)
answers.append(answer)
# 다수결 투표
from collections import Counter
counter = Counter(answers)
best_answer, count = counter.most_common(1)[0]
return best_answer
def _extract_answer(self, text: str) -> str:
"""최종 답변 추출 (boxed 형식 등)"""
import re
# LaTeX boxed 형식
match = re.search(r'\\boxed\{([^}]+)\}', text)
if match:
return match.group(1).strip()
# "답: " 형식
match = re.search(r'답[:\s]+(.+?)(?:\n|$)', text)
if match:
return match.group(1).strip()
# 마지막 줄
lines = text.strip().split('\n')
return lines[-1].strip()
Process Reward Model 구현¶
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
class ProcessRewardModel(nn.Module):
"""Process Reward Model (PRM)"""
def __init__(
self,
base_model_name: str = "meta-llama/Llama-3-8B",
hidden_size: int = 4096
):
super().__init__()
self.backbone = AutoModel.from_pretrained(base_model_name)
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# 단계별 점수 예측 헤드
self.reward_head = nn.Sequential(
nn.Linear(hidden_size, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, 1),
nn.Sigmoid()
)
# Step delimiter token
self.step_token = "\n" # 또는 특수 토큰
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
step_positions: torch.Tensor
) -> torch.Tensor:
"""
Args:
input_ids: (batch, seq_len) 토큰 ID
attention_mask: (batch, seq_len)
step_positions: (batch, max_steps) 각 단계 끝 위치
Returns:
step_scores: (batch, max_steps) 각 단계의 정확도 점수
"""
# Backbone forward
outputs = self.backbone(
input_ids=input_ids,
attention_mask=attention_mask
)
hidden_states = outputs.last_hidden_state # (batch, seq, hidden)
# 각 단계 끝 위치의 hidden state 추출
batch_size = input_ids.shape[0]
max_steps = step_positions.shape[1]
step_hidden = []
for b in range(batch_size):
for s in range(max_steps):
pos = step_positions[b, s]
if pos >= 0:
step_hidden.append(hidden_states[b, pos])
else:
step_hidden.append(torch.zeros_like(hidden_states[b, 0]))
step_hidden = torch.stack(step_hidden).view(batch_size, max_steps, -1)
# 점수 예측
step_scores = self.reward_head(step_hidden).squeeze(-1)
return step_scores
def score_solution(
self,
problem: str,
solution_steps: list
) -> list:
"""풀이의 각 단계에 대한 점수 반환"""
scores = []
for i in range(len(solution_steps)):
# 문제 + 현재까지의 단계를 입력으로 구성
partial_solution = "\n".join(solution_steps[:i+1])
text = f"Problem: {problem}\nSolution:\n{partial_solution}"
tokens = self.tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=2048
)
with torch.no_grad():
outputs = self.backbone(**tokens)
# 마지막 토큰의 hidden state 사용
last_hidden = outputs.last_hidden_state[:, -1, :]
score = self.reward_head(last_hidden).item()
scores.append(score)
return scores
class PRMTrainer:
"""PRM 학습기 (MC Estimation 기반)"""
def __init__(
self,
prm: ProcessRewardModel,
completion_model, # 완성용 LLM
n_completions: int = 32,
learning_rate: float = 1e-5
):
self.prm = prm
self.completion_model = completion_model
self.n_completions = n_completions
self.optimizer = torch.optim.AdamW(
prm.parameters(), lr=learning_rate
)
def estimate_step_labels(
self,
problem: str,
steps: list,
correct_answer: str
) -> list:
"""
MC Estimation으로 각 단계의 레이블 추정
각 단계 이후 N번 completion하여 정답 비율 계산
"""
labels = []
for i in range(len(steps)):
partial = "\n".join(steps[:i+1])
prompt = f"Problem: {problem}\nSolution so far:\n{partial}\nContinue:"
correct_count = 0
for _ in range(self.n_completions):
completion = self.completion_model.generate(
prompt, temperature=0.8, max_tokens=512
)
answer = self._extract_answer(completion)
if self._check_answer(answer, correct_answer):
correct_count += 1
label = correct_count / self.n_completions
labels.append(label)
return labels
def train_step(self, batch):
"""한 배치 학습"""
self.optimizer.zero_grad()
# Forward
predicted_scores = self.prm(
batch['input_ids'],
batch['attention_mask'],
batch['step_positions']
)
# MSE Loss (MC 추정 레이블과의 차이)
target_scores = batch['step_labels']
mask = batch['step_mask'] # 유효한 단계만
loss = ((predicted_scores - target_scores) ** 2 * mask).sum() / mask.sum()
loss.backward()
self.optimizer.step()
return loss.item()
def _extract_answer(self, text):
import re
match = re.search(r'\\boxed\{([^}]+)\}', text)
return match.group(1).strip() if match else text.strip().split('\n')[-1]
def _check_answer(self, pred, target):
return pred.strip() == target.strip()
MCTS for LLM Reasoning¶
import math
import random
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class MCTSNode:
"""MCTS 트리 노드"""
state: str # 현재까지의 추론 텍스트
parent: Optional['MCTSNode'] = None
children: list = field(default_factory=list)
visits: int = 0
value: float = 0.0
step_text: str = "" # 이 노드에서 추가된 단계
is_terminal: bool = False
@property
def q_value(self) -> float:
return self.value / max(self.visits, 1)
def ucb1(self, c: float = 1.414) -> float:
if self.visits == 0:
return float('inf')
exploitation = self.q_value
exploration = c * math.sqrt(
math.log(self.parent.visits) / self.visits
)
return exploitation + exploration
class ReasoningMCTS:
"""LLM 추론을 위한 MCTS"""
def __init__(
self,
generator, # LLM 생성 모델
prm, # Process Reward Model
max_iterations: int = 100,
max_depth: int = 10,
n_expansions: int = 3,
c_explore: float = 1.414
):
self.generator = generator
self.prm = prm
self.max_iterations = max_iterations
self.max_depth = max_depth
self.n_expansions = n_expansions
self.c_explore = c_explore
def solve(self, problem: str) -> str:
"""MCTS로 문제 풀기"""
root = MCTSNode(state=f"Problem: {problem}\nSolution:\n")
for _ in range(self.max_iterations):
# 1. Selection
node = self._select(root)
if node.is_terminal:
continue
# 2. Expansion
children = self._expand(node, problem)
if not children:
node.is_terminal = True
continue
# 3. Simulation & Evaluation
for child in children:
value = self._evaluate(problem, child)
# 4. Backpropagation
self._backpropagate(child, value)
# 최적 경로 추출
return self._best_path(root)
def _select(self, node: MCTSNode) -> MCTSNode:
"""UCB1 기반 노드 선택"""
while node.children and not node.is_terminal:
node = max(
node.children,
key=lambda c: c.ucb1(self.c_explore)
)
return node
def _expand(
self,
node: MCTSNode,
problem: str
) -> list:
"""LLM으로 다음 추론 단계 생성"""
children = []
for _ in range(self.n_expansions):
# 다음 단계 생성
next_step = self.generator.generate(
node.state + "\nNext step:",
temperature=0.8,
max_tokens=256,
stop=["\n\n"] # 한 단계만 생성
)
step_text = next_step.text.strip()
if not step_text:
continue
# 종료 조건 확인
is_terminal = self._is_terminal(step_text)
child = MCTSNode(
state=node.state + step_text + "\n",
parent=node,
step_text=step_text,
is_terminal=is_terminal
)
node.children.append(child)
children.append(child)
return children
def _evaluate(self, problem: str, node: MCTSNode) -> float:
"""PRM으로 노드 가치 평가"""
# 현재 경로의 단계들 추출
steps = self._extract_steps(node.state)
if not steps:
return 0.0
# PRM 점수
scores = self.prm.score_solution(problem, steps)
# 최소 점수 반환 (가장 약한 단계 기준)
return min(scores) if scores else 0.0
def _backpropagate(self, node: MCTSNode, value: float):
"""결과를 루트까지 전파"""
while node is not None:
node.visits += 1
node.value += value
node = node.parent
def _best_path(self, root: MCTSNode) -> str:
"""방문 횟수 기준 최적 경로 추출"""
node = root
path = []
while node.children:
node = max(node.children, key=lambda c: c.visits)
path.append(node.step_text)
return "\n".join(path)
def _is_terminal(self, text: str) -> bool:
"""종료 조건 확인"""
terminal_markers = [
"\\boxed{", "답:", "therefore", "따라서",
"final answer", "최종 답"
]
return any(m in text.lower() for m in terminal_markers)
def _extract_steps(self, state: str) -> list:
"""상태에서 추론 단계 추출"""
lines = state.split("\n")
steps = [l for l in lines if l.strip() and not l.startswith("Problem:")]
return steps
Compute-Optimal Inference Router¶
import numpy as np
from typing import Dict, Callable
from enum import Enum
class DifficultyLevel(Enum):
EASY = "easy"
MEDIUM = "medium"
HARD = "hard"
VERY_HARD = "very_hard"
class InferenceRouter:
"""
문제 난이도에 따른 최적 추론 전략 라우터
Snell et al. (2024)의 compute-optimal inference 구현
"""
def __init__(
self,
generator,
prm=None,
difficulty_estimator=None,
total_compute_budget: int = 1000 # 토큰 단위
):
self.generator = generator
self.prm = prm
self.difficulty_estimator = difficulty_estimator
self.total_compute_budget = total_compute_budget
# 난이도별 전략 매핑
self.strategies: Dict[DifficultyLevel, Callable] = {
DifficultyLevel.EASY: self._strategy_single,
DifficultyLevel.MEDIUM: self._strategy_best_of_n,
DifficultyLevel.HARD: self._strategy_beam_search,
DifficultyLevel.VERY_HARD: self._strategy_mcts
}
def solve(self, problem: str) -> str:
"""문제 난이도를 추정하고 적절한 전략 선택"""
difficulty = self._estimate_difficulty(problem)
strategy = self.strategies[difficulty]
return strategy(problem)
def solve_batch(self, problems: list) -> list:
"""
배치 문제에 대해 연산 예산을 최적 배분
쉬운 문제에 적은 예산, 어려운 문제에 많은 예산
"""
# 1. 난이도 추정
difficulties = [self._estimate_difficulty(p) for p in problems]
# 2. 연산 예산 배분
weights = {
DifficultyLevel.EASY: 1,
DifficultyLevel.MEDIUM: 4,
DifficultyLevel.HARD: 16,
DifficultyLevel.VERY_HARD: 64
}
total_weight = sum(weights[d] for d in difficulties)
budgets = [
int(self.total_compute_budget * weights[d] / total_weight)
for d in difficulties
]
# 3. 각 문제를 배분된 예산으로 풀기
results = []
for problem, difficulty, budget in zip(problems, difficulties, budgets):
result = self._solve_with_budget(problem, difficulty, budget)
results.append(result)
return results
def _estimate_difficulty(self, problem: str) -> DifficultyLevel:
"""문제 난이도 추정"""
if self.difficulty_estimator:
score = self.difficulty_estimator.predict(problem)
else:
# 간단한 휴리스틱: 문제 길이, 수학 기호 수 등
score = self._heuristic_difficulty(problem)
if score < 0.25:
return DifficultyLevel.EASY
elif score < 0.50:
return DifficultyLevel.MEDIUM
elif score < 0.75:
return DifficultyLevel.HARD
else:
return DifficultyLevel.VERY_HARD
def _heuristic_difficulty(self, problem: str) -> float:
"""휴리스틱 난이도 추정"""
import re
features = {
'length': len(problem.split()) / 200,
'math_symbols': len(re.findall(r'[+\-*/^=<>]', problem)) / 20,
'numbers': len(re.findall(r'\d+', problem)) / 10,
'nested_ops': problem.count('(') / 5,
'keywords': sum(1 for kw in [
'prove', 'show that', 'find all',
'maximize', 'minimize', 'determine'
] if kw in problem.lower()) / 3
}
score = np.mean([min(v, 1.0) for v in features.values()])
return score
def _strategy_single(self, problem: str) -> str:
"""쉬운 문제: 단일 생성"""
output = self.generator.generate(
problem, temperature=0.0, max_tokens=512
)
return output.text
def _strategy_best_of_n(self, problem: str, n: int = 8) -> str:
"""중간 난이도: Best-of-N"""
sampler = BestOfNSampler(
self.generator, self.prm,
n_samples=n, use_prm=(self.prm is not None)
)
result = sampler.solve(problem)
return result.text
def _strategy_beam_search(self, problem: str) -> str:
"""어려운 문제: PRM 기반 Beam Search"""
# 간소화된 beam search
beam_width = 4
beams = [f"Problem: {problem}\nSolution:\n"]
for depth in range(8):
candidates = []
for beam in beams:
for _ in range(3):
step = self.generator.generate(
beam + "Next step:",
temperature=0.7,
max_tokens=256,
stop=["\n\n"]
)
new_beam = beam + step.text.strip() + "\n"
if self.prm:
steps = [l for l in new_beam.split("\n")
if l.strip() and not l.startswith("Problem:")]
scores = self.prm.score_solution(problem, steps)
score = min(scores) if scores else 0.0
else:
score = random.random()
candidates.append((new_beam, score))
# 상위 beam_width개 선택
candidates.sort(key=lambda x: x[1], reverse=True)
beams = [c[0] for c in candidates[:beam_width]]
return beams[0] if beams else ""
def _strategy_mcts(self, problem: str) -> str:
"""매우 어려운 문제: MCTS"""
mcts = ReasoningMCTS(
self.generator, self.prm,
max_iterations=50,
n_expansions=3
)
return mcts.solve(problem)
def _solve_with_budget(
self,
problem: str,
difficulty: DifficultyLevel,
budget: int
) -> str:
"""예산 제한 내에서 최적 풀이"""
# 예산에 따라 N 또는 반복 횟수 조정
if difficulty == DifficultyLevel.EASY:
return self._strategy_single(problem)
elif difficulty == DifficultyLevel.MEDIUM:
n = max(2, budget // 128)
return self._strategy_best_of_n(problem, n=n)
else:
return self._strategy_beam_search(problem)
실무 가이드라인¶
방법 선택 기준¶
문제 유형 분석
|
+-- 수학/코딩 (정답 검증 가능)
| +-- 난이도 낮음 -> Self-Consistency (N=8~16)
| +-- 난이도 높음 -> PRM + Beam Search 또는 MCTS
|
+-- 자유 형식 (정답 검증 어려움)
| +-- ORM + Best-of-N
| +-- Extended CoT (긴 추론 체인)
|
+-- 실시간 응답 필요
+-- Budget Forcing (토큰 수 제어)
+-- 경량 PRM + Best-of-4
비용-성능 트레이드오프¶
| 방법 | 연산 비용 | 성능 향상 | 추가 모델 필요 | 구현 복잡도 |
|---|---|---|---|---|
| Self-Consistency | 낮음 (N배) | 5-15% | X | 낮음 |
| Best-of-N + ORM | 중간 | 10-25% | ORM | 중간 |
| Best-of-N + PRM | 높음 | 15-35% | PRM | 높음 |
| Beam Search + PRM | 높음 | 20-40% | PRM | 높음 |
| MCTS + PRM | 매우 높음 | 25-50% | PRM | 매우 높음 |
| Extended CoT (RL) | 중간 | 30-60% | X (내재화) | 학습 비용 높음 |
주의사항¶
1. 수확 체감 (Diminishing Returns)
- Best-of-N에서 N > 64 이후 개선폭 급감
- MCTS에서 반복 > 200 이후 효용 미미
- 문제의 본질적 난이도에 따른 상한 존재
2. 검증기 품질이 핵심
- PRM 품질이 낮으면 오히려 성능 저하
- 잘못된 단계에 높은 점수 -> 오답 선택
- 검증기 학습에 투자하는 것이 중요
3. 분포 이동 (Distribution Shift)
- PRM은 학습 시와 다른 분포의 추론에서 성능 저하
- 생성 모델이 변경되면 PRM도 재학습 필요
- 반복적 self-improvement에서 주의
4. 비용 관리
- 프로덕션 환경에서 N배 비용 증가
- 난이도 라우팅으로 평균 비용 절감 가능
- 사용자 대면 vs 배치 처리에 따라 전략 차별화
관련 연구 흐름¶
Chain-of-Thought (Wei et al., 2022)
|
+-- Self-Consistency (Wang et al., 2022)
|
+-- Let's Verify Step by Step (Lightman et al., 2023)
| |
| +-- PRM800K 데이터셋
| +-- Process 검증이 Outcome 검증보다 우수함 증명
|
+-- Scaling Test-Time Compute (Snell et al., NeurIPS 2024)
| |
| +-- Compute-optimal inference 이론 정립
| +-- 난이도별 최적 전략 분석
|
+-- OpenAI o1 (2024.09)
| |
| +-- Internal CoT의 상용화
| +-- Hidden reasoning tokens
|
+-- DeepSeek-R1 (2025.01)
| |
| +-- 순수 RL로 추론 능력 출현
| +-- GRPO 알고리즘
| +-- 오픈소스 추론 모델
|
+-- s1 (Muennighoff et al., 2025.01)
| |
| +-- Budget Forcing
| +-- 1K 예제만으로 추론 모델 학습
|
+-- ReST-MCTS* (NeurIPS 2024)
| |
| +-- PRM 기반 MCTS + 자기 학습
|
+-- Rewarding Progress (ICLR 2025)
|
+-- 자동화된 PRM 학습 스케일링
참고 자료¶
핵심 논문¶
- Snell, C., Lee, J., Xu, K., & Kumar, A. (2024). Scaling LLM Test-Time Compute Optimally can be More Effective than Scaling Model Parameters. NeurIPS 2024. arXiv:2408.03314.
- Lightman, H. et al. (2023). Let's Verify Step by Step. arXiv:2305.20050.
- DeepSeek-AI. (2025). DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning. arXiv:2501.12948.
- Muennighoff, N. et al. (2025). s1: Simple Test-Time Scaling. arXiv:2501.19393.
확장 논문¶
- Wang, X. et al. (2022). Self-Consistency Improves Chain of Thought Reasoning in Language Models. ICLR 2023.
- Zhang, D. et al. (2024). ReST-MCTS*: LLM Self-Training via Process Reward Guided Tree Search. NeurIPS 2024.
- Setlur, A. et al. (2025). Rewarding Progress: Scaling Automated Process Verifiers for LLM Reasoning. ICLR 2025.
- Luo, L. et al. (2025). The Lessons of Developing Process Reward Models in Mathematical Reasoning. arXiv:2501.07301.
관련 개념¶
- Chain-of-Thought Reasoning: CoT 추론의 기초
- Direct Preference Optimization: RL 기반 정렬
- Knowledge Distillation: 추론 모델 경량화 (R1 -> 소형 모델 증류)
- Neural Scaling Laws: 학습 시 Scaling Law
- Speculative Decoding: 추론 효율화 기법