콘텐츠로 이동
Data Prep
상세

Airflow DAG 설계

Apache Airflow는 워크플로우를 프로그래밍 방식으로 작성, 스케줄링, 모니터링하기 위한 플랫폼. ML 파이프라인의 오케스트레이션에 널리 사용됨.


1. 핵심 개념

1.1 DAG (Directed Acyclic Graph)

DAG는 작업의 의존성과 실행 순서를 정의함.

from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta

default_args = {
    'owner': 'ml-team',
    'depends_on_past': False,
    'start_date': datetime(2024, 1, 1),
    'email_on_failure': True,
    'email': ['ml-team@company.com'],
    'retries': 3,
    'retry_delay': timedelta(minutes=5),
}

with DAG(
    dag_id='ml_training_pipeline',
    default_args=default_args,
    schedule_interval='@daily',  # 또는 '0 2 * * *' (매일 02:00)
    catchup=False,
    tags=['ml', 'training'],
) as dag:
    # Task 정의
    pass

1.2 Operator 유형

Operator 용도 예시
PythonOperator Python 함수 실행 데이터 처리, 모델 학습
BashOperator Shell 명령 실행 스크립트 실행
KubernetesPodOperator K8s Pod 실행 GPU 학습
DockerOperator Docker 컨테이너 실행 격리된 환경
S3Operator S3 작업 데이터 업로드/다운로드

1.3 Task 의존성

# 방법 1: >> 연산자
task_a >> task_b >> task_c

# 방법 2: set_downstream/set_upstream
task_a.set_downstream(task_b)

# 방법 3: 복잡한 의존성
from airflow.utils.task_group import TaskGroup

task_a >> [task_b, task_c] >> task_d  # 병렬 후 합류

2. ML 파이프라인 DAG 패턴

2.1 기본 학습 파이프라인

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.operators.s3 import S3CopyObjectOperator
from datetime import datetime

def extract_data(**context):
    """데이터 추출"""
    import pandas as pd
    from sqlalchemy import create_engine

    engine = create_engine('postgresql://...')
    df = pd.read_sql('SELECT * FROM training_data', engine)

    # XCom으로 경로 전달
    data_path = f"/tmp/data_{context['ds']}.parquet"
    df.to_parquet(data_path)
    return data_path

def preprocess_data(**context):
    """데이터 전처리"""
    import pandas as pd
    from sklearn.preprocessing import StandardScaler

    # 이전 태스크 결과 가져오기
    ti = context['ti']
    data_path = ti.xcom_pull(task_ids='extract_data')

    df = pd.read_parquet(data_path)

    # 전처리 로직
    scaler = StandardScaler()
    df[['feature1', 'feature2']] = scaler.fit_transform(df[['feature1', 'feature2']])

    processed_path = f"/tmp/processed_{context['ds']}.parquet"
    df.to_parquet(processed_path)
    return processed_path

def train_model(**context):
    """모델 학습"""
    import mlflow
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import train_test_split

    ti = context['ti']
    data_path = ti.xcom_pull(task_ids='preprocess_data')

    df = pd.read_parquet(data_path)
    X = df.drop('target', axis=1)
    y = df['target']

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    with mlflow.start_run():
        model = RandomForestClassifier(n_estimators=100)
        model.fit(X_train, y_train)

        # 메트릭 로깅
        accuracy = model.score(X_test, y_test)
        mlflow.log_metric('accuracy', accuracy)

        # 모델 저장
        mlflow.sklearn.log_model(model, 'model')

        return mlflow.active_run().info.run_id

def evaluate_model(**context):
    """모델 평가 및 배포 결정"""
    import mlflow

    ti = context['ti']
    run_id = ti.xcom_pull(task_ids='train_model')

    client = mlflow.tracking.MlflowClient()
    run = client.get_run(run_id)
    accuracy = run.data.metrics['accuracy']

    # 배포 기준
    ACCURACY_THRESHOLD = 0.85

    if accuracy >= ACCURACY_THRESHOLD:
        # 모델 레지스트리에 등록
        model_uri = f"runs:/{run_id}/model"
        mlflow.register_model(model_uri, "production_model")
        return True
    return False

