콘텐츠로 이동
Data Prep
상세

Multi-Task Learning

메타 정보

항목 내용
분류 Optimization / Learning Paradigm
핵심 논문 "An Overview of Multi-Task Learning in Deep Neural Networks" (Ruder, 2017), "Multi-Task Learning as Multi-Objective Optimization" (Sener & Koltun, NIPS 2018), "GradNorm" (Chen et al., ICML 2018), "Conflict-Averse Gradient Descent for MTL" (Liu et al., NeurIPS 2021)
주요 저자 Sebastian Ruder (서베이), Ozan Sener (MOO-MTL), Zhao Chen (GradNorm), Trevor Darrell (cross-stitch)
핵심 개념 관련된 여러 태스크를 동시에 학습하여 shared representation을 통해 일반화 성능을 향상시키는 학습 패러다임
관련 분야 Transfer Learning, Meta-Learning, Knowledge Distillation, Curriculum Learning

정의

Multi-Task Learning(MTL)은 여러 관련 태스크를 동시에 학습하여, 태스크 간 공유되는 표현(shared representation)을 통해 개별 태스크의 일반화 성능을 향상시키는 학습 패러다임이다.

수학적 정의

T개의 태스크가 주어졌을 때, 각 태스크 t는 데이터셋 D_t = {(x_i^t, y_i^t)}를 가진다. MTL의 목적 함수는 다음과 같다:

min_{theta_sh, theta_1, ..., theta_T}  sum_{t=1}^{T} w_t * L_t(theta_sh, theta_t)

여기서: - theta_sh: shared encoder의 파라미터 - theta_t: 태스크 t의 task-specific head 파라미터 - L_t: 태스크 t의 손실 함수 - w_t: 태스크 t의 가중치

핵심 가정

MTL이 효과적이려면 다음 조건이 필요하다:

조건 설명
Task Relatedness 태스크 간 공유 가능한 구조가 존재해야 함
Sufficient Data 각 태스크에 충분한 학습 데이터 필요
Balanced Difficulty 태스크 간 난이도가 극단적으로 차이나지 않아야 함
Compatible Gradients 태스크 간 gradient가 과도하게 충돌하지 않아야 함

Single-Task Learning과의 비교

Single-Task Learning (STL):

  Task 1:  Input --> [Encoder_1] --> [Head_1] --> Output_1
  Task 2:  Input --> [Encoder_2] --> [Head_2] --> Output_2
  Task 3:  Input --> [Encoder_3] --> [Head_3] --> Output_3

  * 독립적 모델, 파라미터 공유 없음
  * 총 파라미터 = 3 * (Encoder + Head)


Multi-Task Learning (MTL):

                        +--> [Head_1] --> Output_1
                        |
  Input --> [Shared     +--> [Head_2] --> Output_2
             Encoder]   |
                        +--> [Head_3] --> Output_3

  * 공유 인코더 + 태스크별 헤드
  * 총 파라미터 = Encoder + 3 * Head

Parameter Sharing 방식

Hard Parameter Sharing

가장 기본적이고 널리 사용되는 방식이다. 은닉층(hidden layers)을 모든 태스크가 공유하고, 출력층(output layers)만 태스크별로 분리한다.

            Hard Parameter Sharing

  Input
    |
    v
+-------------------+
|   Shared Layer 1  |   <-- 모든 태스크 공유
+-------------------+
    |
    v
+-------------------+
|   Shared Layer 2  |   <-- 모든 태스크 공유
+-------------------+
    |
    v
+-------------------+
|   Shared Layer 3  |   <-- 모든 태스크 공유
+-------------------+
    |         |         |
    v         v         v
+------+  +------+  +------+
|Head 1|  |Head 2|  |Head 3|   <-- 태스크별 분리
+------+  +------+  +------+
    |         |         |
    v         v         v
  Out_1     Out_2     Out_3

