콘텐츠로 이동
Data Prep
상세

모델 모니터링 (Drift Detection)

프로덕션 ML 모델의 성능을 지속적으로 모니터링하고, 데이터 드리프트와 모델 성능 저하를 감지하여 적시에 대응함.


1. 모니터링이 필요한 이유

1.1 ML 시스템 실패 모드

실패 유형 원인 증상
데이터 드리프트 입력 분포 변화 특성 통계 변화
컨셉 드리프트 입력-출력 관계 변화 동일 입력, 다른 정답
모델 성능 저하 위 두 가지 결과 정확도/F1 감소
데이터 품질 문제 ETL 오류, 스키마 변경 결측값, 이상치 급증
시스템 장애 인프라 문제 지연 시간 증가, 오류율 증가

1.2 모니터링 레이어

+------------------+
|    Business KPIs  |   (매출, 전환율, 클릭률)
+------------------+
         |
+------------------+
|   Model Metrics   |   (Accuracy, F1, AUC, RMSE)
+------------------+
         |
+------------------+
|   Data Quality    |   (결측값, 이상치, 스키마)
+------------------+
         |
+------------------+
|   Data Drift      |   (분포 변화, 통계 변화)
+------------------+
         |
+------------------+
|   System Metrics  |   (Latency, Throughput, Errors)
+------------------+

2. 데이터 드리프트 (Data Drift)

2.1 드리프트 유형

1. Covariate Shift (입력 분포 변화)
   - P(X)가 변화, P(Y|X)는 동일
   - 예: 새로운 사용자층 유입

2. Prior Probability Shift (레이블 분포 변화)
   - P(Y)가 변화
   - 예: 계절적 수요 변화

3. Concept Drift (관계 변화)
   - P(Y|X)가 변화
   - 예: 사용자 선호도 변화

4. Upstream Data Change
   - 데이터 파이프라인 변경
   - 예: 로깅 형식 변경

2.2 드리프트 감지 방법

통계적 검정

방법 대상 특징
KS Test 연속형 분포 형태 비교
Chi-Square Test 범주형 빈도 분포 비교
PSI (Population Stability Index) 둘 다 점수 기반, 해석 용이
Wasserstein Distance 연속형 Earth Mover's Distance
Jensen-Shannon Divergence 둘 다 분포 간 거리

PSI 계산

import numpy as np

def calculate_psi(reference, current, bins=10):
    """
    Population Stability Index 계산
    PSI < 0.1: 변화 없음
    0.1 <= PSI < 0.2: 약간의 변화
    PSI >= 0.2: 유의미한 변화
    """
    # 동일한 구간으로 나눔
    breakpoints = np.percentile(reference, np.linspace(0, 100, bins + 1))
    breakpoints[0] = -np.inf
    breakpoints[-1] = np.inf

    # 각 구간의 비율 계산
    ref_counts = np.histogram(reference, breakpoints)[0] / len(reference)
    cur_counts = np.histogram(current, breakpoints)[0] / len(current)

    # 0 방지
    ref_counts = np.clip(ref_counts, 0.0001, None)
    cur_counts = np.clip(cur_counts, 0.0001, None)

    # PSI 계산
    psi = np.sum((cur_counts - ref_counts) * np.log(cur_counts / ref_counts))

    return psi

KS Test

from scipy import stats

def ks_test_drift(reference, current, threshold=0.05):
    """
    Kolmogorov-Smirnov 검정
    p-value < threshold이면 드리프트 감지
    """
    statistic, p_value = stats.ks_2samp(reference, current)

    return {
        'statistic': statistic,
        'p_value': p_value,
        'drift_detected': p_value < threshold
    }

3. Evidently AI

3.1 설치 및 기본 사용

# pip install evidently

from evidently import ColumnMapping
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset, DataQualityPreset
from evidently.metrics import *

import pandas as pd

# 데이터 준비
reference_data = pd.read_csv("data/reference.csv")  # 학습 데이터 또는 초기 데이터
current_data = pd.read_csv("data/current.csv")      # 프로덕션 데이터

# 컬럼 매핑
column_mapping = ColumnMapping(
    target="target",
    prediction="prediction",
    numerical_features=["feature1", "feature2", "feature3"],
    categorical_features=["category1", "category2"],
)

