콘텐츠로 이동
Data Prep
상세

Masked Diffusion Models (MDMs)

메타정보

항목 내용
논문 Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions
저자 Kulin Shah (UT Austin), Jaeyeon Kim, Sitan Chen, Vasilis Kontonis, Sham Kakade (Harvard)
발표 ICML 2025 Outstanding Paper Award
arXiv 2502.06768
키워드 Masked Diffusion, Discrete Diffusion, Adaptive Decoding, Generative Modeling

개요

Masked Diffusion Models (MDMs)는 이산 도메인(discrete domain)에서의 생성 모델링을 위한 접근법이다. Autoregressive Models (ARMs)와 달리, MDMs는 학습 시 복잡성을 높이는 대신 추론 시 유연성을 확보한다.

핵심 발견: - MDMs는 학습 시 계산적으로 어려운 subproblem들을 해결해야 함 - 적응형 토큰 디코딩 전략으로 어려운 subproblem 회피 가능 - Sudoku 정확도: <7% -> ~90% 향상 (adaptive inference) - 7배 큰 파라미터의 ARM보다 우수한 성능


배경: ARMs vs MDMs

Autoregressive Models (ARMs)

P(x) = P(x_1) * P(x_2|x_1) * P(x_3|x_1,x_2) * ... * P(x_n|x_1,...,x_{n-1})

특징: - 고정된 순서 (left-to-right)로 토큰 생성 - 학습이 단순함 (teacher forcing) - 추론 시 순차적 디코딩 필수

Masked Diffusion Models (MDMs)

P(x) = sum over all orderings: P(x_sigma(1)) * P(x_sigma(2)|x_sigma(1)) * ...

특징: - 임의의 순서로 토큰 생성 가능 - 학습 시 모든 가능한 infilling 패턴 학습 필요 - 추론 시 디코딩 순서 선택 자유


핵심 문제: 학습-추론 트레이드오프

MDM 학습의 어려움

MDMs는 학습 시 지수적으로 많은 infilling 문제를 해결해야 한다:

n개 토큰에 대해:
- 가능한 마스킹 패턴: 2^n
- 각 패턴에서 masked 토큰 예측 필요
- ARMs: O(n) subproblems
- MDMs: O(2^n) subproblems

이론적 결과: 특정 subproblem들은 계산적으로 intractable (NP-hard class)

추론에서의 기회

학습은 어렵지만, 추론 시에는: - 어떤 토큰을 먼저 생성할지 선택 가능 - 어려운 subproblem을 피할 수 있음 - 올바른 순서 선택이 성능을 극적으로 향상


Adaptive Token Ordering

핵심 아이디어

"쉬운 토큰부터 먼저 생성하고, 어려운 토큰은 나중에"

while not all tokens generated:
    1. 각 masked 토큰의 예측 confidence 계산
    2. 가장 confident한 토큰 선택
    3. 해당 토큰 생성 (unmask)
    4. 반복

Confidence 기반 순서 결정

def adaptive_decode(model, masked_sequence, mask):
    """
    Adaptive decoding: 가장 confident한 토큰부터 생성
    """
    while mask.any():
        # 모든 masked 위치에 대한 예측
        logits = model(masked_sequence)

        # Confidence = max probability
        probs = F.softmax(logits, dim=-1)
        confidence = probs.max(dim=-1).values

        # Masked 위치만 고려
        confidence[~mask] = -float('inf')

        # 가장 confident한 위치 선택
        best_pos = confidence.argmax()

        # 해당 위치 생성
        best_token = probs[best_pos].argmax()
        masked_sequence[best_pos] = best_token
        mask[best_pos] = False

    return masked_sequence

실험: Sudoku 퍼즐

설정

  • 9x9 Sudoku 보드 (81개 셀)
  • 주어진 힌트에서 나머지 셀 예측
  • 정확도: 모든 셀이 규칙을 만족하는 비율

결과

Method Accuracy
MDM (random order) <7%
MDM (adaptive order) ~90%
ARM (7x params, teacher forcing) ~85%

왜 Adaptive가 효과적인가

Sudoku의 constraint propagation과 유사: 1. 명확한 셀 (1개 가능 값)을 먼저 채움 2. 채워진 셀이 다른 셀의 가능 값을 제한 3. 연쇄적으로 전체 보드 완성

MDM + adaptive decoding이 이 과정을 자연스럽게 모방


방법론 상세

BERT-style Masking Objective

