콘텐츠로 이동
Data Prep
상세

Federated Learning

메타 정보

항목 내용
분류 Distributed Learning / Privacy-Preserving ML
핵심 논문 "Communication-Efficient Learning of Deep Networks from Decentralized Data" (McMahan et al., AISTATS 2017), "Federated Optimization in Heterogeneous Networks" (Li et al., MLSys 2020), "SCAFFOLD: Stochastic Controlled Averaging" (Karimireddy et al., ICML 2020)
주요 저자 Brendan McMahan (FedAvg 창시), Tian Li (FedProx), Sai Praneeth Karimireddy (SCAFFOLD), Peter Kairouz (서베이)
핵심 개념 데이터를 중앙 서버로 모으지 않고, 각 클라이언트에서 로컬 학습한 후 모델 파라미터만 집계하여 글로벌 모델을 학습하는 분산 학습 패러다임
관련 분야 Distributed Systems, Differential Privacy, Secure Computation, Communication Theory

정의

Federated Learning(FL)은 다수의 클라이언트(device, organization)가 자신의 로컬 데이터를 공유하지 않으면서, 협력적으로 글로벌 모델을 학습하는 기계학습 패러다임이다. Google이 2016년 처음 제안하였으며, 핵심 원칙은 "데이터가 이동하지 않고 모델이 이동한다"는 것이다.

수학적 정의

K개의 클라이언트가 있고, 각 클라이언트 k는 로컬 데이터셋 D_k를 가진다. FL의 목적 함수:

min_w  F(w) = sum_{k=1}^{K} (n_k / n) * F_k(w)

여기서: - w: 글로벌 모델 파라미터 - n_k = |D_k|: 클라이언트 k의 데이터 수 - n = sum n_k: 전체 데이터 수 - F_k(w) = (1/n_k) sum_{(x,y) in D_k} l(w; x, y): 클라이언트 k의 로컬 목적 함수

FedAvg 알고리즘

Algorithm: Federated Averaging (FedAvg)
==========================================

Server:
  initialize w_0
  for each round t = 0, 1, 2, ...:
      S_t <- random subset of K clients (fraction C)
      for each client k in S_t (in parallel):
          w_k^{t+1} <- ClientUpdate(k, w_t)
      w_{t+1} <- sum_{k in S_t} (n_k / n_S) * w_k^{t+1}

ClientUpdate(k, w):
  w_local <- w
  for each local epoch e = 1, ..., E:
      for each batch b in D_k:
          w_local <- w_local - eta * nabla l(w_local; b)
  return w_local

FL 전체 흐름

Round t:

  Server                           Clients
  ======                           =======

  [Global Model w_t]
        |
        |--- broadcast w_t ------> Client 1: train on D_1
        |                          Client 2: train on D_2
        |                          Client 3: train on D_3
        |                          ...
        |                          Client K: train on D_K
        |
        |<-- send w_1^{t+1} ------ Client 1
        |<-- send w_2^{t+1} ------ Client 2
        |<-- send w_3^{t+1} ------ Client 3
        |                          ...
        |<-- send w_K^{t+1} ------ Client K
        |
  [Aggregate: weighted average]
        |
  [Global Model w_{t+1}]

시나리오 비교: Cross-silo vs Cross-device

항목 Cross-silo Cross-device
클라이언트 조직 (병원, 은행, 기업) 개별 디바이스 (스마트폰, IoT)
클라이언트 수 소수 (2~100) 대규모 (10^6 ~ 10^10)
참여율 거의 100% (항상 가용) 매우 낮음 (일부만 참여)
데이터 크기/클라이언트 대용량 소용량
통신 안정적 (데이터센터 간) 불안정 (모바일 네트워크)
연산 자원 풍부 (GPU 서버) 제한적 (모바일 프로세서)
신뢰도 상대적으로 신뢰 가능 낮음 (악의적 클라이언트 가능)
ID 추적 가능 (조직별 식별) 불가 (stateless)
대표 예시 의료 데이터 연합, 금융 컨소시엄 Google Keyboard, Apple Siri
주요 과제 데이터 이질성, 공정성 통신 효율, 디바이스 이탈

