콘텐츠로 이동
Data Prep
상세

Visual Autoregressive Modeling (VAR)

메타정보

항목 내용
논문 Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction
저자 Keyu Tian, Yi Jiang, Zehuan Yuan, Bingyue Peng, Liwei Wang
기관 Peking University, ByteDance
발표 NeurIPS 2024 Best Paper
arXiv 2404.02905
코드 github.com/FoundationVision/VAR
키워드 Autoregressive Models, Image Generation, Scaling Laws, Next-Scale Prediction, VQ-VAE

개요

Visual Autoregressive Modeling (VAR)은 이미지 생성을 위한 새로운 autoregressive 패러다임으로, 기존의 raster-scan "next-token prediction" 대신 coarse-to-fine "next-scale prediction"을 도입했다. 이 접근법으로 GPT 스타일 AR 모델이 최초로 Diffusion Transformer를 능가했다.

핵심 성과: - ImageNet 256x256에서 FID 18.65 → 1.73 (10배 개선) - Inception Score 80.4 → 350.2 (4배 개선) - Diffusion 대비 20배 빠른 inference - LLM과 유사한 Scaling Laws 발견 (상관계수 -0.998) - Zero-shot 일반화: inpainting, outpainting, editing


배경: 기존 Visual AR 모델의 한계

기존 AR 모델의 접근법

기존 방식 (Raster-Scan Next-Token Prediction):

이미지 → VQ-VAE → 토큰 시퀀스 → 왼쪽→오른쪽 순차 예측

[1] → [2] → [3] → [4] → ...
 ↓     ↓     ↓     ↓
[5] → [6] → [7] → [8] → ...

문제점

문제 설명
Unnatural Order 이미지의 2D 구조를 1D 시퀀스로 강제 변환
Long Sequence 256x256 이미지 → 16x16 토큰 = 256개 토큰 순차 예측
Slow Inference 각 토큰을 하나씩 순차적으로 생성
No Scaling Laws LLM처럼 모델 크기에 따른 명확한 성능 향상 없음
Diffusion 대비 열등 FID, IS 모든 지표에서 DiT에 뒤처짐

대표적 기존 AR 모델 성능 (ImageNet 256x256)

모델 FID IS 방식
VQGAN 18.65 80.4 Raster-scan AR
ViT-VQGAN 4.17 175.1 Raster-scan AR
RQ-Transformer 7.55 134.0 Residual Quantization
DiT-XL/2 2.27 278.2 Diffusion

VAR의 핵심 아이디어

Next-Scale Prediction 패러다임

VAR 방식 (Coarse-to-Fine):

저해상도(coarse) → 고해상도(fine) 점진적 예측

Scale 1: [1x1]     → 전체 이미지 대략적 구조
Scale 2: [2x2]     → 4개 토큰으로 세부 추가
Scale 3: [4x4]     → 16개 토큰으로 더 세부 추가
...
Scale K: [16x16]   → 최종 256개 토큰

핵심 통찰: 인간의 시각 인지와 유사하게, 전체 구조를 먼저 파악하고 세부 사항을 점진적으로 추가한다.

수학적 정의

기존 AR:

p(x) = prod_{i=1}^{n} p(x_i | x_1, ..., x_{i-1})

단점: 각 토큰을 순차적으로 1개씩 예측

VAR:

p(r_1, r_2, ..., r_K) = prod_{k=1}^{K} p(r_k | r_1, ..., r_{k-1})

r_k: scale k의 토큰 맵 (병렬 예측 가능)
K: 총 scale 수 (일반적으로 10개)

장점: 각 scale 내의 토큰들을 병렬로 동시 예측한다.


아키텍처

Multi-Scale VQ-VAE

VAR의 핵심 구성 요소는 Multi-Scale VQVAE로, 이미지를 여러 해상도의 토큰 맵으로 인코딩한다.

입력 이미지 (256x256)
    Encoder
Feature Map (16x16)
Multi-Scale Quantization
r_1 (1x1), r_2 (2x2), r_3 (3x3), ..., r_10 (16x16)

Scale 구성