with DAG(
    dag_id='ml_training_pipeline',
    schedule_interval='@daily',
    start_date=datetime(2024, 1, 1),
    catchup=False,
) as dag:

    extract = PythonOperator(
        task_id='extract_data',
        python_callable=extract_data,
    )

    preprocess = PythonOperator(
        task_id='preprocess_data',
        python_callable=preprocess_data,
    )

    train = PythonOperator(
        task_id='train_model',
        python_callable=train_model,
    )

    evaluate = PythonOperator(
        task_id='evaluate_model',
        python_callable=evaluate_model,
    )

    extract >> preprocess >> train >> evaluate

2.2 GPU 학습 (KubernetesPodOperator)

from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s

gpu_training = KubernetesPodOperator(
    task_id='gpu_training',
    name='llm-finetuning',
    namespace='ml-workloads',
    image='nvidia/cuda:12.0-pytorch:latest',
    cmds=['python'],
    arguments=['/scripts/train.py', '--config', '/configs/finetune.yaml'],

    # GPU 리소스 요청
    container_resources=k8s.V1ResourceRequirements(
        requests={'nvidia.com/gpu': '2'},
        limits={'nvidia.com/gpu': '2', 'memory': '64Gi'}
    ),

    # 볼륨 마운트
    volumes=[
        k8s.V1Volume(
            name='model-storage',
            persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
                claim_name='model-pvc'
            )
        )
    ],
    volume_mounts=[
        k8s.V1VolumeMount(name='model-storage', mount_path='/models')
    ],

    # 노드 선택
    node_selector={'node-type': 'gpu'},

    # 타임아웃
    startup_timeout_seconds=600,

    # 환경 변수
    env_vars={
        'WANDB_API_KEY': '{{ var.value.wandb_api_key }}',
        'HF_TOKEN': '{{ var.value.hf_token }}',
    },

    # 완료 후 Pod 삭제
    is_delete_operator_pod=True,
)

2.3 조건부 실행 (BranchPythonOperator)

from airflow.operators.python import BranchPythonOperator
from airflow.operators.empty import EmptyOperator

def check_data_quality(**context):
    """데이터 품질 검사 후 분기"""
    ti = context['ti']
    data_stats = ti.xcom_pull(task_ids='calculate_stats')

    # 품질 기준
    if data_stats['null_ratio'] > 0.1:
        return 'handle_data_quality_issue'
    elif data_stats['row_count'] < 1000:
        return 'skip_training'
    else:
        return 'proceed_training'

branch = BranchPythonOperator(
    task_id='branch_on_quality',
    python_callable=check_data_quality,
)

handle_issue = PythonOperator(task_id='handle_data_quality_issue', ...)
skip = EmptyOperator(task_id='skip_training')
proceed = PythonOperator(task_id='proceed_training', ...)

branch >> [handle_issue, skip, proceed]

3. 고급 패턴

3.1 동적 DAG 생성

def create_model_training_dag(model_config):
    """모델별 DAG 동적 생성"""
    dag_id = f"train_{model_config['name']}"

    with DAG(
        dag_id=dag_id,
        schedule_interval=model_config.get('schedule', '@weekly'),
        start_date=datetime(2024, 1, 1),
        tags=['ml', model_config['name']],
    ) as dag:

        train = PythonOperator(
            task_id='train',
            python_callable=train_model,
            op_kwargs={'model_config': model_config},
        )

        evaluate = PythonOperator(
            task_id='evaluate',
            python_callable=evaluate_model,
            op_kwargs={'model_config': model_config},
        )

        train >> evaluate

    return dag

# 설정 파일에서 모델 목록 로드
MODEL_CONFIGS = [
    {'name': 'classifier_v1', 'algorithm': 'xgboost', 'schedule': '@daily'},
    {'name': 'classifier_v2', 'algorithm': 'lightgbm', 'schedule': '@weekly'},
    {'name': 'regressor_v1', 'algorithm': 'catboost', 'schedule': '@monthly'},
]

# 동적 DAG 생성
for config in MODEL_CONFIGS:
    dag = create_model_training_dag(config)
    globals()[dag.dag_id] = dag

3.2 TaskGroup으로 구조화

from airflow.utils.task_group import TaskGroup

