콘텐츠로 이동
Data Prep
상세

Direct Preference Optimization (DPO)

메타 정보

항목 내용
논문 Direct Preference Optimization: Your Language Model is Secretly a Reward Model
저자 Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, Chelsea Finn
소속 Stanford University
발표 NeurIPS 2023
arXiv 2305.18290
분야 LLM Alignment, Preference Learning

개요

Direct Preference Optimization(DPO)은 RLHF(Reinforcement Learning from Human Feedback)의 복잡성을 제거한 선호도 학습 알고리즘이다. 보상 모델을 명시적으로 학습하지 않고, 인간 선호도 데이터로부터 직접 언어 모델을 최적화한다.

핵심 통찰

RLHF의 보상 모델과 최적 정책 사이에는 해석적 관계가 존재한다. 이 관계를 역으로 활용하면, 보상 모델 학습과 RL 최적화 단계를 단일 분류 손실로 대체할 수 있다.


RLHF vs DPO 비교

RLHF 파이프라인

[사전학습 LM] --> [SFT] --> [보상 모델 학습] --> [PPO 최적화] --> [정렬된 LM]
                              |                    |
                          선호도 데이터          불안정, 복잡

RLHF 단계:

  1. Supervised Fine-Tuning (SFT): 고품질 데이터로 사전학습 모델 미세조정
  2. Reward Model Training: 인간 선호도 데이터로 보상 모델 학습
  3. RL Fine-Tuning: PPO로 보상 최대화하며 KL 발산 제약 적용

DPO 파이프라인

[사전학습 LM] --> [SFT] --> [DPO 최적화] --> [정렬된 LM]
                              |
                          선호도 데이터
                          (분류 손실만)

DPO 장점:

  • 보상 모델 학습 불필요
  • RL 알고리즘(PPO) 불필요
  • 학습 중 샘플링 불필요
  • 하이퍼파라미터 튜닝 최소화
  • 구현 단순화

수학적 배경

RLHF 목표 함수

RLHF는 다음 목표를 최적화한다:

max_pi E[r(x,y)] - beta * KL(pi || pi_ref)

여기서: - pi: 학습할 정책 - pi_ref: 참조 정책 (SFT 모델) - r(x,y): 보상 함수 - beta: KL 페널티 계수

최적 정책의 해석적 해

위 목표의 최적 정책은 다음과 같이 표현된다:

pi*(y|x) = (1/Z(x)) * pi_ref(y|x) * exp(r(x,y) / beta)

이를 보상에 대해 재정리하면:

r(x,y) = beta * log(pi*(y|x) / pi_ref(y|x)) + beta * log(Z(x))

DPO 손실 함수

Bradley-Terry 선호도 모델을 적용하고 위 관계를 대입하면:

L_DPO(pi; pi_ref) = -E[(x, y_w, y_l)] [ log sigmoid(beta * (log(pi(y_w|x)/pi_ref(y_w|x)) - log(pi(y_l|x)/pi_ref(y_l|x)))) ]

여기서: - y_w: 선호되는 응답 (winner) - y_l: 비선호 응답 (loser) - beta: 온도 파라미터

손실 함수 해석

DPO 손실은 다음을 수행한다: 1. 선호 응답의 상대적 로그 확률 증가 2. 비선호 응답의 상대적 로그 확률 감소 3. 참조 모델 대비 변화량 제어


알고리즘

DPO 학습 절차

입력: 
  - 참조 정책 pi_ref (SFT 모델)
  - 선호도 데이터셋 D = {(x, y_w, y_l)}
  - 파라미터 beta

1. pi_theta를 pi_ref로 초기화
2. 각 배치 (x, y_w, y_l) in D:
   a. 로그 확률 계산:
      log_pi_w = log pi_theta(y_w|x)
      log_pi_l = log pi_theta(y_l|x)
      log_ref_w = log pi_ref(y_w|x)
      log_ref_l = log pi_ref(y_l|x)

   b. 보상 차이 계산:
      delta = beta * ((log_pi_w - log_ref_w) - (log_pi_l - log_ref_l))

   c. 손실 계산:
      loss = -log(sigmoid(delta))

   d. 그래디언트 업데이트
3. 학습된 pi_theta 반환

구현 고려사항

항목 권장값 설명
beta 0.1 - 0.5 낮으면 탐색 증가, 높으면 참조 모델 유지
학습률 1e-6 - 5e-6 SFT보다 낮게 설정
에폭 1-3 과적합 주의
배치 크기 32-128 GPU 메모리에 따라 조정

Python 구현 예시

기본 DPO 손실 함수

import torch
import torch.nn.functional as F


def compute_dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1
) -> torch.Tensor:
    """
    DPO 손실 함수 계산

    Args:
        policy_chosen_logps: 정책 모델의 선호 응답 로그 확률
        policy_rejected_logps: 정책 모델의 비선호 응답 로그 확률
        reference_chosen_logps: 참조 모델의 선호 응답 로그 확률
        reference_rejected_logps: 참조 모델의 비선호 응답 로그 확률
        beta: 온도 파라미터

    Returns:
        DPO 손실값
    """
    # 로그 비율 계산
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps

    # 보상 차이
    logits = beta * (pi_logratios - ref_logratios)

    # 이진 교차 엔트로피 손실
    loss = -F.logsigmoid(logits).mean()

    return loss

로그 확률 계산

