
Transformer

The architecture proposed in the 2017 paper "Attention is All You Need". It replaced RNN/LSTM models and became the foundation of modern LLMs.

Key Innovations

| Before (RNN/LSTM) | Transformer |
| --- | --- |
| Sequential processing | Parallel processing |
| Vanishing gradients | Direct connections between positions |
| Struggles with long-range dependencies | Global attention |
| Slow training | Fast training |

Overall Architecture

Encoder-Decoder (original)

(Diagram: original encoder-decoder Transformer architecture)

Decoder-only (GPT style)

(Diagram: decoder-only Transformer architecture)

Core Components

1. Token Embedding

import torch
import torch.nn as nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        # scale embeddings by sqrt(d_model), as in the original paper
        return self.embedding(x) * (self.d_model ** 0.5)
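
A minimal shape check (the vocabulary size and dimensions below are illustrative, not prescribed by the text):

vocab_size, d_model = 50257, 768                 # illustrative sizes
emb = TokenEmbedding(vocab_size, d_model)
tokens = torch.randint(0, vocab_size, (2, 16))   # (batch, seq_len)
print(emb(tokens).shape)                         # torch.Size([2, 16, 768])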

2. Positional Encoding

The Transformer itself has no notion of token order, so positional information must be added explicitly.

Sinusoidal (original):

$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{\text{model}}})$$

$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{\text{model}}})$$

import math

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * 
                           -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

Learned (GPT style):

class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=2048):
        super().__init__()
        self.position_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        positions = torch.arange(x.size(1), device=x.device)
        return x + self.position_embedding(positions)

RoPE (Rotary Position Embedding) - LLaMA:

def rotate_half(x):
    x1 = x[..., :x.shape[-1]//2]
    x2 = x[..., x.shape[-1]//2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
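
The cos and sin tables that apply_rotary_pos_emb expects are not built above. Below is a minimal sketch of one common way to construct them, assuming the usual base of 10000 and a (seq_len, head_dim) layout that broadcasts against (batch, heads, seq_len, head_dim) queries and keys; build_rope_cache is a hypothetical helper, not part of any library:

def build_rope_cache(seq_len, head_dim, base=10000.0, device=None):
    # one frequency per pair of dimensions, as in the sinusoidal scheme
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(positions, inv_freq)     # (seq_len, head_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)      # (seq_len, head_dim)
    return emb.cos(), emb.sin()

The returned tables can be passed directly as cos and sin to apply_rotary_pos_emb, since the trailing (seq_len, head_dim) dimensions broadcast over batch and heads.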

3. Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Linear projections
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Reshape to (batch, heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, V)

        # Reshape back
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)

        return self.W_o(attn_output)
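
A quick usage check with a causal mask (sizes are illustrative):

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)                                         # (batch, seq_len, d_model)
causal = torch.tril(torch.ones(10, 10)).unsqueeze(0).unsqueeze(0)   # (1, 1, seq, seq)
out = mha(x, mask=causal)
print(out.shape)  # torch.Size([2, 10, 512])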

4. Feed-Forward Network

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

# SwiGLU (LLaMA)
class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
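
SwiGLU uses three weight matrices instead of two, which is why LLaMA shrinks d_ff (11008 is roughly 2/3 of 4 * 4096, rounded up) to keep the parameter count close to a standard 4 * d_model FFN. A quick comparison using the classes above (sizes are illustrative):

d_model = 4096
ffn = FeedForward(d_model, d_ff=4 * d_model)        # 2 projection matrices
swiglu = SwiGLU(d_model, d_ff=11008)                # 3 matrices, smaller d_ff
print(sum(p.numel() for p in ffn.parameters()))     # ~134M
print(sum(p.numel() for p in swiglu.parameters()))  # ~135M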

5. Transformer Block

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-LN (modern placement): normalize before each sublayer
        attn_out = self.attention(self.norm1(x), mask)
        x = x + self.dropout(attn_out)

        ff_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ff_out)

        return x

Full Model

class GPTModel(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, 
                 max_seq_len, dropout=0.1):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying
        self.lm_head.weight = self.token_embedding.weight

        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, targets=None):
        batch_size, seq_len = input_ids.shape

        # Embeddings
        tok_emb = self.token_embedding(input_ids)
        pos = torch.arange(seq_len, device=input_ids.device)
        pos_emb = self.position_embedding(pos)

        x = self.dropout(tok_emb + pos_emb)

        # Causal mask
        mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device))
        mask = mask.unsqueeze(0).unsqueeze(0)

        # Transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )

        return logits, loss
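
A minimal smoke test with an intentionally tiny, illustrative configuration (not one of the configurations in the table below):

model = GPTModel(vocab_size=50257, d_model=256, num_heads=4,
                 num_layers=4, d_ff=1024, max_seq_len=128)
input_ids = torch.randint(0, 50257, (2, 64))
targets = torch.randint(0, 50257, (2, 64))
logits, loss = model(input_ids, targets)
print(logits.shape, loss.item())  # torch.Size([2, 64, 50257]), ~10.8 at init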

Example Model Configurations

| Model | Params | d_model | n_heads | n_layers | d_ff |
| --- | --- | --- | --- | --- | --- |
| GPT-2 Small | 124M | 768 | 12 | 12 | 3072 |
| GPT-2 Medium | 355M | 1024 | 16 | 24 | 4096 |
| GPT-3 | 175B | 12288 | 96 | 96 | 49152 |
| LLaMA-7B | 7B | 4096 | 32 | 32 | 11008 |
| LLaMA-70B | 70B | 8192 | 64 | 80 | 28672 |
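
As a sanity check on the table, here is a crude parameter estimate from these hyperparameters. The 12 * d_model^2 per-layer rule of thumb (weights only, assuming a standard 4 * d_model FFN and tied embeddings) is an assumption, not from the text, and it slightly under-counts SwiGLU models such as LLaMA; approx_params is a hypothetical helper:

def approx_params(d_model, n_layers, vocab_size=50257):
    # attention (~4*d^2) + FFN (~8*d^2 with d_ff = 4*d) per layer, plus tied embedding
    per_layer = 12 * d_model ** 2
    return n_layers * per_layer + vocab_size * d_model

print(f"{approx_params(768, 12) / 1e6:.0f}M")     # ~124M (GPT-2 Small)
print(f"{approx_params(12288, 96) / 1e9:.0f}B")   # ~175B (GPT-3)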

Practical Trade-offs

Architecture Choice

| Structure | Pros | Cons | Use cases |
| --- | --- | --- | --- |
| Encoder-only | Bidirectional context; strong for classification/embeddings | Cannot generate | Retrieval, classification, NER |
| Decoder-only | Optimized for generation; scales well | Unidirectional context | Chatbots, code generation |
| Encoder-Decoder | Separates input and output; strong for translation | More complex structure | Translation, summarization |

Model Size Choice

Recommended model sizes by inference environment:

Consumer GPU (24GB):    7B ~ 13B (4-bit quantization)
Server GPU (80GB):      70B (FP16) or 405B (4-bit)
CPU only:               7B or less (quantization required)
Edge device:            1B ~ 3B (GGUF quantization)
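
These guidelines follow from weight memory being roughly parameter count times bytes per parameter; the KV cache and activations add overhead on top and are ignored here. approx_weight_gb is a hypothetical helper for illustration:

def approx_weight_gb(n_params_billion, bits=16):
    # weights only; KV cache and activations are extra
    return n_params_billion * bits / 8

print(approx_weight_gb(7, bits=16))   # 14.0 GB
print(approx_weight_gb(13, bits=4))   # 6.5 GB
print(approx_weight_gb(70, bits=4))   # 35.0 GB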

Pre-LN vs Post-LN

# Post-LN (original Transformer) - training can be unstable
x = LayerNorm(x + Attention(x))
x = LayerNorm(x + FFN(x))

# Pre-LN (modern LLMs) - training is more stable
x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))

Pre-LN pros: stabilizes training of deep models and is less sensitive to learning-rate warmup.
Pre-LN cons: possible slight drop in final quality (usually negligible).

Memory Optimization

# Gradient checkpointing - saves ~50% of activation memory, adds ~30% compute
from torch.utils.checkpoint import checkpoint

class MemoryEfficientBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.block = TransformerBlock(d_model, num_heads, d_ff, dropout)

    def forward(self, x, use_checkpoint=True):
        # recompute activations in the backward pass instead of storing them
        if use_checkpoint and self.training:
            return checkpoint(self.block, x, use_reentrant=False)
        return self.block(x)

Common Mistakes

| Mistake | Symptom | Fix |
| --- | --- | --- |
| Missing embedding scaling | Loss diverges early in training | embedding * sqrt(d_model) |
| Incorrect masking | Future-token information leaks | verify the causal mask |
| No weight tying | Wasted parameters | lm_head.weight = embed.weight |
| Wrong LayerNorm placement | Unstable training | use Pre-LN |
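
A quick sanity check for the masking mistake above: with a correct causal mask, no attention weight should fall on a future position. A minimal sketch reusing MultiHeadAttention from earlier (sizes are illustrative):

mha = MultiHeadAttention(d_model=64, num_heads=4, dropout=0.0)
seq_len = 8
x = torch.randn(1, seq_len, 64)
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

# recompute the attention weights the same way forward() does
Q = mha.W_q(x).view(1, seq_len, 4, 16).transpose(1, 2)
K = mha.W_k(x).view(1, seq_len, 4, 16).transpose(1, 2)
scores = (Q @ K.transpose(-2, -1)) / 16 ** 0.5
weights = torch.softmax(scores.masked_fill(mask == 0, float('-inf')), dim=-1)
assert torch.all(weights.tril() == weights)  # no weight above the diagonal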

References