with DAG('structured_ml_pipeline', ...) as dag:

    with TaskGroup('data_preparation') as data_prep:
        extract = PythonOperator(task_id='extract', ...)
        validate = PythonOperator(task_id='validate', ...)
        transform = PythonOperator(task_id='transform', ...)
        extract >> validate >> transform

    with TaskGroup('model_training') as training:
        split = PythonOperator(task_id='split_data', ...)
        train = PythonOperator(task_id='train', ...)
        tune = PythonOperator(task_id='hyperparameter_tune', ...)
        split >> train >> tune

    with TaskGroup('evaluation') as evaluation:
        metrics = PythonOperator(task_id='calculate_metrics', ...)
        report = PythonOperator(task_id='generate_report', ...)
        metrics >> report

    with TaskGroup('deployment') as deployment:
        register = PythonOperator(task_id='register_model', ...)
        deploy = PythonOperator(task_id='deploy', ...)
        verify = PythonOperator(task_id='verify_deployment', ...)
        register >> deploy >> verify

    data_prep >> training >> evaluation >> deployment

3.3 센서를 이용한 외부 이벤트 대기

from airflow.sensors.filesystem import FileSensor
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

# 파일 존재 확인
wait_for_data = FileSensor(
    task_id='wait_for_data',
    filepath='/data/daily_export_{{ ds }}.csv',
    poke_interval=300,  # 5분마다 확인
    timeout=3600,  # 1시간 타임아웃
    mode='poke',  # 또는 'reschedule'
)

# S3 파일 대기
wait_for_s3 = S3KeySensor(
    task_id='wait_for_s3_data',
    bucket_name='data-lake',
    bucket_key='raw/{{ ds }}/data.parquet',
    aws_conn_id='aws_default',
)

# 다른 DAG 완료 대기
wait_for_upstream = ExternalTaskSensor(
    task_id='wait_for_etl',
    external_dag_id='etl_pipeline',
    external_task_id='final_task',
    execution_date_fn=lambda dt: dt,
    timeout=7200,
)

4. 실패 처리 및 알림

4.1 재시도 전략

default_args = {
    'retries': 3,
    'retry_delay': timedelta(minutes=5),
    'retry_exponential_backoff': True,
    'max_retry_delay': timedelta(hours=1),
}

# Task별 재시도 설정
critical_task = PythonOperator(
    task_id='critical_task',
    python_callable=critical_function,
    retries=5,
    retry_delay=timedelta(minutes=10),
)

4.2 콜백 함수

def on_failure_callback(context):
    """실패 시 Slack 알림"""
    import requests

    dag_id = context['dag'].dag_id
    task_id = context['task'].task_id
    execution_date = context['execution_date']
    exception = context['exception']

    message = f"""
    :red_circle: Task Failed!
    DAG: {dag_id}
    Task: {task_id}
    Execution Date: {execution_date}
    Error: {str(exception)}
    """

    requests.post(
        'https://hooks.slack.com/services/...',
        json={'text': message}
    )

def on_success_callback(context):
    """성공 시 처리"""
    pass

with DAG(
    dag_id='ml_pipeline_with_alerts',
    on_failure_callback=on_failure_callback,
    on_success_callback=on_success_callback,
    ...
) as dag:
    pass

4.3 SLA (Service Level Agreement)

from datetime import timedelta

with DAG(
    dag_id='sla_monitored_pipeline',
    sla_miss_callback=sla_miss_alert,
    ...
) as dag:

    critical_task = PythonOperator(
        task_id='critical_task',
        python_callable=critical_function,
        sla=timedelta(hours=2),  # 2시간 내 완료 필요
    )

5. 모범 사례

5.1 코드 구조화

airflow diagram 1

5.2 테스트

import pytest
from airflow.models import DagBag

def test_dag_loaded():
    """DAG 로드 테스트"""
    dagbag = DagBag()
    dag = dagbag.get_dag('ml_training_pipeline')

    assert dag is not None
    assert len(dagbag.import_errors) == 0

def test_dag_structure():
    """DAG 구조 테스트"""
    dagbag = DagBag()
    dag = dagbag.get_dag('ml_training_pipeline')

    # 태스크 수 확인
    assert len(dag.tasks) == 4

    # 의존성 확인
    extract_task = dag.get_task('extract_data')
    assert 'preprocess_data' in [t.task_id for t in extract_task.downstream_list]