핵심 문제

1. Non-IID 데이터

FL에서 가장 핵심적인 문제는 각 클라이언트의 데이터 분포가 서로 다르다는 것이다 (Non-IID: Non-Identically and Independently Distributed).

Non-IID 유형:

유형 설명 예시
Label Distribution Skew 클라이언트별 레이블 분포 차이 병원 A: 폐암 데이터 많음, 병원 B: 유방암 많음
Feature Distribution Skew 동일 레이블이지만 특성 분포 차이 지역별 얼굴 인식: 인종 분포 차이
Quantity Skew 클라이언트별 데이터 양 차이 활발한 사용자 10만건, 비활성 사용자 10건
Concept Shift 동일 특성-레이블 관계의 차이 지역별 "좋은 날씨"의 기준 차이

Non-IID가 FedAvg에 미치는 영향: - Client drift: 로컬 학습이 진행될수록 글로벌 최적점에서 멀어짐 - 수렴 속도 저하 및 최종 성능 하락 - 극단적 Non-IID에서는 발산 가능

2. 통신 비용

통신 비용 분석:

매 라운드마다:
  - 서버 -> 클라이언트: 모델 전체 전송 (W bytes)
  - 클라이언트 -> 서버: 업데이트 전송 (W bytes)
  - 총: 2W * K_active bytes per round

예시 (ResNet-50, 100 clients):
  - 모델 크기: ~100MB
  - 라운드당: 100MB * 2 * 100 = 20GB
  - 500 라운드: 10TB 통신량

3. Privacy

원시 데이터를 공유하지 않지만, 모델 업데이트(gradient)로부터 원본 데이터를 복원할 수 있다:

  • Gradient Inversion Attack (Zhu et al., NeurIPS 2019): gradient로부터 학습 이미지 복원 가능
  • Membership Inference: 특정 샘플이 학습에 사용되었는지 추론 가능

4. System Heterogeneity

클라이언트별 연산 능력, 네트워크 속도, 가용 시간이 다르다:

System Heterogeneity:

  Client 1: GPU server     --> 10초에 학습 완료
  Client 2: Laptop         --> 60초에 학습 완료
  Client 3: Smartphone     --> 300초에 학습 완료 (또는 중도 이탈)
  Client 4: IoT device     --> 학습 불가능

  --> Synchronous aggregation은 가장 느린 client에 의해 병목
  --> Straggler problem

알고리즘

FedAvg (McMahan et al., AISTATS 2017)

기본 알고리즘으로, 각 클라이언트에서 여러 epoch의 SGD를 수행한 후 가중 평균으로 집계한다.

수식:

글로벌 업데이트:
  w_{t+1} = sum_{k in S_t} (n_k / n_S) * w_k^{t+1}

로컬 업데이트 (E epochs, learning rate eta):
  w_k^{t+1} = w_t - eta * sum_{e=1}^{E} sum_{b in D_k} nabla l(w; b)

하이퍼파라미터: - C: 라운드당 참여 클라이언트 비율 (0 < C <= 1) - E: 로컬 학습 에포크 수 - B: 로컬 미니배치 크기 - eta: 로컬 학습률

한계: Non-IID 데이터에서 client drift 문제, 수렴 보장 없음 (non-convex case)

FedProx (Li et al., MLSys 2020)

FedAvg에 proximal term을 추가하여 로컬 모델이 글로벌 모델에서 과도하게 벗어나는 것을 방지한다.

로컬 목적 함수:

h_k(w; w_t) = F_k(w) + (mu / 2) * ||w - w_t||^2

여기서 mu는 proximal term의 강도를 조절하는 하이퍼파라미터이다. mu = 0이면 FedAvg와 동일하다.