Scale (k) 해상도 토큰 수 누적 토큰
1 1x1 1 1
2 2x2 4 5
3 3x3 9 14
4 4x4 16 30
5 5x5 25 55
6 6x6 36 91
7 8x8 64 155
8 10x10 100 255
9 13x13 169 424
10 16x16 256 680

VAR Transformer

입력: class embedding + 이전 scale 토큰들
   [CLS] [r_1] [r_2] ... [r_{k-1}]
  Transformer Blocks (Causal Attention)
    Next-Scale 토큰 예측 (r_k)
   병렬 토큰 생성 (전체 scale 동시 예측)

Attention 마스크

Scale-wise Causal Mask:

       CLS  r1   r2   r3   ...
CLS  [  1   0    0    0   ... ]
r1   [  1   1    0    0   ... ]
r2   [  1   1    1    0   ... ]
r3   [  1   1    1    1   ... ]
...

각 scale은 이전 scale들만 참조 가능
같은 scale 내 토큰들은 서로 참조 가능 (bidirectional)

학습 방법

목적 함수

L = -sum_{k=1}^{K} log p(r_k | c, r_1, ..., r_{k-1})

c: class condition
r_k: scale k의 ground truth 토큰 맵

학습 설정

항목
Optimizer AdamW
Learning Rate 1e-4 (cosine decay)
Batch Size 768
Epochs 350
Warmup 100 epochs
Weight Decay 0.05
Codebook Size 4096 (V=4096)
Embedding Dim 32

Classifier-Free Guidance (CFG)

# Inference with CFG
def sample_with_cfg(model, class_label, cfg_scale=1.5):
    logits_cond = model(class_label)      # conditional
    logits_uncond = model(null_label)     # unconditional

    # CFG: 조건부 방향으로 더 강하게 이동
    logits = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
    return sample_from_logits(logits)

성능 비교

ImageNet 256x256 Class-Conditional Generation

모델 타입 FID IS Params Speed
LDM-4 Diffusion 3.60 247.7 400M -
DiT-XL/2 Diffusion 2.27 278.2 675M 1x
VQGAN AR 18.65 80.4 227M -
VAR-d16 VAR 3.30 274.4 310M 15x
VAR-d20 VAR 2.57 302.6 600M 12x
VAR-d24 VAR 2.09 312.0 1.0B 10x
VAR-d30 VAR 1.73 350.2 2.0B 8x

Scaling Laws

VAR은 LLM과 유사한 power-law scaling을 보인다:

Loss ~ N^(-0.067)

N: 모델 파라미터 수
상관계수: -0.998 (거의 완벽한 선형 관계)
모델 크기 Loss FID
310M 2.95 3.30
600M 2.82 2.57
1.0B 2.70 2.09
2.0B 2.56 1.73

추론 속도 분석

단계별 비교

Diffusion (DiT-XL/2):
- 250 denoising steps 필요
- 각 step마다 full forward pass
- A100 기준: ~6초/이미지

VAR (d30):
- 10 scale steps
- 각 scale 내 병렬 예측
- A100 기준: ~0.3초/이미지 (20x 빠름)

Inference 알고리즘

def var_inference(model, class_label, K=10, cfg_scale=1.5):
    """VAR 추론 (단순화)"""
    tokens = []

    for k in range(1, K+1):
        # 이전 scale 토큰들을 컨텍스트로 사용
        context = torch.cat(tokens, dim=1) if tokens else None

        # scale k의 모든 토큰을 병렬 예측
        logits = model.forward_scale(context, class_label, scale=k)

        # CFG 적용
        if cfg_scale > 1.0:
            logits = apply_cfg(model, logits, class_label, cfg_scale)

        # 샘플링
        scale_tokens = sample_tokens(logits)
        tokens.append(scale_tokens)

    # 토큰 → 이미지 디코딩
    image = model.decode(torch.cat(tokens, dim=1))
    return image

Zero-Shot 응용

Inpainting

입력: 마스킹된 이미지 + 마스크

방법:
1. 마스킹되지 않은 영역의 토큰 고정
2. 마스킹된 영역만 조건부 생성
3. 각 scale에서 일관성 유지

Outpainting

입력: 중앙 이미지 + 확장 영역 지정

