Reducing Activation Recomputation in Large Transformer Models

Distributed Training Series (3/4)
  1. Tensor Parallel
  2. Tensor Parallel 구현 비교
  3. Pipeline Parallel (GPipe)
  4. Reducing Activation Recomputation in Large Transformer Models

TL;DR

본 논문은 대규모 transformer 모델 학습에서 activation recomputation의 계산 오버헤드를 줄이기 위한 두 가지 상호보완적 기법—sequence parallelism과 selective activation recomputation—을 제시합니다. 핵심 혁신은 tensor parallelism이 처리하지 않는 LayerNorm, Dropout 등의 연산을 시퀀스 차원으로 분할하고, attention 메커니즘 내의 고메모리·저계산 연산만 선택적으로 재계산하여 전체 재계산 대비 최소한의 오버헤드로 동등한 메모리 절감을 달성하는 것입니다. 22B에서 1T 파라미터까지 일관되게 활성화 메모리를 5배 절감하고 학습 처리량을 약 30% 향상시키며, 2240개 A100 GPU에서 54% GPU 활용률로 1조 파라미터 모델 학습을 가능하게 합니다. 중요한 제약은 시퀀스 길이가 tensor parallel 크기로 나누어떨어져야 하며, 데이터 병렬화를 사용하지 않는 실험 설정이 프로덕션 학습 환경을 완전히 반영하지 못한다는 점입니다.


Related Papers

  • GPipe - 마이크로 배치 파이프라인 병렬화를 통한 대규모 모델 학습
  • Megatron-LM - 텐서 모델 병렬화를 통한 대규모 Transformer 학습
  • How to Scale Your Model: Transformers - TPU 성능과 Transformer 스케일링 가이드
  • TorchTitan - PyTorch 네이티브 LLM 사전 학습 솔루션
  • veScale - Eager Mode SPMD 기반 텐서 프로그래밍

Takeaways

1. Contribution

대규모 transformer 모델 학습에서 activation memory는 가장 심각한 병목 중 하나입니다. 수십억에서 수조 파라미터 모델의 학습 과정에서 역전파를 위해 저장해야 하는 중간 활성화(activation)는 모델 파라미터 자체보다 훨씬 많은 메모리를 소비합니다. 이 문제에 대한 표준 해결책인 activation recomputation(gradient checkpointing)은 중간 활성화를 버리고 역전파 시 재계산하는 것으로, 메모리를 극적으로 절약하지만 순전파를 사실상 두 번 수행해야 하므로 30-40%의 학습 시간 오버헤드를 추가합니다. 이는 수천 GPU에서 수주에 걸쳐 실행되는 대규모 학습에서 막대한 비용 증가를 의미합니다.

본 논문 이전에, activation memory를 줄이기 위한 기존 접근법들은 각각 한계를 가지고 있었습니다. Tensor parallelism(Megatron-LM)은 MLP와 self-attention의 GEMM 연산을 GPU 간에 분할하여 해당 연산의 활성화를 분산시키지만, LayerNorm, Dropout 같은 비텐서병렬 연산의 활성화는 모든 GPU에서 중복 저장됩니다. 이 중복 저장되는 활성화는 모델이 커질수록 무시할 수 없는 비율을 차지합니다. 한편, 전체 activation recomputation은 메모리 문제를 해결하지만 앞서 언급한 대로 심각한 계산 오버헤드를 초래합니다. 이러한 상황에서 연구자들과 실무자들은 “메모리를 아끼면 속도를 잃고, 속도를 유지하면 메모리가 부족한” 딜레마에 직면했습니다.

본 논문의 핵심 기여는 이 딜레마를 두 가지 상호보완적 기법으로 해결하는 것입니다. 첫째, sequence parallelism은 tensor parallelism이 처리하지 않는 연산(LayerNorm, Dropout, 기타 element-wise 연산)에서 시퀀스 차원을 따라 활성화를 GPU 간에 분할합니다. 이 기법의 핵심 통찰은 all-reduce 연산을 reduce-scatter와 all-gather로 분해하면 총 통신량이 동일하면서도 각 GPU가 전체 활성화의 일부분만 저장하면 된다는 것입니다. 둘째, selective activation recomputation은 모든 활성화를 일괄적으로 재계산하는 대신, 메모리 소비는 크지만 재계산 비용은 작은 특정 연산(attention 내의 QK^T, softmax, dropout)만 선택적으로 재계산합니다. 이 두 기법을 결합하면 전체 activation recomputation과 동등한 5배 메모리 절감을 달성하면서도 계산 오버헤드를 39%에서 단 4%로 극적으로 줄입니다.

