콘텐츠로 이동
Data Prep
상세

Meta-Learning (메타 학습)

메타 정보

항목 내용
분류 Learning Paradigm / Few-Shot Learning / Transfer Learning
핵심 논문 "Model-Agnostic Meta-Learning for Fast Adaptation" (ICML 2017), "Matching Networks for One Shot Learning" (NeurIPS 2016), "Prototypical Networks for Few-shot Learning" (NeurIPS 2017)
주요 저자 Chelsea Finn, Pieter Abbeel, Sergey Levine (MAML); Oriol Vinyals et al. (Matching Networks); Jake Snell, Kevin Swersky, Richard Zemel (Prototypical Networks)
핵심 개념 학습하는 방법을 학습 (learning to learn) -- 소수의 샘플로 새로운 태스크에 빠르게 적응
관련 분야 Few-Shot Learning, Transfer Learning, Neural Architecture Search, Hyperparameter Optimization

정의

Meta-Learning은 여러 태스크의 학습 경험을 축적하여, 새로운 태스크를 소수의 샘플만으로 빠르게 학습할 수 있는 능력을 갖추는 학습 패러다임이다. 전통적 머신러닝이 단일 태스크의 데이터 분포를 학습하는 것과 달리, 메타 학습은 태스크 분포를 학습한다.

전통적 학습:
  단일 태스크 T의 데이터 D = {(x_i, y_i)} --> 모델 f_theta

메타 학습:
  태스크 분포 p(T)에서 샘플링된 {T_1, T_2, ..., T_N} --> 메타 학습기 M
  새로운 태스크 T_new + 소수 샘플 --> M이 빠르게 적응

문제 설정 (Episode Training)

메타 학습은 에피소드(episode) 기반 훈련을 사용한다.

N-way K-shot 분류 문제:

태스크 T_i:
  Support Set S = {(x_j, y_j)}_{j=1}^{N*K}   -- K개씩 N개 클래스
  Query Set   Q = {(x_j, y_j)}_{j=1}^{N*Q}   -- 평가용 샘플

에피소드 구성:
  1. p(T)에서 태스크 T_i 샘플링
  2. T_i에서 N개 클래스 무작위 선택
  3. 각 클래스에서 K개 support + Q개 query 샘플링
  4. Support set으로 학습, Query set으로 평가

  메타 학습 단계:
  +--------------------------------------------------+
  | Episode 1: 5-way 1-shot                          |
  |   Support: [dog(1), cat(1), bird(1), fish(1),    |
  |             horse(1)]                             |
  |   Query:   [dog(?), cat(?), bird(?), ...]         |
  |   --> Loss L_1                                    |
  +--------------------------------------------------+
  | Episode 2: 5-way 1-shot                          |
  |   Support: [car(1), bus(1), train(1), plane(1),  |
  |             ship(1)]                              |
  |   Query:   [car(?), bus(?), ...]                  |
  |   --> Loss L_2                                    |
  +--------------------------------------------------+
  ...
  메타 업데이트: theta <- theta - alpha * grad(sum L_i)

주요 접근법

메타 학습은 크게 세 가지 패러다임으로 분류된다.

메타 학습 분류 체계:

Meta-Learning
  |
  +-- Metric-Based (거리 기반)
  |     Siamese Networks, Matching Networks,
  |     Prototypical Networks, Relation Networks
  |
  +-- Optimization-Based (최적화 기반)
  |     MAML, Reptile, Meta-SGD, iMAML, ANIL
  |
  +-- Model-Based (모델 기반)
        MANN, SNAIL, MetaNet, CNP/ANP

1. Metric-Based Meta-Learning (거리 기반)

핵심 아이디어: 임베딩 공간에서 유사한 샘플은 가깝고, 다른 샘플은 멀도록 학습한다. 새 태스크의 query 샘플을 support 샘플과의 거리로 분류한다.

Siamese Networks

Koch et al. (2015). 두 입력의 유사도를 학습하는 쌍둥이 네트워크.

Siamese Network:

  x_1 --[f_theta]--> z_1 --|
                            |--> |z_1 - z_2| --> sigma(W * |z_1 - z_2|)
  x_2 --[f_theta]--> z_2 --|                    --> P(same class)

  동일 가중치 f_theta를 공유하여 일관된 임베딩 생성

