Airflow DAG 설계¶
Apache Airflow는 워크플로우를 프로그래밍 방식으로 작성, 스케줄링, 모니터링하기 위한 플랫폼. ML 파이프라인의 오케스트레이션에 널리 사용됨.
1. 핵심 개념¶
1.1 DAG (Directed Acyclic Graph)¶
DAG는 작업의 의존성과 실행 순서를 정의함.
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta
default_args = {
'owner': 'ml-team',
'depends_on_past': False,
'start_date': datetime(2024, 1, 1),
'email_on_failure': True,
'email': ['ml-team@company.com'],
'retries': 3,
'retry_delay': timedelta(minutes=5),
}
with DAG(
dag_id='ml_training_pipeline',
default_args=default_args,
schedule_interval='@daily', # 또는 '0 2 * * *' (매일 02:00)
catchup=False,
tags=['ml', 'training'],
) as dag:
# Task 정의
pass
1.2 Operator 유형¶
| Operator | 용도 | 예시 |
|---|---|---|
| PythonOperator | Python 함수 실행 | 데이터 처리, 모델 학습 |
| BashOperator | Shell 명령 실행 | 스크립트 실행 |
| KubernetesPodOperator | K8s Pod 실행 | GPU 학습 |
| DockerOperator | Docker 컨테이너 실행 | 격리된 환경 |
| S3Operator | S3 작업 | 데이터 업로드/다운로드 |
1.3 Task 의존성¶
# 방법 1: >> 연산자
task_a >> task_b >> task_c
# 방법 2: set_downstream/set_upstream
task_a.set_downstream(task_b)
# 방법 3: 복잡한 의존성
from airflow.utils.task_group import TaskGroup
task_a >> [task_b, task_c] >> task_d # 병렬 후 합류
2. ML 파이프라인 DAG 패턴¶
2.1 기본 학습 파이프라인¶
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.operators.s3 import S3CopyObjectOperator
from datetime import datetime
def extract_data(**context):
"""데이터 추출"""
import pandas as pd
from sqlalchemy import create_engine
engine = create_engine('postgresql://...')
df = pd.read_sql('SELECT * FROM training_data', engine)
# XCom으로 경로 전달
data_path = f"/tmp/data_{context['ds']}.parquet"
df.to_parquet(data_path)
return data_path
def preprocess_data(**context):
"""데이터 전처리"""
import pandas as pd
from sklearn.preprocessing import StandardScaler
# 이전 태스크 결과 가져오기
ti = context['ti']
data_path = ti.xcom_pull(task_ids='extract_data')
df = pd.read_parquet(data_path)
# 전처리 로직
scaler = StandardScaler()
df[['feature1', 'feature2']] = scaler.fit_transform(df[['feature1', 'feature2']])
processed_path = f"/tmp/processed_{context['ds']}.parquet"
df.to_parquet(processed_path)
return processed_path
def train_model(**context):
"""모델 학습"""
import mlflow
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
ti = context['ti']
data_path = ti.xcom_pull(task_ids='preprocess_data')
df = pd.read_parquet(data_path)
X = df.drop('target', axis=1)
y = df['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
with mlflow.start_run():
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)
# 메트릭 로깅
accuracy = model.score(X_test, y_test)
mlflow.log_metric('accuracy', accuracy)
# 모델 저장
mlflow.sklearn.log_model(model, 'model')
return mlflow.active_run().info.run_id
def evaluate_model(**context):
"""모델 평가 및 배포 결정"""
import mlflow
ti = context['ti']
run_id = ti.xcom_pull(task_ids='train_model')
client = mlflow.tracking.MlflowClient()
run = client.get_run(run_id)
accuracy = run.data.metrics['accuracy']
# 배포 기준
ACCURACY_THRESHOLD = 0.85
if accuracy >= ACCURACY_THRESHOLD:
# 모델 레지스트리에 등록
model_uri = f"runs:/{run_id}/model"
mlflow.register_model(model_uri, "production_model")
return True
return False
with DAG(
dag_id='ml_training_pipeline',
schedule_interval='@daily',
start_date=datetime(2024, 1, 1),
catchup=False,
) as dag:
extract = PythonOperator(
task_id='extract_data',
python_callable=extract_data,
)
preprocess = PythonOperator(
task_id='preprocess_data',
python_callable=preprocess_data,
)
train = PythonOperator(
task_id='train_model',
python_callable=train_model,
)
evaluate = PythonOperator(
task_id='evaluate_model',
python_callable=evaluate_model,
)
extract >> preprocess >> train >> evaluate
2.2 GPU 학습 (KubernetesPodOperator)¶
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s
gpu_training = KubernetesPodOperator(
task_id='gpu_training',
name='llm-finetuning',
namespace='ml-workloads',
image='nvidia/cuda:12.0-pytorch:latest',
cmds=['python'],
arguments=['/scripts/train.py', '--config', '/configs/finetune.yaml'],
# GPU 리소스 요청
container_resources=k8s.V1ResourceRequirements(
requests={'nvidia.com/gpu': '2'},
limits={'nvidia.com/gpu': '2', 'memory': '64Gi'}
),
# 볼륨 마운트
volumes=[
k8s.V1Volume(
name='model-storage',
persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
claim_name='model-pvc'
)
)
],
volume_mounts=[
k8s.V1VolumeMount(name='model-storage', mount_path='/models')
],
# 노드 선택
node_selector={'node-type': 'gpu'},
# 타임아웃
startup_timeout_seconds=600,
# 환경 변수
env_vars={
'WANDB_API_KEY': '{{ var.value.wandb_api_key }}',
'HF_TOKEN': '{{ var.value.hf_token }}',
},
# 완료 후 Pod 삭제
is_delete_operator_pod=True,
)
2.3 조건부 실행 (BranchPythonOperator)¶
from airflow.operators.python import BranchPythonOperator
from airflow.operators.empty import EmptyOperator
def check_data_quality(**context):
"""데이터 품질 검사 후 분기"""
ti = context['ti']
data_stats = ti.xcom_pull(task_ids='calculate_stats')
# 품질 기준
if data_stats['null_ratio'] > 0.1:
return 'handle_data_quality_issue'
elif data_stats['row_count'] < 1000:
return 'skip_training'
else:
return 'proceed_training'
branch = BranchPythonOperator(
task_id='branch_on_quality',
python_callable=check_data_quality,
)
handle_issue = PythonOperator(task_id='handle_data_quality_issue', ...)
skip = EmptyOperator(task_id='skip_training')
proceed = PythonOperator(task_id='proceed_training', ...)
branch >> [handle_issue, skip, proceed]
3. 고급 패턴¶
3.1 동적 DAG 생성¶
def create_model_training_dag(model_config):
"""모델별 DAG 동적 생성"""
dag_id = f"train_{model_config['name']}"
with DAG(
dag_id=dag_id,
schedule_interval=model_config.get('schedule', '@weekly'),
start_date=datetime(2024, 1, 1),
tags=['ml', model_config['name']],
) as dag:
train = PythonOperator(
task_id='train',
python_callable=train_model,
op_kwargs={'model_config': model_config},
)
evaluate = PythonOperator(
task_id='evaluate',
python_callable=evaluate_model,
op_kwargs={'model_config': model_config},
)
train >> evaluate
return dag
# 설정 파일에서 모델 목록 로드
MODEL_CONFIGS = [
{'name': 'classifier_v1', 'algorithm': 'xgboost', 'schedule': '@daily'},
{'name': 'classifier_v2', 'algorithm': 'lightgbm', 'schedule': '@weekly'},
{'name': 'regressor_v1', 'algorithm': 'catboost', 'schedule': '@monthly'},
]
# 동적 DAG 생성
for config in MODEL_CONFIGS:
dag = create_model_training_dag(config)
globals()[dag.dag_id] = dag
3.2 TaskGroup으로 구조화¶
from airflow.utils.task_group import TaskGroup
with DAG('structured_ml_pipeline', ...) as dag:
with TaskGroup('data_preparation') as data_prep:
extract = PythonOperator(task_id='extract', ...)
validate = PythonOperator(task_id='validate', ...)
transform = PythonOperator(task_id='transform', ...)
extract >> validate >> transform
with TaskGroup('model_training') as training:
split = PythonOperator(task_id='split_data', ...)
train = PythonOperator(task_id='train', ...)
tune = PythonOperator(task_id='hyperparameter_tune', ...)
split >> train >> tune
with TaskGroup('evaluation') as evaluation:
metrics = PythonOperator(task_id='calculate_metrics', ...)
report = PythonOperator(task_id='generate_report', ...)
metrics >> report
with TaskGroup('deployment') as deployment:
register = PythonOperator(task_id='register_model', ...)
deploy = PythonOperator(task_id='deploy', ...)
verify = PythonOperator(task_id='verify_deployment', ...)
register >> deploy >> verify
data_prep >> training >> evaluation >> deployment
3.3 센서를 이용한 외부 이벤트 대기¶
from airflow.sensors.filesystem import FileSensor
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
# 파일 존재 확인
wait_for_data = FileSensor(
task_id='wait_for_data',
filepath='/data/daily_export_{{ ds }}.csv',
poke_interval=300, # 5분마다 확인
timeout=3600, # 1시간 타임아웃
mode='poke', # 또는 'reschedule'
)
# S3 파일 대기
wait_for_s3 = S3KeySensor(
task_id='wait_for_s3_data',
bucket_name='data-lake',
bucket_key='raw/{{ ds }}/data.parquet',
aws_conn_id='aws_default',
)
# 다른 DAG 완료 대기
wait_for_upstream = ExternalTaskSensor(
task_id='wait_for_etl',
external_dag_id='etl_pipeline',
external_task_id='final_task',
execution_date_fn=lambda dt: dt,
timeout=7200,
)
4. 실패 처리 및 알림¶
4.1 재시도 전략¶
default_args = {
'retries': 3,
'retry_delay': timedelta(minutes=5),
'retry_exponential_backoff': True,
'max_retry_delay': timedelta(hours=1),
}
# Task별 재시도 설정
critical_task = PythonOperator(
task_id='critical_task',
python_callable=critical_function,
retries=5,
retry_delay=timedelta(minutes=10),
)
4.2 콜백 함수¶
def on_failure_callback(context):
"""실패 시 Slack 알림"""
import requests
dag_id = context['dag'].dag_id
task_id = context['task'].task_id
execution_date = context['execution_date']
exception = context['exception']
message = f"""
:red_circle: Task Failed!
DAG: {dag_id}
Task: {task_id}
Execution Date: {execution_date}
Error: {str(exception)}
"""
requests.post(
'https://hooks.slack.com/services/...',
json={'text': message}
)
def on_success_callback(context):
"""성공 시 처리"""
pass
with DAG(
dag_id='ml_pipeline_with_alerts',
on_failure_callback=on_failure_callback,
on_success_callback=on_success_callback,
...
) as dag:
pass
4.3 SLA (Service Level Agreement)¶
from datetime import timedelta
with DAG(
dag_id='sla_monitored_pipeline',
sla_miss_callback=sla_miss_alert,
...
) as dag:
critical_task = PythonOperator(
task_id='critical_task',
python_callable=critical_function,
sla=timedelta(hours=2), # 2시간 내 완료 필요
)
5. 모범 사례¶
5.1 코드 구조화¶
5.2 테스트¶
import pytest
from airflow.models import DagBag
def test_dag_loaded():
"""DAG 로드 테스트"""
dagbag = DagBag()
dag = dagbag.get_dag('ml_training_pipeline')
assert dag is not None
assert len(dagbag.import_errors) == 0
def test_dag_structure():
"""DAG 구조 테스트"""
dagbag = DagBag()
dag = dagbag.get_dag('ml_training_pipeline')
# 태스크 수 확인
assert len(dag.tasks) == 4
# 의존성 확인
extract_task = dag.get_task('extract_data')
assert 'preprocess_data' in [t.task_id for t in extract_task.downstream_list]
def test_task_function():
"""태스크 함수 단위 테스트"""
from dags.ml.training_dag import preprocess_data
# Mock context
context = {'ds': '2024-01-01', 'ti': MockTaskInstance()}
result = preprocess_data(**context)
assert result is not None
5.3 성능 최적화¶
# XCom 대신 외부 저장소 사용 (대용량 데이터)
def train_with_external_storage(**context):
# S3에 직접 저장
model_path = f"s3://models/{context['ds']}/model.pkl"
# XCom에는 경로만 저장
return model_path
# Pool로 동시 실행 제한
train_task = PythonOperator(
task_id='train',
python_callable=train_model,
pool='gpu_pool', # 사전 정의된 Pool
pool_slots=2, # 슬롯 수
)
# 병렬 처리
with DAG(..., max_active_tasks=10) as dag:
# 동시에 최대 10개 태스크 실행
pass
6. Airflow 2.x 신기능¶
6.1 TaskFlow API¶
from airflow.decorators import dag, task
from datetime import datetime
@dag(
dag_id='taskflow_ml_pipeline',
schedule_interval='@daily',
start_date=datetime(2024, 1, 1),
catchup=False,
)
def ml_pipeline():
@task()
def extract_data():
import pandas as pd
df = pd.read_csv('/data/raw.csv')
return df.to_dict()
@task()
def preprocess(data: dict):
import pandas as pd
df = pd.DataFrame(data)
# 전처리
return df.to_dict()
@task()
def train(data: dict):
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
df = pd.DataFrame(data)
model = RandomForestClassifier()
# 학습
return {'accuracy': 0.95}
@task()
def evaluate(metrics: dict):
if metrics['accuracy'] > 0.9:
print("Model passed evaluation")
# TaskFlow 의존성 (자동 XCom 처리)
data = extract_data()
processed = preprocess(data)
metrics = train(processed)
evaluate(metrics)
# DAG 인스턴스 생성
ml_dag = ml_pipeline()
6.2 Dynamic Task Mapping¶
@dag(...)
def dynamic_training():
@task()
def get_model_configs():
return [
{'name': 'model_a', 'params': {'n_estimators': 100}},
{'name': 'model_b', 'params': {'n_estimators': 200}},
{'name': 'model_c', 'params': {'n_estimators': 300}},
]
@task()
def train_model(config):
# 각 설정으로 모델 학습
print(f"Training {config['name']}")
return config['name']
@task()
def aggregate_results(results):
print(f"Trained models: {results}")
configs = get_model_configs()
# 동적으로 병렬 태스크 생성
trained = train_model.expand(config=configs)
aggregate_results(trained)
7. 트러블슈팅 가이드¶
7.1 일반적인 문제¶
| 문제 | 증상 | 원인 | 해결책 |
|---|---|---|---|
| DAG 미표시 | UI에 DAG 없음 | 문법 오류, import 실패 | airflow dags list 로 오류 확인 |
| Task 멈춤 | Running 상태 유지 | Deadlock, 리소스 부족 | Worker 로그 확인, 리소스 증가 |
| XCom 실패 | 데이터 전달 안됨 | 직렬화 불가, 크기 초과 | 외부 저장소 사용 |
| 스케줄 미실행 | 트리거 안됨 | Scheduler 중단, 시간대 | Scheduler 상태, timezone 확인 |
| 메모리 부족 | Worker OOM | 대용량 데이터 처리 | 청크 처리, Worker 스펙 업 |
7.2 DAG 로드 오류 디버깅¶
# DAG 로드 테스트
airflow dags list-import-errors
# 특정 DAG 검증
python /path/to/dags/my_dag.py
# 구문 검사
airflow dags test my_dag_id 2024-01-01
# 전체 파싱 시간 확인
airflow dags report
# 디버그 모드로 DAG 로드
import logging
logging.basicConfig(level=logging.DEBUG)
from airflow.models import DagBag
dag_bag = DagBag(dag_folder="/opt/airflow/dags", include_examples=False)
if dag_bag.import_errors:
for dag_id, error in dag_bag.import_errors.items():
print(f"DAG {dag_id} failed to load:")
print(error)
else:
print("All DAGs loaded successfully")
for dag_id, dag in dag_bag.dags.items():
print(f" - {dag_id}: {len(dag.tasks)} tasks")
7.3 Task 실패 디버깅¶
# 실패한 Task 재실행
airflow tasks test my_dag task_id 2024-01-01
# 특정 Task 강제 성공 처리
airflow tasks set_state my_dag task_id 2024-01-01 --state success
# Task 인스턴스 상태 확인
airflow tasks states-for-dag-run my_dag 2024-01-01
# DAG 내 디버깅 로직
from airflow.operators.python import PythonOperator
import traceback
def robust_task(**context):
"""실패 시 상세 정보 로깅"""
try:
# 실제 작업
result = do_work()
return result
except Exception as e:
# 상세 오류 정보
error_info = {
"task_id": context["task"].task_id,
"dag_id": context["dag"].dag_id,
"execution_date": str(context["execution_date"]),
"error_type": type(e).__name__,
"error_message": str(e),
"traceback": traceback.format_exc(),
}
# 외부 로깅 (Slack, DB 등)
send_error_to_slack(error_info)
raise # 원래 예외 다시 발생
task = PythonOperator(
task_id="robust_task",
python_callable=robust_task,
provide_context=True,
)
7.4 성능 문제 해결¶
# Scheduler 성능 튜닝
# airflow.cfg 또는 환경 변수
# DAG 파싱 간격 (초)
AIRFLOW__SCHEDULER__MIN_FILE_PROCESS_INTERVAL=30
# 동시 DAG 파싱 프로세스 수
AIRFLOW__SCHEDULER__PARSING_PROCESSES=4
# 스케줄러 하트비트 간격
AIRFLOW__SCHEDULER__SCHEDULER_HEARTBEAT_SEC=5
# 최대 동시 실행 Task
AIRFLOW__CORE__PARALLELISM=32
AIRFLOW__CORE__MAX_ACTIVE_TASKS_PER_DAG=16
# 무거운 연산 최적화
from airflow.decorators import task
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
# 나쁜 예: Airflow Worker에서 무거운 연산
@task
def heavy_computation():
# 이 코드는 Worker 리소스를 소모
result = train_large_model()
return result
# 좋은 예: 외부 컴퓨팅 리소스 사용
heavy_task = KubernetesPodOperator(
task_id="heavy_computation",
name="training-pod",
image="training-image:latest",
resources={"limit_memory": "32Gi", "limit_cpu": "8"},
# 무거운 연산은 별도 Pod에서
)
7.5 XCom 문제 해결¶
# XCom 크기 제한 문제
# 기본 XCom은 메타데이터 DB에 저장 (크기 제한)
# 해결책 1: 외부 저장소 사용
@task
def process_large_data():
result = compute_large_result()
# S3에 저장
s3_path = f"s3://bucket/xcom/{context['run_id']}/result.parquet"
result.to_parquet(s3_path)
# 경로만 XCom으로 전달
return s3_path
@task
def use_large_data(s3_path: str):
import pandas as pd
result = pd.read_parquet(s3_path)
# 처리
# 해결책 2: Custom XCom Backend
# airflow.cfg
# xcom_backend = my_package.s3_xcom_backend.S3XComBackend
7.6 연결 및 인증 문제¶
# Connection 테스트
from airflow.hooks.base import BaseHook
def test_connections():
connections_to_test = ["aws_default", "postgres_default", "slack_webhook"]
for conn_id in connections_to_test:
try:
conn = BaseHook.get_connection(conn_id)
print(f"✓ {conn_id}: {conn.host}")
except Exception as e:
print(f"✗ {conn_id}: {str(e)}")
# DAG에서 연결 검증
from airflow.operators.python import ShortCircuitOperator
def check_db_connection():
from airflow.providers.postgres.hooks.postgres import PostgresHook
hook = PostgresHook(postgres_conn_id="my_postgres")
try:
hook.get_conn()
return True
except Exception:
return False
validate_connection = ShortCircuitOperator(
task_id="validate_connection",
python_callable=check_db_connection,
)
8. 실무 사례¶
8.1 ML 재학습 파이프라인 자동화¶
from airflow import DAG
from airflow.decorators import task, task_group
from airflow.operators.python import BranchPythonOperator
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from datetime import datetime, timedelta
default_args = {
"retries": 2,
"retry_delay": timedelta(minutes=10),
"execution_timeout": timedelta(hours=6),
}
with DAG(
dag_id="ml_retraining_pipeline",
schedule_interval="0 2 * * 0", # 매주 일요일 02:00
start_date=datetime(2024, 1, 1),
catchup=False,
default_args=default_args,
tags=["ml", "retraining"],
) as dag:
@task
def check_drift():
"""드리프트 체크하여 재학습 필요 여부 판단"""
from my_package.monitoring import check_model_drift
drift_result = check_model_drift()
return drift_result["needs_retraining"]
def decide_retraining(**context):
"""재학습 여부에 따른 분기"""
ti = context["ti"]
needs_retraining = ti.xcom_pull(task_ids="check_drift")
if needs_retraining:
return "retraining_group.prepare_data"
else:
return "skip_retraining"
branching = BranchPythonOperator(
task_id="decide_retraining",
python_callable=decide_retraining,
)
@task
def skip_retraining():
print("No retraining needed, skipping...")
@task_group
def retraining_group():
@task
def prepare_data():
"""데이터 준비"""
import subprocess
subprocess.run(["dvc", "pull"], check=True)
return "/data/train"
train = KubernetesPodOperator(
task_id="train_model",
name="model-training",
image="ml-training:latest",
arguments=["python", "train.py", "--config", "prod.yaml"],
resources={"limit_gpu": "2", "limit_memory": "64Gi"},
env_vars={
"MLFLOW_TRACKING_URI": "{{ var.value.mlflow_uri }}",
"WANDB_API_KEY": "{{ var.value.wandb_key }}",
},
get_logs=True,
is_delete_operator_pod=True,
)
@task
def evaluate_model():
"""모델 평가"""
from my_package.evaluation import evaluate_model
metrics = evaluate_model()
return metrics
@task
def promote_if_better(metrics: dict):
"""기존 모델보다 좋으면 프로모션"""
import mlflow
current_accuracy = get_current_model_accuracy()
if metrics["accuracy"] > current_accuracy * 1.01: # 1% 이상 개선
promote_model_to_production()
return True
return False
data = prepare_data()
data >> train
train >> evaluate_model() >> promote_if_better()
drift_check = check_drift()
drift_check >> branching
branching >> [retraining_group(), skip_retraining()]
8.2 Feature Store 업데이트 파이프라인¶
from airflow import DAG
from airflow.decorators import task
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator
from datetime import datetime, timedelta
with DAG(
dag_id="feature_store_update",
schedule_interval="@hourly",
start_date=datetime(2024, 1, 1),
catchup=False,
max_active_runs=1,
) as dag:
# Spark로 특성 계산
compute_features = SparkSubmitOperator(
task_id="compute_features",
application="/opt/spark/jobs/compute_features.py",
conf={
"spark.executor.memory": "8g",
"spark.executor.cores": "4",
},
application_args=[
"--date", "{{ ds }}",
"--output-path", "s3://features/computed/{{ ds }}/",
],
)
@task
def validate_features():
"""계산된 특성 검증"""
import great_expectations as gx
context = gx.get_context()
result = context.run_checkpoint(
checkpoint_name="feature_validation"
)
if not result.success:
raise ValueError("Feature validation failed")
return True
@task
def materialize_to_online():
"""Online Store로 Materialize"""
from feast import FeatureStore
store = FeatureStore(repo_path="/app/feature_repo")
store.materialize_incremental(end_date=datetime.now())
@task
def notify_completion():
"""완료 알림"""
import requests
requests.post(
"https://hooks.slack.com/services/...",
json={"text": f"Feature store updated successfully at {datetime.now()}"}
)
compute_features >> validate_features() >> materialize_to_online() >> notify_completion()