# 데이터 드리프트 리포트
report = Report(metrics=[
    DataDriftPreset(),
])

report.run(
    reference_data=reference_data,
    current_data=current_data,
    column_mapping=column_mapping,
)

# HTML 저장
report.save_html("data_drift_report.html")

# Dictionary 변환
result = report.as_dict()

3.2 데이터 품질 리포트

from evidently.metric_preset import DataQualityPreset

quality_report = Report(metrics=[
    DataQualityPreset(),
])

quality_report.run(
    reference_data=reference_data,
    current_data=current_data,
    column_mapping=column_mapping,
)

# 결과 확인
quality_dict = quality_report.as_dict()

3.3 모델 성능 리포트

from evidently.metric_preset import ClassificationPreset, RegressionPreset

# 분류 모델
classification_report = Report(metrics=[
    ClassificationPreset(),
])

classification_report.run(
    reference_data=reference_data,
    current_data=current_data,
    column_mapping=column_mapping,
)

# 회귀 모델
regression_report = Report(metrics=[
    RegressionPreset(),
])

3.4 커스텀 메트릭

from evidently.metrics import (
    DatasetDriftMetric,
    DataDriftTable,
    ColumnDriftMetric,
    DatasetMissingValuesMetric,
    ColumnQuantileMetric,
)

custom_report = Report(metrics=[
    # 데이터셋 전체 드리프트
    DatasetDriftMetric(),

    # 개별 컬럼 드리프트 테이블
    DataDriftTable(),

    # 특정 컬럼 드리프트
    ColumnDriftMetric(column_name="feature1"),
    ColumnDriftMetric(column_name="category1"),

    # 결측값 비율
    DatasetMissingValuesMetric(),

    # 분위수 모니터링
    ColumnQuantileMetric(column_name="feature1", quantile=0.5),
    ColumnQuantileMetric(column_name="feature1", quantile=0.95),
])

3.5 Test Suite (자동화된 검증)

from evidently.test_suite import TestSuite
from evidently.test_preset import DataDriftTestPreset, DataQualityTestPreset
from evidently.tests import *

# 테스트 스위트 정의
test_suite = TestSuite(tests=[
    # 프리셋
    DataDriftTestPreset(),
    DataQualityTestPreset(),

    # 개별 테스트
    TestNumberOfRows(gte=1000),  # 최소 1000행
    TestNumberOfColumns(eq=10),  # 정확히 10개 컬럼
    TestShareOfMissingValues(lte=0.05),  # 결측값 5% 이하
    TestColumnDrift(column_name="feature1"),
])

test_suite.run(
    reference_data=reference_data,
    current_data=current_data,
    column_mapping=column_mapping,
)

# 결과 확인
result = test_suite.as_dict()
all_passed = result["summary"]["all_passed"]

if not all_passed:
    failed_tests = [t for t in result["tests"] if t["status"] == "FAIL"]
    print(f"Failed tests: {len(failed_tests)}")
    for test in failed_tests:
        print(f"  - {test['name']}: {test['description']}")

4. 프로덕션 모니터링 파이프라인

4.1 Airflow DAG

from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta
import pandas as pd
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset
from evidently.test_suite import TestSuite
from evidently.test_preset import DataDriftTestPreset

def load_reference_data():
    """학습 데이터 또는 기준 데이터 로드"""
    return pd.read_parquet("s3://bucket/reference/data.parquet")

def load_production_data(**context):
    """최근 프로덕션 데이터 로드"""
    execution_date = context["execution_date"]
    date_str = execution_date.strftime("%Y-%m-%d")
    return pd.read_parquet(f"s3://bucket/production/{date_str}/data.parquet")

def run_drift_tests(**context):
    """드리프트 테스트 실행"""
    reference = load_reference_data()
    current = load_production_data(**context)

    test_suite = TestSuite(tests=[DataDriftTestPreset()])
    test_suite.run(reference_data=reference, current_data=current)

    result = test_suite.as_dict()

    if not result["summary"]["all_passed"]:
        # 알림 전송
        send_alert("Data drift detected!", result)

    return result["summary"]["all_passed"]

