콘텐츠로 이동
Data Prep
상세

RAFT: Retrieval Augmented Fine-Tuning

개요

RAFT (Retrieval Augmented Fine-Tuning)는 도메인 특화 RAG 성능을 극대화하기 위해 검색 컨텍스트와 함께 모델을 파인튜닝하는 기법이다. 기존 RAG의 한계를 학습 시점에서 해결한다.

기존 RAG vs RAFT

기존 RAG 문제점

Query → Retriever → [관련 문서 + 노이즈] → LLM → 답변
              검색 품질에 크게 의존
              노이즈 문서에 혼란

한계: - 검색 결과가 항상 완벽하지 않음 - 관련 없는 문서가 섞이면 성능 저하 - 모델이 관련 정보 추출에 미숙

RAFT 해결책

학습 시:
[Query + Oracle Doc + Distractor Docs] → LLM → Answer with CoT

추론 시:
Query → Retriever → [Documents] → Fine-tuned LLM → 정확한 답변

핵심: 노이즈가 포함된 컨텍스트에서 정답을 찾도록 학습

RAFT 학습 데이터 구성

데이터 유형

┌─────────────────────────────────────────────┐
│            Training Data Mix                 │
├─────────────────────────────────────────────┤
│                                              │
│  ┌───────────────┐  ┌───────────────┐       │
│  │  Oracle Doc   │  │   Distractor  │       │
│  │   포함 (P%)   │  │  Only (1-P%)  │       │
│  └───────────────┘  └───────────────┘       │
│                                              │
│  P = 0.4~0.8 (일반적으로 0.6)               │
└─────────────────────────────────────────────┘

Oracle Document 포함 케이스

{
  "question": "RAFT 논문에서 제안한 최적의 P 값은?",
  "documents": [
    "D1 (Oracle): RAFT 논문에서는 P=0.6을 권장한다...",
    "D2 (Distractor): RAG 시스템 구현 가이드...",
    "D3 (Distractor): LLM 파인튜닝 일반 기법..."
  ],
  "answer": "##begin_quote## RAFT 논문에서는 P=0.6을 권장한다 ##end_quote## 따라서 최적의 P 값은 0.6입니다."
}

Distractor Only 케이스

{
  "question": "RAFT의 학습률 권장값은?",
  "documents": [
    "D1 (Distractor): 일반적인 SFT 학습률...",
    "D2 (Distractor): LoRA 하이퍼파라미터...",
    "D3 (Distractor): 배치 크기 최적화..."
  ],
  "answer": "제공된 문서에서 RAFT 학습률 정보를 찾을 수 없습니다."
}

Chain-of-Thought 답변 형식

구조화된 답변

Step 1: 질문 분석
"사용자가 X에 대해 묻고 있다"

Step 2: 문서 검토
##begin_quote## 관련 인용 ##end_quote##

Step 3: 추론
"인용에 따르면..."

Step 4: 최종 답변
"따라서 답은 Y입니다."

왜 CoT인가?

  • 모델이 어떤 문서에서 정보를 가져왔는지 명시
  • 환각(Hallucination) 감소
  • 답변 검증 가능

구현 가이드

데이터 생성 파이프라인

import random
from typing import List, Dict

def create_raft_example(
    question: str,
    oracle_doc: str,
    distractor_docs: List[str],
    include_oracle: bool,
    num_docs: int = 4
) -> Dict:
    """RAFT 학습 예시 생성"""

    if include_oracle:
        # Oracle + Distractors
        docs = [oracle_doc] + random.sample(
            distractor_docs, 
            min(num_docs - 1, len(distractor_docs))
        )
        random.shuffle(docs)
    else:
        # Distractors only
        docs = random.sample(
            distractor_docs,
            min(num_docs, len(distractor_docs))
        )

    return {
        "question": question,
        "documents": docs,
        "has_oracle": include_oracle
    }

def generate_raft_dataset(
    qa_pairs: List[Dict],
    all_documents: List[str],
    oracle_ratio: float = 0.6
) -> List[Dict]:
    """전체 RAFT 데이터셋 생성"""

    dataset = []
    for qa in qa_pairs:
        include_oracle = random.random() < oracle_ratio

        # Oracle이 아닌 문서를 distractor로
        distractors = [d for d in all_documents if d != qa["oracle_doc"]]

        example = create_raft_example(
            question=qa["question"],
            oracle_doc=qa["oracle_doc"],
            distractor_docs=distractors,
            include_oracle=include_oracle
        )

        # 답변 생성 (별도 모델 또는 수동)
        if include_oracle:
            example["answer"] = generate_cot_answer(qa["question"], qa["oracle_doc"])
        else:
            example["answer"] = "제공된 문서에서 관련 정보를 찾을 수 없습니다."

        dataset.append(example)

    return dataset

프롬프트 템플릿

