Masked Diffusion Models (MDMs)¶
메타정보¶
| 항목 | 내용 |
|---|---|
| 논문 | Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions |
| 저자 | Kulin Shah (UT Austin), Jaeyeon Kim, Sitan Chen, Vasilis Kontonis, Sham Kakade (Harvard) |
| 발표 | ICML 2025 Outstanding Paper Award |
| arXiv | 2502.06768 |
| 키워드 | Masked Diffusion, Discrete Diffusion, Adaptive Decoding, Generative Modeling |
개요¶
Masked Diffusion Models (MDMs)는 이산 도메인(discrete domain)에서의 생성 모델링을 위한 접근법이다. Autoregressive Models (ARMs)와 달리, MDMs는 학습 시 복잡성을 높이는 대신 추론 시 유연성을 확보한다.
핵심 발견: - MDMs는 학습 시 계산적으로 어려운 subproblem들을 해결해야 함 - 적응형 토큰 디코딩 전략으로 어려운 subproblem 회피 가능 - Sudoku 정확도: <7% -> ~90% 향상 (adaptive inference) - 7배 큰 파라미터의 ARM보다 우수한 성능
배경: ARMs vs MDMs¶
Autoregressive Models (ARMs)¶
특징: - 고정된 순서 (left-to-right)로 토큰 생성 - 학습이 단순함 (teacher forcing) - 추론 시 순차적 디코딩 필수
Masked Diffusion Models (MDMs)¶
특징: - 임의의 순서로 토큰 생성 가능 - 학습 시 모든 가능한 infilling 패턴 학습 필요 - 추론 시 디코딩 순서 선택 자유
핵심 문제: 학습-추론 트레이드오프¶
MDM 학습의 어려움¶
MDMs는 학습 시 지수적으로 많은 infilling 문제를 해결해야 한다:
n개 토큰에 대해:
- 가능한 마스킹 패턴: 2^n
- 각 패턴에서 masked 토큰 예측 필요
- ARMs: O(n) subproblems
- MDMs: O(2^n) subproblems
이론적 결과: 특정 subproblem들은 계산적으로 intractable (NP-hard class)
추론에서의 기회¶
학습은 어렵지만, 추론 시에는: - 어떤 토큰을 먼저 생성할지 선택 가능 - 어려운 subproblem을 피할 수 있음 - 올바른 순서 선택이 성능을 극적으로 향상
Adaptive Token Ordering¶
핵심 아이디어¶
"쉬운 토큰부터 먼저 생성하고, 어려운 토큰은 나중에"
while not all tokens generated:
1. 각 masked 토큰의 예측 confidence 계산
2. 가장 confident한 토큰 선택
3. 해당 토큰 생성 (unmask)
4. 반복
Confidence 기반 순서 결정¶
def adaptive_decode(model, masked_sequence, mask):
"""
Adaptive decoding: 가장 confident한 토큰부터 생성
"""
while mask.any():
# 모든 masked 위치에 대한 예측
logits = model(masked_sequence)
# Confidence = max probability
probs = F.softmax(logits, dim=-1)
confidence = probs.max(dim=-1).values
# Masked 위치만 고려
confidence[~mask] = -float('inf')
# 가장 confident한 위치 선택
best_pos = confidence.argmax()
# 해당 위치 생성
best_token = probs[best_pos].argmax()
masked_sequence[best_pos] = best_token
mask[best_pos] = False
return masked_sequence
실험: Sudoku 퍼즐¶
설정¶
- 9x9 Sudoku 보드 (81개 셀)
- 주어진 힌트에서 나머지 셀 예측
- 정확도: 모든 셀이 규칙을 만족하는 비율
결과¶
| Method | Accuracy |
|---|---|
| MDM (random order) | <7% |
| MDM (adaptive order) | ~90% |
| ARM (7x params, teacher forcing) | ~85% |
왜 Adaptive가 효과적인가¶
Sudoku의 constraint propagation과 유사: 1. 명확한 셀 (1개 가능 값)을 먼저 채움 2. 채워진 셀이 다른 셀의 가능 값을 제한 3. 연쇄적으로 전체 보드 완성
MDM + adaptive decoding이 이 과정을 자연스럽게 모방
방법론 상세¶
BERT-style Masking Objective¶
def mdm_loss(model, x, mask_ratio=0.15):
"""
MDM 학습: 마스킹된 토큰 복원
"""
batch_size, seq_len = x.shape
# 랜덤 마스킹
mask = torch.rand(batch_size, seq_len) < mask_ratio
# 마스크 토큰으로 대체
x_masked = x.clone()
x_masked[mask] = MASK_TOKEN
# 예측
logits = model(x_masked)
# Masked 위치만 loss 계산
loss = F.cross_entropy(
logits[mask],
x[mask]
)
return loss
Diffusion Formulation¶
MDM을 diffusion 관점에서 해석:
Forward process: x_0 -> x_1 -> ... -> x_T (점진적 마스킹)
Reverse process: x_T -> x_{T-1} -> ... -> x_0 (점진적 언마스킹)
x_t: t/T 비율의 토큰이 마스킹된 상태
Score Matching Analogy¶
Continuous diffusion: score = grad log p(x_t)
Discrete diffusion: "score" = log p(x_unmask | x_masked)
Python 구현¶
MDM 모델 구조¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class MaskedDiffusionModel(nn.Module):
"""
Masked Diffusion Model for discrete sequences
학습: 마스킹된 토큰 복원
추론: adaptive ordering으로 생성
"""
def __init__(
self,
vocab_size: int,
hidden_dim: int = 512,
num_layers: int = 6,
num_heads: int = 8,
max_len: int = 512
):
super().__init__()
self.vocab_size = vocab_size
self.hidden_dim = hidden_dim
self.mask_token_id = vocab_size # Special mask token
# Embedding (vocab + mask token)
self.embedding = nn.Embedding(vocab_size + 1, hidden_dim)
self.pos_embedding = nn.Embedding(max_len, hidden_dim)
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
batch_first=True
)
self.transformer = nn.TransformerEncoder(
encoder_layer,
num_layers=num_layers
)
# Output projection
self.output = nn.Linear(hidden_dim, vocab_size)
def forward(self, x, mask=None):
"""
Args:
x: Input sequence with mask tokens, (B, L)
mask: Optional attention mask
Returns:
logits: (B, L, vocab_size)
"""
B, L = x.shape
# Embeddings
pos = torch.arange(L, device=x.device).unsqueeze(0)
h = self.embedding(x) + self.pos_embedding(pos)
# Transformer
h = self.transformer(h, src_key_padding_mask=mask)
# Output
logits = self.output(h)
return logits
def compute_loss(self, x, mask_ratio=0.15):
"""
Compute MDM training loss
"""
B, L = x.shape
device = x.device
# Random masking
mask = torch.rand(B, L, device=device) < mask_ratio
# Create masked input
x_masked = x.clone()
x_masked[mask] = self.mask_token_id
# Forward
logits = self.forward(x_masked)
# Loss on masked positions only
loss = F.cross_entropy(
logits[mask],
x[mask],
reduction='mean'
)
return loss
Adaptive Decoding¶
class AdaptiveDecoder:
"""
Adaptive token ordering for MDM inference
핵심: confidence가 높은 토큰부터 생성
"""
def __init__(self, model: MaskedDiffusionModel):
self.model = model
self.mask_token_id = model.mask_token_id
@torch.no_grad()
def generate(
self,
prompt: torch.Tensor,
generate_mask: torch.Tensor,
temperature: float = 1.0,
top_k: int = 0
) -> torch.Tensor:
"""
Adaptive generation
Args:
prompt: Initial sequence with mask tokens, (B, L)
generate_mask: Boolean mask for positions to generate, (B, L)
temperature: Sampling temperature
top_k: Top-k filtering (0 = greedy)
Returns:
Generated sequence
"""
device = prompt.device
sequence = prompt.clone()
remaining_mask = generate_mask.clone()
while remaining_mask.any():
# Get predictions
logits = self.model(sequence)
# Apply temperature
logits = logits / temperature
# Calculate confidence for each position
probs = F.softmax(logits, dim=-1)
max_probs, max_tokens = probs.max(dim=-1) # (B, L)
# Mask out non-generate positions
max_probs[~remaining_mask] = -float('inf')
# Select most confident position per batch
best_positions = max_probs.argmax(dim=-1) # (B,)
# Generate tokens at best positions
for b in range(sequence.shape[0]):
pos = best_positions[b].item()
if remaining_mask[b, pos]:
if top_k > 0:
# Top-k sampling
pos_logits = logits[b, pos]
top_k_logits, top_k_indices = pos_logits.topk(top_k)
top_k_probs = F.softmax(top_k_logits, dim=-1)
idx = torch.multinomial(top_k_probs, 1)
token = top_k_indices[idx]
else:
# Greedy
token = max_tokens[b, pos]
sequence[b, pos] = token
remaining_mask[b, pos] = False
return sequence
@torch.no_grad()
def generate_parallel(
self,
prompt: torch.Tensor,
generate_mask: torch.Tensor,
steps: int = 10,
temperature: float = 1.0
) -> torch.Tensor:
"""
Parallel adaptive generation (faster)
한 step에 여러 토큰을 동시에 생성
"""
device = prompt.device
sequence = prompt.clone()
remaining_mask = generate_mask.clone()
# 각 step에서 생성할 토큰 수
total_to_generate = remaining_mask.sum().item()
tokens_per_step = max(1, total_to_generate // steps)
for _ in range(steps):
if not remaining_mask.any():
break
# Get predictions
logits = self.model(sequence) / temperature
probs = F.softmax(logits, dim=-1)
max_probs, max_tokens = probs.max(dim=-1)
# Mask out non-generate positions
max_probs[~remaining_mask] = -float('inf')
# Select top-k most confident positions
for b in range(sequence.shape[0]):
batch_mask = remaining_mask[b]
if not batch_mask.any():
continue
# Get confidence for this batch
batch_probs = max_probs[b].clone()
# Number to generate this step
n_remaining = batch_mask.sum().item()
n_generate = min(tokens_per_step, n_remaining)
# Top-k positions
_, top_positions = batch_probs.topk(n_generate)
for pos in top_positions:
pos = pos.item()
if remaining_mask[b, pos]:
sequence[b, pos] = max_tokens[b, pos]
remaining_mask[b, pos] = False
return sequence
Sudoku 특화 구현¶
class SudokuMDM(MaskedDiffusionModel):
"""
Sudoku-specific MDM
9x9 보드, 각 셀은 1-9 값
"""
def __init__(self, hidden_dim=256, num_layers=4):
super().__init__(
vocab_size=9, # 1-9 (0-indexed: 0-8)
hidden_dim=hidden_dim,
num_layers=num_layers,
max_len=81 # 9x9 board
)
# 추가: row, col, box positional encoding
self.row_embed = nn.Embedding(9, hidden_dim)
self.col_embed = nn.Embedding(9, hidden_dim)
self.box_embed = nn.Embedding(9, hidden_dim)
def forward(self, x, mask=None):
B, L = x.shape
device = x.device
# 위치 정보 계산
positions = torch.arange(L, device=device)
rows = positions // 9
cols = positions % 9
boxes = (rows // 3) * 3 + (cols // 3)
# Embeddings with Sudoku structure
h = self.embedding(x)
h = h + self.row_embed(rows)
h = h + self.col_embed(cols)
h = h + self.box_embed(boxes)
# Transformer
h = self.transformer(h)
return self.output(h)
def validate_sudoku(board: torch.Tensor) -> bool:
"""
Sudoku 솔루션 검증
"""
board = board.view(9, 9).cpu().numpy()
# 각 행 검사
for row in board:
if len(set(row)) != 9 or set(row) != set(range(9)):
return False
# 각 열 검사
for col in board.T:
if len(set(col)) != 9:
return False
# 각 3x3 박스 검사
for i in range(0, 9, 3):
for j in range(0, 9, 3):
box = board[i:i+3, j:j+3].flatten()
if len(set(box)) != 9:
return False
return True
학습 파이프라인¶
def train_mdm(
model: MaskedDiffusionModel,
dataloader,
epochs: int = 100,
lr: float = 1e-4,
mask_ratio_schedule: str = 'linear'
):
"""
MDM 학습
Args:
mask_ratio_schedule: 'linear', 'cosine', 'constant'
"""
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, epochs
)
model.train()
for epoch in range(epochs):
total_loss = 0
# Mask ratio scheduling
if mask_ratio_schedule == 'linear':
mask_ratio = 0.15 + 0.5 * (epoch / epochs)
elif mask_ratio_schedule == 'cosine':
mask_ratio = 0.15 + 0.5 * (1 - np.cos(np.pi * epoch / epochs)) / 2
else:
mask_ratio = 0.5
for batch in dataloader:
x = batch['sequence']
loss = model.compute_loss(x, mask_ratio=mask_ratio)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
scheduler.step()
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch}: Loss = {avg_loss:.4f}, "
f"Mask Ratio = {mask_ratio:.2f}")
def evaluate_mdm(
model: MaskedDiffusionModel,
test_data,
use_adaptive: bool = True
):
"""
MDM 평가 (Sudoku)
"""
model.eval()
decoder = AdaptiveDecoder(model) if use_adaptive else None
correct = 0
total = 0
for puzzle, solution in test_data:
# puzzle: 힌트가 있는 보드 (빈 셀 = mask_token)
# solution: 정답
generate_mask = (puzzle == model.mask_token_id)
if use_adaptive:
prediction = decoder.generate(
puzzle.unsqueeze(0),
generate_mask.unsqueeze(0)
).squeeze(0)
else:
# Random order decoding
prediction = random_order_decode(model, puzzle, generate_mask)
if validate_sudoku(prediction):
correct += 1
total += 1
accuracy = correct / total
print(f"Accuracy: {accuracy:.2%}")
return accuracy
이론적 분석¶
MDM 학습의 Hardness¶
정리: 특정 분포에서 MDM의 일부 subproblem은 NP-hard
증명 스케치: 1. Boolean satisfiability를 MDM subproblem으로 환원 2. 특정 마스킹 패턴에서 복원이 SAT 해결과 동치 3. 따라서 polynomial time 해 불가능 (P != NP 가정)
실질적 의미: - 모든 subproblem을 완벽히 학습하는 것은 계산적으로 불가능 - 그러나 "쉬운" subproblem은 잘 학습 가능 - Adaptive decoding은 쉬운 path를 선택
Adaptive Decoding의 이론적 기반¶
ARM: P(x) = prod_i P(x_i | x_{<i}) # 고정 순서
MDM: P(x) = prod_i P(x_sigma(i) | x_sigma(<i)) # 가변 순서
where sigma = argmax_ordering product of confidences
Adaptive ordering은 각 step에서 최고 confidence를 선택하여: - Local optimum을 통해 global solution에 접근 - Constraint propagation 효과 달성
응용 분야¶
1. 언어 모델링¶
- 양방향 컨텍스트 활용 가능
- Infilling tasks (코드 완성, 텍스트 편집)
- Non-monotonic generation
2. 생물학적 시퀀스¶
# 단백질/RNA 서열 생성
class ProteinMDM(MaskedDiffusionModel):
def __init__(self):
super().__init__(
vocab_size=20, # 20 amino acids
hidden_dim=768,
num_layers=12
)
- 구조적 제약 (2차/3차 구조)이 있는 서열 생성
- Adaptive decoding이 자연스럽게 구조적 일관성 유지
3. 논리적 추론¶
- Sudoku, constraint satisfaction problems
- Mathematical proof generation
- Code synthesis with constraints
4. 멀티모달 생성¶
- 이미지 토큰화 후 MDM 적용
- Text-to-image에서 양방향 attention 활용
ARMs vs MDMs 비교 요약¶
| 측면 | ARMs | MDMs |
|---|---|---|
| 학습 복잡도 | O(n) subproblems | O(2^n) subproblems |
| 추론 순서 | 고정 (left-to-right) | 유연 (adaptive 가능) |
| Teacher forcing | 가능 | 어려움 |
| 양방향 컨텍스트 | 불가 | 가능 |
| Constraint 문제 | 약함 | 강함 (adaptive 시) |
| 학습 안정성 | 높음 | 낮음 (variance 큼) |
| 추론 효율성 | 순차적 | 병렬화 가능 |
핵심 인사이트¶
- Trade-off 이해: MDM은 학습의 어려움을 추론의 유연성과 교환
- Adaptive의 핵심: 올바른 디코딩 순서가 성능을 극적으로 개선
- Constraint propagation: MDM + adaptive = 자연스러운 constraint solver
- 실용성: 논리적 추론, 구조적 생성에서 ARM 대비 우위
한계 및 향후 연구¶
현재 한계¶
- 학습 시 variance가 높음 (다양한 마스킹 패턴)
- 최적 디코딩 순서 찾기가 NP-hard일 수 있음
- 대규모 언어 모델로의 확장 검증 필요
향후 방향¶
- 효율적 학습: Curriculum learning, importance sampling
- Better ordering: 학습된 ordering policy
- Hybrid approaches: ARM + MDM 결합
- Scaling: LLM 규모에서의 MDM 탐구
참고 문헌¶
- Shah et al. "Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions" ICML 2025
- Austin et al. "Structured Denoising Diffusion Models in Discrete State-Spaces" NeurIPS 2021
- He et al. "Diffusion Language Models Can Perform Many Tasks with Scaling and Instruction-Finetuning" arXiv 2023
- Lou et al. "Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution" ICML 2024
- Sahoo et al. "Simple and Effective Masked Diffusion Language Models" NeurIPS 2024