실용적 영향은 상당합니다. 2240개 A100 GPU에서 1조 파라미터 모델을 54% GPU 활용률로 학습할 수 있게 되었으며, 이는 이전에는 메모리 제약으로 불가능했던 규모입니다. 22B에서 1T까지 다양한 크기의 모델에서 일관되게 약 30%의 학습 처리량 향상을 보이며, 이는 대규모 학습의 비용을 직접적으로 30% 절감하는 것과 동일한 효과입니다. 이 기법들은 Megatron-LM 프레임워크에 통합되어 프로덕션 환경에서 즉시 사용 가능하며, 이후 대부분의 대규모 transformer 학습 시스템의 표준 구성 요소가 되었습니다.

특히 주목할 만한 점은 sequence parallelism이 통신을 추가함에도 불구하고 오히려 학습을 가속화한다는 발견입니다. 이는 LayerNorm과 Dropout이 예상보다 계산적으로 비싸다는 것을 암시하며, 더 작은 텐서에 대한 연산이 GPU의 메모리 서브시스템을 더 효율적으로 활용할 수 있음을 보여줍니다. 이 발견은 분산 시스템 최적화에서 단순히 통신을 최소화하는 것이 아니라, 메모리 접근 패턴과 GPU 활용률을 함께 고려해야 한다는 중요한 교훈을 제공합니다.

2. Methodology

2.1 Core Intuition

본 논문의 근본적인 통찰은 대규모 transformer 학습에서 activation memory의 구성을 정밀하게 분석하고, 각 구성 요소에 최적화된 전략을 적용하는 “연산 인식(operation-aware)” 접근법에 있습니다.

Tensor parallelism이 적용된 transformer 레이어에서 활성화 메모리를 분석하면 두 가지 범주로 나뉩니다. 첫째, tensor parallel 영역 내의 활성화—MLP의 GEMM 연산과 self-attention의 GEMM 연산에 의해 생성되는 활성화로, 이미 $t$개의 GPU에 걸쳐 분할되어 있습니다. 둘째, tensor parallel 영역 외부의 활성화—LayerNorm의 입력과 출력, Dropout의 마스크, residual connection의 입력 등으로, 모든 GPU에서 중복 저장됩니다. 핵심 관찰은 이 두 번째 범주가 전체 활성화 메모리에서 상당한 비율을 차지하며, 모델이 커질수록 그 비율이 증가한다는 것입니다.

Sequence parallelism의 핵심 아이디어는 이러한 비텐서병렬 연산들이 시퀀스 차원을 따라 독립적이라는 사실을 활용하는 것입니다. LayerNorm은 hidden dimension에 대해 정규화하므로 시퀀스의 각 위치가 독립적으로 계산될 수 있고, Dropout은 각 원소에 독립적으로 적용되며, element-wise 연산은 정의상 독립적입니다. 따라서 이러한 연산의 입출력을 시퀀스 차원을 따라 $t$개의 GPU에 분할할 수 있습니다.

이 분할을 가능하게 하는 수학적 기반은 all-reduce 연산의 분해입니다. 기존 tensor parallelism에서 forward pass의 $g$ 연산자는 all-reduce를 수행합니다. All-reduce는 수학적으로 reduce-scatter와 all-gather의 합성으로 분해될 수 있으며, 총 통신량은 동일합니다:

\[\text{all-reduce}(X) = \text{all-gather}(\text{reduce-scatter}(X))\]