def mdm_loss(model, x, mask_ratio=0.15):
    """
    MDM 학습: 마스킹된 토큰 복원
    """
    batch_size, seq_len = x.shape

    # 랜덤 마스킹
    mask = torch.rand(batch_size, seq_len) < mask_ratio

    # 마스크 토큰으로 대체
    x_masked = x.clone()
    x_masked[mask] = MASK_TOKEN

    # 예측
    logits = model(x_masked)

    # Masked 위치만 loss 계산
    loss = F.cross_entropy(
        logits[mask],
        x[mask]
    )

    return loss

Diffusion Formulation

MDM을 diffusion 관점에서 해석:

Forward process: x_0 -> x_1 -> ... -> x_T (점진적 마스킹)
Reverse process: x_T -> x_{T-1} -> ... -> x_0 (점진적 언마스킹)

x_t: t/T 비율의 토큰이 마스킹된 상태

Score Matching Analogy

Continuous diffusion: score = grad log p(x_t)
Discrete diffusion: "score" = log p(x_unmask | x_masked)

Python 구현

MDM 모델 구조

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedDiffusionModel(nn.Module):
    """
    Masked Diffusion Model for discrete sequences

    학습: 마스킹된 토큰 복원
    추론: adaptive ordering으로 생성
    """
    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int = 512,
        num_layers: int = 6,
        num_heads: int = 8,
        max_len: int = 512
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.mask_token_id = vocab_size  # Special mask token

        # Embedding (vocab + mask token)
        self.embedding = nn.Embedding(vocab_size + 1, hidden_dim)
        self.pos_embedding = nn.Embedding(max_len, hidden_dim)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

        # Output projection
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input sequence with mask tokens, (B, L)
            mask: Optional attention mask
        Returns:
            logits: (B, L, vocab_size)
        """
        B, L = x.shape

        # Embeddings
        pos = torch.arange(L, device=x.device).unsqueeze(0)
        h = self.embedding(x) + self.pos_embedding(pos)

        # Transformer
        h = self.transformer(h, src_key_padding_mask=mask)

        # Output
        logits = self.output(h)

        return logits

    def compute_loss(self, x, mask_ratio=0.15):
        """
        Compute MDM training loss
        """
        B, L = x.shape
        device = x.device

        # Random masking
        mask = torch.rand(B, L, device=device) < mask_ratio

        # Create masked input
        x_masked = x.clone()
        x_masked[mask] = self.mask_token_id

        # Forward
        logits = self.forward(x_masked)

        # Loss on masked positions only
        loss = F.cross_entropy(
            logits[mask],
            x[mask],
            reduction='mean'
        )

        return loss

Adaptive Decoding

class AdaptiveDecoder:
    """
    Adaptive token ordering for MDM inference

    핵심: confidence가 높은 토큰부터 생성
    """
    def __init__(self, model: MaskedDiffusionModel):
        self.model = model
        self.mask_token_id = model.mask_token_id

    @torch.no_grad()
    def generate(
        self,
        prompt: torch.Tensor,
        generate_mask: torch.Tensor,
        temperature: float = 1.0,
        top_k: int = 0
    ) -> torch.Tensor:
        """
        Adaptive generation

        Args:
            prompt: Initial sequence with mask tokens, (B, L)
            generate_mask: Boolean mask for positions to generate, (B, L)
            temperature: Sampling temperature
            top_k: Top-k filtering (0 = greedy)

        Returns:
            Generated sequence
        """
        device = prompt.device
        sequence = prompt.clone()
        remaining_mask = generate_mask.clone()

        while remaining_mask.any():
            # Get predictions
            logits = self.model(sequence)

            # Apply temperature
            logits = logits / temperature

            # Calculate confidence for each position
            probs = F.softmax(logits, dim=-1)
            max_probs, max_tokens = probs.max(dim=-1)  # (B, L)

            # Mask out non-generate positions
            max_probs[~remaining_mask] = -float('inf')

            # Select most confident position per batch
            best_positions = max_probs.argmax(dim=-1)  # (B,)

            # Generate tokens at best positions
            for b in range(sequence.shape[0]):
                pos = best_positions[b].item()
                if remaining_mask[b, pos]:
                    if top_k > 0:
                        # Top-k sampling
                        pos_logits = logits[b, pos]
                        top_k_logits, top_k_indices = pos_logits.topk(top_k)
                        top_k_probs = F.softmax(top_k_logits, dim=-1)
                        idx = torch.multinomial(top_k_probs, 1)
                        token = top_k_indices[idx]
                    else:
                        # Greedy
                        token = max_tokens[b, pos]

                    sequence[b, pos] = token
                    remaining_mask[b, pos] = False

        return sequence

    @torch.no_grad()
    def generate_parallel(
        self,
        prompt: torch.Tensor,
        generate_mask: torch.Tensor,
        steps: int = 10,
        temperature: float = 1.0
    ) -> torch.Tensor:
        """
        Parallel adaptive generation (faster)

        한 step에 여러 토큰을 동시에 생성
        """
        device = prompt.device
        sequence = prompt.clone()
        remaining_mask = generate_mask.clone()

        # 각 step에서 생성할 토큰 수
        total_to_generate = remaining_mask.sum().item()
        tokens_per_step = max(1, total_to_generate // steps)

        for _ in range(steps):
            if not remaining_mask.any():
                break

            # Get predictions
            logits = self.model(sequence) / temperature
            probs = F.softmax(logits, dim=-1)
            max_probs, max_tokens = probs.max(dim=-1)

            # Mask out non-generate positions
            max_probs[~remaining_mask] = -float('inf')

            # Select top-k most confident positions
            for b in range(sequence.shape[0]):
                batch_mask = remaining_mask[b]
                if not batch_mask.any():
                    continue

                # Get confidence for this batch
                batch_probs = max_probs[b].clone()

                # Number to generate this step
                n_remaining = batch_mask.sum().item()
                n_generate = min(tokens_per_step, n_remaining)

                # Top-k positions
                _, top_positions = batch_probs.topk(n_generate)

                for pos in top_positions:
                    pos = pos.item()
                    if remaining_mask[b, pos]:
                        sequence[b, pos] = max_tokens[b, pos]
                        remaining_mask[b, pos] = False

        return sequence

Sudoku 특화 구현

class SudokuMDM(MaskedDiffusionModel):
    """
    Sudoku-specific MDM

    9x9 보드, 각 셀은 1-9 값
    """
    def __init__(self, hidden_dim=256, num_layers=4):
        super().__init__(
            vocab_size=9,  # 1-9 (0-indexed: 0-8)
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            max_len=81  # 9x9 board
        )

        # 추가: row, col, box positional encoding
        self.row_embed = nn.Embedding(9, hidden_dim)
        self.col_embed = nn.Embedding(9, hidden_dim)
        self.box_embed = nn.Embedding(9, hidden_dim)

    def forward(self, x, mask=None):
        B, L = x.shape
        device = x.device

        # 위치 정보 계산
        positions = torch.arange(L, device=device)
        rows = positions // 9
        cols = positions % 9
        boxes = (rows // 3) * 3 + (cols // 3)

        # Embeddings with Sudoku structure
        h = self.embedding(x)
        h = h + self.row_embed(rows)
        h = h + self.col_embed(cols)
        h = h + self.box_embed(boxes)

        # Transformer
        h = self.transformer(h)

        return self.output(h)


def validate_sudoku(board: torch.Tensor) -> bool:
    """
    Sudoku 솔루션 검증
    """
    board = board.view(9, 9).cpu().numpy()

    # 각 행 검사
    for row in board:
        if len(set(row)) != 9 or set(row) != set(range(9)):
            return False

    # 각 열 검사
    for col in board.T:
        if len(set(col)) != 9:
            return False

    # 각 3x3 박스 검사
    for i in range(0, 9, 3):
        for j in range(0, 9, 3):
            box = board[i:i+3, j:j+3].flatten()
            if len(set(box)) != 9:
                return False

    return True

학습 파이프라인

def train_mdm(
    model: MaskedDiffusionModel,
    dataloader,
    epochs: int = 100,
    lr: float = 1e-4,
    mask_ratio_schedule: str = 'linear'
):
    """
    MDM 학습

    Args:
        mask_ratio_schedule: 'linear', 'cosine', 'constant'
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, epochs
    )

    model.train()

    for epoch in range(epochs):
        total_loss = 0

        # Mask ratio scheduling
        if mask_ratio_schedule == 'linear':
            mask_ratio = 0.15 + 0.5 * (epoch / epochs)
        elif mask_ratio_schedule == 'cosine':
            mask_ratio = 0.15 + 0.5 * (1 - np.cos(np.pi * epoch / epochs)) / 2
        else:
            mask_ratio = 0.5

        for batch in dataloader:
            x = batch['sequence']

            loss = model.compute_loss(x, mask_ratio=mask_ratio)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch}: Loss = {avg_loss:.4f}, "
              f"Mask Ratio = {mask_ratio:.2f}")