def generate_monitoring_report(**context):
    """모니터링 리포트 생성"""
    reference = load_reference_data()
    current = load_production_data(**context)

    report = Report(metrics=[DataDriftPreset()])
    report.run(reference_data=reference, current_data=current)

    # S3에 저장
    date_str = context["execution_date"].strftime("%Y-%m-%d")
    report.save_html(f"s3://bucket/reports/{date_str}/drift_report.html")

def send_alert(message, details):
    """Slack/Email 알림"""
    import requests

    requests.post(
        "https://hooks.slack.com/services/...",
        json={
            "text": f":warning: {message}",
            "attachments": [{"text": str(details)}]
        }
    )

with DAG(
    dag_id="model_monitoring",
    schedule_interval="@daily",
    start_date=datetime(2024, 1, 1),
    catchup=False,
    default_args={
        "retries": 3,
        "retry_delay": timedelta(minutes=5),
    },
) as dag:

    drift_test = PythonOperator(
        task_id="run_drift_tests",
        python_callable=run_drift_tests,
    )

    report = PythonOperator(
        task_id="generate_report",
        python_callable=generate_monitoring_report,
    )

    drift_test >> report

4.2 실시간 모니터링 (FastAPI + Prometheus)

from fastapi import FastAPI
from prometheus_client import Counter, Histogram, Gauge, generate_latest
from prometheus_client import CONTENT_TYPE_LATEST
from starlette.responses import Response
import numpy as np
from collections import deque

app = FastAPI()

# Prometheus 메트릭 정의
PREDICTION_COUNTER = Counter(
    "model_predictions_total",
    "Total number of predictions",
    ["model_name", "model_version"]
)

PREDICTION_LATENCY = Histogram(
    "model_prediction_latency_seconds",
    "Prediction latency",
    ["model_name"],
    buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0]
)

FEATURE_VALUE = Gauge(
    "model_feature_value",
    "Feature value distribution",
    ["feature_name", "stat"]  # stat: mean, std, min, max
)

PREDICTION_SCORE = Histogram(
    "model_prediction_score",
    "Prediction score distribution",
    ["model_name"],
    buckets=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
)

# 최근 예측 저장 (드리프트 감지용)
recent_predictions = deque(maxlen=10000)
recent_features = {f"feature_{i}": deque(maxlen=10000) for i in range(5)}

@app.post("/predict")
async def predict(request: PredictionRequest):
    import time

    start = time.time()

    # 특성 저장
    for i, value in enumerate(request.features):
        recent_features[f"feature_{i}"].append(value)

    # 예측
    prediction = model.predict([request.features])[0]

    # 메트릭 기록
    latency = time.time() - start
    PREDICTION_COUNTER.labels(model_name="main", model_version="v1").inc()
    PREDICTION_LATENCY.labels(model_name="main").observe(latency)
    PREDICTION_SCORE.labels(model_name="main").observe(prediction)

    # 특성 통계 업데이트
    for name, values in recent_features.items():
        if len(values) >= 100:
            arr = np.array(list(values))
            FEATURE_VALUE.labels(feature_name=name, stat="mean").set(np.mean(arr))
            FEATURE_VALUE.labels(feature_name=name, stat="std").set(np.std(arr))

    recent_predictions.append(prediction)

    return {"prediction": prediction, "latency_ms": latency * 1000}

@app.get("/metrics")
async def metrics():
    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)

@app.get("/drift/status")
async def drift_status():
    """간단한 드리프트 상태 확인"""
    if len(recent_predictions) < 1000:
        return {"status": "insufficient_data"}

    # 예측값 분포 확인
    recent_mean = np.mean(list(recent_predictions))
    reference_mean = 0.5  # 학습 시 평균

    drift_score = abs(recent_mean - reference_mean)

    return {
        "status": "drift_detected" if drift_score > 0.1 else "normal",
        "drift_score": drift_score,
        "recent_mean": recent_mean,
        "reference_mean": reference_mean,
    }

4.3 Grafana Dashboard