특징: - Client drift 완화 - Partial work 허용: 클라이언트가 E epochs를 다 못하고 중간에 반환해도 됨 - mu 튜닝 필요 (너무 크면 로컬 학습 효과 감소)

SCAFFOLD (Karimireddy et al., ICML 2020)

Control variate를 도입하여 client drift를 직접적으로 보정한다.

핵심 아이디어:

FedAvg의 로컬 업데이트:
  w <- w - eta * g_k(w)              (로컬 gradient만 사용)

SCAFFOLD의 로컬 업데이트:
  w <- w - eta * (g_k(w) - c_k + c)  (drift 보정)

여기서:
  c   = 글로벌 control variate (서버의 전체 gradient 추정)
  c_k = 클라이언트 k의 로컬 control variate
  (c_k - c)가 client drift를 보정

Control variate 업데이트:

c_k^{new} = c_k - c + (1 / (K * eta)) * (w_t - w_k^{t+1})
c^{new} = c + (1 / K) * sum_k (c_k^{new} - c_k)

장점: IID에서도 Non-IID에서도 빠른 수렴, client drift 근본적 해결 단점: 추가 통신 비용 (control variate 전송), 메모리 오버헤드

FedNova (Wang et al., NeurIPS 2020)

클라이언트별 로컬 학습 횟수가 다를 때 발생하는 objective inconsistency를 해결한다.

문제점: FedAvg에서 클라이언트별 로컬 스텝 수가 다르면, 더 많이 학습한 클라이언트의 영향이 과대 반영된다.

수식:

FedAvg (문제):
  w_{t+1} = sum (n_k/n) * w_k    (tau_k 스텝 학습한 결과를 그대로 평균)

FedNova (보정):
  d_k = (w_t - w_k) / tau_k       (정규화된 gradient)
  d = sum (n_k/n) * d_k           (정규화된 평균)
  tau_eff = sum (n_k/n) * tau_k   (유효 스텝 수)
  w_{t+1} = w_t - tau_eff * d     (보정된 업데이트)

FedBN (Li et al., ICLR 2021)

Batch Normalization 레이어를 공유하지 않고 로컬에 유지하여 feature shift 문제를 해결한다.

FedBN의 파라미터 분할:

  Global (공유):
    - Conv layers
    - FC layers
    - 기타 가중치

  Local (비공유):
    - BN running mean
    - BN running var
    - BN gamma, beta

이유: 각 클라이언트의 데이터 분포가 다르므로,
      BN 통계치도 달라야 자연스러움

알고리즘 비교

알고리즘 핵심 메커니즘 Non-IID 대응 추가 통신 추가 메모리
FedAvg 가중 평균 약함 없음 없음
FedProx Proximal term 중간 없음 없음
SCAFFOLD Control variate 강함 2x c, c_k 저장
FedNova 정규화된 집계 중간 없음 없음
FedBN BN 로컬 유지 중간 (feature shift) 약간 감소 BN 파라미터

Privacy 기법

Secure Aggregation

서버가 개별 클라이언트의 모델 업데이트를 볼 수 없도록, 암호학적 프로토콜로 집계를 수행한다.

Secure Aggregation 흐름:

  Client 1: w_1 + mask_1  ----+
  Client 2: w_2 + mask_2  ----+--> Server: sum(w_k + mask_k)
  Client 3: w_3 + mask_3  ----+            = sum(w_k) + sum(mask_k)
                                            = sum(w_k) + 0  (masks cancel out)
                                            = sum(w_k)

  * 서버는 sum(w_k)만 알 수 있고, 개별 w_k는 알 수 없음
  * Pairwise masking: mask_{ij} = -mask_{ji}

Differential Privacy (DP)

각 클라이언트의 업데이트에 노이즈를 추가하여 개인 정보를 보호한다.