def evaluate_mdm(
    model: MaskedDiffusionModel,
    test_data,
    use_adaptive: bool = True
):
    """
    MDM 평가 (Sudoku)
    """
    model.eval()
    decoder = AdaptiveDecoder(model) if use_adaptive else None

    correct = 0
    total = 0

    for puzzle, solution in test_data:
        # puzzle: 힌트가 있는 보드 (빈 셀 = mask_token)
        # solution: 정답

        generate_mask = (puzzle == model.mask_token_id)

        if use_adaptive:
            prediction = decoder.generate(
                puzzle.unsqueeze(0),
                generate_mask.unsqueeze(0)
            ).squeeze(0)
        else:
            # Random order decoding
            prediction = random_order_decode(model, puzzle, generate_mask)

        if validate_sudoku(prediction):
            correct += 1
        total += 1

    accuracy = correct / total
    print(f"Accuracy: {accuracy:.2%}")
    return accuracy

이론적 분석

MDM 학습의 Hardness

정리: 특정 분포에서 MDM의 일부 subproblem은 NP-hard

증명 스케치: 1. Boolean satisfiability를 MDM subproblem으로 환원 2. 특정 마스킹 패턴에서 복원이 SAT 해결과 동치 3. 따라서 polynomial time 해 불가능 (P != NP 가정)