# Prometheus 알림 규칙
groups:
  - name: ml_model_alerts
    rules:
      - alert: ModelLatencyHigh
        expr: histogram_quantile(0.95, rate(model_prediction_latency_seconds_bucket[5m])) > 1
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "Model prediction latency is high"
          description: "P95 latency is {{ $value }}s"

      - alert: PredictionDrift
        expr: |
          abs(
            avg_over_time(model_prediction_score_sum[1h] / model_prediction_score_count[1h])
            - avg_over_time(model_prediction_score_sum[24h] / model_prediction_score_count[24h])
          ) > 0.1
        for: 15m
        labels:
          severity: warning
        annotations:
          summary: "Prediction score distribution drift detected"

      - alert: FeatureDrift
        expr: |
          abs(
            model_feature_value{stat="mean"} - model_feature_reference_mean
          ) / model_feature_reference_std > 2
        for: 30m
        labels:
          severity: warning
        annotations:
          summary: "Feature {{ $labels.feature_name }} drift detected"

5. 모니터링 전략

5.1 단계별 접근

Phase 1: 기본 시스템 메트릭
- Latency, Throughput, Error Rate
- 구현: Prometheus + Grafana

Phase 2: 모델 성능 메트릭
- Ground Truth 수집 후 정확도 계산
- 구현: 배치 평가 파이프라인

Phase 3: 데이터 품질 모니터링
- 결측값, 이상치, 스키마 검증
- 구현: Great Expectations, Evidently

Phase 4: 드리프트 감지
- 통계적 테스트 자동화
- 구현: Evidently Test Suite

Phase 5: 자동 대응
- 드리프트 감지 시 자동 재학습
- 구현: Airflow + MLflow

5.2 Ground Truth 수집 전략

monitoring diagram 1

5.3 임계값 설정 가이드

메트릭 경고 위험
PSI 0.1 0.2
KS Statistic 0.1 0.2
결측값 비율 5% 증가 10% 증가
정확도 저하 2% 5%
지연 시간 증가 50% 100%

6. 실무 케이스 스터디

6.1 케이스 1: 추천 시스템 드리프트 감지

상황: 이커머스 추천 시스템의 클릭률(CTR)이 점진적으로 하락

# 실제 모니터링 대시보드 설정
from evidently.report import Report
from evidently.metrics import (
    ColumnDriftMetric,
    DatasetDriftMetric,
    ColumnQuantileMetric,
)
import pandas as pd
from datetime import datetime, timedelta

def daily_drift_check():
    """일일 드리프트 체크"""

    # 기준 데이터 (학습 시점 또는 지난 주)
    reference = pd.read_parquet("s3://data/reference/user_features.parquet")

    # 오늘 데이터
    today = datetime.now().strftime("%Y-%m-%d")
    current = pd.read_parquet(f"s3://data/production/{today}/user_features.parquet")

    # 중요 특성에 대한 드리프트 체크
    critical_features = [
        "user_purchase_count_30d",
        "avg_session_duration",
        "category_preference_score",
    ]

    report = Report(metrics=[
        DatasetDriftMetric(),
        *[ColumnDriftMetric(column_name=f) for f in critical_features],
        *[ColumnQuantileMetric(column_name=f, quantile=0.95) for f in critical_features],
    ])

    report.run(reference_data=reference, current_data=current)
    result = report.as_dict()

    # 드리프트 감지 시 알림
    drift_detected = result["metrics"][0]["result"]["dataset_drift"]

    if drift_detected:
        drifted_features = [
            f["column_name"] 
            for f in result["metrics"][1:len(critical_features)+1]
            if f["result"]["drift_detected"]
        ]

        send_slack_alert(
            channel="#ml-alerts",
            message=f"Drift detected in features: {drifted_features}",
            severity="warning"
        )

        # 자동 재학습 트리거 (선택적)
        if len(drifted_features) > 2:
            trigger_retraining_pipeline()

    return result

# 조사 결과: 신규 마케팅 캠페인으로 새로운 사용자층 유입
# 해결: 새로운 사용자 세그먼트 포함하여 모델 재학습

6.2 케이스 2: 사기 탐지 모델 성능 저하

상황: 사기 탐지 모델의 precision 급격히 하락, false positive 증가

# 레이블 드리프트 감지
from evidently.metrics import ClassificationQualityMetric
from evidently.test_suite import TestSuite
from evidently.tests import TestPrecisionScore, TestRecallScore