DP-FedAvg:

  1. Gradient clipping:
     g_k <- g_k * min(1, S / ||g_k||)       (sensitivity 제한)

  2. Noise addition:
     g_k <- g_k + N(0, sigma^2 * S^2 * I)   (Gaussian noise)

  3. Aggregation:
     w_{t+1} = w_t - eta * (1/K) * sum g_k

  Privacy guarantee: (epsilon, delta)-DP
  * sigma가 클수록 프라이버시 강함, 성능 저하
  * epsilon이 작을수록 프라이버시 강함

Privacy-Utility 트레이드오프:

epsilon Privacy 수준 일반적 성능 영향
0.1 ~ 1 매우 강함 유의미한 성능 저하
1 ~ 10 강함 중간 수준 저하
10 ~ 100 약함 미미한 저하
> 100 거의 없음 거의 영향 없음

Homomorphic Encryption (HE)

암호화된 상태에서 연산을 수행할 수 있는 암호 체계를 이용한다.

HE 기반 FL:

  Client k: Enc(w_k) --> Server
  Server:   Enc(w_1) + Enc(w_2) + ... = Enc(w_1 + w_2 + ...)
  Server:   Enc(sum) / K = Enc(avg)
  Server:   Enc(avg) --> Clients
  Client k: Dec(Enc(avg)) = avg

  * 서버는 암호화된 값만 처리
  * 복호화 키를 모르므로 개별 값 확인 불가

한계: 연산 오버헤드가 매우 크다 (plaintext 대비 1000x~10000x). 실용적으로는 제한적이며, Secure Aggregation + DP 조합이 더 일반적이다.


Communication 최적화

Gradient Compression

전체 gradient 대신 중요한 성분만 전송하여 통신량을 줄인다.

Top-k Sparsification:

원래 gradient: g = [0.1, -0.5, 0.02, 0.8, -0.01, 0.3, ...]
                    (d = 10^6 개 원소)

Top-k (k = 0.1%):  [0, -0.5, 0, 0.8, 0, 0.3, ...]
                    (1000개 비영 원소만 전송)

통신량: d * 32bit --> k * (32bit + index) = 99.9% 감소

Error Feedback: 전송하지 않은 gradient를 누적하여 다음 라운드에 반영

e_k^{t+1} = g_k^t - Compress(g_k^t + e_k^t)

Quantization

gradient 값의 비트 수를 줄인다.

32-bit float --> 8-bit quantization:

  원래: 0.2345678 (32 bits)
  양자화: 0.234 (8 bits)

  통신량 4x 감소

SignSGD (극단적 1-bit):
  원래: [0.2, -0.5, 0.1, -0.3]
  SignSGD: [+1, -1, +1, -1]  (각 1 bit)

  통신량 32x 감소, 성능 저하 있음

Federated Distillation

모델 파라미터 대신 soft label(logit)만 주고받는 방식이다.

일반 FL: 모델 전체 전송 (수백 MB)
FedDistill: logit만 전송 (수 KB)

Client k: input x --> model_k(x) = logit_k
Server:   avg_logit = mean(logit_k)
Client k: KD loss = KL(model_k(x) || avg_logit) 로 학습

통신 최적화 비교

기법 압축률 성능 영향 적용 난이도
Top-k Sparsification 100~1000x 낮음 쉬움
Random Sparsification 10~100x 중간 쉬움
8-bit Quantization 4x 매우 낮음 쉬움
1-bit (SignSGD) 32x 중간~높음 쉬움
Federated Distillation 1000x+ 중간 중간
Low-rank Decomposition 10~100x 낮음 중간

적용 분야

의료 (Healthcare)

적용 사례 설명
의료 영상 분석 여러 병원의 CT/MRI 데이터로 질병 진단 모델 학습 (HIPAA 준수)
전자건강기록 (EHR) 환자 기록을 공유하지 않고 질병 예측 모델 학습
약물 발견 제약사 간 분자 데이터 연합 학습
웨어러블 개인 건강 데이터로 이상 징후 탐지