실질적 의미: - 모든 subproblem을 완벽히 학습하는 것은 계산적으로 불가능 - 그러나 "쉬운" subproblem은 잘 학습 가능 - Adaptive decoding은 쉬운 path를 선택

Adaptive Decoding의 이론적 기반

ARM: P(x) = prod_i P(x_i | x_{<i})   # 고정 순서

MDM: P(x) = prod_i P(x_sigma(i) | x_sigma(<i))  # 가변 순서
     where sigma = argmax_ordering product of confidences

Adaptive ordering은 각 step에서 최고 confidence를 선택하여: - Local optimum을 통해 global solution에 접근 - Constraint propagation 효과 달성


응용 분야

1. 언어 모델링

  • 양방향 컨텍스트 활용 가능
  • Infilling tasks (코드 완성, 텍스트 편집)
  • Non-monotonic generation

2. 생물학적 시퀀스

# 단백질/RNA 서열 생성
class ProteinMDM(MaskedDiffusionModel):
    def __init__(self):
        super().__init__(
            vocab_size=20,  # 20 amino acids
            hidden_dim=768,
            num_layers=12
        )
  • 구조적 제약 (2차/3차 구조)이 있는 서열 생성
  • Adaptive decoding이 자연스럽게 구조적 일관성 유지

3. 논리적 추론

  • Sudoku, constraint satisfaction problems
  • Mathematical proof generation
  • Code synthesis with constraints

4. 멀티모달 생성

  • 이미지 토큰화 후 MDM 적용
  • Text-to-image에서 양방향 attention 활용

ARMs vs MDMs 비교 요약

측면 ARMs MDMs
학습 복잡도 O(n) subproblems O(2^n) subproblems
추론 순서 고정 (left-to-right) 유연 (adaptive 가능)
Teacher forcing 가능 어려움
양방향 컨텍스트 불가 가능
Constraint 문제 약함 강함 (adaptive 시)
학습 안정성 높음 낮음 (variance 큼)
추론 효율성 순차적 병렬화 가능

핵심 인사이트

  1. Trade-off 이해: MDM은 학습의 어려움을 추론의 유연성과 교환
  2. Adaptive의 핵심: 올바른 디코딩 순서가 성능을 극적으로 개선
  3. Constraint propagation: MDM + adaptive = 자연스러운 constraint solver
  4. 실용성: 논리적 추론, 구조적 생성에서 ARM 대비 우위

한계 및 향후 연구

현재 한계

  • 학습 시 variance가 높음 (다양한 마스킹 패턴)
  • 최적 디코딩 순서 찾기가 NP-hard일 수 있음
  • 대규모 언어 모델로의 확장 검증 필요

향후 방향

  1. 효율적 학습: Curriculum learning, importance sampling
  2. Better ordering: 학습된 ordering policy
  3. Hybrid approaches: ARM + MDM 결합
  4. Scaling: LLM 규모에서의 MDM 탐구

참고 문헌

  1. Shah et al. "Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions" ICML 2025
  2. Austin et al. "Structured Denoising Diffusion Models in Discrete State-Spaces" NeurIPS 2021
  3. He et al. "Diffusion Language Models Can Perform Many Tasks with Scaling and Instruction-Finetuning" arXiv 2023
  4. Lou et al. "Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution" ICML 2024
  5. Sahoo et al. "Simple and Effective Masked Diffusion Language Models" NeurIPS 2024