연합학습 (Federated Learning)
1. 개요
연합학습은 데이터를 중앙 서버로 모으지 않고, 분산된 클라이언트(디바이스)에서 로컬 데이터로 모델을 학습하고 파라미터만 공유하는 머신러닝 패러다임. 데이터 프라이버시를 보존하면서 협력적 학습을 가능하게 함.
정의
연합학습: 분산 데이터에서 중앙화 없이 학습
기존 ML:
데이터 수집 → 중앙 서버 → 모델 학습
연합학습:
클라이언트 로컬 학습 → 파라미터 집계 → 글로벌 모델 업데이트
(데이터는 클라이언트에 유지)
동기
| 동기 |
설명 |
| 프라이버시 |
민감한 데이터 중앙화 불가 |
| 규제 |
GDPR, 데이터 지역화 법규 |
| 통신 비용 |
대용량 데이터 전송 비효율 |
| 실시간성 |
엣지에서 즉시 학습 |
2. 핵심 개념
2.1 시스템 구성
┌──────────────┐
│ 중앙 서버 │
│ (Aggregator) │
└──────┬───────┘
│ 글로벌 모델
┌─────────┼─────────┐
↓ ↓ ↓
┌───────┐ ┌───────┐ ┌───────┐
│Client1│ │Client2│ │Client3│
│ Data1 │ │ Data2 │ │ Data3 │
└───────┘ └───────┘ └───────┘
2.2 연합학습 유형
| 유형 |
특징 |
예시 |
| Cross-device |
수백만 모바일 디바이스 |
키보드 예측 |
| Cross-silo |
소수 조직 간 협력 |
병원 간 학습 |
2.3 학습 과정 (FedAvg)
글로벌 라운드 t = 1, 2, ...:
1. 서버: 현재 글로벌 모델 wₜ 전송
2. 클라이언트 k (샘플링):
- 로컬 데이터로 E 에폭 학습
- wₖᵗ⁺¹ = wₜ - η∇L(wₜ; Dₖ)
3. 서버: 클라이언트 업데이트 집계
- wₜ₊₁ = Σₖ (nₖ/n) wₖᵗ⁺¹
2.4 주요 과제
| 과제 |
설명 |
| Non-IID 데이터 |
클라이언트 간 데이터 분포 차이 |
| 통신 효율성 |
대역폭 제한 |
| 시스템 이질성 |
디바이스 성능 차이 |
| 프라이버시 |
파라미터에서 정보 누출 |
| 악의적 클라이언트 |
공격 방어 |
3. 주요 알고리즘/기법
3.1 집계 알고리즘
FedAvg (Federated Averaging)
가장 기본적인 알고리즘:
1. 서버가 글로벌 모델 배포
2. 각 클라이언트 로컬 SGD 수행
3. 가중 평균으로 집계
wₜ₊₁ = Σₖ (nₖ/n) wₖᵗ⁺¹
nₖ: 클라이언트 k의 데이터 수
n: 전체 데이터 수
FedProx
Non-IID 문제 완화를 위한 proximal term:
min_w Lₖ(w) + (μ/2)||w - wₜ||²
글로벌 모델에서 크게 벗어나지 않도록 제약
FedOpt / FedAdam
서버 측 최적화 적용:
- FedAdagrad
- FedYogi
- FedAdam
wₜ₊₁ = ServerOpt(wₜ, Δwₜ)
Δwₜ = Σₖ (nₖ/n)(wₖᵗ⁺¹ - wₜ)
3.2 통신 효율성
| 기법 |
방법 |
| Gradient Compression |
그래디언트 양자화 |
| Sparsification |
상위 k% 전송 |
| Knowledge Distillation |
모델 대신 예측 전송 |
| Partial Model Update |
일부 레이어만 |
Gradient Quantization
32-bit → k-bit 양자화
SignSGD: 부호만 전송 (1-bit)
TernGrad: {-1, 0, +1} 3값
3.3 Non-IID 처리
| 기법 |
설명 |
| Data Sharing |
공유 데이터셋 |
| FedProx |
Proximal term |
| SCAFFOLD |
제어 변수 |
| FedNova |
정규화된 평균 |
| Per-FedAvg |
개인화 |
SCAFFOLD:
클라이언트 드리프트 보정을 위한 제어 변수:
c: 서버 제어 변수
cᵢ: 클라이언트 i 제어 변수
업데이트에 (c - cᵢ) 보정 항 추가
3.4 개인화 연합학습
글로벌 모델 + 로컬 적응
방법:
- 로컬 Fine-tuning
- 혼합 모델: wᵢ = αwglobal + (1-α)wlocal
- 멀티태스크 학습
- 클러스터링
3.5 프라이버시 강화
Differential Privacy
민감도 제한 + 노이즈 추가:
클립: Δwᵢ = Δwᵢ / max(1, ||Δwᵢ||/C)
노이즈: Δw̃ᵢ = Δwᵢ + N(0, σ²C²I)
(ε, δ)-DP 보장
Secure Aggregation
암호화된 집계:
- 서버도 개별 업데이트 볼 수 없음
- 합계만 복호화
프로토콜: Secure Multi-party Computation
4. 실무 적용 사례
4.1 모바일 키보드 예측
Google Gboard:
- 다음 단어 예측
- 오타 교정
연합학습 적용:
- 사용자 타이핑 데이터 로컬 유지
- 언어 모델 협력 학습
- 야간 유휴 시 학습
4.2 의료 데이터 협력
병원 간 협력 학습:
- 환자 데이터 공유 불가 (HIPAA)
- 희귀 질환 데이터 부족
연합학습:
- 각 병원 로컬 학습
- 모델 파라미터만 공유
- 프라이버시 보존 진단 모델
4.3 금융 사기 탐지
은행 간 협력:
- 고객 데이터 공유 불가
- 사기 패턴은 공통
연합학습:
- 각 은행 로컬 모델 학습
- 집계로 범용 탐지 모델
4.4 자율주행
차량 플릿 학습:
- 개별 주행 데이터 방대
- 엣지 디바이스 학습
연합학습:
- 차량별 로컬 모델 개선
- 야간 충전 시 서버 동기화
5. 참고 논문/저널
핵심 논문
| 논문 |
저자 |
출처 |
기여 |
| "Communication-Efficient Learning of Deep Networks from Decentralized Data" |
McMahan et al. |
AISTATS 2017 |
FedAvg |
| "Federated Optimization in Heterogeneous Networks" |
Li et al. |
MLSys 2020 |
FedProx |
| "SCAFFOLD: Stochastic Controlled Averaging for Federated Learning" |
Karimireddy et al. |
ICML 2020 |
SCAFFOLD |
| "Advances and Open Problems in Federated Learning" |
Kairouz et al. |
2021 |
종합 서베이 |
| "Secure Aggregation for Federated Learning" |
Bonawitz et al. |
CCS 2017 |
Secure Aggregation |
주요 컨퍼런스
| 컨퍼런스 |
분야 |
| ICML, NeurIPS, ICLR |
FL 알고리즘 |
| MLSys |
시스템 |
| CCS, S&P |
보안/프라이버시 |
| FL-ICML Workshop |
FL 전문 워크샵 |
6. 구현 프레임워크
| 프레임워크 |
개발 |
특징 |
| Flower |
Adap |
유연, 프레임워크 무관 |
| PySyft |
OpenMined |
프라이버시 중심 |
| TensorFlow Federated |
Google |
TF 통합 |
| NVIDIA FLARE |
NVIDIA |
기업용 |
| FedML |
FedML Inc. |
연구/프로덕션 |
| OpenFL |
Intel |
의료 특화 |
Flower 예시
import flwr as fl
# 클라이언트 정의
class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return model.get_weights()
def fit(self, parameters, config):
model.set_weights(parameters)
model.fit(x_train, y_train, epochs=1)
return model.get_weights(), len(x_train), {}
def evaluate(self, parameters, config):
model.set_weights(parameters)
loss, accuracy = model.evaluate(x_test, y_test)
return loss, len(x_test), {"accuracy": accuracy}
# 클라이언트 시작
fl.client.start_numpy_client(
server_address="localhost:8080",
client=FlowerClient()
)
# 서버 시작
fl.server.start_server(
server_address="localhost:8080",
config=fl.server.ServerConfig(num_rounds=3)
)
7. 실무 고려사항
| 고려사항 |
설명 |
| 클라이언트 선택 |
샘플링 전략, 가용성 |
| 비동기 업데이트 |
Stale 업데이트 처리 |
| 모델 검증 |
악의적 업데이트 탐지 |
| 공정성 |
클라이언트 간 성능 균형 |
| 디버깅 |
분산 환경 로깅/모니터링 |
| 규제 준수 |
GDPR, HIPAA 검증 |