특징: - Overfitting 감소 효과 (Baxter, 2000): 공유 파라미터의 overfitting 위험이 O(T*N/T) = O(N)에서 감소 - 구현이 단순하고 파라미터 효율적 - 태스크 간 관련성이 낮으면 성능 저하 가능 (negative transfer)

Soft Parameter Sharing

각 태스크가 자체 파라미터를 가지되, 파라미터 간 유사성을 regularization으로 강제한다.

            Soft Parameter Sharing

  Input             Input             Input
    |                 |                 |
    v                 v                 v
+--------+       +--------+       +--------+
|Layer 1a| <---> |Layer 1b| <---> |Layer 1c|
+--------+  reg  +--------+  reg  +--------+
    |                 |                 |
    v                 v                 v
+--------+       +--------+       +--------+
|Layer 2a| <---> |Layer 2b| <---> |Layer 2c|
+--------+  reg  +--------+  reg  +--------+
    |                 |                 |
    v                 v                 v
+--------+       +--------+       +--------+
| Head 1 |       | Head 2 |       | Head 3 |
+--------+       +--------+       +--------+

  <---> : regularization constraint
          (L2 distance, trace norm 등)

Regularization 수식:

L_total = sum_t L_t + lambda * sum_{i,j} ||W_t^i - W_t^j||_F^2

여기서 ||.||_F는 Frobenius norm이다.

Hard vs Soft 비교

항목 Hard Sharing Soft Sharing
파라미터 수 적음 (공유) 많음 (태스크별 별도)
구현 복잡도 낮음 높음
유연성 낮음 높음
Negative transfer 취약 상대적으로 강건
Overfitting 방지 강함 보통
대표 사례 대부분의 MTL Cross-stitch, Sluice

아키텍처

Cross-Stitch Networks (Misra et al., CVPR 2016)

두 태스크의 activation map을 선형 결합하여 각 태스크에 전달하는 방식이다.

Task A Layer l      Task B Layer l
     |                   |
     v                   v
  [x_A^l]            [x_B^l]
     \       cross      /
      \     stitch     /
       v   unit       v
  +---------------------+
  | alpha_AA  alpha_AB  |
  | alpha_BA  alpha_BB  |
  +---------------------+
       |             |
       v             v
  [x_A^{l+1}]   [x_B^{l+1}]

수식:

[x_A^{l+1}]   [alpha_AA  alpha_AB] [x_A^l]
[x_B^{l+1}] = [alpha_BA  alpha_BB] [x_B^l]

alpha 값은 학습을 통해 결정되며, 어떤 레이어에서 얼마나 공유할지 자동으로 학습한다.

Multi-Task Attention Network (MTAN, Liu et al., CVPR 2019)

Shared encoder에 task-specific attention module을 붙여 각 태스크가 shared feature에서 필요한 부분을 선택적으로 추출한다.

Input --> [Shared Encoder] --> Shared Features (F)
                                    |
                  +-----------------+-----------------+
                  |                 |                 |
                  v                 v                 v
           [Attention_1]    [Attention_2]    [Attention_3]
                  |                 |                 |
                  v                 v                 v
           Task 1 feat      Task 2 feat      Task 3 feat
                  |                 |                 |
                  v                 v                 v
              [Head_1]         [Head_2]         [Head_3]

Attention 수식:

a_t^k = sigmoid(W_t^k * F^k + b_t^k)
F_t^k = a_t^k (element-wise) F^k

여기서 a_t^k는 태스크 t의 k번째 레이어에서의 attention mask이다.

NDDR-CNN (Gao et al., CVPR 2019)

Neural Discriminative Dimensionality Reduction을 이용하여 태스크별 feature를 학습 가능한 방식으로 결합한다.

Task A feature map     Task B feature map
       |                      |
       v                      v
  [concat along channel dim]
       |
       v
  [1x1 Conv + BN + ReLU]   <-- 학습 가능한 결합
       |
       v
  Task-specific feature

아키텍처 비교