방법:
1. 저해상도에서 전체 구조 생성
2. 고해상도에서 원본 영역 고정
3. 확장 영역만 새로 생성

Image Editing

입력: 원본 이미지 + 수정할 영역 + 새로운 조건

방법:
1. 원본 이미지를 multi-scale 토큰으로 인코딩
2. 수정 영역의 토큰만 재생성
3. 나머지 영역은 원본 유지

Python 구현 예시

Multi-Scale Tokenizer

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiScaleVQVAE(nn.Module):
    """Multi-Scale VQ-VAE for VAR"""

    def __init__(
        self,
        vocab_size: int = 4096,
        embed_dim: int = 32,
        scales: list = [1, 2, 3, 4, 5, 6, 8, 10, 13, 16]
    ):
        super().__init__()
        self.scales = scales
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        # 공유 codebook
        self.codebook = nn.Embedding(vocab_size, embed_dim)

        # Encoder/Decoder
        self.encoder = self._build_encoder()
        self.decoder = self._build_decoder()

        # Scale별 projection
        self.scale_projs = nn.ModuleList([
            nn.Conv2d(embed_dim, embed_dim, 1)
            for _ in scales
        ])

    def _build_encoder(self):
        return nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(256, self.embed_dim, 4, 2, 1),
        )

    def _build_decoder(self):
        return nn.Sequential(
            nn.ConvTranspose2d(self.embed_dim, 256, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
            nn.Tanh(),
        )

    def encode_multiscale(self, x):
        """이미지를 multi-scale 토큰으로 인코딩"""
        # x: (B, 3, 256, 256)
        z = self.encoder(x)  # (B, D, 16, 16)

        scale_tokens = []
        for i, s in enumerate(self.scales):
            # 각 scale로 downsampling
            z_s = F.interpolate(z, size=(s, s), mode='bilinear')
            z_s = self.scale_projs[i](z_s)

            # Quantize
            tokens = self.quantize(z_s)
            scale_tokens.append(tokens)

        return scale_tokens

    def quantize(self, z):
        """Vector quantization"""
        B, D, H, W = z.shape
        z_flat = z.permute(0, 2, 3, 1).reshape(-1, D)

        # 가장 가까운 codebook entry 찾기
        distances = torch.cdist(z_flat, self.codebook.weight)
        indices = distances.argmin(dim=-1)

        return indices.reshape(B, H, W)

    def decode_from_tokens(self, all_tokens):
        """Multi-scale 토큰에서 이미지 복원"""
        # 최종 scale 토큰만 사용 (또는 모든 scale 합성)
        final_tokens = all_tokens[-1]
        B, H, W = final_tokens.shape

        z_q = self.codebook(final_tokens)
        z_q = z_q.permute(0, 3, 1, 2)

        return self.decoder(z_q)

VAR Transformer

class VARTransformer(nn.Module):
    """Visual Autoregressive Transformer"""

    def __init__(
        self,
        vocab_size: int = 4096,
        num_classes: int = 1000,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        scales: list = [1, 2, 3, 4, 5, 6, 8, 10, 13, 16]
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.scales = scales

        # Token embedding
        self.token_embed = nn.Embedding(vocab_size, embed_dim)

        # Class embedding
        self.class_embed = nn.Embedding(num_classes + 1, embed_dim)  # +1 for null

        # Position embedding (scale-aware)
        max_tokens = sum(s * s for s in scales)
        self.pos_embed = nn.Parameter(torch.randn(1, max_tokens + 1, embed_dim) * 0.02)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])

        # Scale-specific output heads
        self.output_heads = nn.ModuleList([
            nn.Linear(embed_dim, vocab_size)
            for _ in scales
        ])

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, tokens_list, class_label, target_scale=None):
        """
        tokens_list: 이전 scale들의 토큰 리스트
        class_label: (B,) class indices
        target_scale: 예측할 scale index
        """
        B = class_label.shape[0]

        # Class embedding
        class_emb = self.class_embed(class_label).unsqueeze(1)  # (B, 1, D)

        # 이전 scale 토큰들 임베딩
        if tokens_list:
            prev_tokens = torch.cat([t.flatten(1) for t in tokens_list], dim=1)
            token_emb = self.token_embed(prev_tokens)  # (B, N_prev, D)
            x = torch.cat([class_emb, token_emb], dim=1)
        else:
            x = class_emb

        # Position embedding
        x = x + self.pos_embed[:, :x.shape[1]]

        # Causal mask (scale-wise)
        mask = self.create_scale_causal_mask(tokens_list, target_scale)

        # Transformer forward
        for block in self.blocks:
            x = block(x, mask)

        x = self.norm(x)

        # Output logits for target scale
        if target_scale is not None:
            logits = self.output_heads[target_scale](x[:, -1:])
            return logits.expand(-1, self.scales[target_scale] ** 2, -1)

        return x

    def create_scale_causal_mask(self, tokens_list, target_scale):
        """Scale-wise causal attention mask"""
        # 구현 생략 (scale 간 causal, scale 내 bidirectional)
        return None

    @torch.no_grad()
    def generate(self, class_label, cfg_scale=1.5, temperature=1.0):
        """전체 이미지 생성"""
        B = class_label.shape[0]
        generated_tokens = []

        for k, scale in enumerate(self.scales):
            # Forward pass
            logits = self.forward(generated_tokens, class_label, target_scale=k)

            # CFG
            if cfg_scale > 1.0:
                null_label = torch.full_like(class_label, self.num_classes)
                logits_uncond = self.forward(generated_tokens, null_label, target_scale=k)
                logits = logits_uncond + cfg_scale * (logits - logits_uncond)

            # Sample
            probs = F.softmax(logits / temperature, dim=-1)
            tokens = torch.multinomial(probs.view(-1, self.vocab_size), 1)
            tokens = tokens.view(B, scale, scale)

            generated_tokens.append(tokens)

        return generated_tokens