def compute_log_probs(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    labels: torch.Tensor
) -> torch.Tensor:
    """
    시퀀스의 로그 확률 합 계산

    Args:
        model: 언어 모델
        input_ids: 입력 토큰 ID
        attention_mask: 어텐션 마스크
        labels: 레이블 (응답 부분만 유효)

    Returns:
        시퀀스별 로그 확률 합
    """
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask
    )
    logits = outputs.logits

    # 다음 토큰 예측을 위해 시프트
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()

    # 토큰별 로그 확률
    log_probs = F.log_softmax(shift_logits, dim=-1)

    # 실제 토큰의 로그 확률 추출
    token_log_probs = torch.gather(
        log_probs, 
        dim=-1, 
        index=shift_labels.unsqueeze(-1)
    ).squeeze(-1)

    # 유효한 토큰만 합산 (labels != -100)
    mask = (shift_labels != -100).float()
    sequence_log_probs = (token_log_probs * mask).sum(dim=-1)

    return sequence_log_probs

전체 학습 루프

from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW
from tqdm import tqdm


class DPOTrainer:
    def __init__(
        self,
        model_name: str,
        beta: float = 0.1,
        learning_rate: float = 1e-6
    ):
        self.beta = beta

        # 정책 모델 (학습 대상)
        self.policy_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16
        )

        # 참조 모델 (고정)
        self.ref_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16
        )
        self.ref_model.eval()
        for param in self.ref_model.parameters():
            param.requires_grad = False

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.optimizer = AdamW(
            self.policy_model.parameters(), 
            lr=learning_rate
        )

    def train_step(self, batch: dict) -> float:
        """단일 배치 학습"""
        self.policy_model.train()

        # 선호/비선호 응답의 로그 확률 계산
        with torch.no_grad():
            ref_chosen_logps = compute_log_probs(
                self.ref_model,
                batch['chosen_input_ids'],
                batch['chosen_attention_mask'],
                batch['chosen_labels']
            )
            ref_rejected_logps = compute_log_probs(
                self.ref_model,
                batch['rejected_input_ids'],
                batch['rejected_attention_mask'],
                batch['rejected_labels']
            )

        policy_chosen_logps = compute_log_probs(
            self.policy_model,
            batch['chosen_input_ids'],
            batch['chosen_attention_mask'],
            batch['chosen_labels']
        )
        policy_rejected_logps = compute_log_probs(
            self.policy_model,
            batch['rejected_input_ids'],
            batch['rejected_attention_mask'],
            batch['rejected_labels']
        )

        # DPO 손실 계산
        loss = compute_dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            ref_chosen_logps,
            ref_rejected_logps,
            beta=self.beta
        )

        # 역전파
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

Hugging Face TRL 라이브러리 활용

from trl import DPOTrainer, DPOConfig
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


# 모델 및 토크나이저 로드
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token

# 선호도 데이터셋 로드
# 형식: {"prompt": str, "chosen": str, "rejected": str}
dataset = load_dataset("Anthropic/hh-rlhf", split="train")

# DPO 설정
training_args = DPOConfig(
    output_dir="./dpo_output",
    beta=0.1,
    learning_rate=1e-6,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    warmup_ratio=0.1,
    logging_steps=10,
    save_steps=500,
    bf16=True,
)

# 트레이너 생성 및 학습
trainer = DPOTrainer(
    model=model,
    ref_model=None,  # None이면 자동으로 복사본 생성
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)

trainer.train()

DPO 변형 및 후속 연구

IPO (Identity Preference Optimization)

  • 논문: A General Theoretical Paradigm to Understand Learning from Human Feedback (2024)
  • 특징: DPO의 과적합 문제 해결, 정규화 개선

KTO (Kahneman-Tversky Optimization)

  • 논문: KTO: Model Alignment as Prospect Theoretic Optimization (2024)
  • 특징: 쌍대 선호도 대신 이진 피드백 사용 가능

ORPO (Odds Ratio Preference Optimization)

  • 논문: ORPO: Monolithic Preference Optimization without Reference Model (2024)
  • 특징: 참조 모델 불필요, SFT와 정렬 동시 수행

Cal-DPO (Calibrated DPO)

  • 발표: NeurIPS 2024
  • 특징: 선호도 강도에 따른 보정 적용

Beta-DPO

  • 발표: NeurIPS 2024
  • 특징: 동적 beta 파라미터 조정

실험 결과 요약

벤치마크 성능 (원 논문)

태스크 메트릭 RLHF (PPO) DPO
감정 제어 정확도 0.78 0.82
TL;DR 요약 GPT-4 승률 0.51 0.55
대화 GPT-4 승률 0.48 0.52

계산 효율성

항목 RLHF DPO
학습 시간 1x 0.3-0.5x
GPU 메모리 높음 (4개 모델) 낮음 (2개 모델)
하이퍼파라미터 많음 적음 (주로 beta)
안정성 낮음 높음

한계점 및 고려사항

알려진 한계

  1. 분포 이동(Distribution Shift): 참조 모델에서 멀어지면 성능 저하 가능
  2. 데이터 품질 의존성: 선호도 데이터의 일관성이 중요
  3. 길이 바이어스: 긴 응답이 선호될 수 있음
  4. 과적합 위험: 적은 데이터로 학습 시 발생

모범 사례

  • 고품질 SFT 모델에서 시작
  • 다양하고 일관된 선호도 데이터 확보
  • beta 값 실험으로 최적값 탐색
  • 검증 세트로 과적합 모니터링

참고 문헌

  1. Rafailov, R., et al. (2023). Direct Preference Optimization: Your Language Model is Secretly a Reward Model. NeurIPS 2023.
  2. Azar, M. G., et al. (2024). A General Theoretical Paradigm to Understand Learning from Human Feedback. ICML 2024.
  3. Ethayarajh, K., et al. (2024). KTO: Model Alignment as Prospect Theoretic Optimization.
  4. Hong, J., et al. (2024). ORPO: Monolithic Preference Optimization without Reference Model.
  5. Wu, J., et al. (2024). Beta-DPO: Direct Preference Optimization with Dynamic Beta. NeurIPS 2024.