아키텍처 공유 방식 학습 가능 여부 태스크 수 제한 오버헤드
Hard Sharing 고정 X 무제한 매우 낮음
Cross-Stitch 선형 결합 O (alpha) 2 (원래) 낮음
MTAN Attention O 무제한 중간
NDDR-CNN 1x1 Conv O 2+ 낮음
Sluice 조합 O 2+ 중간

손실 함수 밸런싱

MTL의 핵심 문제 중 하나는 여러 태스크의 손실 함수를 어떻게 결합할 것인가이다. 단순 합산(equal weighting)은 태스크 간 스케일 차이, 학습 속도 차이로 인해 최적이 아닌 경우가 많다.

Uniform Weighting (Baseline)

L_total = sum_{t=1}^{T} L_t

모든 태스크에 동일한 가중치를 부여한다. 단순하지만 태스크 간 스케일 차이가 크면 특정 태스크가 학습을 지배할 수 있다.

Uncertainty Weighting (Kendall et al., CVPR 2018)

Homoscedastic uncertainty를 이용하여 태스크별 가중치를 자동으로 학습한다.

L_total = sum_{t=1}^{T} (1 / (2 * sigma_t^2)) * L_t + log(sigma_t)

여기서 sigma_t는 태스크 t의 관측 noise를 나타내는 학습 가능한 파라미터다. Noise가 큰 태스크(어려운 태스크)의 가중치가 자동으로 낮아진다. log(sigma_t) 항은 모든 sigma가 무한대로 발산하는 것을 방지한다.

구현 팁: 실제로는 log(sigma_t^2)를 직접 학습시키는 것이 수치적으로 안정적이다.

# log_var = log(sigma^2)
log_var_t = nn.Parameter(torch.zeros(T))

# loss for task t
loss_t = (1 / (2 * torch.exp(log_var_t[t]))) * task_loss_t + 0.5 * log_var_t[t]

GradNorm (Chen et al., ICML 2018)

Gradient의 크기(norm)를 동적으로 조절하여 태스크 간 학습 속도를 균형 있게 맞춘다.

핵심 아이디어: 1. 각 태스크의 gradient norm이 비슷하도록 가중치 조절 2. 학습이 느린 태스크(상대적 loss가 큰 태스크)에 더 큰 가중치 부여

알고리즘:

1. 각 태스크 t에 대해 gradient norm 계산:
   G_t = ||nabla_{W_sh} w_t * L_t||_2

2. 평균 gradient norm:
   G_avg = E[G_t]

3. 상대적 역 학습 속도:
   r_t = L_t(epoch) / L_t(0)      (현재 loss / 초기 loss)
   r_avg = E[r_t]

4. 타겟 gradient norm:
   G_target_t = G_avg * (r_t / r_avg)^alpha

5. 가중치 업데이트 (L1 loss로):
   L_grad = sum_t |G_t - G_target_t|
   w_t <- w_t - lr_w * nabla_{w_t} L_grad

6. 가중치 정규화:
   w_t <- w_t * T / sum(w_t)

alpha는 하이퍼파라미터로, 높을수록 태스크 간 학습 속도 균형을 강하게 맞춘다. 논문에서는 alpha = 1.5를 권장한다.

PCGrad (Yu et al., NeurIPS 2020)

Projecting Conflicting Gradients: 태스크 간 gradient가 충돌(cosine similarity < 0)할 때, 충돌하는 성분을 제거한다.

알고리즘:

for each task i:
    g_i = nabla L_i
    for each other task j (random order):
        if g_i . g_j < 0:   (충돌 감지)
            g_i = g_i - (g_i . g_j / ||g_j||^2) * g_j   (투영)
    use modified g_i for update

기하학적 해석:

       g_j
        ^
        |        g_i (원래)
        |       /
        |      /
        |     /
        |----/-------->  g_i' (투영 후)
        |
        |

g_i에서 g_j와 충돌하는 성분(g_j 방향의 음수 투영)을 제거하여, 다른 태스크에 해를 끼치지 않는 방향으로만 업데이트한다.