Prototypical Networks

Snell et al. (NeurIPS 2017). 각 클래스의 support 샘플 임베딩 평균을 프로토타입으로 사용한다.

Prototypical Networks:

1. 클래스별 프로토타입 계산:
   c_k = (1/|S_k|) * sum_{(x_i, y_i) in S_k} f_theta(x_i)

2. Query 샘플의 클래스 확률:
   p(y = k | x) = softmax(-d(f_theta(x), c_k))

   여기서 d는 유클리드 거리 또는 코사인 거리

임베딩 공간 시각화:

    c_1 (dog)          c_2 (cat)
      *                  *
     / \                / \
    .   .              .   .     <- support 샘플
                 ?               <- query: c_1에 더 가까움 --> dog
특성 설명
거리 함수 유클리드 거리 (원논문), 코사인 유사도, Mahalanobis 거리
장점 단순하고 효과적, 계산 효율적
한계 클래스 내 분산이 큰 경우 프로토타입이 대표성 부족
import torch
import torch.nn as nn
import torch.nn.functional as F

class PrototypicalNetwork(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder  # CNN 또는 ResNet backbone

    def compute_prototypes(self, support, labels, n_way):
        """클래스별 프로토타입 계산"""
        z_support = self.encoder(support)
        prototypes = []
        for k in range(n_way):
            mask = (labels == k)
            prototypes.append(z_support[mask].mean(dim=0))
        return torch.stack(prototypes)  # (n_way, embedding_dim)

    def forward(self, support, support_labels, query, n_way):
        prototypes = self.compute_prototypes(support, support_labels, n_way)
        z_query = self.encoder(query)

        # 유클리드 거리 계산
        dists = torch.cdist(z_query, prototypes)  # (n_query, n_way)
        log_probs = F.log_softmax(-dists, dim=1)
        return log_probs

# 학습 루프 (에피소드 기반)
def train_episode(model, support, support_labels, query, query_labels, n_way):
    log_probs = model(support, support_labels, query, n_way)
    loss = F.nll_loss(log_probs, query_labels)
    return loss

Matching Networks

Vinyals et al. (NeurIPS 2016). Attention 메커니즘으로 query와 support 간 유사도를 가중 합산한다.

Matching Networks:

  p(y | x, S) = sum_{i=1}^{|S|} a(x, x_i) * y_i

  a(x, x_i) = softmax(cosine(f(x), g(x_i)))

  f: query 인코더 (LSTM 기반 Full Context Embedding)
  g: support 인코더

  핵심 차이: f와 g가 전체 support set을 조건으로 인코딩
  -> support set의 맥락 정보를 반영

Relation Networks

Sung et al. (CVPR 2018). 거리 함수 자체를 학습 가능한 신경망으로 대체한다.

Relation Network:

  z_query   = f_theta(x_query)
  z_support = f_theta(x_support)

  relation_score = g_phi(concat(z_query, z_support))

  g_phi: 학습 가능한 관계 네트워크 (2-layer CNN + FC)

  장점: 유클리드/코사인보다 유연한 유사도 학습

2. Optimization-Based Meta-Learning (최적화 기반)

핵심 아이디어: 소수의 경사 하강 스텝으로 새 태스크에 빠르게 적응할 수 있는 초기 파라미터를 학습한다.

MAML (Model-Agnostic Meta-Learning)

Finn et al. (ICML 2017). 메타 학습 분야의 대표적 알고리즘으로, 모델 구조에 무관하게 적용 가능하다.

MAML 알고리즘:

외부 루프 (Meta-update):
  for each meta-batch:

    내부 루프 (Task adaptation):
      for each task T_i:
        1. 현재 파라미터 theta로 시작
        2. Support set으로 k-step gradient descent:
           theta_i' = theta - alpha * grad_theta L_{T_i}(f_theta)
        3. 적응된 theta_i'로 Query set 손실 계산

    메타 업데이트:
      theta <- theta - beta * grad_theta sum_i L_{T_i}(f_{theta_i'})

핵심: "빠르게 적응할 수 있는 초기점"을 학습

  파라미터 공간 시각화:

  theta (초기점)
    |\
    | \  1-step adaptation
    |  \
    v   v
  theta_1' theta_2'   <- 각 태스크에 적응된 파라미터
  (T_1)    (T_2)

  메타 업데이트는 theta를 모든 태스크에 잘 적응하는 위치로 이동

이중 경사(Bi-level Optimization) 구조

MAML의 수학적 구조:

외부 목적함수 (meta-objective):
  min_theta  sum_{T_i ~ p(T)} L_{T_i}(U_k(theta, D^train_{T_i}), D^test_{T_i})

내부 최적화 (task adaptation):
  U_k(theta, D) = theta - alpha * grad L(theta, D)  (1-step)

  또는 k-step:
  theta_0 = theta
  theta_{j+1} = theta_j - alpha * grad L(theta_j, D)  for j = 0,...,k-1

메타 그래디언트 (2차 미분 포함):
  grad_theta L(theta') = grad_theta L(theta - alpha * grad_theta L(theta))
                       = (I - alpha * H) * grad_{theta'} L(theta')

  H: 헤시안 행렬 -- 계산 비용이 큼
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

class MAML:
    def __init__(self, model, inner_lr=0.01, meta_lr=0.001, inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.inner_steps = inner_steps
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=meta_lr)

    def inner_loop(self, support_x, support_y):
        """태스크별 내부 적응"""
        fast_weights = {name: p.clone() for name, p in self.model.named_parameters()}

        for step in range(self.inner_steps):
            logits = self.model.functional_forward(support_x, fast_weights)
            loss = F.cross_entropy(logits, support_y)
            grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)

            fast_weights = {
                name: w - self.inner_lr * g
                for (name, w), g in zip(fast_weights.items(), grads)
            }

        return fast_weights

    def meta_step(self, tasks):
        """메타 업데이트"""
        meta_loss = 0.0

        for support_x, support_y, query_x, query_y in tasks:
            # 내부 루프: 태스크 적응
            fast_weights = self.inner_loop(support_x, support_y)

            # 외부 루프: Query set으로 메타 손실 계산
            query_logits = self.model.functional_forward(query_x, fast_weights)
            task_loss = F.cross_entropy(query_logits, query_y)
            meta_loss += task_loss

        meta_loss /= len(tasks)

        self.meta_optimizer.zero_grad()
        meta_loss.backward()
        self.meta_optimizer.step()

        return meta_loss.item()

MAML 변형들

변형 핵심 차이 논문
First-Order MAML (FOMAML) 헤시안 무시, 1차 미분만 사용 Finn et al. 2017
Reptile 내부 루프 후 파라미터 차이 방향으로 업데이트 Nichol et al. 2018
Meta-SGD 학습률도 메타 학습 Li et al. ICML 2017
ANIL 마지막 레이어만 내부 루프에서 적응 Raghu et al. ICLR 2020
iMAML 암묵적 미분으로 메모리 효율화 Rajeswaran et al. NeurIPS 2019

Reptile

Nichol et al. (2018). MAML보다 구현이 단순한 1차 근사 방법.

Reptile 알고리즘:

  for each iteration:
    1. 태스크 T_i 샘플링
    2. theta에서 시작, k-step SGD 수행:
       theta_i' = SGD(theta, T_i, k steps)
    3. 메타 업데이트:
       theta <- theta + epsilon * (theta_i' - theta)

  직관: "각 태스크에 적응한 파라미터 방향의 평균"으로 이동

  MAML과의 차이:
  - 2차 미분 불필요
  - 구현이 매우 단순
  - 성능은 MAML에 근접

3. Model-Based Meta-Learning (모델 기반)

핵심 아이디어: 외부 메모리 또는 특수 아키텍처를 사용하여 새 태스크의 정보를 빠르게 인코딩하고 활용한다.

Memory-Augmented Neural Network (MANN)

Santoro et al. (ICML 2016). Neural Turing Machine에 메타 학습을 결합.

MANN 구조:

  x_t, y_{t-1}  --[Controller]--> read/write --> [External Memory M]
       |                                              |
       |              <-- read --                     |
       v                                              
    output y_t

  핵심: 1-step offset으로 학습
  - 시점 t에서 x_t와 함께 이전 정답 y_{t-1}을 입력
  - 네트워크는 x_t를 메모리에 저장하고, 유사한 과거 패턴을 검색
  - 다음 시점에 같은 클래스가 나타나면 메모리에서 검색하여 분류

SNAIL (Simple Neural Attentive Learner)

Mishra et al. (ICLR 2018). Temporal Convolution + Soft Attention으로 에피소드 내 패턴을 학습.

SNAIL 아키텍처:

  입력 시퀀스: [(x_1, y_1), (x_2, y_2), ..., (x_t, ?)]
       |
  [Temporal Convolution] -- 지역 패턴 추출
       |
  [Causal Attention]     -- 전역 의존성 캡처
       |
  [Temporal Convolution]
       |
  [Causal Attention]
       |
    output: y_t 예측

Conditional Neural Processes (CNP/ANP)

Garnelo et al. (ICML 2018). 함수 분포를 학습하는 메타 학습 프레임워크.

Neural Process:

  Context set: {(x_c, y_c)}  --> Encoder --> r_c (context representation)
                                                |
                                                v
  Target x_t  --------------------------------> Decoder --> p(y_t | x_t, C)

  ANP (Attentive Neural Process):
  - Cross-attention으로 target이 관련 context를 선택적으로 참조
  - 불확실성 추정 가능 (predictive distribution 출력)

Metric vs Optimization vs Model 비교

측면 Metric-Based Optimization-Based Model-Based
적응 방식 거리 함수 비교 Gradient descent 메모리/아키텍처
적응 속도 Feed-forward (즉시) K-step gradient Feed-forward
유연성 분류에 특화 모든 태스크 적용 가능 아키텍처 의존적
계산 비용 낮음 높음 (2차 미분) 중간
대표 방법 ProtoNet, MatchNet MAML, Reptile MANN, SNAIL, CNP
스케일링 대규모 가능 모델 크기에 제한 메모리 크기에 의존

Few-Shot 벤치마크

데이터셋 도메인 클래스 수 설명
Omniglot 문자 인식 1,623 50개 알파벳의 필기 문자
mini-ImageNet 이미지 분류 100 ImageNet 서브셋 (600장/클래스)
tiered-ImageNet 이미지 분류 608 계층적 ImageNet 서브셋
CUB-200 세밀 분류 200 새 종류 분류
Meta-Dataset 다중 도메인 다양 10개 데이터셋 통합 벤치마크

mini-ImageNet 5-way 5-shot 성능 비교

방법 정확도 (%) 년도
Matching Networks 55.3 2016
Prototypical Networks 68.2 2017
MAML 63.1 2017
Relation Networks 65.3 2018
TADAM 76.7 2018
MetaOptNet 78.6 2019
FEAT 82.7 2020
P>M>F (Pre-train, Meta-train, Fine-tune) 85.4+ 2022

메타 학습의 응용 분야

1. Few-Shot 이미지 분류

가장 전통적인 응용. 소수 샘플로 새 클래스를 인식한다.

2. Few-Shot Object Detection

새로운 물체 클래스를 소수 바운딩 박스 어노테이션만으로 탐지.

3. Drug Discovery

새로운 화합물-타겟 상호작용을 소수 실험 결과로 예측.

4. Robotics

새로운 환경/태스크에 몇 번의 시행착오로 적응하는 로봇 정책 학습.

5. Neural Architecture Search (NAS)

아키텍처 성능을 소수 학습 에피소드로 빠르게 추정.

6. Hyperparameter Optimization

학습률, 정규화 등 하이퍼파라미터를 메타 학습으로 자동 조정.

실전 파이프라인

import torch
import torch.nn as nn
from torchvision import models

# 1. Backbone 정의 (feature extractor)
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        return self.block(x)

class FewShotEncoder(nn.Module):
    """4-layer Conv backbone (mini-ImageNet 표준)"""
    def __init__(self, in_channels=3, hidden_dim=64, out_dim=1600):
        super().__init__()
        self.encoder = nn.Sequential(
            ConvBlock(in_channels, hidden_dim),
            ConvBlock(hidden_dim, hidden_dim),
            ConvBlock(hidden_dim, hidden_dim),
            ConvBlock(hidden_dim, hidden_dim),
        )

    def forward(self, x):
        return self.encoder(x).view(x.size(0), -1)

# 2. Prototypical Networks 학습
def prototypical_loss(support, support_labels, query, query_labels, n_way):
    """
    Args:
        support: (n_way * k_shot, embedding_dim)
        query: (n_way * n_query, embedding_dim)
    """
    prototypes = []
    for k in range(n_way):
        mask = (support_labels == k)
        prototypes.append(support[mask].mean(dim=0))
    prototypes = torch.stack(prototypes)  # (n_way, dim)

    dists = torch.cdist(query, prototypes)  # (n_query_total, n_way)
    log_probs = (-dists).log_softmax(dim=1)

    loss = nn.functional.nll_loss(log_probs, query_labels)
    acc = (log_probs.argmax(dim=1) == query_labels).float().mean()

    return loss, acc

# 3. 에피소드 샘플러
class EpisodeSampler:
    def __init__(self, labels, n_way, k_shot, n_query, n_episodes):
        self.labels = labels
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.n_episodes = n_episodes
        self.classes = list(set(labels.tolist()))

    def __iter__(self):
        for _ in range(self.n_episodes):
            selected = torch.randperm(len(self.classes))[:self.n_way]
            support_idx, query_idx = [], []

            for i, c in enumerate(selected):
                c = self.classes[c]
                class_idx = (self.labels == c).nonzero(as_tuple=True)[0]
                perm = class_idx[torch.randperm(len(class_idx))]
                support_idx.extend(perm[:self.k_shot].tolist())
                query_idx.extend(perm[self.k_shot:self.k_shot + self.n_query].tolist())

            yield support_idx, query_idx

    def __len__(self):
        return self.n_episodes

최근 동향 (2024-2025)

트렌드 설명
Pre-train + Meta-train + Fine-tune 대규모 사전학습 후 메타 학습으로 few-shot 성능 극대화 (P>M>F)
Foundation Model + Few-Shot CLIP, DINOv2 등 foundation model을 few-shot backbone으로 활용
In-Context Learning as Meta-Learning LLM의 ICL을 메타 학습 관점에서 해석하는 연구
Task-Agnostic Meta-Learning 태스크 경계 없이 연속적 학습 환경에서의 메타 학습
Meta-Learning for LLM Alignment 사용자별 선호도에 빠르게 적응하는 개인화 기법
Unsupervised Meta-Learning 레이블 없이 pseudo-task를 구성하여 메타 학습
Open-Task Meta-Learning N-way K-shot 고정 설정을 넘어, 가변적 태스크 구조에 대응

한계와 과제

한계 설명
태스크 분포 의존성 메타 학습 시 태스크 분포와 테스트 태스크 분포가 달라지면 성능 저하
Cross-Domain 일반화 도메인이 크게 다른 태스크로의 전이가 어려움
스케일링 MAML 계열은 대규모 모델에 적용 시 메모리/계산 비용 문제
벤치마크 한계 mini-ImageNet 등 기존 벤치마크가 실제 few-shot 시나리오를 충분히 반영하지 못함
Foundation Model과의 경쟁 CLIP 등 대규모 모델이 zero-shot으로도 few-shot 메타 학습을 능가하는 경우 증가

참고 문헌

  1. Finn, C., Abbeel, P., & Levine, S. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. ICML 2017.
  2. Vinyals, O., et al. (2016). Matching Networks for One Shot Learning. NeurIPS 2016.
  3. Snell, J., Swersky, K., & Zemel, R. (2017). Prototypical Networks for Few-shot Learning. NeurIPS 2017.
  4. Sung, F., et al. (2018). Learning to Compare: Relation Network for Few-Shot Learning. CVPR 2018.
  5. Nichol, A., Achiam, J., & Schulman, J. (2018). On First-Order Meta-Learning Algorithms. arXiv:1803.02999.
  6. Hospedales, T., et al. (2022). Meta-Learning in Neural Networks: A Survey. IEEE TPAMI.
  7. Hu, S. X., et al. (2022). Pushing the Limits of Simple Pipelines for Few-Shot Learning (P>M>F). CVPR 2022.
  8. Garnelo, M., et al. (2018). Conditional Neural Processes. ICML 2018.