이 분해를 활용하면, tensor parallel 영역의 출력에서 all-reduce 대신 reduce-scatter를 수행하여 시퀀스 차원으로 분할된 결과를 얻고, 다음 tensor parallel 영역의 입력에서 all-gather를 수행하여 전체 시퀀스를 복원할 수 있습니다. 이 과정에서 tensor parallel 영역 사이의 모든 연산(LayerNorm, Dropout, residual addition)은 시퀀스 분할된 상태에서 수행되므로, 각 GPU는 전체 시퀀스의 $1/t$만큼의 활성화만 저장하면 됩니다.

Selective activation recomputation의 직관은 더 직접적입니다. Transformer 레이어의 활성화를 메모리 소비량과 재계산 비용으로 분류하면, attention 메커니즘 내의 특정 연산들—QK^T 행렬 곱셈의 출력, softmax의 출력, dropout 마스크—이 전체 활성화 메모리의 대부분을 차지하지만 재계산 비용은 매우 작다는 것을 발견합니다. 이는 이러한 연산이 시퀀스 길이의 제곱에 비례하는 텐서($s \times s$)를 생성하는 반면, 실제 계산은 주로 element-wise 연산(softmax, dropout)이기 때문입니다. 따라서 이 연산들만 선택적으로 재계산하면 최소한의 계산 비용으로 대부분의 메모리를 절약할 수 있습니다.

2.2 Model Architecture

본 논문은 Megatron-LM의 tensor parallelism 프레임워크를 확장하여 sequence parallelism을 통합합니다. 수정된 transformer 레이어의 데이터 흐름은 다음과 같습니다:

입력 X_sp (시퀀스 분할: [s/t, b, h])     ← 시퀀스 병렬 영역
  |
  v
LayerNorm (시퀀스 분할 상태에서)          ← 각 GPU: s/t 위치만 처리
  |
  v
[all-gather] → X_full [s, b, h]          ← 텐서 병렬 영역 진입
  |
  +---> Self-Attention (텐서 병렬)
  |      Q,K,V 프로젝션: [s, b, h/t]
  |      Attention 계산: 각 GPU는 h/t 차원의 head 처리
  |      출력 프로젝션: [s, b, h] (부분합)
  |
  +---> [reduce-scatter] → [s/t, b, h]   ← 시퀀스 병렬 영역 복귀
  |
  v
Dropout + Residual (시퀀스 분할 상태에서)  ← 각 GPU: s/t 위치만 처리
  |
  v
LayerNorm (시퀀스 분할 상태에서)
  |
  v
[all-gather] → X_full [s, b, h]          ← 텐서 병렬 영역 진입
  |
  +---> MLP (텐서 병렬)
  |      첫 번째 GEMM: [s, b, 4h/t] (열 분할)
  |      GeLU: 각 파티션 독립 적용
  |      두 번째 GEMM: [s, b, h] (부분합)
  |
  +---> [reduce-scatter] → [s/t, b, h]   ← 시퀀스 병렬 영역 복귀
  |
  v
Dropout + Residual (시퀀스 분할 상태에서)
  |
  v
출력 Y_sp (시퀀스 분할: [s/t, b, h])

기존 tensor parallelism 대비 핵심 변경 사항은 다음과 같습니다. 첫째, $g$ 연산자의 forward pass에서 all-reduce가 reduce-scatter로 대체됩니다. 이는 텐서 병렬 영역의 출력을 시퀀스 차원으로 분할합니다. 둘째, $f$ 연산자의 forward pass에서 identity가 all-gather로 대체됩니다. 이는 시퀀스 분할된 입력을 전체 시퀀스로 복원하여 텐서 병렬 GEMM에 공급합니다. Backward pass에서는 각각의 켤레 연산이 수행됩니다: reduce-scatter의 역은 all-gather, all-gather의 역은 reduce-scatter입니다.

통신량의 불변성은 다음과 같이 증명됩니다. 기존 tensor parallelism에서 각 transformer 레이어의 forward pass는 2회의 all-reduce를 수행합니다. 크기 $M$의 메시지에 대해, all-reduce의 통신량은 $2M \cdot \frac{t-1}{t}$입니다(ring-based 구현 기준). Sequence parallelism에서는 2회의 reduce-scatter와 2회의 all-gather를 수행하며, 각각의 통신량은 $M \cdot \frac{t-1}{t}$입니다. 따라서 총 통신량은:

