콘텐츠로 이동
Data Prep
상세

Kolmogorov-Arnold Networks (KAN)

메타 정보

항목 내용
논문 KAN: Kolmogorov-Arnold Networks
저자 Ziming Liu, Yixuan Wang, Sachin Vaidya, Fabian Ruehle, James Halverson, Marin Soljacic, Thomas Y. Hou, Max Tegmark (MIT, Caltech)
학회 ICLR 2025
arXiv 2404.19756 (2024.04)
코드 https://github.com/KindXiaoming/pykan
분야 Neural Architecture, Function Approximation

1. 핵심 아이디어

Kolmogorov-Arnold Networks (KAN)는 MLP의 대안으로 제안된 신경망 아키텍처다. MLP가 Universal Approximation Theorem에 기반한다면, KAN은 Kolmogorov-Arnold Representation Theorem에 기반한다.

Kolmogorov-Arnold 표현 정리

임의의 연속 다변수 함수 f: [0,1]^n -> R는 다음과 같이 표현 가능:

f(x1, ..., xn) = sum_{q=0}^{2n} Phi_q( sum_{p=1}^{n} phi_{q,p}(x_p) )
  • phi_{q,p}: [0,1] -> R (내부 univariate 함수)
  • Phi_q: R -> R (외부 univariate 함수)

핵심: 다변수 함수를 일변수 함수들의 합성으로 분해 가능


2. MLP vs KAN 구조 비교

┌─────────────────────────────────────────────────────────────────┐
│                    MLP (Multi-Layer Perceptron)                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│    Input         Hidden Layer         Output                    │
│                                                                 │
│     x1 ─────┬─────> [w] ──> (sigma) ──┬                        │
│             │                          │                        │
│     x2 ─────┼─────> [w] ──> (sigma) ──┼────> y                 │
│             │                          │                        │
│     x3 ─────┴─────> [w] ──> (sigma) ──┴                        │
│                                                                 │
│    * 활성화 함수: 노드(뉴런)에 위치                              │
│    * 가중치: 고정된 선형 변환                                    │
│    * 활성화: ReLU, sigmoid 등 고정 함수                         │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│                 KAN (Kolmogorov-Arnold Network)                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│    Input         Hidden Layer         Output                    │
│                                                                 │
│     x1 ─────┬───[phi]───> (+) ──┬                              │
│             │                    │                              │
│     x2 ─────┼───[phi]───> (+) ──┼──[Phi]──> y                  │
│             │                    │                              │
│     x3 ─────┴───[phi]───> (+) ──┴                              │
│                                                                 │
│    * 활성화 함수: 엣지(연결)에 위치                              │
│    * 가중치 없음: 선형 가중치를 학습 가능한 함수로 대체           │
│    * 활성화: B-spline으로 파라미터화된 학습 가능 함수             │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
특성 MLP KAN
이론적 기반 Universal Approximation Theorem Kolmogorov-Arnold Theorem
활성화 함수 위치 노드 (뉴런) 엣지 (연결)
활성화 함수 고정 (ReLU, sigmoid 등) 학습 가능 (B-spline)
선형 가중치 있음 없음
파라미터 효율성 낮음 높음
해석 가능성 낮음 높음
Neural Scaling O(N^4) O(N^2.5)

3. KAN 레이어 수학적 정의

KAN 레이어는 1D 함수들의 행렬로 정의:

Phi = { phi_{q,p} }, p = 1, ..., n_in, q = 1, ..., n_out

입력 x = (x1, ..., x_n_in)에 대해:

x_out,q = sum_{p=1}^{n_in} phi_{q,p}(x_p)

B-Spline 파라미터화

각 phi는 B-spline + SiLU 기저 함수의 합으로 표현:

phi(x) = w_b * silu(x) + w_s * spline(x)

where:
  silu(x) = x / (1 + e^(-x))
  spline(x) = sum_i c_i * B_i(x)
  • B_i(x): B-spline 기저 함수
  • c_i: 학습 가능한 계수
  • grid: spline의 그리드 포인트 수 (해상도 결정)

4. 주요 장점

4.1 정확도 (Accuracy)

  • 더 작은 네트워크로 동등/우수한 성능: [2,5,5,1] KAN이 [2,128,128,1] MLP보다 우수
  • 빠른 Neural Scaling Laws: 파라미터 증가 대비 오차 감소율이 MLP보다 가파름
  • PDE 솔빙: 물리/수학 문제에서 높은 정밀도 달성