def test_task_function():
    """태스크 함수 단위 테스트"""
    from dags.ml.training_dag import preprocess_data

    # Mock context
    context = {'ds': '2024-01-01', 'ti': MockTaskInstance()}

    result = preprocess_data(**context)
    assert result is not None

5.3 성능 최적화

# XCom 대신 외부 저장소 사용 (대용량 데이터)
def train_with_external_storage(**context):
    # S3에 직접 저장
    model_path = f"s3://models/{context['ds']}/model.pkl"

    # XCom에는 경로만 저장
    return model_path

# Pool로 동시 실행 제한
train_task = PythonOperator(
    task_id='train',
    python_callable=train_model,
    pool='gpu_pool',  # 사전 정의된 Pool
    pool_slots=2,     # 슬롯 수
)

# 병렬 처리
with DAG(..., max_active_tasks=10) as dag:
    # 동시에 최대 10개 태스크 실행
    pass

6. Airflow 2.x 신기능

6.1 TaskFlow API

from airflow.decorators import dag, task
from datetime import datetime

@dag(
    dag_id='taskflow_ml_pipeline',
    schedule_interval='@daily',
    start_date=datetime(2024, 1, 1),
    catchup=False,
)
def ml_pipeline():

    @task()
    def extract_data():
        import pandas as pd
        df = pd.read_csv('/data/raw.csv')
        return df.to_dict()

    @task()
    def preprocess(data: dict):
        import pandas as pd
        df = pd.DataFrame(data)
        # 전처리
        return df.to_dict()

    @task()
    def train(data: dict):
        import pandas as pd
        from sklearn.ensemble import RandomForestClassifier

        df = pd.DataFrame(data)
        model = RandomForestClassifier()
        # 학습
        return {'accuracy': 0.95}

    @task()
    def evaluate(metrics: dict):
        if metrics['accuracy'] > 0.9:
            print("Model passed evaluation")

    # TaskFlow 의존성 (자동 XCom 처리)
    data = extract_data()
    processed = preprocess(data)
    metrics = train(processed)
    evaluate(metrics)

# DAG 인스턴스 생성
ml_dag = ml_pipeline()

6.2 Dynamic Task Mapping

@dag(...)
def dynamic_training():

    @task()
    def get_model_configs():
        return [
            {'name': 'model_a', 'params': {'n_estimators': 100}},
            {'name': 'model_b', 'params': {'n_estimators': 200}},
            {'name': 'model_c', 'params': {'n_estimators': 300}},
        ]

    @task()
    def train_model(config):
        # 각 설정으로 모델 학습
        print(f"Training {config['name']}")
        return config['name']

    @task()
    def aggregate_results(results):
        print(f"Trained models: {results}")

    configs = get_model_configs()
    # 동적으로 병렬 태스크 생성
    trained = train_model.expand(config=configs)
    aggregate_results(trained)

7. 트러블슈팅 가이드

7.1 일반적인 문제

문제 증상 원인 해결책
DAG 미표시 UI에 DAG 없음 문법 오류, import 실패 airflow dags list 로 오류 확인
Task 멈춤 Running 상태 유지 Deadlock, 리소스 부족 Worker 로그 확인, 리소스 증가
XCom 실패 데이터 전달 안됨 직렬화 불가, 크기 초과 외부 저장소 사용
스케줄 미실행 트리거 안됨 Scheduler 중단, 시간대 Scheduler 상태, timezone 확인
메모리 부족 Worker OOM 대용량 데이터 처리 청크 처리, Worker 스펙 업

7.2 DAG 로드 오류 디버깅

# DAG 로드 테스트
airflow dags list-import-errors

# 특정 DAG 검증
python /path/to/dags/my_dag.py

# 구문 검사
airflow dags test my_dag_id 2024-01-01

# 전체 파싱 시간 확인
airflow dags report
# 디버그 모드로 DAG 로드
import logging
logging.basicConfig(level=logging.DEBUG)

from airflow.models import DagBag

dag_bag = DagBag(dag_folder="/opt/airflow/dags", include_examples=False)

if dag_bag.import_errors:
    for dag_id, error in dag_bag.import_errors.items():
        print(f"DAG {dag_id} failed to load:")
        print(error)