\[2 \times M \cdot \frac{t-1}{t} + 2 \times M \cdot \frac{t-1}{t} = 4M \cdot \frac{t-1}{t} = 2 \times 2M \cdot \frac{t-1}{t}\]

이는 기존 2회의 all-reduce와 정확히 동일합니다.

2.3 Key Algorithms & Mechanisms

Activation Memory 분석. Transformer 레이어당 활성화 메모리를 정밀하게 분석하기 위해 논문은 각 연산의 활성화를 개별적으로 추적합니다. 시퀀스 길이 $s$, 마이크로 배치 크기 $b$, hidden dimension $h$, attention head 수 $a$를 갖는 표준 transformer 레이어에서, 혼합 정밀도 학습(FP16 forward, FP32 gradient 누적)을 가정하면 레이어당 총 활성화 메모리는:

\[\text{Activation Memory} = sbh \left(34 + 5 \cdot \frac{as}{h}\right) \text{ bytes}\]

이 공식에서 $34sbh$ 항은 linear 연산(GEMM)의 활성화를 나타내고, $5as^2b$ 항은 attention 메커니즘 내의 활성화($QK^T$ 결과, softmax 출력, dropout 마스크)를 나타냅니다. $34sbh$ 항의 세부 구성은:

연산 저장 텐서 메모리 (bytes)
LayerNorm (×2) 입력 $4sbh$
QKV 프로젝션 입력, 출력 $2sbh + 6sbh/t$
출력 프로젝션 입력 $2sbh/t$
MLP 첫 번째 GEMM 입력 $2sbh$
GeLU 입력 $8sbh/t$
MLP 두 번째 GEMM 입력 $8sbh/t$
Dropout (×2) 마스크 $2sbh$
Residual 입력 $4sbh$

$5as^2b$ 항의 구성:

연산 저장 텐서 메모리 (bytes)
$QK^T$ 결과 $[a/t, s, s]$ $2as^2b/t$
Softmax 출력 $[a/t, s, s]$ $2as^2b/t$
Attention dropout 마스크 $[a/t, s, s]$ $as^2b/t$

Sequence Parallelism의 메모리 절감. Tensor parallelism만 적용된 경우, 레이어당 활성화 메모리는:

\[\text{TP only} = sbh \left(10 + \frac{24}{t} + 5 \cdot \frac{as}{ht}\right)\]

여기서 $10sbh$는 tensor parallel 영역 외부의 활성화(모든 GPU에 중복), $\frac{24sbh}{t}$는 tensor parallel 영역 내부의 활성화(GPU 간 분할), $\frac{5as^2b}{t}$는 attention 활성화(GPU 간 분할)입니다. Sequence parallelism을 추가하면:

\[\text{TP + SP} = sbh \left(\frac{10}{t} + \frac{24}{t} + 5 \cdot \frac{as}{ht}\right) = \frac{sbh}{t}\left(34 + 5 \cdot \frac{as}{h}\right)\]

즉, 모든 활성화가 $t$개의 GPU에 걸쳐 균등하게 분할됩니다. $10sbh$의 중복 저장이 $10sbh/t$로 줄어들며, 이는 $t = 8$일 때 $8.75sbh$의 메모리 절감에 해당합니다.

Selective Activation Recomputation. 전체 activation recomputation은 각 transformer 레이어의 경계 활성화만 저장하고 나머지를 모두 재계산합니다. Selective recomputation은 이를 세분화하여, 높은 메모리-계산 비율을 갖는 연산의 활성화만 재계산 대상으로 선택합니다.

구체적으로, attention 메커니즘 내의 $QK^T$ 결과, softmax 출력, attention dropout 마스크를 재계산 대상으로 지정합니다. 이 연산들의 활성화 메모리는 $5as^2b/t$ bytes이며, 이는 시퀀스 길이의 제곱에 비례하여 긴 시퀀스에서 지배적입니다. 반면 이들의 재계산 비용은 전체 transformer 레이어 계산의 약 3%에 불과합니다. 이는 어텐션 점수 계산이 주로 softmax(element-wise)와 dropout(element-wise)으로 구성되어 계산적으로 가벼운 반면, 생성하는 텐서의 크기는 $[a, s, s]$로 매우 크기 때문입니다.