def fraud_model_monitoring():
    """사기 탐지 모델 모니터링"""

    # 최근 레이블 확보된 데이터
    labeled_data = get_recent_labeled_predictions(days=7)

    # 기준 기간 (안정 시점)
    baseline_data = get_baseline_labeled_data()

    # 성능 테스트
    test_suite = TestSuite(tests=[
        TestPrecisionScore(gte=0.85),   # Precision >= 85%
        TestRecallScore(gte=0.90),      # Recall >= 90%
    ])

    test_suite.run(reference_data=baseline_data, current_data=labeled_data)

    result = test_suite.as_dict()

    if not result["summary"]["all_passed"]:
        # 실패한 테스트 분석
        for test in result["tests"]:
            if test["status"] == "FAIL":
                analyze_failure(test)

    return result

def analyze_failure(test_result):
    """실패 원인 분석"""

    # 특성별 중요도 vs 드리프트 상관관계 분석
    feature_importance = get_model_feature_importance()
    feature_drift = get_feature_drift_scores()

    # 높은 중요도 + 높은 드리프트 = 주요 원인
    analysis = pd.DataFrame({
        "feature": feature_importance.keys(),
        "importance": feature_importance.values(),
        "drift_score": [feature_drift.get(f, 0) for f in feature_importance.keys()],
    })

    analysis["impact_score"] = analysis["importance"] * analysis["drift_score"]
    top_causes = analysis.nlargest(5, "impact_score")

    print("Top causes of model degradation:")
    print(top_causes)

    return top_causes

# 조사 결과: 새로운 사기 패턴 등장 (Concept Drift)
# 해결: 
# 1. 새로운 사기 패턴 샘플 레이블링
# 2. 규칙 기반 필터 추가 (임시)
# 3. 모델 재학습 및 배포

6.3 케이스 3: 배포 후 지연시간 급증

상황: 모델 업데이트 후 P99 지연시간 3배 증가

# 시스템 메트릭과 모델 메트릭 통합 모니터링
import prometheus_client
from prometheus_client import Histogram, Counter, Gauge

# 메트릭 정의
PREDICTION_LATENCY = Histogram(
    "model_prediction_latency_seconds",
    "Prediction latency in seconds",
    ["model_version", "batch_size_bucket"],
    buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5]
)

INPUT_SIZE = Histogram(
    "model_input_size",
    "Size of input features",
    ["feature_type"],
    buckets=[10, 50, 100, 500, 1000, 5000]
)

def monitor_prediction(func):
    """예측 함수 데코레이터"""
    def wrapper(input_data, model_version):
        # 입력 크기 기록
        INPUT_SIZE.labels(feature_type="numeric").observe(len(input_data))

        # 배치 크기 버킷
        batch_bucket = "small" if len(input_data) < 10 else "medium" if len(input_data) < 100 else "large"

        # 예측 시간 측정
        with PREDICTION_LATENCY.labels(
            model_version=model_version, 
            batch_size_bucket=batch_bucket
        ).time():
            result = func(input_data, model_version)

        return result
    return wrapper

# 분석 쿼리 (Prometheus)
"""
# P99 지연시간
histogram_quantile(0.99, rate(model_prediction_latency_seconds_bucket[5m]))

# 배치 크기별 지연시간
histogram_quantile(0.95, 
    rate(model_prediction_latency_seconds_bucket{batch_size_bucket="large"}[5m])
)
"""

# 조사 결과: 새 모델의 특성 전처리 복잡도 증가
# 해결:
# 1. 특성 전처리 병렬화
# 2. 무거운 변환은 오프라인으로 이동
# 3. 모델 양자화 적용

6.4 알림 전략

# alerting_rules.yaml
groups:
  - name: ml_model_alerts
    rules:
      # 급격한 성능 저하
      - alert: ModelPerformanceDegradation
        expr: |
          (
            avg_over_time(model_accuracy[1h]) 
            - avg_over_time(model_accuracy[24h])
          ) / avg_over_time(model_accuracy[24h]) < -0.05
        for: 30m
        labels:
          severity: critical
          team: ml-platform
        annotations:
          summary: "Model accuracy dropped more than 5%"
          runbook: "https://wiki/ml-runbooks/performance-degradation"

      # 점진적 드리프트
      - alert: GradualDataDrift
        expr: |
          avg_over_time(feature_psi_score[7d]) > 0.15
        for: 1d
        labels:
          severity: warning
          team: data-science
        annotations:
          summary: "Gradual data drift detected over past week"

      # 예측 분포 이상
      - alert: PredictionDistributionAnomaly
        expr: |
          abs(
            avg(model_prediction_score) - 0.5
          ) > 0.2
        for: 2h
        labels:
          severity: warning
        annotations:
          summary: "Prediction distribution significantly shifted"