else:
    print("All DAGs loaded successfully")
    for dag_id, dag in dag_bag.dags.items():
        print(f"  - {dag_id}: {len(dag.tasks)} tasks")

7.3 Task 실패 디버깅

# 실패한 Task 재실행
airflow tasks test my_dag task_id 2024-01-01

# 특정 Task 강제 성공 처리
airflow tasks set_state my_dag task_id 2024-01-01 --state success

# Task 인스턴스 상태 확인
airflow tasks states-for-dag-run my_dag 2024-01-01
# DAG 내 디버깅 로직
from airflow.operators.python import PythonOperator
import traceback

def robust_task(**context):
    """실패 시 상세 정보 로깅"""
    try:
        # 실제 작업
        result = do_work()
        return result
    except Exception as e:
        # 상세 오류 정보
        error_info = {
            "task_id": context["task"].task_id,
            "dag_id": context["dag"].dag_id,
            "execution_date": str(context["execution_date"]),
            "error_type": type(e).__name__,
            "error_message": str(e),
            "traceback": traceback.format_exc(),
        }

        # 외부 로깅 (Slack, DB 등)
        send_error_to_slack(error_info)

        raise  # 원래 예외 다시 발생

task = PythonOperator(
    task_id="robust_task",
    python_callable=robust_task,
    provide_context=True,
)

7.4 성능 문제 해결

# Scheduler 성능 튜닝
# airflow.cfg 또는 환경 변수

# DAG 파싱 간격 (초)
AIRFLOW__SCHEDULER__MIN_FILE_PROCESS_INTERVAL=30

# 동시 DAG 파싱 프로세스 수
AIRFLOW__SCHEDULER__PARSING_PROCESSES=4

# 스케줄러 하트비트 간격
AIRFLOW__SCHEDULER__SCHEDULER_HEARTBEAT_SEC=5

# 최대 동시 실행 Task
AIRFLOW__CORE__PARALLELISM=32
AIRFLOW__CORE__MAX_ACTIVE_TASKS_PER_DAG=16
# 무거운 연산 최적화
from airflow.decorators import task
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

# 나쁜 예: Airflow Worker에서 무거운 연산
@task
def heavy_computation():
    # 이 코드는 Worker 리소스를 소모
    result = train_large_model()
    return result

# 좋은 예: 외부 컴퓨팅 리소스 사용
heavy_task = KubernetesPodOperator(
    task_id="heavy_computation",
    name="training-pod",
    image="training-image:latest",
    resources={"limit_memory": "32Gi", "limit_cpu": "8"},
    # 무거운 연산은 별도 Pod에서
)

7.5 XCom 문제 해결

# XCom 크기 제한 문제
# 기본 XCom은 메타데이터 DB에 저장 (크기 제한)

# 해결책 1: 외부 저장소 사용
@task
def process_large_data():
    result = compute_large_result()

    # S3에 저장
    s3_path = f"s3://bucket/xcom/{context['run_id']}/result.parquet"
    result.to_parquet(s3_path)

    # 경로만 XCom으로 전달
    return s3_path

@task
def use_large_data(s3_path: str):
    import pandas as pd
    result = pd.read_parquet(s3_path)
    # 처리

# 해결책 2: Custom XCom Backend
# airflow.cfg
# xcom_backend = my_package.s3_xcom_backend.S3XComBackend

7.6 연결 및 인증 문제

# Connection 테스트
from airflow.hooks.base import BaseHook

def test_connections():
    connections_to_test = ["aws_default", "postgres_default", "slack_webhook"]

    for conn_id in connections_to_test:
        try:
            conn = BaseHook.get_connection(conn_id)
            print(f"✓ {conn_id}: {conn.host}")
        except Exception as e:
            print(f"✗ {conn_id}: {str(e)}")

# DAG에서 연결 검증
from airflow.operators.python import ShortCircuitOperator

def check_db_connection():
    from airflow.providers.postgres.hooks.postgres import PostgresHook
    hook = PostgresHook(postgres_conn_id="my_postgres")

    try:
        hook.get_conn()
        return True
    except Exception:
        return False

validate_connection = ShortCircuitOperator(
    task_id="validate_connection",
    python_callable=check_db_connection,
)

8. 실무 사례

