콘텐츠로 이동
Data Prep
상세

Inference-Time Scaling (추론 시 연산 확장)

메타 정보

항목 내용
분류 LLM Reasoning / Test-Time Compute / Scaling Laws
핵심 논문 "Scaling LLM Test-Time Compute Optimally can be More Effective than Scaling Model Parameters" (Snell et al., NeurIPS 2024, arXiv 2408.03314)
주요 저자 Charlie Snell, Jaehoon Lee, Kelvin Xu, Aviral Kumar (Google DeepMind)
핵심 개념 사전학습 단계의 모델 크기/데이터 확장 대신, 추론 시점의 연산량을 늘려 성능을 향상시키는 패러다임
관련 시스템 OpenAI o1/o3, DeepSeek-R1, s1, QwQ, Gemini 2.0 Flash Thinking
관련 분야 Process Reward Models, MCTS, Chain-of-Thought, Reinforcement Learning, Self-Play

정의

Inference-Time Scaling (ITS)은 학습이 완료된 LLM에 대해 추론 단계에서 추가 연산을 투입하여 응답 품질을 향상시키는 기법의 총칭이다. 전통적인 Scaling Law가 학습 시 모델 크기(N), 데이터(D), 연산량(C)의 확장에 초점을 맞췄다면, ITS는 고정된 모델에서 추론 시 연산(C_inference)을 확장하는 새로운 축을 제시한다.

전통적 Scaling Law (Kaplan et al., 2020; Hoffmann et al., 2022):
  Performance ~ f(N, D, C_train)

  성능 개선 = 더 큰 모델 + 더 많은 데이터 + 더 많은 학습 연산

Inference-Time Scaling (Snell et al., 2024):
  Performance ~ g(N, C_train, C_inference)

  성능 개선 = 고정된 모델 + 추론 시 더 많은 연산

핵심 발견:
  추론 연산 확장이 모델 크기 확장보다 효율적일 수 있음
  (특히 어려운 문제에서)

핵심 아이디어

왜 추론 시 연산을 확장하는가

인간의 문제 해결 과정:
  쉬운 문제: 즉시 답변 (System 1)
  어려운 문제: 숙고, 검증, 재시도 (System 2)

기존 LLM:
  모든 문제에 동일한 연산량 (1회 forward pass)
  -> 어려운 문제에서 체계적 실패

ITS 적용 LLM:
  쉬운 문제: 빠른 응답 (적은 연산)
  어려운 문제: 더 많은 토큰 생성, 검증, 탐색 (많은 연산)
  -> 문제 난이도에 따른 적응적 연산 배분

두 가지 확장 축

설명 예시
학습 시 확장 (Train-Time) 모델 파라미터, 학습 데이터, 학습 연산 증가 GPT-4 -> GPT-5
추론 시 확장 (Test-Time) 고정된 모델에서 추론 연산 증가 o1의 "thinking" 토큰

Snell et al. (2024)의 핵심 결과:

실험 설정: MATH 벤치마크

소형 모델 + 추론 연산 확장  vs  대형 모델 + 단일 추론

결과:
  Llama-3-8B + 최적 추론 확장 >= Llama-3-70B + 단일 추론

  -> 14배 이상의 FLOPs 절감 가능
  -> 어려운 문제일수록 추론 확장의 이점이 큼

추론 시 연산 확장 방법론

크게 두 가지 접근으로 분류된다.

1. 탐색 기반 (Search-Based)

여러 후보 답변을 생성하고 검증기를 통해 최적 답변을 선택한다.

1.1 Best-of-N (Rejection Sampling)

Algorithm: Best-of-N
Input: 문제 x, 생성 모델 G, 검증 모델 V, 샘플 수 N
Output: 최적 답변

1. N개의 독립적 답변 생성:
   y_1, y_2, ..., y_N ~ G(y|x)

2. 각 답변에 대해 검증 점수 계산:
   s_i = V(x, y_i)   for i = 1, ..., N

3. 최고 점수 답변 반환:
   y* = argmax_{y_i} s_i

연산 비용: O(N * C_generate + N * C_verify)
성능 변화: log(N)에 비례하여 개선 (수확 체감)
검증 방법 설명 장점 단점
Self-Consistency 다수결 투표 별도 검증기 불필요 정답이 소수일 때 실패
ORM (Outcome RM) 최종 답변의 정확성 평가 학습 데이터 구축 용이 중간 과정 무시
PRM (Process RM) 각 추론 단계별 평가 세밀한 신용 할당 단계별 레이블 필요

1.2 Beam Search with Reward Model

Algorithm: Beam Search + PRM
Input: 문제 x, 생성 모델 G, PRM V, beam 폭 B, 최대 단계 T
Output: 최적 추론 경로