CAGrad (Liu et al., NeurIPS 2021)

Conflict-Averse Gradient descent: 평균 gradient 방향을 유지하면서, 모든 태스크의 최소 improvement를 최대화하는 방향을 찾는다.

목적 함수:

max_d  min_t  <g_t, d>
s.t.   ||d - g_avg|| <= c * ||g_avg||

여기서 g_avg = (1/T) sum g_t이고, c는 평균 gradient에서 벗어날 수 있는 범위를 제어하는 하이퍼파라미터이다.

이 문제는 dual form으로 변환하여 효율적으로 풀 수 있다.

Nash-MTL (Navon et al., ICML 2022)

MTL을 Nash bargaining game으로 모델링한다. 각 태스크를 player로, gradient를 action으로 보고 Nash bargaining solution을 구한다.

목적 함수:

max_{alpha}  sum_{t=1}^{T} log(alpha^T * g_t)
s.t.         alpha >= 0, ||alpha||_1 = 1

이는 각 태스크의 improvement의 로그합을 최대화하는 것으로, proportional fairness를 보장한다.

밸런싱 기법 비교

기법 핵심 아이디어 연산 오버헤드 하이퍼파라미터 성능 (NYUv2)
Uniform 동일 가중치 없음 없음 baseline
Uncertainty noise로 가중치 매우 낮음 없음 +
GradNorm gradient norm 균등화 낮음 alpha ++
PCGrad 충돌 gradient 투영 중간 없음 ++
CAGrad min-improvement 최대화 중간 c +++
Nash-MTL Nash bargaining 높음 없음 +++

Negative Transfer

정의

Negative transfer는 MTL에서 태스크를 함께 학습한 결과, 일부 또는 전체 태스크의 성능이 단독 학습(STL) 대비 오히려 저하되는 현상이다.

원인

원인 설명
Task Conflict 태스크 간 최적 representation이 서로 호환 불가
Gradient Conflict 태스크별 gradient가 반대 방향을 가리킴
Capacity Bottleneck 공유 네트워크의 용량이 부족하여 모든 태스크를 수용 불가
Dominance 특정 태스크가 학습을 지배하여 다른 태스크가 희생됨
Label Noise Propagation 한 태스크의 noisy label이 shared representation을 오염

해결책

  1. 태스크 그룹핑: 관련 태스크끼리 묶어서 학습 (task grouping)
  2. Gradient 조작: PCGrad, CAGrad 등으로 gradient 충돌 완화
  3. Soft Sharing: Hard sharing 대신 soft sharing으로 태스크 간 간섭 감소
  4. 모델 용량 증가: shared encoder의 크기를 늘려 여러 태스크를 수용
  5. Auxiliary Task 선택: 보조 태스크를 신중하게 선택하고, 도움이 되지 않는 태스크는 제거
  6. Dynamic Weighting: 학습 과정에서 태스크 가중치를 동적으로 조절

Negative Transfer 감지

if perf_MTL(task_t) < perf_STL(task_t):
    negative_transfer = True
    delta = perf_MTL(task_t) - perf_STL(task_t)
    print(f"Task {t}: negative transfer = {delta}")

도메인별 적용

NLP

T5 (Raffel et al., JMLR 2020): - Text-to-text 프레임워크로 번역, 요약, QA, 분류 등을 동시 학습 - 모든 NLP 태스크를 "text in, text out" 형태로 통일 - 태스크별 prefix (예: "translate English to German:", "summarize:") 사용

MT-DNN (Liu et al., ACL 2019): - BERT 기반 MTL로 GLUE 벤치마크 태스크를 동시 학습 - Shared BERT encoder + task-specific output layers - 학습 시 태스크를 mini-batch 단위로 번갈아 선택

MT-DNN 구조:

  Input Text --> [WordPiece Tokenizer]
                        |
                        v
               [Shared BERT Encoder]
                        |
           +------------+------------+
           |            |            |
           v            v            v
     [NLI Head]   [QA Head]   [Sentiment Head]
     (softmax)    (span)      (softmax)