Selective recomputation 후 레이어당 활성화 메모리:

\[\text{TP + SP + Selective} = \frac{34sbh}{t}\]

$5as^2b/t$ 항이 완전히 제거되어, 시퀀스 길이의 제곱 스케일링이 사라집니다. 이는 매우 긴 시퀀스에서 특히 중요한 결과입니다.

전체 대비 선택적 재계산의 비교. 전체 activation recomputation을 적용하면:

\[\text{Full Recomputation} = \frac{2sbh}{t}\]

이는 각 레이어의 입력만 저장하면 되므로 메모리가 최소화되지만, 전체 forward pass를 재실행해야 하므로 약 33%의 계산 오버헤드(forward pass는 전체의 약 1/3)가 추가됩니다. 선택적 재계산은 $\frac{34sbh}{t}$로 전체 재계산($\frac{2sbh}{t}$)보다 17배 많은 메모리를 사용하지만, 활성화 저장 없는 baseline($\frac{34sbh + 5as^2b}{t}$) 대비 $5as^2b/t$ 항을 제거하면서 계산 오버헤드는 약 3%에 불과합니다.

2.4 Implementation Details

Megatron-LM 통합. Sequence parallelism과 selective recomputation은 모두 NVIDIA의 Megatron-LM 프레임워크 내에서 구현되었습니다. Sequence parallelism의 구현은 기존 tensor parallelism의 통신 연산자 $f$와 $g$를 수정하는 것으로 달성됩니다:

연산자 기존 (TP) 수정 (TP + SP)
$f$: forward identity all-gather
$f$: backward all-reduce reduce-scatter
$g$: forward all-reduce reduce-scatter
$g$: backward identity all-gather

이 수정은 비교적 간단하며, all-reduce를 reduce-scatter + all-gather로 분해하는 것의 직접적 구현입니다.

Selective recomputation의 구현은 PyTorch의 torch.utils.checkpoint 메커니즘을 활용합니다. 전체 transformer 레이어를 단일 checkpoint 단위로 감싸는 대신, attention 블록 내의 core attention 연산($QK^T$, softmax, dropout, attention over $V$)만을 별도의 checkpoint 단위로 분리합니다. Q, K, V 텐서를 재계산의 입력으로 저장하고, attention 내부의 중간 텐서는 필요시 재계산됩니다.

실험 모델 구성. 논문은 GPT-3 아키텍처를 기반으로 22B, 175B, 530B, 1T 파라미터 모델을 평가합니다:

모델 레이어 수 Hidden dim Head 수 시퀀스 길이 TP 크기
22B 48 6144 64 2048 8
175B 96 12288 96 2048 8
530B 105 20480 128 2048 8
1T 128 25600 160 2048 8

모든 실험은 NVIDIA A100 80GB GPU에서 수행되었으며, tensor parallelism 크기는 8(단일 노드 내 NVLink 연결)로 고정되었습니다. 주목할 점은 이 실험에서 데이터 병렬화를 사용하지 않고 tensor parallelism(+ pipeline parallelism)만 사용했다는 것입니다.

메모리 복잡도 요약. 각 구성에 따른 레이어당 활성화 메모리:

구성 활성화 메모리 (레이어당)
Baseline (활성화 저장) $sbh(34 + 5as/h)$
TP only $sbh(10 + 24/t + 5as/(ht))$
TP + SP $sbh(34 + 5as/h) / t$
TP + SP + Selective $34sbh / t$
TP + Full Recomputation $2sbh / t$

3. Results

논문의 실험적 평가는 메모리 절감, 학습 처리량, 그리고 각 기법의 개별 기여도를 체계적으로 분석합니다.

주요 성능 결과. 네 가지 모델 크기에 걸친 종합 결과:

모델 크기 활성화 메모리 감소 처리량 향상 GPU 활용률 실용적 영향
22B 5배 29.0% 41.5% → 43.7% 메모리 여유로 더 큰 배치 가능
175B (GPT-3) 5배 31.8% 51.4% → 52.8% 학습 비용 약 24% 절감
530B (MT-NLG) 5배 29.7% 56.0% → 57.0% 표준 클러스터에서 학습 가능
1T 5배 32.1% 56.3% → 57.0% 1조 파라미터 학습 실현