4.2 해석 가능성 (Interpretability)

  • 시각화 용이: model.plot()으로 각 활성화 함수 확인 가능
  • Symbolic Regression: 학습된 spline을 수학적 수식으로 변환 가능
  • 과학적 발견: 물리/수학 법칙 재발견에 활용

4.3 파라미터 효율성

┌──────────────────────────────────────────────────────────┐
│             Neural Scaling Laws 비교                     │
├──────────────────────────────────────────────────────────┤
│                                                          │
│  Test RMSE                                               │
│     |                                                    │
│  10^0 ┤  o MLP                                           │
│       │   o                                              │
│  10^-2├     o                                            │
│       │       o  x KAN                                   │
│  10^-4├         x                                        │
│       │           x                                      │
│  10^-6├             x                                    │
│       │               x                                  │
│       └────────────────────────> Parameters              │
│        10^1   10^2   10^3   10^4                         │
│                                                          │
│  MLP: Test Loss ~ N^(-4)                                 │
│  KAN: Test Loss ~ N^(-2.5) (더 빠른 수렴)                │
│                                                          │
└──────────────────────────────────────────────────────────┘

5. 한계 및 고려사항

한계 설명
계산 비용 Spline 연산으로 인해 MLP보다 느림
대규모 적용 LLM 등 대규모 모델에는 아직 검증 부족
하이퍼파라미터 튜닝 MLP 직관이 그대로 적용되지 않음
병렬화 Symbolic branch 사용 시 병렬화 어려움

적합한 사용 사례

  • 과학/수학 문제 (PDE, 물리 법칙 발견)
  • 고정밀도 함수 근사
  • 해석 가능한 모델이 필요한 경우
  • 작은 규모의 tabular 데이터

덜 적합한 사용 사례

  • 대규모 딥러닝 (LLM, Vision Transformer)
  • 실시간 추론이 필요한 경우
  • 데이터가 매우 많은 경우 (스케일링 이점 감소)

6. Python 구현 예시

6.1 기본 사용법 (pykan)

# 설치
# pip install pykan

from kan import KAN
import torch

# 모델 정의
# width: [입력차원, 은닉층1, ..., 출력차원]
# grid: B-spline 그리드 포인트 수
# k: B-spline 차수
model = KAN(width=[2, 5, 1], grid=5, k=3)

# 학습 데이터 생성 (예: f(x,y) = sin(pi*x) + y^2)
def target_fn(x):
    return torch.sin(torch.pi * x[:, [0]]) + x[:, [1]]**2

x_train = torch.rand(1000, 2) * 2 - 1  # [-1, 1] 범위
y_train = target_fn(x_train)

x_test = torch.rand(200, 2) * 2 - 1
y_test = target_fn(x_test)

dataset = {
    'train_input': x_train,
    'train_label': y_train,
    'test_input': x_test,
    'test_label': y_test
}

# 학습
# lamb: sparsity regularization
model.train(dataset, opt="LBFGS", steps=50, lamb=0.01)

# 시각화
model.plot()

6.2 효율성 모드