7. 트러블슈팅 가이드

7.1 일반적인 문제

증상 가능한 원인 진단 방법 해결책
갑작스러운 정확도 하락 데이터 파이프라인 오류 입력 데이터 스키마 검증 ETL 로그 확인, 롤백
점진적 성능 저하 데이터 드리프트 PSI, KS 테스트 재학습 또는 모델 교체
특정 시간대 오류 급증 트래픽 패턴 변화 시간별 메트릭 분석 동적 스케일링, 캐싱
False Positive 증가 Concept Drift 레이블 분석 새 패턴 포함 재학습
지연시간 증가 입력 크기 증가 입력 분포 모니터링 전처리 최적화

7.2 드리프트 진단 플로우

성능 저하 감지
      |
      v
+------------------+
| 데이터 품질 확인  |
+------------------+
| - 결측값 증가?   |
| - 스키마 변경?   |
| - 이상치 급증?   |
+------------------+
      |
      v (문제 없음)
+------------------+
| 특성 드리프트    |
+------------------+
| - PSI 계산      |
| - 분포 비교     |
| - 상관관계 변화  |
+------------------+
      |
      v (드리프트 확인)
+------------------+
| 영향도 분석      |
+------------------+
| - 특성 중요도   |
| - 부분 의존성   |
+------------------+
      |
      v
+------------------+
| 대응 결정        |
+------------------+
| - 재학습?       |
| - 특성 재설계?  |
| - 규칙 추가?    |
+------------------+

7.3 디버깅 체크리스트

def debug_model_issue():
    """모델 이슈 디버깅 체크리스트"""

    checks = {
        "data_quality": {
            "null_ratio": check_null_ratio(),
            "schema_match": check_schema(),
            "value_ranges": check_value_ranges(),
        },
        "feature_drift": {
            "psi_scores": calculate_all_psi(),
            "top_drifted": get_top_drifted_features(n=5),
        },
        "prediction_analysis": {
            "distribution": get_prediction_distribution(),
            "confidence_scores": analyze_confidence(),
            "error_patterns": analyze_error_patterns(),
        },
        "system_health": {
            "latency_p95": get_latency_percentile(0.95),
            "error_rate": get_error_rate(),
            "throughput": get_throughput(),
        },
    }

    # 자동 진단
    issues = []

    if checks["data_quality"]["null_ratio"] > 0.1:
        issues.append("High null ratio detected")

    if any(psi > 0.2 for psi in checks["feature_drift"]["psi_scores"].values()):
        issues.append(f"Significant drift in: {checks['feature_drift']['top_drifted']}")

    if checks["system_health"]["latency_p95"] > 1.0:
        issues.append("High latency detected")

    return {
        "checks": checks,
        "issues": issues,
        "recommended_actions": generate_recommendations(issues),
    }

7.4 롤백 전략

# 모델 롤백 자동화
import mlflow
from datetime import datetime

def rollback_model(model_name: str, reason: str):
    """프로덕션 모델 롤백"""
    client = mlflow.tracking.MlflowClient()

    # 현재 프로덕션 버전
    current_prod = client.get_latest_versions(model_name, stages=["Production"])[0]

    # 이전 프로덕션 버전 (Archived에서)
    archived = client.get_latest_versions(model_name, stages=["Archived"])
    if not archived:
        raise ValueError("No previous version to rollback to")

    previous_version = max(archived, key=lambda x: x.version)

    # 롤백 실행
    client.transition_model_version_stage(
        name=model_name,
        version=previous_version.version,
        stage="Production",
    )

    client.transition_model_version_stage(
        name=model_name,
        version=current_prod.version,
        stage="Archived",
    )

    # 롤백 기록
    client.set_model_version_tag(
        name=model_name,
        version=current_prod.version,
        key="rollback_reason",
        value=reason,
    )

    client.set_model_version_tag(
        name=model_name,
        version=current_prod.version,
        key="rollback_time",
        value=datetime.now().isoformat(),
    )

    print(f"Rolled back from v{current_prod.version} to v{previous_version.version}")

    return previous_version.version

8. 참고 자료