모든 모델에서 일관되게 약 30%의 처리량 향상을 보이며, 이는 기법이 특정 규모에 특화된 것이 아니라 근본적인 최적화임을 시사합니다.

구성요소 분석(Ablation Study). 22B 모델에서 각 기법의 개별 및 결합 효과:

기법 메모리 절약 속도 영향 핵심 발견
Sequence parallelism만 ~50% +3% 속도 향상 메모리 감소와 동시에 성능 향상
Selective recomputation만 ~50% −7% 오버헤드 동일한 메모리 절약, 다른 성능 프로필
둘 모두 결합 80% (5배) −4% 오버헤드 시너지 효과—개별 효과의 합보다 우수
전체 recomputation 90% −39% 오버헤드 기존 방식의 막대한 비용 확인

이 ablation에서 가장 주목할 만한 발견은 sequence parallelism이 통신을 추가함에도 불구하고 오히려 학습을 3% 가속화한다는 점입니다. 이는 두 가지 요인으로 설명됩니다. 첫째, LayerNorm과 Dropout이 더 작은 텐서에 대해 수행되어 GPU 메모리 서브시스템을 더 효율적으로 활용합니다. 둘째, 감소된 메모리 사용량이 GPU의 캐시 효율성을 향상시킵니다.

결합 효과에서 −4% 오버헤드는 sequence parallelism의 +3% 가속과 selective recomputation의 −7% 오버헤드의 합과 유사하며, 이는 두 기법이 독립적으로 작용함을 나타냅니다. 그러나 메모리 절감에서는 개별 ~50% + ~50%가 아닌 80%를 달성하여 상호보완적 효과를 보입니다. 이는 두 기법이 활성화 메모리의 서로 다른 구성 요소를 대상으로 하기 때문입니다: sequence parallelism은 비텐서병렬 영역의 중복 활성화를, selective recomputation은 attention의 시퀀스 제곱 스케일링 활성화를 각각 제거합니다.

대규모 모델 결과. 1T 파라미터 모델(128 레이어, hidden dimension 25600, 160 attention head)은 2240개 A100 GPU에서 학습되었으며, 57.0%의 GPU 활용률을 달성합니다. Pipeline parallelism 8단계, tensor parallelism 8-way, 마이크로 배치 크기 1의 구성에서, 본 논문의 기법 없이는 활성화 메모리로 인해 단일 마이크로 배치조차 GPU 메모리에 적재할 수 없었습니다. 이는 본 기법이 단순히 효율성 향상이 아니라 대규모 학습의 가능성 자체를 열었음을 의미합니다.

하드웨어 효율성 분석. 이론적 모델 FLOPs와 실제 하드웨어 FLOPs가 밀접하게 일치하는 것으로 보고됩니다. 이는 두 가지를 확인합니다: 첫째, 추가된 통신이 GPU의 계산 파이프라인을 의미있게 방해하지 않으며, 둘째, 논문의 수학적 메모리 분석 모델이 실제 시스템을 정확히 반영합니다.

4. Critical Assessment

Strengths

  1. 정밀한 수학적 분석: 논문은 transformer 레이어의 활성화 메모리를 바이트 단위로 분석하여, 각 연산의 메모리 기여도를 정확히 정량화합니다. 이 분석은 최적화 대상을 식별하는 데 직접적으로 활용되며, 단순한 경험적 관찰이 아닌 원칙적 접근법의 기반이 됩니다.
  2. 통신량 불변성의 보장: Sequence parallelism이 추가 통신 오버헤드를 도입하지 않는다는 것은 all-reduce의 분해를 통해 수학적으로 증명되며, 실험적으로도 확인됩니다. 이는 기법의 채택 장벽을 크게 낮춥니다.
  3. 일관된 스케일링: 22B에서 1T까지 4개의 서로 다른 규모에서 일관되게 약 30%의 처리량 향상을 보이며, 이는 기법이 특정 규모에 과적합되지 않은 근본적 최적화임을 강력히 시사합니다.
  4. 실용적 즉시성: Megatron-LM과 NeMo 프레임워크에 통합되어 프로덕션 환경에서 즉시 사용 가능하며, 이후 대부분의 대규모 학습 시스템의 표준 구성 요소가 되었습니다.
  5. 시퀀스 병렬화의 예상 외 가속 발견: Sequence parallelism이 통신을 추가함에도 학습을 가속화한다는 발견은 GPU의 메모리 서브시스템 활용 패턴에 대한 중요한 통찰을 제공합니다.