Computer Vision (자율주행)

자율주행은 대표적인 MTL 응용 분야이다. 단일 카메라 입력에서 여러 태스크를 동시에 수행한다:

태스크 출력 형태 Loss
Semantic Segmentation 픽셀별 클래스 Cross-entropy
Depth Estimation 픽셀별 깊이값 L1 / Berhu
Surface Normal 픽셀별 법선 벡터 Cosine similarity
Object Detection Bounding box Focal + L1

NYUv2 / Cityscapes 데이터셋이 MTL 벤치마크로 널리 사용된다.

LLM: Instruction Tuning as MTL

현대의 instruction tuning은 사실상 MTL이다. 다양한 태스크의 instruction 데이터를 함께 학습한다.

Instruction Tuning = MTL

태스크 1 (번역):     "Translate to Korean: Hello" --> "안녕하세요"
태스크 2 (요약):     "Summarize: ..."              --> "요약 결과"
태스크 3 (코드):     "Write Python code for ..."   --> "def ..."
태스크 4 (수학):     "Solve: 2x + 3 = 7"          --> "x = 2"
태스크 5 (대화):     "User: Hi\nAssistant:"        --> "Hello!"
...

모든 태스크를 하나의 LLM이 next-token prediction으로 학습
--> implicit hard parameter sharing

FLAN (Wei et al., ICLR 2022): - 1,800+ 태스크를 instruction 형태로 변환하여 학습 - Zero-shot / few-shot 성능 향상


Task Affinity

어떤 태스크를 함께 학습할 것인가

태스크 조합에 따라 MTL의 효과가 크게 달라지므로, 태스크 간 affinity(친화도)를 측정하는 것이 중요하다.

측정 방법

방법 설명 비용
Brute Force 모든 조합을 학습하여 비교 O(2^T) - 비현실적
Task2Vec (Achille et al., 2019) Fisher information으로 태스크 임베딩 중간
TAG (Fifty et al., ICML 2021) Inter-task affinity 그래프 중간
Gradient Cosine Similarity 태스크 간 gradient 유사도 낮음

Gradient Cosine Similarity

가장 간단한 방법: 두 태스크의 gradient 방향이 유사하면 함께 학습하기 좋다.

affinity(t_i, t_j) = cos(nabla L_i, nabla L_j)
                   = (nabla L_i . nabla L_j) / (||nabla L_i|| * ||nabla L_j||)

affinity > 0   -->  태스크 호환 (함께 학습 가능)
affinity ~ 0   -->  태스크 독립 (함께 학습해도 크게 영향 없음)
affinity < 0   -->  태스크 충돌 (negative transfer 위험)

Task Grouping

T개의 태스크를 K개의 그룹으로 나누어, 그룹 내에서만 MTL을 수행하는 전략이다.

예시: 5개 태스크, 2개 그룹

Affinity Matrix:
         T1    T2    T3    T4    T5
   T1  [ 1.0   0.8   0.2   0.1  -0.3 ]
   T2  [ 0.8   1.0   0.3   0.2  -0.1 ]
   T3  [ 0.2   0.3   1.0   0.9   0.7 ]
   T4  [ 0.1   0.2   0.9   1.0   0.8 ]
   T5  [-0.3  -0.1   0.7   0.8   1.0 ]

Grouping 결과:
  Group 1: {T1, T2}      --> 높은 affinity
  Group 2: {T3, T4, T5}  --> 높은 affinity

PyTorch 구현 예시

Hard Parameter Sharing

import torch
import torch.nn as nn


class SharedEncoder(nn.Module):
    """태스크 간 공유되는 인코더"""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)


class TaskHead(nn.Module):
    """태스크별 출력 헤드"""

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, output_dim),
        )

    def forward(self, x):
        return self.head(x)


