Visual Autoregressive Modeling (VAR)
메타정보
| 항목 |
내용 |
| 논문 |
Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction |
| 저자 |
Keyu Tian, Yi Jiang, Zehuan Yuan, Bingyue Peng, Liwei Wang |
| 기관 |
Peking University, ByteDance |
| 발표 |
NeurIPS 2024 Best Paper |
| arXiv |
2404.02905 |
| 코드 |
github.com/FoundationVision/VAR |
| 키워드 |
Autoregressive Models, Image Generation, Scaling Laws, Next-Scale Prediction, VQ-VAE |
개요
Visual Autoregressive Modeling (VAR)은 이미지 생성을 위한 새로운 autoregressive 패러다임으로, 기존의 raster-scan "next-token prediction" 대신 coarse-to-fine "next-scale prediction"을 도입했다. 이 접근법으로 GPT 스타일 AR 모델이 최초로 Diffusion Transformer를 능가했다.
핵심 성과:
- ImageNet 256x256에서 FID 18.65 → 1.73 (10배 개선)
- Inception Score 80.4 → 350.2 (4배 개선)
- Diffusion 대비 20배 빠른 inference
- LLM과 유사한 Scaling Laws 발견 (상관계수 -0.998)
- Zero-shot 일반화: inpainting, outpainting, editing
배경: 기존 Visual AR 모델의 한계
기존 AR 모델의 접근법
기존 방식 (Raster-Scan Next-Token Prediction):
이미지 → VQ-VAE → 토큰 시퀀스 → 왼쪽→오른쪽 순차 예측
[1] → [2] → [3] → [4] → ...
↓ ↓ ↓ ↓
[5] → [6] → [7] → [8] → ...
문제점
| 문제 |
설명 |
| Unnatural Order |
이미지의 2D 구조를 1D 시퀀스로 강제 변환 |
| Long Sequence |
256x256 이미지 → 16x16 토큰 = 256개 토큰 순차 예측 |
| Slow Inference |
각 토큰을 하나씩 순차적으로 생성 |
| No Scaling Laws |
LLM처럼 모델 크기에 따른 명확한 성능 향상 없음 |
| Diffusion 대비 열등 |
FID, IS 모든 지표에서 DiT에 뒤처짐 |
대표적 기존 AR 모델 성능 (ImageNet 256x256)
| 모델 |
FID |
IS |
방식 |
| VQGAN |
18.65 |
80.4 |
Raster-scan AR |
| ViT-VQGAN |
4.17 |
175.1 |
Raster-scan AR |
| RQ-Transformer |
7.55 |
134.0 |
Residual Quantization |
| DiT-XL/2 |
2.27 |
278.2 |
Diffusion |
VAR의 핵심 아이디어
Next-Scale Prediction 패러다임
VAR 방식 (Coarse-to-Fine):
저해상도(coarse) → 고해상도(fine) 점진적 예측
Scale 1: [1x1] → 전체 이미지 대략적 구조
Scale 2: [2x2] → 4개 토큰으로 세부 추가
Scale 3: [4x4] → 16개 토큰으로 더 세부 추가
...
Scale K: [16x16] → 최종 256개 토큰
핵심 통찰: 인간의 시각 인지와 유사하게, 전체 구조를 먼저 파악하고 세부 사항을 점진적으로 추가한다.
수학적 정의
기존 AR:
p(x) = prod_{i=1}^{n} p(x_i | x_1, ..., x_{i-1})
단점: 각 토큰을 순차적으로 1개씩 예측
VAR:
p(r_1, r_2, ..., r_K) = prod_{k=1}^{K} p(r_k | r_1, ..., r_{k-1})
r_k: scale k의 토큰 맵 (병렬 예측 가능)
K: 총 scale 수 (일반적으로 10개)
장점: 각 scale 내의 토큰들을 병렬로 동시 예측한다.
아키텍처
Multi-Scale VQ-VAE
VAR의 핵심 구성 요소는 Multi-Scale VQVAE로, 이미지를 여러 해상도의 토큰 맵으로 인코딩한다.
입력 이미지 (256x256)
↓
Encoder
↓
Feature Map (16x16)
↓
Multi-Scale Quantization
↓
r_1 (1x1), r_2 (2x2), r_3 (3x3), ..., r_10 (16x16)
Scale 구성
| Scale (k) |
해상도 |
토큰 수 |
누적 토큰 |
| 1 |
1x1 |
1 |
1 |
| 2 |
2x2 |
4 |
5 |
| 3 |
3x3 |
9 |
14 |
| 4 |
4x4 |
16 |
30 |
| 5 |
5x5 |
25 |
55 |
| 6 |
6x6 |
36 |
91 |
| 7 |
8x8 |
64 |
155 |
| 8 |
10x10 |
100 |
255 |
| 9 |
13x13 |
169 |
424 |
| 10 |
16x16 |
256 |
680 |
입력: class embedding + 이전 scale 토큰들
↓
[CLS] [r_1] [r_2] ... [r_{k-1}]
↓
Transformer Blocks (Causal Attention)
↓
Next-Scale 토큰 예측 (r_k)
↓
병렬 토큰 생성 (전체 scale 동시 예측)
Attention 마스크
Scale-wise Causal Mask:
CLS r1 r2 r3 ...
CLS [ 1 0 0 0 ... ]
r1 [ 1 1 0 0 ... ]
r2 [ 1 1 1 0 ... ]
r3 [ 1 1 1 1 ... ]
...
각 scale은 이전 scale들만 참조 가능
같은 scale 내 토큰들은 서로 참조 가능 (bidirectional)
학습 방법
목적 함수
L = -sum_{k=1}^{K} log p(r_k | c, r_1, ..., r_{k-1})
c: class condition
r_k: scale k의 ground truth 토큰 맵
학습 설정
| 항목 |
값 |
| Optimizer |
AdamW |
| Learning Rate |
1e-4 (cosine decay) |
| Batch Size |
768 |
| Epochs |
350 |
| Warmup |
100 epochs |
| Weight Decay |
0.05 |
| Codebook Size |
4096 (V=4096) |
| Embedding Dim |
32 |
Classifier-Free Guidance (CFG)
# Inference with CFG
def sample_with_cfg(model, class_label, cfg_scale=1.5):
logits_cond = model(class_label) # conditional
logits_uncond = model(null_label) # unconditional
# CFG: 조건부 방향으로 더 강하게 이동
logits = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
return sample_from_logits(logits)
성능 비교
ImageNet 256x256 Class-Conditional Generation
| 모델 |
타입 |
FID |
IS |
Params |
Speed |
| LDM-4 |
Diffusion |
3.60 |
247.7 |
400M |
- |
| DiT-XL/2 |
Diffusion |
2.27 |
278.2 |
675M |
1x |
| VQGAN |
AR |
18.65 |
80.4 |
227M |
- |
| VAR-d16 |
VAR |
3.30 |
274.4 |
310M |
15x |
| VAR-d20 |
VAR |
2.57 |
302.6 |
600M |
12x |
| VAR-d24 |
VAR |
2.09 |
312.0 |
1.0B |
10x |
| VAR-d30 |
VAR |
1.73 |
350.2 |
2.0B |
8x |
Scaling Laws
VAR은 LLM과 유사한 power-law scaling을 보인다:
Loss ~ N^(-0.067)
N: 모델 파라미터 수
상관계수: -0.998 (거의 완벽한 선형 관계)
| 모델 크기 |
Loss |
FID |
| 310M |
2.95 |
3.30 |
| 600M |
2.82 |
2.57 |
| 1.0B |
2.70 |
2.09 |
| 2.0B |
2.56 |
1.73 |
추론 속도 분석
단계별 비교
Diffusion (DiT-XL/2):
- 250 denoising steps 필요
- 각 step마다 full forward pass
- A100 기준: ~6초/이미지
VAR (d30):
- 10 scale steps
- 각 scale 내 병렬 예측
- A100 기준: ~0.3초/이미지 (20x 빠름)
Inference 알고리즘
def var_inference(model, class_label, K=10, cfg_scale=1.5):
"""VAR 추론 (단순화)"""
tokens = []
for k in range(1, K+1):
# 이전 scale 토큰들을 컨텍스트로 사용
context = torch.cat(tokens, dim=1) if tokens else None
# scale k의 모든 토큰을 병렬 예측
logits = model.forward_scale(context, class_label, scale=k)
# CFG 적용
if cfg_scale > 1.0:
logits = apply_cfg(model, logits, class_label, cfg_scale)
# 샘플링
scale_tokens = sample_tokens(logits)
tokens.append(scale_tokens)
# 토큰 → 이미지 디코딩
image = model.decode(torch.cat(tokens, dim=1))
return image
Zero-Shot 응용
Inpainting
입력: 마스킹된 이미지 + 마스크
방법:
1. 마스킹되지 않은 영역의 토큰 고정
2. 마스킹된 영역만 조건부 생성
3. 각 scale에서 일관성 유지
Outpainting
입력: 중앙 이미지 + 확장 영역 지정
방법:
1. 저해상도에서 전체 구조 생성
2. 고해상도에서 원본 영역 고정
3. 확장 영역만 새로 생성
Image Editing
입력: 원본 이미지 + 수정할 영역 + 새로운 조건
방법:
1. 원본 이미지를 multi-scale 토큰으로 인코딩
2. 수정 영역의 토큰만 재생성
3. 나머지 영역은 원본 유지
Python 구현 예시
Multi-Scale Tokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiScaleVQVAE(nn.Module):
"""Multi-Scale VQ-VAE for VAR"""
def __init__(
self,
vocab_size: int = 4096,
embed_dim: int = 32,
scales: list = [1, 2, 3, 4, 5, 6, 8, 10, 13, 16]
):
super().__init__()
self.scales = scales
self.vocab_size = vocab_size
self.embed_dim = embed_dim
# 공유 codebook
self.codebook = nn.Embedding(vocab_size, embed_dim)
# Encoder/Decoder
self.encoder = self._build_encoder()
self.decoder = self._build_decoder()
# Scale별 projection
self.scale_projs = nn.ModuleList([
nn.Conv2d(embed_dim, embed_dim, 1)
for _ in scales
])
def _build_encoder(self):
return nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1),
nn.ReLU(),
nn.Conv2d(64, 128, 4, 2, 1),
nn.ReLU(),
nn.Conv2d(128, 256, 4, 2, 1),
nn.ReLU(),
nn.Conv2d(256, self.embed_dim, 4, 2, 1),
)
def _build_decoder(self):
return nn.Sequential(
nn.ConvTranspose2d(self.embed_dim, 256, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(128, 64, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(64, 3, 4, 2, 1),
nn.Tanh(),
)
def encode_multiscale(self, x):
"""이미지를 multi-scale 토큰으로 인코딩"""
# x: (B, 3, 256, 256)
z = self.encoder(x) # (B, D, 16, 16)
scale_tokens = []
for i, s in enumerate(self.scales):
# 각 scale로 downsampling
z_s = F.interpolate(z, size=(s, s), mode='bilinear')
z_s = self.scale_projs[i](z_s)
# Quantize
tokens = self.quantize(z_s)
scale_tokens.append(tokens)
return scale_tokens
def quantize(self, z):
"""Vector quantization"""
B, D, H, W = z.shape
z_flat = z.permute(0, 2, 3, 1).reshape(-1, D)
# 가장 가까운 codebook entry 찾기
distances = torch.cdist(z_flat, self.codebook.weight)
indices = distances.argmin(dim=-1)
return indices.reshape(B, H, W)
def decode_from_tokens(self, all_tokens):
"""Multi-scale 토큰에서 이미지 복원"""
# 최종 scale 토큰만 사용 (또는 모든 scale 합성)
final_tokens = all_tokens[-1]
B, H, W = final_tokens.shape
z_q = self.codebook(final_tokens)
z_q = z_q.permute(0, 3, 1, 2)
return self.decoder(z_q)
class VARTransformer(nn.Module):
"""Visual Autoregressive Transformer"""
def __init__(
self,
vocab_size: int = 4096,
num_classes: int = 1000,
embed_dim: int = 1024,
depth: int = 24,
num_heads: int = 16,
scales: list = [1, 2, 3, 4, 5, 6, 8, 10, 13, 16]
):
super().__init__()
self.vocab_size = vocab_size
self.num_classes = num_classes
self.embed_dim = embed_dim
self.scales = scales
# Token embedding
self.token_embed = nn.Embedding(vocab_size, embed_dim)
# Class embedding
self.class_embed = nn.Embedding(num_classes + 1, embed_dim) # +1 for null
# Position embedding (scale-aware)
max_tokens = sum(s * s for s in scales)
self.pos_embed = nn.Parameter(torch.randn(1, max_tokens + 1, embed_dim) * 0.02)
# Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads)
for _ in range(depth)
])
# Scale-specific output heads
self.output_heads = nn.ModuleList([
nn.Linear(embed_dim, vocab_size)
for _ in scales
])
self.norm = nn.LayerNorm(embed_dim)
def forward(self, tokens_list, class_label, target_scale=None):
"""
tokens_list: 이전 scale들의 토큰 리스트
class_label: (B,) class indices
target_scale: 예측할 scale index
"""
B = class_label.shape[0]
# Class embedding
class_emb = self.class_embed(class_label).unsqueeze(1) # (B, 1, D)
# 이전 scale 토큰들 임베딩
if tokens_list:
prev_tokens = torch.cat([t.flatten(1) for t in tokens_list], dim=1)
token_emb = self.token_embed(prev_tokens) # (B, N_prev, D)
x = torch.cat([class_emb, token_emb], dim=1)
else:
x = class_emb
# Position embedding
x = x + self.pos_embed[:, :x.shape[1]]
# Causal mask (scale-wise)
mask = self.create_scale_causal_mask(tokens_list, target_scale)
# Transformer forward
for block in self.blocks:
x = block(x, mask)
x = self.norm(x)
# Output logits for target scale
if target_scale is not None:
logits = self.output_heads[target_scale](x[:, -1:])
return logits.expand(-1, self.scales[target_scale] ** 2, -1)
return x
def create_scale_causal_mask(self, tokens_list, target_scale):
"""Scale-wise causal attention mask"""
# 구현 생략 (scale 간 causal, scale 내 bidirectional)
return None
@torch.no_grad()
def generate(self, class_label, cfg_scale=1.5, temperature=1.0):
"""전체 이미지 생성"""
B = class_label.shape[0]
generated_tokens = []
for k, scale in enumerate(self.scales):
# Forward pass
logits = self.forward(generated_tokens, class_label, target_scale=k)
# CFG
if cfg_scale > 1.0:
null_label = torch.full_like(class_label, self.num_classes)
logits_uncond = self.forward(generated_tokens, null_label, target_scale=k)
logits = logits_uncond + cfg_scale * (logits - logits_uncond)
# Sample
probs = F.softmax(logits / temperature, dim=-1)
tokens = torch.multinomial(probs.view(-1, self.vocab_size), 1)
tokens = tokens.view(B, scale, scale)
generated_tokens.append(tokens)
return generated_tokens
class TransformerBlock(nn.Module):
"""Standard Transformer block with pre-norm"""
def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
nn.Dropout(dropout),
)
def forward(self, x, mask=None):
x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x), attn_mask=mask)[0]
x = x + self.mlp(self.norm2(x))
return x
학습 루프
def train_var(
model: VARTransformer,
vqvae: MultiScaleVQVAE,
dataloader,
epochs: int = 350,
lr: float = 1e-4,
device: str = 'cuda'
):
"""VAR 학습"""
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
model.train()
for epoch in range(epochs):
total_loss = 0
for batch_idx, (images, labels) in enumerate(dataloader):
images, labels = images.to(device), labels.to(device)
# Multi-scale 토큰화
with torch.no_grad():
scale_tokens = vqvae.encode_multiscale(images)
# 각 scale에 대한 loss 계산
loss = 0
for k in range(len(model.scales)):
# 이전 scale 토큰들을 입력으로
prev_tokens = scale_tokens[:k] if k > 0 else None
target_tokens = scale_tokens[k].flatten(1)
# Forward
logits = model(prev_tokens, labels, target_scale=k)
# Cross-entropy loss
loss += F.cross_entropy(
logits.reshape(-1, model.vocab_size),
target_tokens.reshape(-1)
)
loss = loss / len(model.scales)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
scheduler.step()
print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")
후속 연구
| 연구 |
기여 |
연도 |
| VAR-CLIP |
CLIP 조건화로 text-to-image 확장 |
2024 |
| Infinity |
무한 해상도 생성을 위한 VAR 확장 |
2024 |
| MAR |
Masked AR과 VAR의 결합 |
2024 |
| Open-MAGVIT2 |
오픈소스 VQVAE + VAR 구현 |
2024 |
| LlamaGen |
VAR 아이디어를 LLaMA 아키텍처에 적용 |
2024 |
핵심 요약
| 측면 |
VAR의 혁신 |
| 패러다임 |
Next-token → Next-scale prediction |
| 생성 순서 |
Raster-scan → Coarse-to-fine |
| 토큰 예측 |
순차적 → Scale 내 병렬 |
| Diffusion 비교 |
처음으로 AR이 DiT 능가 |
| Scaling Laws |
LLM 수준의 명확한 power-law |
| 속도 |
20배 빠른 inference |
| Zero-shot |
Inpainting, outpainting, editing 지원 |
참고 자료
- VAR 논문 (arXiv)
- 공식 코드 (GitHub)
- 프로젝트 페이지
- NeurIPS 2024 Best Paper