class TransformerBlock(nn.Module):
    """Standard Transformer block with pre-norm"""

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x), attn_mask=mask)[0]
        x = x + self.mlp(self.norm2(x))
        return x

학습 루프

def train_var(
    model: VARTransformer,
    vqvae: MultiScaleVQVAE,
    dataloader,
    epochs: int = 350,
    lr: float = 1e-4,
    device: str = 'cuda'
):
    """VAR 학습"""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    model.train()

    for epoch in range(epochs):
        total_loss = 0

        for batch_idx, (images, labels) in enumerate(dataloader):
            images, labels = images.to(device), labels.to(device)

            # Multi-scale 토큰화
            with torch.no_grad():
                scale_tokens = vqvae.encode_multiscale(images)

            # 각 scale에 대한 loss 계산
            loss = 0
            for k in range(len(model.scales)):
                # 이전 scale 토큰들을 입력으로
                prev_tokens = scale_tokens[:k] if k > 0 else None
                target_tokens = scale_tokens[k].flatten(1)

                # Forward
                logits = model(prev_tokens, labels, target_scale=k)

                # Cross-entropy loss
                loss += F.cross_entropy(
                    logits.reshape(-1, model.vocab_size),
                    target_tokens.reshape(-1)
                )

            loss = loss / len(model.scales)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")

후속 연구

연구 기여 연도
VAR-CLIP CLIP 조건화로 text-to-image 확장 2024
Infinity 무한 해상도 생성을 위한 VAR 확장 2024
MAR Masked AR과 VAR의 결합 2024
Open-MAGVIT2 오픈소스 VQVAE + VAR 구현 2024
LlamaGen VAR 아이디어를 LLaMA 아키텍처에 적용 2024

핵심 요약

측면 VAR의 혁신
패러다임 Next-token → Next-scale prediction
생성 순서 Raster-scan → Coarse-to-fine
토큰 예측 순차적 → Scale 내 병렬
Diffusion 비교 처음으로 AR이 DiT 능가
Scaling Laws LLM 수준의 명확한 power-law
속도 20배 빠른 inference
Zero-shot Inpainting, outpainting, editing 지원

참고 자료

  1. VAR 논문 (arXiv)
  2. 공식 코드 (GitHub)
  3. 프로젝트 페이지
  4. NeurIPS 2024 Best Paper