Limitations

  1. 데이터 병렬화 미사용: 모든 실험이 tensor parallelism과 pipeline parallelism만을 사용하며, 프로덕션 학습에서 필수적인 데이터 병렬화를 포함하지 않습니다. 데이터 병렬화와의 상호작용이나 잠재적 간섭 효과가 검증되지 않았습니다.
  2. 기준선 비교의 불완전성: ZeRO-stage 1/2/3, 매개변수 샤딩, CPU 오프로딩, 혼합 정밀도 최적화 등 다른 메모리 최적화 기법과의 직접 비교가 제공되지 않습니다. 이는 실무자가 최적의 기법 조합을 선택하는 데 필요한 정보를 제한합니다.
  3. A100 하드웨어 특이성: 모든 실험이 NVIDIA A100 80GB GPU에서만 수행되었습니다. NVLink 대역폭, GPU 메모리 크기, 계산-통신 비율이 다른 하드웨어에서의 성능은 불확실합니다.
  4. 시퀀스 길이 제약: 시퀀스 길이가 tensor parallelism 크기($t$)로 정확히 나누어떨어져야 하며, 가변 길이 시퀀스에 대한 처리가 논의되지 않습니다. 이는 패딩으로 인한 계산 낭비를 초래할 수 있습니다.
  5. 학습 수렴에 대한 영향 미검증: 기법이 학습 동역학(수렴 속도, 최종 loss)에 미치는 영향이 명시적으로 분석되지 않았습니다. 이론적으로는 수학적 동등성이 보장되지만, 수치적 정밀도 차이가 대규모 학습에서 축적될 가능성이 있습니다.

Future Directions

  1. 동적 시퀀스 분할: 가변 길이 시퀀스를 효율적으로 처리하기 위한 적응적 시퀀스 분할 메커니즘 개발, 패딩 없이 불균등한 시퀀스 길이를 처리하는 것이 핵심 과제입니다.
  2. ZeRO 및 FSDP와의 통합: 데이터 병렬화 기반 메모리 최적화(ZeRO, PyTorch FSDP)와 sequence parallelism의 결합 시 발생하는 통신 패턴 최적화 및 메모리 절감 효과의 정량적 분석이 필요합니다.
  3. 다양한 하드웨어 플랫폼 검증: H100, MI300X, TPU v5 등 다양한 가속기에서의 성능 특성 분석을 통해, 하드웨어별 최적 구성을 식별하고 이식성을 검증할 수 있습니다.
  4. 자동화된 재계산 전략 선택: 모델 구조, 하드웨어 특성, 메모리 제약을 입력으로 받아 최적의 재계산 전략을 자동으로 선택하는 시스템 개발, 프로파일링 기반 또는 ML 기반 접근법이 가능합니다.
  5. 비-Transformer 아키텍처로의 확장: State Space Model(Mamba 등)이나 Mixture of Experts 아키텍처에서 유사한 연산 인식 메모리 최적화 원칙을 적용하는 연구가 새로운 기회를 제공합니다.

Distributed Training Series (3/4)
  1. Tensor Parallel
  2. Tensor Parallel 구현 비교
  3. Pipeline Parallel (GPipe)
  4. Reducing Activation Recomputation in Large Transformer Models



    Enjoy Reading This Article?

    Here are some more articles you might like to read next:

  • Tensor Parallel 구현 비교
  • Flash Attention 3
  • Flash Attention 2
  • Flash Attention
  • Unified Sequence Parallelism
  • Stay updated — subscribe via RSS




    Leave a Comment

    Found this useful or have questions? Sign in with GitHub to join the conversation.