class MultiTaskModel(nn.Module):
    """Multi-Task Learning 모델 (Hard Parameter Sharing)"""

    def __init__(self, input_dim, hidden_dim, task_output_dims):
        super().__init__()
        self.encoder = SharedEncoder(input_dim, hidden_dim)
        self.task_heads = nn.ModuleDict({
            name: TaskHead(hidden_dim, out_dim)
            for name, out_dim in task_output_dims.items()
        })

    def forward(self, x, task_name=None):
        shared_repr = self.encoder(x)

        if task_name is not None:
            return self.task_heads[task_name](shared_repr)

        # 모든 태스크 출력
        return {
            name: head(shared_repr)
            for name, head in self.task_heads.items()
        }


# --- 학습 루프 ---

def train_mtl(model, dataloaders, loss_fns, optimizer, epochs=100):
    """
    dataloaders: dict of {task_name: DataLoader}
    loss_fns: dict of {task_name: loss_function}
    """
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        task_losses = {}

        # 모든 태스크의 배치를 가져옴
        iterators = {name: iter(dl) for name, dl in dataloaders.items()}

        while True:
            optimizer.zero_grad()
            batch_loss = 0.0
            active = False

            for task_name, it in iterators.items():
                try:
                    x, y = next(it)
                except StopIteration:
                    continue

                active = True
                pred = model(x, task_name)
                loss = loss_fns[task_name](pred, y)
                batch_loss += loss
                task_losses[task_name] = loss.item()

            if not active:
                break

            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: total_loss={total_loss:.4f}")
            for name, l in task_losses.items():
                print(f"  {name}: {l:.4f}")


# --- 사용 예시 ---

model = MultiTaskModel(
    input_dim=128,
    hidden_dim=256,
    task_output_dims={
        "classification": 10,    # 10-class classification
        "regression": 1,         # regression
        "segmentation": 64,      # 64-dim segmentation
    }
)

loss_fns = {
    "classification": nn.CrossEntropyLoss(),
    "regression": nn.MSELoss(),
    "segmentation": nn.CrossEntropyLoss(),
}

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Uncertainty Weighting 구현

class UncertaintyWeightedLoss(nn.Module):
    """Kendall et al. (2018) Uncertainty Weighting"""

    def __init__(self, num_tasks):
        super().__init__()
        # log(sigma^2) 초기화
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        """
        losses: list of scalar tensors (각 태스크의 loss)
        """
        total = 0.0
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])
            total += precision * loss + self.log_vars[i]
        return total

장단점 비교표

항목 장점 단점
일반화 Inductive bias로 overfitting 감소 태스크 선택을 잘못하면 역효과
효율성 하나의 모델로 여러 태스크 수행, 파라미터/연산 절약 밸런싱이 어려움, 튜닝 비용
데이터 다른 태스크의 데이터가 regularizer 역할 태스크별 데이터 불균형 문제
성능 관련 태스크 간 시너지 효과 Negative transfer 위험
배포 단일 모델 서빙으로 인프라 단순화 특정 태스크만 업데이트하기 어려움
학습 Auxiliary task로 주 태스크 성능 향상 가능 최적 태스크 조합 탐색이 비쌈
표현 학습 더 풍부하고 일반적인 representation 학습 Capacity 부족 시 모든 태스크 성능 저하

핵심 논문 목록

논문 연도 기여
Caruana, "Multitask Learning" 1997 MTL 개념 정립
Ruder, "An Overview of MTL in Deep NNs" 2017 포괄적 서베이
Kendall et al., "Multi-Task Learning Using Uncertainty" 2018 Uncertainty weighting
Chen et al., "GradNorm" 2018 Gradient norm 밸런싱
Sener & Koltun, "MTL as MOO" 2018 Multi-objective 관점
Yu et al., "PCGrad" 2020 Gradient 충돌 해소
Liu et al., "CAGrad" 2021 Conflict-averse optimization
Navon et al., "Nash-MTL" 2022 Nash bargaining 기반

마지막 업데이트: 2026-03-25