사례: HealthChain 프로젝트 - 유럽 4개 병원 연합으로 유방암 조직 분류 모델 학습. 중앙 집중 학습 대비 97% 성능 달성.

금융 (Finance)

적용 사례 설명
사기 탐지 은행 간 거래 데이터를 공유하지 않고 사기 패턴 학습
신용 평가 여러 금융 기관의 데이터로 신용 모델 개선
AML 자금세탁 패턴을 은행 연합으로 탐지
리스크 관리 시장 리스크 모델의 분산 학습

모바일 (Mobile/Edge)

적용 사례 설명
키보드 예측 Google Gboard: 사용자 타이핑 패턴으로 다음 단어 예측 개선
음성 인식 Apple Siri: 사용자 음성 데이터를 디바이스에서 학습
추천 시스템 사용자 행동 데이터를 로컬에서 학습하여 개인화
자율주행 각 차량의 주행 데이터로 인식 모델 개선

프레임워크 비교

프레임워크 개발 언어 특징 적합 용도
Flower Adap (독립) Python 프레임워크 무관, 유연함, 쉬운 시작 연구, 프로토타이핑
PySyft OpenMined Python Privacy 중심, DP/HE 통합 Privacy 연구
FATE WeBank Python/Java 산업용, 보안 프로토콜 완비 금융/엔터프라이즈
TFF Google Python TensorFlow 생태계, 시뮬레이션 TF 기반 연구
FedML FedML Inc. Python MLOps 통합, 다양한 토폴로지 산업 배포
NVIDIA FLARE NVIDIA Python 의료 특화, Clara 통합 의료 영상

Flower 코드 예시

"""
Flower를 이용한 Federated Learning 예시
pip install flwr torch torchvision
"""

import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms


# --- 모델 정의 ---

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


# --- Flower Client ---

class MNISTClient(fl.client.NumPyClient):
    def __init__(self, model, trainloader, testloader, device):
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader
        self.device = device

    def get_parameters(self, config):
        """모델 파라미터를 numpy array로 반환"""
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """서버에서 받은 파라미터를 모델에 설정"""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """로컬 학습 수행"""
        self.set_parameters(parameters)

        optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
        criterion = nn.CrossEntropyLoss()

        self.model.train()
        for epoch in range(config.get("local_epochs", 1)):
            for images, labels in self.trainloader:
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                loss = criterion(self.model(images), labels)
                loss.backward()
                optimizer.step()

        return self.get_parameters(config), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """글로벌 모델 평가"""
        self.set_parameters(parameters)

        criterion = nn.CrossEntropyLoss()
        self.model.eval()
        loss, correct, total = 0.0, 0, 0

        with torch.no_grad():
            for images, labels in self.testloader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss += criterion(outputs, labels).item()
                correct += (outputs.argmax(1) == labels).sum().item()
                total += labels.size(0)

        return loss / len(self.testloader), total, {"accuracy": correct / total}


# --- 데이터 분할 (Non-IID 시뮬레이션) ---

def partition_data_noniid(dataset, num_clients, alpha=0.5):
    """
    Dirichlet 분포를 이용한 Non-IID 분할
    alpha가 작을수록 더 Non-IID (0.1: 극심한 Non-IID, 100: 거의 IID)
    """
    import numpy as np

    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    num_classes = len(set(labels))

    client_indices = [[] for _ in range(num_clients)]

    for c in range(num_classes):
        class_indices = np.where(labels == c)[0]
        np.random.shuffle(class_indices)

        # Dirichlet 분포로 비율 생성
        proportions = np.random.dirichlet([alpha] * num_clients)
        proportions = (proportions * len(class_indices)).astype(int)

        # 나머지 할당
        proportions[-1] = len(class_indices) - proportions[:-1].sum()

        start = 0
        for k in range(num_clients):
            end = start + proportions[k]
            client_indices[k].extend(class_indices[start:end].tolist())
            start = end

    return client_indices