8.1 ML 재학습 파이프라인 자동화

from airflow import DAG
from airflow.decorators import task, task_group
from airflow.operators.python import BranchPythonOperator
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from datetime import datetime, timedelta

default_args = {
    "retries": 2,
    "retry_delay": timedelta(minutes=10),
    "execution_timeout": timedelta(hours=6),
}

with DAG(
    dag_id="ml_retraining_pipeline",
    schedule_interval="0 2 * * 0",  # 매주 일요일 02:00
    start_date=datetime(2024, 1, 1),
    catchup=False,
    default_args=default_args,
    tags=["ml", "retraining"],
) as dag:

    @task
    def check_drift():
        """드리프트 체크하여 재학습 필요 여부 판단"""
        from my_package.monitoring import check_model_drift

        drift_result = check_model_drift()
        return drift_result["needs_retraining"]

    def decide_retraining(**context):
        """재학습 여부에 따른 분기"""
        ti = context["ti"]
        needs_retraining = ti.xcom_pull(task_ids="check_drift")

        if needs_retraining:
            return "retraining_group.prepare_data"
        else:
            return "skip_retraining"

    branching = BranchPythonOperator(
        task_id="decide_retraining",
        python_callable=decide_retraining,
    )

    @task
    def skip_retraining():
        print("No retraining needed, skipping...")

    @task_group
    def retraining_group():
        @task
        def prepare_data():
            """데이터 준비"""
            import subprocess
            subprocess.run(["dvc", "pull"], check=True)
            return "/data/train"

        train = KubernetesPodOperator(
            task_id="train_model",
            name="model-training",
            image="ml-training:latest",
            arguments=["python", "train.py", "--config", "prod.yaml"],
            resources={"limit_gpu": "2", "limit_memory": "64Gi"},
            env_vars={
                "MLFLOW_TRACKING_URI": "{{ var.value.mlflow_uri }}",
                "WANDB_API_KEY": "{{ var.value.wandb_key }}",
            },
            get_logs=True,
            is_delete_operator_pod=True,
        )

        @task
        def evaluate_model():
            """모델 평가"""
            from my_package.evaluation import evaluate_model

            metrics = evaluate_model()
            return metrics

        @task
        def promote_if_better(metrics: dict):
            """기존 모델보다 좋으면 프로모션"""
            import mlflow

            current_accuracy = get_current_model_accuracy()

            if metrics["accuracy"] > current_accuracy * 1.01:  # 1% 이상 개선
                promote_model_to_production()
                return True
            return False

        data = prepare_data()
        data >> train
        train >> evaluate_model() >> promote_if_better()

    drift_check = check_drift()
    drift_check >> branching
    branching >> [retraining_group(), skip_retraining()]

8.2 Feature Store 업데이트 파이프라인

from airflow import DAG
from airflow.decorators import task
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator
from datetime import datetime, timedelta

with DAG(
    dag_id="feature_store_update",
    schedule_interval="@hourly",
    start_date=datetime(2024, 1, 1),
    catchup=False,
    max_active_runs=1,
) as dag:

    # Spark로 특성 계산
    compute_features = SparkSubmitOperator(
        task_id="compute_features",
        application="/opt/spark/jobs/compute_features.py",
        conf={
            "spark.executor.memory": "8g",
            "spark.executor.cores": "4",
        },
        application_args=[
            "--date", "{{ ds }}",
            "--output-path", "s3://features/computed/{{ ds }}/",
        ],
    )

    @task
    def validate_features():
        """계산된 특성 검증"""
        import great_expectations as gx

        context = gx.get_context()
        result = context.run_checkpoint(
            checkpoint_name="feature_validation"
        )

        if not result.success:
            raise ValueError("Feature validation failed")

        return True

    @task
    def materialize_to_online():
        """Online Store로 Materialize"""
        from feast import FeatureStore

        store = FeatureStore(repo_path="/app/feature_repo")
        store.materialize_incremental(end_date=datetime.now())

    @task
    def notify_completion():
        """완료 알림"""
        import requests

        requests.post(
            "https://hooks.slack.com/services/...",
            json={"text": f"Feature store updated successfully at {datetime.now()}"}
        )

    compute_features >> validate_features() >> materialize_to_online() >> notify_completion()

참고 자료