1. 초기 beam: B_0 = {x}
2. for t = 1 to T:
   a. 각 beam에서 K개 후보 다음 단계 생성:
      candidates = Union_{b in B_{t-1}} {G(step|b)_1, ..., G(step|b)_K}
   b. PRM으로 각 후보 평가:
      scores = {V(candidate) for candidate in candidates}
   c. 상위 B개 선택:
      B_t = top_B(candidates, scores)
3. 최종 beam에서 최고 점수 경로 반환

연산 비용: O(T * B * K * C_generate + T * B * K * C_verify)
MCTS for LLM Reasoning:

      [문제]
      /    \
   [단계1a]  [단계1b]     -- 확장
    / \        |
 [2a] [2b]   [2c]         -- 시뮬레이션
  |     |      |
 답A   답B    답C          -- 평가
 0.8   0.3    0.6          -- PRM 점수

4단계 반복:
  Selection:  UCB1으로 유망한 노드 선택
  Expansion:  LLM으로 다음 단계 생성
  Simulation: 끝까지 생성 (rollout)
  Backprop:   결과를 부모 노드로 전파

UCB1(s) = Q(s)/N(s) + c * sqrt(ln(N_parent) / N(s))

ReST-MCTS* (NeurIPS 2024): PRM을 가치 함수로 사용하는 MCTS 기반 자기 학습 프레임워크. MCTS로 고품질 추론 경로를 탐색하고, 이를 학습 데이터로 활용하여 모델과 PRM을 반복 개선한다.

2. 내재화 기반 (Internalized Reasoning)

모델 자체가 긴 추론 체인을 생성하도록 학습한다.

2.1 Long Chain-of-Thought (Extended CoT)

일반 CoT (수백 토큰):
  문제 -> [짧은 풀이] -> 답