# --- 서버 Strategy ---

strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.5,           # 라운드당 50% 클라이언트 참여
    fraction_evaluate=0.3,      # 평가에 30% 참여
    min_fit_clients=2,          # 최소 2개 클라이언트
    min_evaluate_clients=2,
    min_available_clients=5,    # 최소 5개 클라이언트 대기
    on_fit_config_fn=lambda rnd: {"local_epochs": 2},
)


# --- 서버 실행 ---

def start_server():
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=20),
        strategy=strategy,
    )


# --- 클라이언트 실행 ---

def start_client(client_id, num_clients=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    full_train = datasets.MNIST("./data", train=True, download=True, transform=transform)
    full_test = datasets.MNIST("./data", train=False, transform=transform)

    # Non-IID 분할
    client_indices = partition_data_noniid(full_train, num_clients, alpha=0.5)

    train_subset = Subset(full_train, client_indices[client_id])
    trainloader = DataLoader(train_subset, batch_size=32, shuffle=True)
    testloader = DataLoader(full_test, batch_size=64)

    model = SimpleCNN().to(device)

    fl.client.start_numpy_client(
        server_address="127.0.0.1:8080",
        client=MNISTClient(model, trainloader, testloader, device),
    )

최신 동향

Personalized Federated Learning

글로벌 모델 하나로 모든 클라이언트를 만족시키기 어려우므로, 클라이언트별 맞춤 모델을 제공한다.

방법 설명
Fine-tuning 글로벌 모델을 로컬 데이터로 추가 학습
Per-FedAvg MAML 기반: 글로벌 모델이 좋은 초기화가 되도록 학습
pFedMe Moreau envelope로 개인화와 글로벌 간 균형
FedPer 하위 레이어 공유 + 상위 레이어 개인화
FedRep Representation 공유 + Head 개인화
Ditto 글로벌 모델 + 개인화 모델 동시 학습, proximal 제약

FL + LLM Fine-tuning

대규모 언어 모델의 fine-tuning을 federated 방식으로 수행하는 연구가 활발하다.

FL + LLM Fine-tuning 시나리오:

  기업 A (법률 문서)  --+
  기업 B (의료 기록)  --+--> Federated LoRA Fine-tuning
  기업 C (금융 보고서) --+         |
                                   v
                          [공유: LoRA adapter만]
                          [비공유: 원본 데이터, full weight]
                          [통신량: 수 MB (full model 수십 GB 대비)]

핵심 기법: - FedLoRA: LoRA adapter만 federated 학습, base model은 고정 - FFA-LoRA: Freeze-A, federate-B (LoRA의 A 행렬 고정, B만 연합 학습) - 통신량이 수 MB 수준으로 매우 효율적

기타 동향

분야 설명
Vertical FL 동일 사용자, 다른 feature를 가진 기관 간 학습 (예: 은행+통신사)
Federated Analytics 모델 학습이 아닌 통계 분석을 분산으로 수행
Asynchronous FL 동기화 없이 비동기적으로 업데이트 집계
Incentive Mechanism 클라이언트의 참여를 유도하는 보상 메커니즘 설계
Fairness 클라이언트 간 성능 공정성 보장

핵심 논문 목록

논문 연도 기여
McMahan et al., "Communication-Efficient Learning" 2017 FedAvg 제안
Li et al., "Federated Optimization in Heterogeneous Networks" 2020 FedProx
Karimireddy et al., "SCAFFOLD" 2020 Control variate 기반 drift 보정
Wang et al., "Tackling the Objective Inconsistency" 2020 FedNova
Li et al., "FedBN" 2021 BN 로컬 유지
Kairouz et al., "Advances and Open Problems in FL" 2021 종합 서베이
Zhang et al., "Federated Learning for LLMs" 2024 FL + LLM 서베이

마지막 업데이트: 2026-03-25