# 학습 루프를 직접 작성하고 symbolic branch를 사용하지 않는 경우
# 반드시 speed() 호출 필요 (성능 최적화)
model.speed()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(100):
    pred = model(x_train)
    loss = ((pred - y_train)**2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

6.3 Grid Extension (정확도 향상)

# 초기 학습: 낮은 grid
model = KAN(width=[2, 5, 1], grid=3, k=3)
model.train(dataset, opt="LBFGS", steps=20)

# Grid 확장: 더 높은 해상도로
model = model.refine(grid=10)
model.train(dataset, opt="LBFGS", steps=20)

# 추가 확장
model = model.refine(grid=20)
model.train(dataset, opt="LBFGS", steps=20)

6.4 Symbolic Regression

# 학습된 KAN에서 수학적 수식 추출
model.auto_symbolic()

# 또는 특정 활성화 함수 지정
model.fix_symbolic(0, 0, 0, 'sin')  # layer 0, input 0, output 0을 sin으로
model.fix_symbolic(0, 1, 0, 'x^2')  # layer 0, input 1, output 0을 x^2으로

# 수식 출력
formula = model.symbolic_formula()
print(formula)

6.5 Pruning (해석 가능성 향상)

# Sparsity regularization으로 학습
model.train(dataset, opt="LBFGS", steps=50, lamb=0.01)

# 불필요한 뉴런 제거
pruned_model = model.prune()

# 시각화
pruned_model.plot()

6.6 PyTorch 직접 구현 (간소화 버전)

import torch
import torch.nn as nn
import torch.nn.functional as F

class KANLayer(nn.Module):
    """단순화된 KAN 레이어 구현"""

    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # B-spline 계수
        self.spline_weight = nn.Parameter(
            torch.randn(out_features, in_features, grid_size + spline_order)
        )

        # Base weight (SiLU 부분)
        self.base_weight = nn.Parameter(
            torch.randn(out_features, in_features)
        )

        # Grid 정의
        h = 2.0 / grid_size
        grid = torch.linspace(-1 - h * spline_order, 1 + h * spline_order, 
                             grid_size + 2 * spline_order + 1)
        self.register_buffer('grid', grid)

    def b_spline_basis(self, x, k=0):
        """B-spline 기저 함수 계산"""
        grid = self.grid
        if k == 0:
            return ((x >= grid[:-1]) & (x < grid[1:])).float()
        else:
            b1 = self.b_spline_basis(x, k-1)
            b2 = self.b_spline_basis(x, k-1)

            denom1 = grid[k:-1] - grid[:-k-1]
            denom2 = grid[k+1:] - grid[1:-k]

            term1 = (x - grid[:-k-1]) / (denom1 + 1e-8) * b1[..., :-1]
            term2 = (grid[k+1:] - x) / (denom2 + 1e-8) * b2[..., 1:]

            return term1 + term2

    def forward(self, x):
        # x: (batch, in_features)
        batch_size = x.shape[0]

        # Spline 부분
        x_expanded = x.unsqueeze(-1)  # (batch, in_features, 1)
        bases = self.b_spline_basis(x_expanded, self.spline_order)
        # bases: (batch, in_features, grid_size + spline_order)

        spline_out = torch.einsum('oin,bin->bo', self.spline_weight, bases)

        # Base 부분 (SiLU)
        base_out = F.silu(x) @ self.base_weight.T

        return base_out + spline_out


class SimpleKAN(nn.Module):
    """단순화된 KAN 모델"""

    def __init__(self, width, grid_size=5, spline_order=3):
        super().__init__()
        self.layers = nn.ModuleList([
            KANLayer(width[i], width[i+1], grid_size, spline_order)
            for i in range(len(width) - 1)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# 사용 예시
if __name__ == "__main__":
    # 모델 생성: 2 -> 5 -> 1
    model = SimpleKAN(width=[2, 5, 1], grid_size=5)

    # 테스트
    x = torch.randn(32, 2)
    y = model(x)
    print(f"Input: {x.shape}, Output: {y.shape}")

7. 하이퍼파라미터 가이드

7.1 기본 원칙

시작: 단순하게 → 점진적 확장
      (작은 width, 작은 grid, 정규화 없음)

7.2 권장 설정

파라미터 시작값 조정 방향
width [n_in, 1, n_out] 성능 부족 시 증가
grid 3 Grid extension으로 점진적 증가
k (spline order) 3 일반적으로 고정
lamb (sparsity) 0 해석 가능성 필요 시 0.01부터

7.3 Overfitting 대응

train/test loss 격차 큼 → 
  1. 데이터 증가 
  2. grid 감소 (가장 중요)
  3. width 감소

8. 관련 연구 및 확장

프로젝트 설명
efficient-kan 효율성 개선 구현
FourierKAN Fourier 기저 함수 사용
GraphKAN 그래프 신경망에 KAN 적용
TKAN Temporal KAN (시계열)
KAN-GNN 분자 특성 예측용 KAN+GNN

9. 참고 자료

  • 논문: https://arxiv.org/abs/2404.19756
  • KAN 2.0: https://arxiv.org/abs/2408.10205
  • 공식 코드: https://github.com/KindXiaoming/pykan
  • 문서: https://kindxiaoming.github.io/pykan/
  • OpenReview: https://openreview.net/forum?id=Ozo7qJ5vZi

최종 업데이트: 2026-02-15