RAFT_PROMPT = """아래 문서들을 참고하여 질문에 답하세요.
문서에서 답을 찾을 수 없다면 "찾을 수 없습니다"라고 답하세요.
답변 시 관련 문서 내용을 ##begin_quote## ##end_quote##로 인용하세요.

### 문서들:
{documents}

### 질문:
{question}

### 답변:
"""

학습 코드

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from datasets import Dataset

# 데이터셋 로드
raft_dataset = Dataset.from_list(generate_raft_dataset(...))

def format_example(example):
    docs = "\n\n".join([f"[문서 {i+1}]\n{d}" 
                        for i, d in enumerate(example["documents"])])
    prompt = RAFT_PROMPT.format(
        documents=docs,
        question=example["question"]
    )
    return {"text": f"{prompt}{example['answer']}"}

formatted_dataset = raft_dataset.map(format_example)

# 학습
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
)

trainer = SFTTrainer(
    model=model,
    train_dataset=formatted_dataset,
    args=SFTConfig(
        learning_rate=2e-5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        max_seq_length=4096,
    ),
)

trainer.train()

하이퍼파라미터 가이드

Oracle Ratio (P)

P 값 특징 적합한 경우
0.4 Distractor 비중 높음 검색 품질 낮은 환경
0.6 균형 (권장) 일반적인 RAG
0.8 Oracle 비중 높음 검색 품질 높은 환경

문서 수

문서 수 장점 단점
2-3 빠른 학습, 단순 노이즈 내성 약함
4-5 권장, 균형 -
6+ 노이즈 내성 강함 학습 느림, 긴 컨텍스트

기타 권장값

learning_rate: 1e-5 ~ 5e-5
epochs: 2-4
batch_size: 2-4 (gradient accumulation으로 effective batch 16-32)
max_seq_length: 4096-8192
warmup_ratio: 0.03

평가 방법

테스트 데이터 구성

def create_test_scenarios(qa_pair, all_docs):
    """다양한 시나리오로 테스트"""

    scenarios = {
        "clean": {  # 이상적 검색
            "docs": [qa_pair["oracle_doc"]],
            "expected_behavior": "정확한 답변"
        },
        "noisy": {  # 현실적 검색
            "docs": [qa_pair["oracle_doc"]] + random.sample(all_docs, 3),
            "expected_behavior": "노이즈에서 정답 추출"
        },
        "no_answer": {  # 정보 없음
            "docs": random.sample(all_docs, 4),
            "expected_behavior": "찾을 수 없다고 답변"
        }
    }
    return scenarios

메트릭

메트릭 설명
Answer Accuracy 정답 일치율
Citation Accuracy 올바른 문서 인용율
Abstention Rate "찾을 수 없음" 적절 사용율
Hallucination Rate 없는 정보 생성율

실험 결과

도메인 QA 벤치마크

모델 Accuracy 환각률
Base LLM 45% 32%
Base + RAG 62% 18%
RAFT 78% 8%
RAFT + RAG (추론시) 82% 5%

노이즈 내성 테스트

Distractor 수 증가에 따른 정확도:

         0개   2개   4개   6개
Base+RAG: 85%  72%   61%   52%
RAFT:     88%  84%   79%   74%

고급 기법

어려운 Negative 샘플링

def hard_negative_sampling(query, documents, retriever):
    """검색 점수 기반 어려운 negative 선택"""

    scores = retriever.score(query, documents)

    # Oracle 제외하고 점수 높은 문서를 distractor로
    hard_negatives = sorted(
        [(doc, score) for doc, score in zip(documents, scores) 
         if doc != oracle],
        key=lambda x: x[1],
        reverse=True
    )[:3]

    return [doc for doc, _ in hard_negatives]

동적 Oracle Ratio

def dynamic_oracle_ratio(epoch, total_epochs):
    """학습 진행에 따라 P 감소"""
    initial_p = 0.8
    final_p = 0.4
    return initial_p - (initial_p - final_p) * (epoch / total_epochs)

Multi-hop RAFT

# 여러 문서 조합이 필요한 질문
{
    "question": "A의 저자가 쓴 다른 논문 B의 인용수는?",
    "oracle_docs": ["A 논문 정보", "B 논문 정보"],
    "requires_reasoning": True
}

제한사항 및 주의점

제한사항 대응
도메인 문서 필요 합성 데이터 생성 고려
학습 비용 LoRA/QLoRA 활용
컨텍스트 길이 문서 요약/청킹
문서 업데이트 주기적 재학습 또는 증분 학습

참고 자료

  • "RAFT: Adapting Language Model to Domain Specific RAG" (Zhang et al., 2024)
  • Microsoft Research RAFT Implementation
  • LlamaIndex RAFT Tutorial
  • Berkeley RAG Fine-tuning Recipes

최종 업데이트: 2026-02-18