Reducing Activation Recomputation in Large Transformer Models
TL;DR
문제: 대규모 트랜스포머 모델(100B+ 매개변수) 훈련 시 중대한 병목현상 발생—그래디언트 계산을 위한 활성화(activation) 저장이 엄청난 메모리를 소모하여, 비싼 “활성화 재계산(activation recomputation)”을 강요하고 이는 30-40%의 훈련 시간 오버헤드를 추가함.
해결책: 함께 5배 메모리 감소와 30% 속도 향상을 달성하는 두 가지 상호보완적 기법:
- 시퀀스 병렬화(Sequence Parallelism): 텐서 병렬화가 처리할 수 없는 연산에서 시퀀스 차원을 따라 활성화를 분할
- 선택적 활성화 재계산(Selective Activation Recomputation): 특정한 고메모리, 저계산 어텐션 연산만 재계산
영향: 2240개 A100에서 54% GPU 사용률로 1조 매개변수 모델 훈련 가능—이전에는 불가능했던 규모를 실용적으로 만듦.
- Paper Link: https://arxiv.org/pdf/2205.05198
Related Papers
메모리 최적화 기법:
- Tensor Parallelism - 텐서 병렬화와 메모리 효율성 결합
- GPipe - 파이프라인 병렬화에서의 메모리 최적화
- Ring Self-Attention - 시퀀스 병렬화와 메모리 관리
대규모 모델 훈련:
- Switch Transformers - 대규모 MoE 모델의 메모리 효율성
- MoE - 전문가 혼합 모델의 메모리 요구사항
- LoongTrain - 긴 시퀀스 훈련에서의 메모리 최적화
분산 훈련 시스템:
- DeepSpeed Ulysses - 시퀀스 병렬화와 메모리 효율성
- USP - 통합 병렬화 프레임워크에서의 메모리 관리
- Blockwise RingAttention - 어텐션 계산의 메모리 효율성
Takeaways
근본적 문제: 트랜스포머 훈련의 메모리 벽
메모리에 맞지 않을 정도로 큰 모델을 훈련한다고 상상해보세요. 전통적인 해결책은 활성화 재계산(그래디언트 체크포인팅)입니다—역전파를 위한 중간 계산을 저장하는 대신, 버리고 나중에 다시 계산하는 것입니다. 이는 메모리를 절약하지만 본질적으로 순전파를 두 번 실행해야 하므로 엄청난 계산 오버헤드를 추가합니다.
문제의 규모:
# 530B 매개변수 모델의 경우
activation_memory = 34 * seq_len * batch * hidden + 5 * heads * seq_len²
# ≈ 레이어당 160GB (최적화 없이)
# vs. 80GB A100 GPU 메모리 용량
이는 고통스러운 트레이드오프를 강요합니다: 비싼 재계산을 사용하거나 모델을 전혀 훈련할 수 없거나.
핵심 혁신 1: 시퀀스 병렬화
통찰: 텐서 병렬화는 계산 집약적인 연산(행렬 곱셈)에는 훌륭하지만 간단한 연산(레이어 정규화, 드롭아웃)은 장치 간에 복제되어 메모리를 낭비합니다.
해결책: 이러한 “간단한” 연산들은 시퀀스 차원을 따라 독립적이므로, 대신 그 차원에서 분할할 수 있습니다.
# 전통적인 텐서 병렬화
def tensor_parallel_layer(x): # x: [seq_len, batch, hidden]
# 레이어 정규화가 각 장치에서 전체 데이터로 실행됨 (낭비적!)
x_norm = layer_norm(x) # 각 랭크에서 [seq_len, batch, hidden]
# 행렬 곱셈은 효율적으로 분할됨
x_split = matmul_tensor_parallel(x_norm) # [seq_len, batch, hidden/num_ranks]
return x_split
# 시퀀스 병렬화 적용
def sequence_tensor_parallel_layer(x_seq): # x_seq: [seq_len/num_ranks, batch, hidden]
# 시퀀스 분할 데이터에서 레이어 정규화 (메모리 효율적!)
x_norm = layer_norm(x_seq) # 랭크당 [seq_len/num_ranks, batch, hidden]
# 필요시 텐서 병렬로 변환
x_full = all_gather(x_norm) # 각 랭크에서 [seq_len, batch, hidden]
x_split = matmul_tensor_parallel(x_full)
# 시퀀스 병렬로 다시 변환
return reduce_scatter(x_split) # 랭크당 [seq_len/num_ranks, batch, hidden]
핵심 수학적 통찰: 통신 비용이 동일합니다—all_reduce = reduce_scatter + all_gather
—따라서 대역폭 오버헤드가 없습니다.
핵심 혁신 2: 선택적 활성화 재계산
통찰: 모든 연산이 동등하지 않습니다. 일부는 많은 메모리를 소모하지만 재생성하는 데 최소한의 계산만 필요합니다.
대상 연산: 어텐션에서 Q, K, V 계산 후:
- QK^T 행렬 곱셈 → 소프트맥스 → 드롭아웃 → V에 대한 어텐션
- 대형 모델에서 활성화 메모리의 ~70%를 차지
- 하지만 전체 계산의 ~3%만 차지 (대부분 원소별 연산)
def selective_attention_recomputation(q, k, v):
# 저장: Q, K, V (메모리 집약적 연산의 입력)
q_checkpoint = q.detach().requires_grad_(True)
k_checkpoint = k.detach().requires_grad_(True)
v_checkpoint = v.detach().requires_grad_(True)
# 재계산: 고메모리, 저계산 연산들
def recompute_expensive_memory_ops(q_in, k_in, v_in):
scores = torch.matmul(q_in, k_in.transpose(-1, -2)) # QK^T
weights = F.softmax(scores, dim=-1) # 소프트맥스
weights = F.dropout(weights, training=True) # 드롭아웃
output = torch.matmul(weights, v_in) # V에 대한 어텐션
return output
# 자동 재계산을 위한 PyTorch 체크포인트 사용
return torch.utils.checkpoint.checkpoint(
recompute_expensive_memory_ops, q_checkpoint, k_checkpoint, v_checkpoint
)
메모리 수학:
원래: 34*s*b*h + 5*a*s²*b (모든 것 저장)
선택적: 34*s*b*h (저메모리/고계산 연산만 저장)
감소: 5*a*s²*b 항 제거 (시퀀스 길이 제곱 스케일링 제거!)
실험 결과: 숫자의 진정한 의미
주요 성능 결과
| 모델 크기 | 메모리 감소 | 훈련 속도 향상 | GPU 사용률 | 실용적 영향 | |———–|————-|—————-|————|—————–| | 22B | 5배 | 29.0% | 41.5% → 43.7% | 겨우 맞던 모델이 이제 편안하게 훈련됨 | | 175B (GPT-3) | 5배 | 31.8% | 51.4% → 52.8% | 10만 달러 훈련 → 7만 6천 달러 (2만 4천 달러 절약) | | 530B (MT-NLG) | 5배 | 29.7% | 56.0% → 57.0% | 표준 클러스터에서 훈련 가능 | | 1T | 5배 | 32.1% | 56.3% → 57.0% | 1조 매개변수 모델을 실용적으로 만듦 |
구성요소 분석: 각 기법의 기여도 이해
| 기법 | 메모리 절약 | 속도 영향 | 핵심 통찰 | |——|————-|———–|—————| | 시퀀스 병렬화만 | ~50% | +3% 속도 향상 | 메모리 감소 및 성능 향상—레이어 정규화/드롭아웃이 예상보다 비쌌음 | | 선택적 재계산만 | ~50% | +7% 오버헤드 | 동일한 메모리 절약이지만 다른 성능 프로필—서로 다른 병목 최적화 | | 둘 모두 결합 | 80% (5배) | +4% 오버헤드 | 시너지 효과—이익이 선형적으로 결합되는 것보다 좋음 | | 전통적 재계산 | 90% | +39% 오버헤드 | 현상 유지가 극도로 비싸다는 것을 검증 |
놀라운 발견: 시퀀스 병렬화는 통신을 추가함에도 불구하고 실제로 훈련을 가속화합니다. 이는 레이어 정규화와 드롭아웃이 예상보다 계산적으로 비싸다는 것을 보여줍니다.
비판적 평가: 장점과 한계
주요 장점
- 수학적 엄밀성: 다양한 병렬화 전략 하에서 메모리 스케일링에 대한 정확한 공식 제공
- 일관된 스케일링: 22B에서 1T 매개변수까지 30% 향상 유지—근본적 최적화를 시사
- 실용적 구현: 프로덕션 프레임워크(Megatron-LM, NeMo)에서 사용 가능
- 이론적 검증: 하드웨어 FLOPs가 예측된 모델 FLOPs와 밀접하게 일치하여 수학적 모델 확인
중요한 한계점
- 실험적 현실성: 데이터 병렬화 미사용 (프로덕션 훈련에 비현실적)
- 기준선 완전성: ZeRO, 매개변수 샤딩, CPU 오프로딩과의 비교 누락
- 하드웨어 특이성: A100에서만 테스트—일반화 불분명
- 스케일 제약: 시퀀스 길이가 텐서 병렬 크기로 나누어떨어져야 함
성공을 위한 숨겨진 가정들
# 종종 간과되는 중요한 요구사항들:
assert sequence_length % tensor_parallel_size == 0 # 균등하게 나누어떨어져야 함
assert torch.cuda.memory_fragmentation() < 0.1 # 낮은 단편화 필요
assert all_ranks_synchronized() # 완벽한 동기화 필요
assert batch_sequences_uniform_length() # 가변 길이 문제 있음
실용적 구현 가이드
가장 잘 작동하는 경우
- 대형 모델 (>100B 매개변수) 활성화 메모리가 지배적인 경우
- 긴 시퀀스 5as²/h 항이 중요한 경우
- 고대역폭 상호연결 (NVLink, InfiniBand) 효율적인 통신을 위해
- 균등한 워크로드 일관된 시퀀스 길이
주의해야 할 경우
- 가변 시퀀스 길이: 계산을 낭비할 수 있는 패딩 필요
- 소형 모델: 오버헤드가 이익을 상회할 수 있음
- 메모리 제약 환경: 추가 기법(ZeRO, 오프로딩) 필요할 수 있음
- 레거시 하드웨어: 현대 가속기에 최적화된 통신 패턴
통합 전략
# 권장 구현 방식:
def adaptive_optimization_strategy(model_size, available_memory, sequence_length):
if model_size > 100e9: # >100B 매개변수
use_sequence_parallelism = True
memory_ratio = estimate_memory_usage() / available_memory
if memory_ratio > 0.9:
use_selective_recomputation = True
elif memory_ratio > 0.7:
use_microbatch_recomputation = True
else:
use_selective_recomputation = False
else:
# 작은 모델의 경우, 이익이 복잡성을 정당화하지 못할 수 있음
return "standard_tensor_parallelism"
미래 방향과 연구 기회
즉시 확장 가능한 영역
- 동적 시퀀스 병렬화: 가변 길이 시퀀스를 효율적으로 처리
- 하이브리드 메모리 전략: ZeRO 및 CPU 오프로딩과 결합
- 하드웨어 최적화: 다양한 가속기 아키텍처에 맞춘 튜닝
- 자동화된 최적화: ML 가이드 재계산 전략 선택
근본적 질문들
- 일반화: 이러한 패턴이 다른 아키텍처(Vision Transformer 등)에 적용되는가?
- 스케일링 한계: 1조 매개변수 모델을 넘어서면 어떻게 되는가?
- 통신 진화: 미래의 상호연결이 최적 전략을 어떻게 바꿀 것인가?
큰 그림
이 연구는 “일률적인” 최적화에서 연산 인식 메모리 관리로의 패러다임 전환을 나타냅니다. 서로 다른 연산이 서로 다른 메모리-계산 트레이드오프를 가진다는 것을 인식함으로써, 훨씬 더 정교한 최적화 전략의 문을 엽니다.
연구자들을 위해: 대규모 훈련에서 메모리 병목을 분석하기 위한 수학적 프레임워크 제공
실무자들을 위해: 훈련 효율성을 즉시 개선할 수 있는 프로덕션 준비된 기법 제공
분야 전체를 위해: 계산 패턴의 신중한 분석이 놀라운 최적화를 가져올 수 있음을 보여줌—대형 모델 훈련 방식에는 여전히 상당한 개선 여지가 있음
이 기법들은 이미 이전에는 불가능했던 훈련을 가능하게 하고 있으며, 수학적 프레임워크는 모델이 다조 매개변수 체제로 계속 확장됨에 따라 미래 최적화의 기반을 제공합니다.
Leave a comment