Kolmogorov-Arnold Networks (KAN)¶
메타 정보¶
| 항목 | 내용 |
|---|---|
| 논문 | KAN: Kolmogorov-Arnold Networks |
| 저자 | Ziming Liu, Yixuan Wang, Sachin Vaidya, Fabian Ruehle, James Halverson, Marin Soljacic, Thomas Y. Hou, Max Tegmark (MIT, Caltech) |
| 학회 | ICLR 2025 |
| arXiv | 2404.19756 (2024.04) |
| 코드 | https://github.com/KindXiaoming/pykan |
| 분야 | Neural Architecture, Function Approximation |
1. 핵심 아이디어¶
Kolmogorov-Arnold Networks (KAN)는 MLP의 대안으로 제안된 신경망 아키텍처다. MLP가 Universal Approximation Theorem에 기반한다면, KAN은 Kolmogorov-Arnold Representation Theorem에 기반한다.
Kolmogorov-Arnold 표현 정리¶
임의의 연속 다변수 함수 f: [0,1]^n -> R는 다음과 같이 표현 가능:
- phi_{q,p}: [0,1] -> R (내부 univariate 함수)
- Phi_q: R -> R (외부 univariate 함수)
핵심: 다변수 함수를 일변수 함수들의 합성으로 분해 가능
2. MLP vs KAN 구조 비교¶
┌─────────────────────────────────────────────────────────────────┐
│ MLP (Multi-Layer Perceptron) │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Input Hidden Layer Output │
│ │
│ x1 ─────┬─────> [w] ──> (sigma) ──┬ │
│ │ │ │
│ x2 ─────┼─────> [w] ──> (sigma) ──┼────> y │
│ │ │ │
│ x3 ─────┴─────> [w] ──> (sigma) ──┴ │
│ │
│ * 활성화 함수: 노드(뉴런)에 위치 │
│ * 가중치: 고정된 선형 변환 │
│ * 활성화: ReLU, sigmoid 등 고정 함수 │
│ │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ KAN (Kolmogorov-Arnold Network) │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Input Hidden Layer Output │
│ │
│ x1 ─────┬───[phi]───> (+) ──┬ │
│ │ │ │
│ x2 ─────┼───[phi]───> (+) ──┼──[Phi]──> y │
│ │ │ │
│ x3 ─────┴───[phi]───> (+) ──┴ │
│ │
│ * 활성화 함수: 엣지(연결)에 위치 │
│ * 가중치 없음: 선형 가중치를 학습 가능한 함수로 대체 │
│ * 활성화: B-spline으로 파라미터화된 학습 가능 함수 │
│ │
└─────────────────────────────────────────────────────────────────┘
| 특성 | MLP | KAN |
|---|---|---|
| 이론적 기반 | Universal Approximation Theorem | Kolmogorov-Arnold Theorem |
| 활성화 함수 위치 | 노드 (뉴런) | 엣지 (연결) |
| 활성화 함수 | 고정 (ReLU, sigmoid 등) | 학습 가능 (B-spline) |
| 선형 가중치 | 있음 | 없음 |
| 파라미터 효율성 | 낮음 | 높음 |
| 해석 가능성 | 낮음 | 높음 |
| Neural Scaling | O(N^4) | O(N^2.5) |
3. KAN 레이어 수학적 정의¶
KAN 레이어는 1D 함수들의 행렬로 정의:
입력 x = (x1, ..., x_n_in)에 대해:
B-Spline 파라미터화¶
각 phi는 B-spline + SiLU 기저 함수의 합으로 표현:
phi(x) = w_b * silu(x) + w_s * spline(x)
where:
silu(x) = x / (1 + e^(-x))
spline(x) = sum_i c_i * B_i(x)
- B_i(x): B-spline 기저 함수
- c_i: 학습 가능한 계수
- grid: spline의 그리드 포인트 수 (해상도 결정)
4. 주요 장점¶
4.1 정확도 (Accuracy)¶
- 더 작은 네트워크로 동등/우수한 성능: [2,5,5,1] KAN이 [2,128,128,1] MLP보다 우수
- 빠른 Neural Scaling Laws: 파라미터 증가 대비 오차 감소율이 MLP보다 가파름
- PDE 솔빙: 물리/수학 문제에서 높은 정밀도 달성
4.2 해석 가능성 (Interpretability)¶
- 시각화 용이:
model.plot()으로 각 활성화 함수 확인 가능 - Symbolic Regression: 학습된 spline을 수학적 수식으로 변환 가능
- 과학적 발견: 물리/수학 법칙 재발견에 활용
4.3 파라미터 효율성¶
┌──────────────────────────────────────────────────────────┐
│ Neural Scaling Laws 비교 │
├──────────────────────────────────────────────────────────┤
│ │
│ Test RMSE │
│ | │
│ 10^0 ┤ o MLP │
│ │ o │
│ 10^-2├ o │
│ │ o x KAN │
│ 10^-4├ x │
│ │ x │
│ 10^-6├ x │
│ │ x │
│ └────────────────────────> Parameters │
│ 10^1 10^2 10^3 10^4 │
│ │
│ MLP: Test Loss ~ N^(-4) │
│ KAN: Test Loss ~ N^(-2.5) (더 빠른 수렴) │
│ │
└──────────────────────────────────────────────────────────┘
5. 한계 및 고려사항¶
| 한계 | 설명 |
|---|---|
| 계산 비용 | Spline 연산으로 인해 MLP보다 느림 |
| 대규모 적용 | LLM 등 대규모 모델에는 아직 검증 부족 |
| 하이퍼파라미터 튜닝 | MLP 직관이 그대로 적용되지 않음 |
| 병렬화 | Symbolic branch 사용 시 병렬화 어려움 |
적합한 사용 사례¶
- 과학/수학 문제 (PDE, 물리 법칙 발견)
- 고정밀도 함수 근사
- 해석 가능한 모델이 필요한 경우
- 작은 규모의 tabular 데이터
덜 적합한 사용 사례¶
- 대규모 딥러닝 (LLM, Vision Transformer)
- 실시간 추론이 필요한 경우
- 데이터가 매우 많은 경우 (스케일링 이점 감소)
6. Python 구현 예시¶
6.1 기본 사용법 (pykan)¶
# 설치
# pip install pykan
from kan import KAN
import torch
# 모델 정의
# width: [입력차원, 은닉층1, ..., 출력차원]
# grid: B-spline 그리드 포인트 수
# k: B-spline 차수
model = KAN(width=[2, 5, 1], grid=5, k=3)
# 학습 데이터 생성 (예: f(x,y) = sin(pi*x) + y^2)
def target_fn(x):
return torch.sin(torch.pi * x[:, [0]]) + x[:, [1]]**2
x_train = torch.rand(1000, 2) * 2 - 1 # [-1, 1] 범위
y_train = target_fn(x_train)
x_test = torch.rand(200, 2) * 2 - 1
y_test = target_fn(x_test)
dataset = {
'train_input': x_train,
'train_label': y_train,
'test_input': x_test,
'test_label': y_test
}
# 학습
# lamb: sparsity regularization
model.train(dataset, opt="LBFGS", steps=50, lamb=0.01)
# 시각화
model.plot()
6.2 효율성 모드¶
# 학습 루프를 직접 작성하고 symbolic branch를 사용하지 않는 경우
# 반드시 speed() 호출 필요 (성능 최적화)
model.speed()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(100):
pred = model(x_train)
loss = ((pred - y_train)**2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
6.3 Grid Extension (정확도 향상)¶
# 초기 학습: 낮은 grid
model = KAN(width=[2, 5, 1], grid=3, k=3)
model.train(dataset, opt="LBFGS", steps=20)
# Grid 확장: 더 높은 해상도로
model = model.refine(grid=10)
model.train(dataset, opt="LBFGS", steps=20)
# 추가 확장
model = model.refine(grid=20)
model.train(dataset, opt="LBFGS", steps=20)
6.4 Symbolic Regression¶
# 학습된 KAN에서 수학적 수식 추출
model.auto_symbolic()
# 또는 특정 활성화 함수 지정
model.fix_symbolic(0, 0, 0, 'sin') # layer 0, input 0, output 0을 sin으로
model.fix_symbolic(0, 1, 0, 'x^2') # layer 0, input 1, output 0을 x^2으로
# 수식 출력
formula = model.symbolic_formula()
print(formula)
6.5 Pruning (해석 가능성 향상)¶
# Sparsity regularization으로 학습
model.train(dataset, opt="LBFGS", steps=50, lamb=0.01)
# 불필요한 뉴런 제거
pruned_model = model.prune()
# 시각화
pruned_model.plot()
6.6 PyTorch 직접 구현 (간소화 버전)¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class KANLayer(nn.Module):
"""단순화된 KAN 레이어 구현"""
def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order
# B-spline 계수
self.spline_weight = nn.Parameter(
torch.randn(out_features, in_features, grid_size + spline_order)
)
# Base weight (SiLU 부분)
self.base_weight = nn.Parameter(
torch.randn(out_features, in_features)
)
# Grid 정의
h = 2.0 / grid_size
grid = torch.linspace(-1 - h * spline_order, 1 + h * spline_order,
grid_size + 2 * spline_order + 1)
self.register_buffer('grid', grid)
def b_spline_basis(self, x, k=0):
"""B-spline 기저 함수 계산"""
grid = self.grid
if k == 0:
return ((x >= grid[:-1]) & (x < grid[1:])).float()
else:
b1 = self.b_spline_basis(x, k-1)
b2 = self.b_spline_basis(x, k-1)
denom1 = grid[k:-1] - grid[:-k-1]
denom2 = grid[k+1:] - grid[1:-k]
term1 = (x - grid[:-k-1]) / (denom1 + 1e-8) * b1[..., :-1]
term2 = (grid[k+1:] - x) / (denom2 + 1e-8) * b2[..., 1:]
return term1 + term2
def forward(self, x):
# x: (batch, in_features)
batch_size = x.shape[0]
# Spline 부분
x_expanded = x.unsqueeze(-1) # (batch, in_features, 1)
bases = self.b_spline_basis(x_expanded, self.spline_order)
# bases: (batch, in_features, grid_size + spline_order)
spline_out = torch.einsum('oin,bin->bo', self.spline_weight, bases)
# Base 부분 (SiLU)
base_out = F.silu(x) @ self.base_weight.T
return base_out + spline_out
class SimpleKAN(nn.Module):
"""단순화된 KAN 모델"""
def __init__(self, width, grid_size=5, spline_order=3):
super().__init__()
self.layers = nn.ModuleList([
KANLayer(width[i], width[i+1], grid_size, spline_order)
for i in range(len(width) - 1)
])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
# 사용 예시
if __name__ == "__main__":
# 모델 생성: 2 -> 5 -> 1
model = SimpleKAN(width=[2, 5, 1], grid_size=5)
# 테스트
x = torch.randn(32, 2)
y = model(x)
print(f"Input: {x.shape}, Output: {y.shape}")
7. 하이퍼파라미터 가이드¶
7.1 기본 원칙¶
7.2 권장 설정¶
| 파라미터 | 시작값 | 조정 방향 |
|---|---|---|
| width | [n_in, 1, n_out] | 성능 부족 시 증가 |
| grid | 3 | Grid extension으로 점진적 증가 |
| k (spline order) | 3 | 일반적으로 고정 |
| lamb (sparsity) | 0 | 해석 가능성 필요 시 0.01부터 |
7.3 Overfitting 대응¶
8. 관련 연구 및 확장¶
| 프로젝트 | 설명 |
|---|---|
| efficient-kan | 효율성 개선 구현 |
| FourierKAN | Fourier 기저 함수 사용 |
| GraphKAN | 그래프 신경망에 KAN 적용 |
| TKAN | Temporal KAN (시계열) |
| KAN-GNN | 분자 특성 예측용 KAN+GNN |
9. 참고 자료¶
- 논문: https://arxiv.org/abs/2404.19756
- KAN 2.0: https://arxiv.org/abs/2408.10205
- 공식 코드: https://github.com/KindXiaoming/pykan
- 문서: https://kindxiaoming.github.io/pykan/
- OpenReview: https://openreview.net/forum?id=Ozo7qJ5vZi
최종 업데이트: 2026-02-15