Extended CoT (수천~수만 토큰):
  문제 -> [단계 1] -> [검증] -> [단계 2] -> [오류 발견]
       -> [재시도] -> [단계 2'] -> [검증] -> ... -> 답

특징:
  - 자기 검증 (self-verification)
  - 오류 수정 (backtracking)
  - 다중 접근법 시도
  - 명시적 불확실성 표현

2.2 Budget Forcing (s1)

Muennighoff et al. (2025)의 s1 모델이 제안한 기법이다.

Budget Forcing:

1. 연산 예산 T_budget 설정 (원하는 thinking 토큰 수)

2. 과소 사용 방지:
   if thinking_tokens < T_budget:
     "Wait" 토큰을 삽입하여 추가 사고 유도
     -> 모델이 검증, 재확인, 대안 탐색

3. 과다 사용 방지:
   if thinking_tokens > T_budget:
     강제로 end-of-thinking 토큰 삽입
     -> 즉시 최종 답변 생성

효과 (AIME 2024):
  Budget forcing 없이: 50%
  Budget forcing 적용: 57% (+7%p)

2.3 RL 기반 추론 학습

DeepSeek-R1 (2025)의 접근법이다.

DeepSeek-R1 학습 파이프라인:

Phase 1: Cold Start
  - 소량의 고품질 long-CoT 데이터로 SFT
  - 추론 형식(think -> answer) 학습

Phase 2: 순수 RL (GRPO)
  - Reward: 정답 여부 (outcome reward)
  - Policy: 추론 체인 생성
  - 사전 정의된 추론 형식 없이 RL만으로 학습

  GRPO (Group Relative Policy Optimization):
    동일 문제에 대해 그룹 내 상대적 보상으로 학습
    PPO 대비 별도 가치 네트워크 불필요 -> 메모리 효율적

Phase 3: Rejection Sampling + SFT
  - RL 모델로 대량 추론 경로 생성
  - 정답인 경로만 필터링
  - 필터링된 데이터로 재학습

Phase 4: 2차 RL
  - 안전성, 유용성 보상 추가
  - 최종 모델 완성

Process Reward Model (PRM) 상세

ITS의 핵심 구성 요소인 PRM을 상세히 다룬다.

ORM vs PRM

Outcome Reward Model (ORM):
  입력: (문제, 전체 풀이)
  출력: 최종 답변의 정확도 점수

  장점: 학습 데이터 구축이 쉬움 (정답/오답만 필요)
  단점: 어디서 틀렸는지 모름 -> sparse credit assignment

Process Reward Model (PRM):
  입력: (문제, 풀이의 각 단계)
  출력: 각 단계의 정확도 점수

  장점: 세밀한 피드백 -> dense credit assignment
  단점: 단계별 레이블이 필요 -> 비용이 높음

PRM 학습 데이터 구축

방법 설명 품질 비용
인간 주석 전문가가 각 단계를 정답/오답으로 레이블 최상 매우 높음
MC Estimation 각 단계에서 completion을 여러 번 샘플링하여 정답 비율 계산 중간 높음
LLM-as-Judge GPT-4 등으로 각 단계 평가 중상 중간
Self-labeling 학습 모델 자체가 평가 낮음 낮음
MC Estimation 방식:

문제: 2x + 3 = 7에서 x는?

풀이 단계 1: 양변에서 3을 뺀다 -> 2x = 4
  -> 이 단계 이후 100회 completion 수행
  -> 78회 정답 도달
  -> PRM score = 0.78

풀이 단계 2: 양변을 2로 나눈다 -> x = 2
  -> 이 단계 이후 100회 completion 수행
  -> 95회 정답 도달
  -> PRM score = 0.95

문제점 (Luo et al., 2025):
  - Completion 모델의 능력에 의존
  - 현재 단계가 맞더라도 이후 단계에서 실패할 수 있음
  - 단계의 "올바름"과 "정답 도달 가능성"이 다를 수 있음

PRM 활용 패턴

1. Best-of-N Reranking:
   N개 답변 생성 -> PRM으로 각 답변의 최소 단계 점수 기반 순위 매김

   score(y) = min_{t} PRM(x, y_{1:t})  또는  prod_{t} PRM(x, y_{1:t})

2. Step-level Beam Search:
   각 추론 단계마다 PRM 점수 기반으로 beam 유지

3. MCTS Value Function:
   PRM을 MCTS의 노드 가치 추정에 활용

4. RL Training Signal:
   PRM을 dense reward로 사용하여 policy 학습

주요 시스템 비교

추론 모델 계보

2024.09  OpenAI o1-preview          -- 최초 상용 추론 모델
2024.12  OpenAI o1                  -- 정식 출시
2024.12  Google Gemini 2.0 Flash Thinking
2025.01  DeepSeek-R1                -- 오픈소스 추론 모델
2025.01  s1 (Simple Test-Time Scaling)
2025.01  QwQ (Alibaba)
2025.02  OpenAI o3                  -- IOI 금메달 수준
2025.04  Claude 3.5 Sonnet (Extended Thinking)
2025.06  OpenAI o4-mini
2025~    다수의 오픈소스 추론 모델 등장

기술 비교

시스템 추론 방식 학습 방법 오픈소스 핵심 특징
o1/o3 Internal CoT RL (추정) X Hidden reasoning, 토큰 수 자동 조절
DeepSeek-R1 Explicit long CoT GRPO (RL) O Cold start + 순수 RL, 자기 검증 출현
s1 Budget forcing SFT (1K examples) O 극도로 적은 학습 데이터, 예산 제어
QwQ Extended thinking SFT + RL O (부분) Qwen 기반 추론 특화

성능 비교 (AIME 2024 기준)

모델 AIME 2024 (%) 추론 방식
GPT-4o ~13 Standard
Claude 3.5 Sonnet ~16 Standard
o1-preview ~44 Internal CoT
DeepSeek-R1 ~79 Long CoT + RL
o1 ~83 Internal CoT
o3 (high compute) ~96 Internal CoT + Search
s1 (budget forced) ~57 Budget forcing

Compute-Optimal Inference

Snell et al. (2024)의 핵심 기여는 "주어진 추론 연산 예산을 어떻게 최적 배분하는가"이다.

문제 난이도별 최적 전략

Easy Questions (낮은 난이도):
  최적 전략: 단일 생성, 추가 연산 불필요
  추가 연산의 한계 효용: 낮음

Medium Questions (중간 난이도):
  최적 전략: Best-of-N 또는 짧은 beam search
  추가 연산의 한계 효용: 중간

Hard Questions (높은 난이도):
  최적 전략: PRM 기반 tree search 또는 extended CoT
  추가 연산의 한계 효용: 높음

핵심 통찰:
  - 어려운 문제에 더 많은 연산을 투입해야 함
  - 난이도를 사전에 예측하여 연산을 적응적으로 배분
  - "Compute-optimal" = 난이도별 최적 연산 배분

연산 배분 함수

주어진 총 예산 C_total과 문제 집합 {x_1, ..., x_M}에 대해:

max_{c_1, ..., c_M} sum_i P(correct | x_i, c_i)
s.t. sum_i c_i <= C_total

여기서 c_i는 문제 x_i에 배분된 추론 연산량

최적 해: 한계 효용이 균등하도록 배분
  dP/dc_i = lambda (all i)

실용적 근사:
  1. 난이도 추정기로 각 문제의 난이도 d_i 예측
  2. 난이도에 따라 연산 예산 배분:
     c_i = C_total * w(d_i) / sum_j w(d_j)
  3. 가중 함수 w(d)는 어려운 문제에 더 많은 예산 부여

Python 구현

Best-of-N with Reward Model

import torch
import torch.nn as nn
from typing import List, Tuple, Optional
from dataclasses import dataclass


@dataclass
class GenerationResult:
    """생성 결과"""
    text: str
    steps: List[str]
    score: float = 0.0
    tokens_used: int = 0


class BestOfNSampler:
    """Best-of-N 추론 시 연산 확장"""

    def __init__(
        self,
        generator,      # LLM 생성 모델
        verifier,        # 보상 모델 (ORM 또는 PRM)
        n_samples: int = 16,
        temperature: float = 0.7,
        use_prm: bool = True
    ):
        self.generator = generator
        self.verifier = verifier
        self.n_samples = n_samples
        self.temperature = temperature
        self.use_prm = use_prm

    def generate_candidates(
        self,
        prompt: str,
        n: int
    ) -> List[GenerationResult]:
        """N개의 후보 답변 생성"""
        candidates = []

        for _ in range(n):
            # 온도 샘플링으로 다양한 답변 생성
            output = self.generator.generate(
                prompt,
                temperature=self.temperature,
                max_tokens=2048
            )

            # 추론 단계 파싱
            steps = self._parse_steps(output.text)

            candidates.append(GenerationResult(
                text=output.text,
                steps=steps,
                tokens_used=output.num_tokens
            ))

        return candidates

    def score_candidates(
        self,
        prompt: str,
        candidates: List[GenerationResult]
    ) -> List[GenerationResult]:
        """보상 모델로 후보 점수 매기기"""
        for candidate in candidates:
            if self.use_prm:
                # PRM: 각 단계별 점수의 최솟값 또는 곱
                step_scores = []
                for i, step in enumerate(candidate.steps):
                    partial = "\n".join(candidate.steps[:i+1])
                    score = self.verifier.score(prompt, partial)
                    step_scores.append(score)

                # 최솟값 전략 (가장 약한 단계 기준)
                candidate.score = min(step_scores) if step_scores else 0.0
            else:
                # ORM: 전체 답변 점수
                candidate.score = self.verifier.score(
                    prompt, candidate.text
                )

        return candidates

    def solve(self, prompt: str) -> GenerationResult:
        """Best-of-N으로 문제 풀기"""
        # 1. N개 후보 생성
        candidates = self.generate_candidates(prompt, self.n_samples)

        # 2. 점수 매기기
        candidates = self.score_candidates(prompt, candidates)

        # 3. 최고 점수 선택
        best = max(candidates, key=lambda c: c.score)

        return best

    def _parse_steps(self, text: str) -> List[str]:
        """추론 텍스트를 단계별로 분리"""
        # 줄바꿈 또는 "Step N:" 패턴으로 분리
        lines = text.strip().split("\n")
        steps = []
        current_step = []

        for line in lines:
            if line.strip().startswith("Step") or line.strip().startswith("단계"):
                if current_step:
                    steps.append("\n".join(current_step))
                current_step = [line]
            else:
                current_step.append(line)

        if current_step:
            steps.append("\n".join(current_step))

        return steps if steps else [text]


class SelfConsistency:
    """Self-Consistency (다수결 투표) 기반 추론 확장"""

    def __init__(self, generator, n_samples: int = 40, temperature: float = 0.7):
        self.generator = generator
        self.n_samples = n_samples
        self.temperature = temperature

    def solve(self, prompt: str) -> str:
        """다수결 투표로 답변 결정"""
        answers = []

        for _ in range(self.n_samples):
            output = self.generator.generate(
                prompt,
                temperature=self.temperature,
                max_tokens=2048
            )
            answer = self._extract_answer(output.text)
            answers.append(answer)

        # 다수결 투표
        from collections import Counter
        counter = Counter(answers)
        best_answer, count = counter.most_common(1)[0]

        return best_answer

    def _extract_answer(self, text: str) -> str:
        """최종 답변 추출 (boxed 형식 등)"""
        import re
        # LaTeX boxed 형식
        match = re.search(r'\\boxed\{([^}]+)\}', text)
        if match:
            return match.group(1).strip()

        # "답: " 형식
        match = re.search(r'답[:\s]+(.+?)(?:\n|$)', text)
        if match:
            return match.group(1).strip()

        # 마지막 줄
        lines = text.strip().split('\n')
        return lines[-1].strip()

Process Reward Model 구현

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


class ProcessRewardModel(nn.Module):
    """Process Reward Model (PRM)"""

    def __init__(
        self,
        base_model_name: str = "meta-llama/Llama-3-8B",
        hidden_size: int = 4096
    ):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(base_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        # 단계별 점수 예측 헤드
        self.reward_head = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

        # Step delimiter token
        self.step_token = "\n"  # 또는 특수 토큰

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        step_positions: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            input_ids: (batch, seq_len) 토큰 ID
            attention_mask: (batch, seq_len)
            step_positions: (batch, max_steps) 각 단계 끝 위치

        Returns:
            step_scores: (batch, max_steps) 각 단계의 정확도 점수
        """
        # Backbone forward
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden_states = outputs.last_hidden_state  # (batch, seq, hidden)

        # 각 단계 끝 위치의 hidden state 추출
        batch_size = input_ids.shape[0]
        max_steps = step_positions.shape[1]

        step_hidden = []
        for b in range(batch_size):
            for s in range(max_steps):
                pos = step_positions[b, s]
                if pos >= 0:
                    step_hidden.append(hidden_states[b, pos])
                else:
                    step_hidden.append(torch.zeros_like(hidden_states[b, 0]))

        step_hidden = torch.stack(step_hidden).view(batch_size, max_steps, -1)

        # 점수 예측
        step_scores = self.reward_head(step_hidden).squeeze(-1)

        return step_scores

    def score_solution(
        self,
        problem: str,
        solution_steps: list
    ) -> list:
        """풀이의 각 단계에 대한 점수 반환"""
        scores = []

        for i in range(len(solution_steps)):
            # 문제 + 현재까지의 단계를 입력으로 구성
            partial_solution = "\n".join(solution_steps[:i+1])
            text = f"Problem: {problem}\nSolution:\n{partial_solution}"

            tokens = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=2048
            )

            with torch.no_grad():
                outputs = self.backbone(**tokens)
                # 마지막 토큰의 hidden state 사용
                last_hidden = outputs.last_hidden_state[:, -1, :]
                score = self.reward_head(last_hidden).item()

            scores.append(score)

        return scores


class PRMTrainer:
    """PRM 학습기 (MC Estimation 기반)"""

    def __init__(
        self,
        prm: ProcessRewardModel,
        completion_model,   # 완성용 LLM
        n_completions: int = 32,
        learning_rate: float = 1e-5
    ):
        self.prm = prm
        self.completion_model = completion_model
        self.n_completions = n_completions
        self.optimizer = torch.optim.AdamW(
            prm.parameters(), lr=learning_rate
        )

    def estimate_step_labels(
        self,
        problem: str,
        steps: list,
        correct_answer: str
    ) -> list:
        """
        MC Estimation으로 각 단계의 레이블 추정

        각 단계 이후 N번 completion하여 정답 비율 계산
        """
        labels = []

        for i in range(len(steps)):
            partial = "\n".join(steps[:i+1])
            prompt = f"Problem: {problem}\nSolution so far:\n{partial}\nContinue:"

            correct_count = 0
            for _ in range(self.n_completions):
                completion = self.completion_model.generate(
                    prompt, temperature=0.8, max_tokens=512
                )
                answer = self._extract_answer(completion)
                if self._check_answer(answer, correct_answer):
                    correct_count += 1

            label = correct_count / self.n_completions
            labels.append(label)

        return labels

    def train_step(self, batch):
        """한 배치 학습"""
        self.optimizer.zero_grad()

        # Forward
        predicted_scores = self.prm(
            batch['input_ids'],
            batch['attention_mask'],
            batch['step_positions']
        )

        # MSE Loss (MC 추정 레이블과의 차이)
        target_scores = batch['step_labels']
        mask = batch['step_mask']  # 유효한 단계만

        loss = ((predicted_scores - target_scores) ** 2 * mask).sum() / mask.sum()

        loss.backward()
        self.optimizer.step()

        return loss.item()

    def _extract_answer(self, text):
        import re
        match = re.search(r'\\boxed\{([^}]+)\}', text)
        return match.group(1).strip() if match else text.strip().split('\n')[-1]

    def _check_answer(self, pred, target):
        return pred.strip() == target.strip()

MCTS for LLM Reasoning

import math
import random
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class MCTSNode:
    """MCTS 트리 노드"""
    state: str                          # 현재까지의 추론 텍스트
    parent: Optional['MCTSNode'] = None
    children: list = field(default_factory=list)
    visits: int = 0
    value: float = 0.0
    step_text: str = ""                 # 이 노드에서 추가된 단계
    is_terminal: bool = False

    @property
    def q_value(self) -> float:
        return self.value / max(self.visits, 1)

    def ucb1(self, c: float = 1.414) -> float:
        if self.visits == 0:
            return float('inf')
        exploitation = self.q_value
        exploration = c * math.sqrt(
            math.log(self.parent.visits) / self.visits
        )
        return exploitation + exploration


class ReasoningMCTS:
    """LLM 추론을 위한 MCTS"""

    def __init__(
        self,
        generator,          # LLM 생성 모델
        prm,                # Process Reward Model
        max_iterations: int = 100,
        max_depth: int = 10,
        n_expansions: int = 3,
        c_explore: float = 1.414
    ):
        self.generator = generator
        self.prm = prm
        self.max_iterations = max_iterations
        self.max_depth = max_depth
        self.n_expansions = n_expansions
        self.c_explore = c_explore

    def solve(self, problem: str) -> str:
        """MCTS로 문제 풀기"""
        root = MCTSNode(state=f"Problem: {problem}\nSolution:\n")

        for _ in range(self.max_iterations):
            # 1. Selection
            node = self._select(root)

            if node.is_terminal:
                continue

            # 2. Expansion
            children = self._expand(node, problem)

            if not children:
                node.is_terminal = True
                continue

            # 3. Simulation & Evaluation
            for child in children:
                value = self._evaluate(problem, child)

                # 4. Backpropagation
                self._backpropagate(child, value)

        # 최적 경로 추출
        return self._best_path(root)

    def _select(self, node: MCTSNode) -> MCTSNode:
        """UCB1 기반 노드 선택"""
        while node.children and not node.is_terminal:
            node = max(
                node.children,
                key=lambda c: c.ucb1(self.c_explore)
            )
        return node

    def _expand(
        self,
        node: MCTSNode,
        problem: str
    ) -> list:
        """LLM으로 다음 추론 단계 생성"""
        children = []

        for _ in range(self.n_expansions):
            # 다음 단계 생성
            next_step = self.generator.generate(
                node.state + "\nNext step:",
                temperature=0.8,
                max_tokens=256,
                stop=["\n\n"]  # 한 단계만 생성
            )

            step_text = next_step.text.strip()

            if not step_text:
                continue

            # 종료 조건 확인
            is_terminal = self._is_terminal(step_text)

            child = MCTSNode(
                state=node.state + step_text + "\n",
                parent=node,
                step_text=step_text,
                is_terminal=is_terminal
            )

            node.children.append(child)
            children.append(child)

        return children

    def _evaluate(self, problem: str, node: MCTSNode) -> float:
        """PRM으로 노드 가치 평가"""
        # 현재 경로의 단계들 추출
        steps = self._extract_steps(node.state)

        if not steps:
            return 0.0

        # PRM 점수
        scores = self.prm.score_solution(problem, steps)

        # 최소 점수 반환 (가장 약한 단계 기준)
        return min(scores) if scores else 0.0

    def _backpropagate(self, node: MCTSNode, value: float):
        """결과를 루트까지 전파"""
        while node is not None:
            node.visits += 1
            node.value += value
            node = node.parent

    def _best_path(self, root: MCTSNode) -> str:
        """방문 횟수 기준 최적 경로 추출"""
        node = root
        path = []

        while node.children:
            node = max(node.children, key=lambda c: c.visits)
            path.append(node.step_text)

        return "\n".join(path)

    def _is_terminal(self, text: str) -> bool:
        """종료 조건 확인"""
        terminal_markers = [
            "\\boxed{", "답:", "therefore", "따라서",
            "final answer", "최종 답"
        ]
        return any(m in text.lower() for m in terminal_markers)

    def _extract_steps(self, state: str) -> list:
        """상태에서 추론 단계 추출"""
        lines = state.split("\n")
        steps = [l for l in lines if l.strip() and not l.startswith("Problem:")]
        return steps

Compute-Optimal Inference Router

import numpy as np
from typing import Dict, Callable
from enum import Enum


class DifficultyLevel(Enum):
    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"
    VERY_HARD = "very_hard"


class InferenceRouter:
    """
    문제 난이도에 따른 최적 추론 전략 라우터

    Snell et al. (2024)의 compute-optimal inference 구현
    """

    def __init__(
        self,
        generator,
        prm=None,
        difficulty_estimator=None,
        total_compute_budget: int = 1000  # 토큰 단위
    ):
        self.generator = generator
        self.prm = prm
        self.difficulty_estimator = difficulty_estimator
        self.total_compute_budget = total_compute_budget

        # 난이도별 전략 매핑
        self.strategies: Dict[DifficultyLevel, Callable] = {
            DifficultyLevel.EASY: self._strategy_single,
            DifficultyLevel.MEDIUM: self._strategy_best_of_n,
            DifficultyLevel.HARD: self._strategy_beam_search,
            DifficultyLevel.VERY_HARD: self._strategy_mcts
        }

    def solve(self, problem: str) -> str:
        """문제 난이도를 추정하고 적절한 전략 선택"""
        difficulty = self._estimate_difficulty(problem)
        strategy = self.strategies[difficulty]
        return strategy(problem)

    def solve_batch(self, problems: list) -> list:
        """
        배치 문제에 대해 연산 예산을 최적 배분

        쉬운 문제에 적은 예산, 어려운 문제에 많은 예산
        """
        # 1. 난이도 추정
        difficulties = [self._estimate_difficulty(p) for p in problems]

        # 2. 연산 예산 배분
        weights = {
            DifficultyLevel.EASY: 1,
            DifficultyLevel.MEDIUM: 4,
            DifficultyLevel.HARD: 16,
            DifficultyLevel.VERY_HARD: 64
        }

        total_weight = sum(weights[d] for d in difficulties)
        budgets = [
            int(self.total_compute_budget * weights[d] / total_weight)
            for d in difficulties
        ]

        # 3. 각 문제를 배분된 예산으로 풀기
        results = []
        for problem, difficulty, budget in zip(problems, difficulties, budgets):
            result = self._solve_with_budget(problem, difficulty, budget)
            results.append(result)

        return results

    def _estimate_difficulty(self, problem: str) -> DifficultyLevel:
        """문제 난이도 추정"""
        if self.difficulty_estimator:
            score = self.difficulty_estimator.predict(problem)
        else:
            # 간단한 휴리스틱: 문제 길이, 수학 기호 수 등
            score = self._heuristic_difficulty(problem)

        if score < 0.25:
            return DifficultyLevel.EASY
        elif score < 0.50:
            return DifficultyLevel.MEDIUM
        elif score < 0.75:
            return DifficultyLevel.HARD
        else:
            return DifficultyLevel.VERY_HARD

    def _heuristic_difficulty(self, problem: str) -> float:
        """휴리스틱 난이도 추정"""
        import re

        features = {
            'length': len(problem.split()) / 200,
            'math_symbols': len(re.findall(r'[+\-*/^=<>]', problem)) / 20,
            'numbers': len(re.findall(r'\d+', problem)) / 10,
            'nested_ops': problem.count('(') / 5,
            'keywords': sum(1 for kw in [
                'prove', 'show that', 'find all',
                'maximize', 'minimize', 'determine'
            ] if kw in problem.lower()) / 3
        }

        score = np.mean([min(v, 1.0) for v in features.values()])
        return score

    def _strategy_single(self, problem: str) -> str:
        """쉬운 문제: 단일 생성"""
        output = self.generator.generate(
            problem, temperature=0.0, max_tokens=512
        )
        return output.text

    def _strategy_best_of_n(self, problem: str, n: int = 8) -> str:
        """중간 난이도: Best-of-N"""
        sampler = BestOfNSampler(
            self.generator, self.prm,
            n_samples=n, use_prm=(self.prm is not None)
        )
        result = sampler.solve(problem)
        return result.text

    def _strategy_beam_search(self, problem: str) -> str:
        """어려운 문제: PRM 기반 Beam Search"""
        # 간소화된 beam search
        beam_width = 4
        beams = [f"Problem: {problem}\nSolution:\n"]

        for depth in range(8):
            candidates = []
            for beam in beams:
                for _ in range(3):
                    step = self.generator.generate(
                        beam + "Next step:",
                        temperature=0.7,
                        max_tokens=256,
                        stop=["\n\n"]
                    )
                    new_beam = beam + step.text.strip() + "\n"

                    if self.prm:
                        steps = [l for l in new_beam.split("\n") 
                                if l.strip() and not l.startswith("Problem:")]
                        scores = self.prm.score_solution(problem, steps)
                        score = min(scores) if scores else 0.0
                    else:
                        score = random.random()

                    candidates.append((new_beam, score))

            # 상위 beam_width개 선택
            candidates.sort(key=lambda x: x[1], reverse=True)
            beams = [c[0] for c in candidates[:beam_width]]

        return beams[0] if beams else ""

    def _strategy_mcts(self, problem: str) -> str:
        """매우 어려운 문제: MCTS"""
        mcts = ReasoningMCTS(
            self.generator, self.prm,
            max_iterations=50,
            n_expansions=3
        )
        return mcts.solve(problem)

    def _solve_with_budget(
        self,
        problem: str,
        difficulty: DifficultyLevel,
        budget: int
    ) -> str:
        """예산 제한 내에서 최적 풀이"""
        # 예산에 따라 N 또는 반복 횟수 조정
        if difficulty == DifficultyLevel.EASY:
            return self._strategy_single(problem)
        elif difficulty == DifficultyLevel.MEDIUM:
            n = max(2, budget // 128)
            return self._strategy_best_of_n(problem, n=n)
        else:
            return self._strategy_beam_search(problem)

실무 가이드라인

방법 선택 기준

문제 유형 분석
    |
    +-- 수학/코딩 (정답 검증 가능)
    |       +-- 난이도 낮음 -> Self-Consistency (N=8~16)
    |       +-- 난이도 높음 -> PRM + Beam Search 또는 MCTS
    |
    +-- 자유 형식 (정답 검증 어려움)
    |       +-- ORM + Best-of-N
    |       +-- Extended CoT (긴 추론 체인)
    |
    +-- 실시간 응답 필요
            +-- Budget Forcing (토큰 수 제어)
            +-- 경량 PRM + Best-of-4

비용-성능 트레이드오프

방법 연산 비용 성능 향상 추가 모델 필요 구현 복잡도
Self-Consistency 낮음 (N배) 5-15% X 낮음
Best-of-N + ORM 중간 10-25% ORM 중간
Best-of-N + PRM 높음 15-35% PRM 높음
Beam Search + PRM 높음 20-40% PRM 높음
MCTS + PRM 매우 높음 25-50% PRM 매우 높음
Extended CoT (RL) 중간 30-60% X (내재화) 학습 비용 높음

주의사항

1. 수확 체감 (Diminishing Returns)
   - Best-of-N에서 N > 64 이후 개선폭 급감
   - MCTS에서 반복 > 200 이후 효용 미미
   - 문제의 본질적 난이도에 따른 상한 존재

2. 검증기 품질이 핵심
   - PRM 품질이 낮으면 오히려 성능 저하
   - 잘못된 단계에 높은 점수 -> 오답 선택
   - 검증기 학습에 투자하는 것이 중요

3. 분포 이동 (Distribution Shift)
   - PRM은 학습 시와 다른 분포의 추론에서 성능 저하
   - 생성 모델이 변경되면 PRM도 재학습 필요
   - 반복적 self-improvement에서 주의

4. 비용 관리
   - 프로덕션 환경에서 N배 비용 증가
   - 난이도 라우팅으로 평균 비용 절감 가능
   - 사용자 대면 vs 배치 처리에 따라 전략 차별화

관련 연구 흐름

Chain-of-Thought (Wei et al., 2022)
    |
    +-- Self-Consistency (Wang et al., 2022)
    |
    +-- Let's Verify Step by Step (Lightman et al., 2023)
    |       |
    |       +-- PRM800K 데이터셋
    |       +-- Process 검증이 Outcome 검증보다 우수함 증명
    |
    +-- Scaling Test-Time Compute (Snell et al., NeurIPS 2024)
    |       |
    |       +-- Compute-optimal inference 이론 정립
    |       +-- 난이도별 최적 전략 분석
    |
    +-- OpenAI o1 (2024.09)
    |       |
    |       +-- Internal CoT의 상용화
    |       +-- Hidden reasoning tokens
    |
    +-- DeepSeek-R1 (2025.01)
    |       |
    |       +-- 순수 RL로 추론 능력 출현
    |       +-- GRPO 알고리즘
    |       +-- 오픈소스 추론 모델
    |
    +-- s1 (Muennighoff et al., 2025.01)
    |       |
    |       +-- Budget Forcing
    |       +-- 1K 예제만으로 추론 모델 학습
    |
    +-- ReST-MCTS* (NeurIPS 2024)
    |       |
    |       +-- PRM 기반 MCTS + 자기 학습
    |
    +-- Rewarding Progress (ICLR 2025)
            |
            +-- 자동화된 PRM 학습 스케일링

참고 자료

핵심 논문

  1. Snell, C., Lee, J., Xu, K., & Kumar, A. (2024). Scaling LLM Test-Time Compute Optimally can be More Effective than Scaling Model Parameters. NeurIPS 2024. arXiv:2408.03314.
  2. Lightman, H. et al. (2023). Let's Verify Step by Step. arXiv:2305.20050.
  3. DeepSeek-AI. (2025). DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning. arXiv:2501.12948.
  4. Muennighoff, N. et al. (2025). s1: Simple Test-Time Scaling. arXiv:2501.19393.

확장 논문

  1. Wang, X. et al. (2022). Self-Consistency Improves Chain of Thought Reasoning in Language Models. ICLR 2023.
  2. Zhang, D. et al. (2024). ReST-MCTS*: LLM Self-Training via Process Reward Guided Tree Search. NeurIPS 2024.
  3. Setlur, A. et al. (2025). Rewarding Progress: Scaling Automated Process Verifiers for LLM Reasoning. ICLR 2025.
  4. Luo, L. et al. (2025). The Lessons of Developing Process Reward Models in Mathematical Reasoning. arXiv:2501.07301.

관련 개념