Flash Attention
TL;DR
FlashAttention은 표준 구현보다 2-3배 빠르면서 10-20배 적은 메모리를 사용하는 IO-aware 어텐션 알고리즘입니다. GPU 메모리에 거대한 N×N 어텐션 행렬을 만드는 대신, 타일링(tiling)과 온라인 소프트맥스(online softmax)를 사용하여 빠른 온칩 SRAM에 맞는 작은 블록 단위로 어텐션을 처리합니다. 주요 기여사항:
- IO-Awareness: 어텐션의 실제 병목이 연산이 아닌 메모리 대역폭임을 파악
- 타일링 알고리즘: 느린 메모리 접근을 최소화하기 위해 블록 단위로 처리
- 정확한 계산: 근사 방법과 달리 정확한 어텐션을 계산하면서도 더 빠른 속도 달성
- 선형 메모리: 메모리 복잡도를 O(N²)에서 O(N)으로 줄여 16K+ 시퀀스 길이 가능
- 실용적 성과: MLPerf BERT 기록보다 15% 빠름, GPT-2 학습 3배 가속, Path-X/Path-256 최초 해결
FlashAttention 시리즈 발전:
- FlashAttention-2 - 2배 빠른 속도와 향상된 병렬화
- FlashAttention-3 - 비대칭 어텐션과 FP8 지원
- From Online Softmax to FlashAttention - 이론적 기초와 온라인 소프트맥스
긴 시퀀스 처리:
- Context Parallelism for Scalable Million-Token Inference - 추론 시 컨텍스트 분산
- Scaling Laws of RoPE-based Extrapolation - 위치 인코딩 확장
- YaRN - RoPE 기반 컨텍스트 길이 확장
- RoFormer - 회전 위치 임베딩
시스템 최적화:
- Reducing Activation Recomputation in Large Transformer Models - 메모리 효율적인 병렬 훈련
- USP - 통합 시퀀스 병렬화 프레임워크
- Tensor Parallelism - 텐서 병렬화와의 결합
- GPipe - 파이프라인 병렬화와의 통합
Takeaways
1. 문제: 표준 어텐션이 느린 이유
메모리 계층 구조의 불일치
이 논문은 대부분의 ML 연구자들이 놓치는 중요한 관찰로 시작합니다: 현대 GPU는 어텐션에서 연산 제한(compute-bound)이 아닌 메모리 제한(memory-bound) 상태입니다. 그 이유는:
# 표준 어텐션 - 실제로 일어나는 일
def standard_attention_reality(Q, K, V):
"""
Q, K, V: [batch, seq_len, dim]
실제로 시간이 소요되는 부분을 보여줍니다
"""
# 단계 1: 점수 계산 - 빠름 (텐서 코어 사용)
S = Q @ K.T # [seq_len, seq_len] - 하지만 느린 HBM에 써야 함!
# 단계 2: 소프트맥스 - 느림 (메모리 제한)
# 전체 S를 HBM에서 읽고, exp() 계산하고, 다시 쓰기
P = softmax(S) # seq_len² 개의 원소를 읽고/쓰기
# 단계 3: 가중 합 - 연산은 빠르지만 메모리는 느림
O = P @ V # 거대한 P 행렬을 HBM에서 읽어야 함
# 총 HBM 접근: O(seq_len²) - 이것이 병목!
제 통찰: 논문은 우리가 잘못된 것을 최적화하고 있다는 점을 훌륭하게 파악했습니다. 모두가 FLOPs(부동소수점 연산)를 줄이는 데 집중하는 동안, 실제 문제는 데이터 이동입니다. A100 GPU에서:
- 행렬 곱셈: 312 TFLOPS
- 메모리 대역폭: 1.5 TB/s
- 어텐션의 경우, 312 TFLOPS가 아닌 1.5 TB/s에 제한됩니다!
동기 부여 실험
저자들은 아마 다음과 같은 실험으로 연구 동기를 얻었을 것입니다:
# 표준 어텐션 프로파일링으로 병목 지점 찾기
def profile_attention(seq_len):
# 각 구성요소에서 소요되는 시간
compute_time = seq_len² * time_per_flop
memory_time = seq_len² * matrix_size * time_per_byte_transferred
# A100에서 seq_len=2048인 경우:
# 연산: ~0.5ms (활용도 낮음)
# 메모리: ~10ms (병목!)
print(f"연산 활용도: {compute_time/total_time * 100}%") # ~5%
print(f"메모리 제한 시간: {memory_time/total_time * 100}%") # ~95%
2. 해결책: FlashAttention 알고리즘
핵심 아이디어: 빠른 메모리에서 작업하기
핵심 통찰은 데이터를 가능한 한 SRAM(빠른 온칩 메모리)에 유지하는 것입니다:
# GPU 메모리 계층 구조 (A100 예시)
SRAM_per_SM = 192 * 1024 # 192 KB - 빠름 (19 TB/s)
HBM_TOTAL = 40 * 1024**3 # 40 GB - 느림 (1.5 TB/s)
# 표준 어텐션은 N×N 행렬을 위해 HBM을 사용해야 함
# FlashAttention은 타일링을 사용해 모든 것을 SRAM에 유지!
FlashAttention 알고리즘
구체적인 예시와 함께 알고리즘을 설명합니다:
def flash_attention(Q, K, V, block_size=128):
"""
상세한 예시와 함께하는 FlashAttention
Q, K, V: [1024, 64] (seq_len=1024, dim=64)
block_size: 128 (128×128 타일 처리)
"""
seq_len, dim = Q.shape
num_blocks = seq_len // block_size # 1024/128 = 8 블록
# 출력과 통계 초기화
O = zeros([seq_len, dim])
m = full([seq_len], -inf) # 수치 안정성을 위한 최댓값
l = zeros([seq_len]) # 소프트맥스를 위한 지수 합
# 블록 단위로 처리 - 이것이 핵심!
for i in range(num_blocks): # i = 0,1,...,7
# Q 블록 [128, 64]를 SRAM으로 로드
Q_block = Q[i*block_size:(i+1)*block_size]
# 블록 통계 초기화
m_block = m[i*block_size:(i+1)*block_size]
l_block = l[i*block_size:(i+1)*block_size]
O_block = O[i*block_size:(i+1)*block_size]
for j in range(num_blocks): # j = 0,1,...,7
# K,V 블록 [128, 64]를 SRAM으로 로드
K_block = K[j*block_size:(j+1)*block_size]
V_block = V[j*block_size:(j+1)*block_size]
# ---- 모든 계산이 SRAM에서 발생! ----
# 1. 블록 어텐션 점수 계산 [128, 128]
S_block = (Q_block @ K_block.T) / sqrt(dim)
# 2. 온라인 소프트맥스 - 실행 중 통계 업데이트
m_block_new = maximum(m_block, S_block.max(dim=1))
# 3. 새 최댓값으로 지수 계산 (수치 안정성)
P_block = exp(S_block - m_block_new.unsqueeze(1))
# 4. 지수 합 업데이트
l_block_new = exp(m_block - m_block_new) * l_block + P_block.sum(dim=1)
# 5. 블록 출력 계산
O_block_new = P_block @ V_block
# 6. 이전 출력 재조정 후 새 출력 추가
O_block = (exp(m_block - m_block_new).unsqueeze(1) * l_block.unsqueeze(1) * O_block +
O_block_new) / l_block_new.unsqueeze(1)
# 통계 업데이트
m_block = m_block_new
l_block = l_block_new
# 최종 결과를 HBM에 쓰기 (블록당 한 번만!)
O[i*block_size:(i+1)*block_size] = O_block
m[i*block_size:(i+1)*block_size] = m_block
l[i*block_size:(i+1)*block_size] = l_block
return O
구체적 예시: 하나의 블록 처리
특정 블록에서 일어나는 일을 추적해보겠습니다:
# 예시: Q[0:128]이 K[256:384]에 주목하는 처리
# 1024 길이 시퀀스에서 블록 (i=0, j=2)
# 이 블록 이전 상태:
# - O[0:128]은 K[0:256]으로부터의 어텐션 결과 포함
# - m[0:128] = [0.82, 0.79, ...] (지금까지의 최대 점수)
# - l[0:128] = [45.2, 52.1, ...] (지금까지의 지수 합)
# 단계 1: 새 점수 계산
S_block = Q[0:128] @ K[256:384].T / 8 # [128, 128] 행렬
# 예시 값: [[0.71, 0.65, ...], [0.73, 0.68, ...], ...]
# 단계 2: 최댓값 업데이트 (수치 안정성을 위해)
m_new = [0.82, 0.79, ...] # 이전 최댓값이 여전히 더 큼
# 단계 3: 이 블록의 어텐션 가중치 계산
P_block = exp(S_block - m_new.unsqueeze(1))
# 값: [[0.89, 0.84, ...], [0.91, 0.87, ...], ...]
# 단계 4: 지수 합 업데이트
l_new = l * 1.0 + P_block.sum(dim=1) # 재조정 불필요
# 새 값: [67.3, 73.5, ...]
# 단계 5: 가중 값 누적
O_new = (l * O + P_block @ V[256:384]) / l_new
# 이전과 새로운 기여를 올바르게 가중!
제 통찰: 아름다운 점은 시퀀스 길이와 관계없이 언제나 128×128 행렬만 SRAM에 유지한다는 것입니다!
3. 중요한 가정과 조건
하드웨어 요구사항
def check_flashattention_viability(gpu_specs, model_config):
"""
FlashAttention은 특정 하드웨어 속성이 필요합니다
"""
# 1. SM당 충분한 SRAM
sram_needed = (3 * block_size * dim + # Q, K, V 블록
block_size * block_size + # S 블록
block_size * dim) # O 블록
if sram_needed > gpu_specs.sram_per_sm:
print("❌ 블록 크기가 SRAM에 비해 너무 큼")
return False
# 2. 높은 메모리 대역폭 비율
compute_to_memory_ratio = gpu_specs.tflops / gpu_specs.memory_bandwidth
if compute_to_memory_ratio < 100: # 대략적인 임계값
print("⚠️ GPU가 메모리 제한이 아닌 연산 제한일 수 있음")
# 3. 효율적인 행렬 곱셈 유닛
if not gpu_specs.has_tensor_cores:
print("⚠️ 텐서 코어 없이는 성능 향상 감소")
return True
# 예시 확인
a100_specs = {
'sram_per_sm': 192 * 1024, # 192 KB
'tflops': 312, # FP16
'memory_bandwidth': 1.5, # TB/s
'has_tensor_cores': True
}
check_flashattention_viability(a100_specs, {'block_size': 128, 'dim': 64})
알고리즘적 가정
FlashAttention이 가장 잘 작동하는 경우에 대한 제 분석:
- 메모리 제한 워크로드: 시퀀스 길이 > 512
- 표준 어텐션 패턴: 커스텀 어텐션 마스크와는 효율적으로 작동하지 않음
- 수치 정밀도: FP16/BF16 권장 (FP32는 SRAM 효율성 감소)
- 배치 크기 유연성: 시퀀스 길이 차원으로 병렬화 불가
4. 실험 결과 분석
주요 성능 결과
모델 | 설정 | 기준선 | FlashAttention | 제 해석 |
---|---|---|---|---|
BERT-large | MLPerf 1.1 | 100% (기준) | 115% (1.15배) | MLPerf가 이미 고도로 최적화된 점을 고려하면 인상적 |
GPT-2 | HuggingFace | 100% | 300% (3.0배) | 덜 최적화된 구현에서의 가치를 보여줌 |
GPT-2 | Megatron-LM | 100% | 180% (1.8배) | 최적화된 코드에서도 여전히 큰 향상 |
Long Range Arena | Seq 1K-4K | 100% | 240% (2.4배) | 예상대로 시퀀스 길이가 길수록 향상 증가 |
제 견해: 결과는 메모리 제한 가설을 검증합니다. 더 긴 시퀀스에서 더 큰 속도 향상은 O(N²) 메모리 접근이 병목임을 확인합니다.
메모리 효율성 결과
시퀀스 길이 | 표준 메모리 | FlashAttention 메모리 | 절약 |
---|---|---|---|
512 | 1 GB | 200 MB | 5배 |
2,048 | 16 GB | 1.6 GB | 10배 |
8,192 | 256 GB | 6.4 GB | 40배 |
16,384 | 1 TB (OOM) | 25.6 GB | 새로운 기능 가능 |
제 분석: 선형 대 이차 스케일링은 혁신적입니다. 이것은 단순한 최적화가 아니라 가능한 것을 근본적으로 바꿉니다.
품질 개선
작업 | 지표 | 결과 | 제 해석 |
---|---|---|---|
GPT-2 (4K vs 1K 컨텍스트) | Perplexity | -0.7 개선 | 더 긴 컨텍스트 = 더 나은 언어 이해 |
문서 분류 | F1 점수 | +6.4 포인트 | 전체 문서를 보는 것으로 큰 향상 |
Path-X (16K) | 정확도 | 61.4% (랜덤 50% 대비) | 이를 해결한 최초 사례! |
Path-256 (64K) | 정확도 | 63.1% (랜덤 50% 대비) | 시퀀스 모델링의 한계를 넓힘 |
제 통찰: 이것들은 단순한 벤치마크 개선이 아니라 새로운 능력을 나타냅니다. Path-X는 트랜스포머에게 불가능하다고 여겨졌습니다!
5. 절제 연구: 무엇이 정말 중요한가?
구성요소 기여도
구성요소 | 없을 때 | 있을 때 | 영향 | 제 분석 |
---|---|---|---|---|
타일링 | 1.0배 | 1.4배 | +40% | 핵심 혁신 - 데이터를 SRAM에 유지 |
온라인 소프트맥스 | 1.4배 | 1.75배 | +25% | 데이터를 한 번만 통과 가능 |
재계산 | 1.75배 | 2.1배 | +20% | N×N 어텐션 행렬 저장 회피 |
커널 융합 | 2.1배 | 2.4배 | +15% | 오버헤드 감소, SRAM 사용 최대화 |
제 견해: 각 구성요소는 필요하지만 충분하지 않습니다. 이들 간의 시너지가 FlashAttention을 특별하게 만듭니다.
블록 크기 민감도
블록 크기 | 성능 | 메모리 사용 | 제 분석 |
---|---|---|---|
32 | 85% | 최소 | 너무 많은 HBM 접근이 목적을 무너뜨림 |
64 | 95% | 낮음 | 작은 모델/차원에 좋음 |
128 | 100% (최고) | 최적 | 대부분 GPU의 스위트 스팟 |
256 | 98% | 높음 | SRAM 제한이 나타남 |
제 통찰: 블록 크기는 중요하고 하드웨어 의존적입니다. 저자들은 128이 다양한 GPU에서 잘 작동한다는 것을 발견했지만, 조정이 필요합니다.
6. 실용적 의미
실무자들에게 의미하는 것
# FlashAttention 이전
def train_gpt_traditional(seq_len=2048):
if seq_len > 2048:
raise OOMError("어텐션 행렬을 메모리에 맞출 수 없음")
# 느린 학습, 제한된 컨텍스트
# FlashAttention 이후
def train_gpt_flash(seq_len=16384):
# 그냥 작동합니다! 그리고 더 빠르게도
model = GPT(seq_len=16384, use_flash_attention=True)
# 8배 더 긴 컨텍스트, 3배 빠른 학습
미래 방향 (제 생각)
- 하드웨어 공동 설계: 미래 GPU는 SRAM을 늘려 FlashAttention을 더 효과적으로 만들 수 있음
- 희소 패턴: 블록 희소 FlashAttention은 100K+ 시퀀스에 대한 가능성을 보여줌
- 다른 연산: IO-aware 원칙은 다른 메모리 제한 연산에도 적용 가능
- 추론 최적화: 현재 작업은 학습에 초점을 맞추고 있지만, 추론은 다른 패턴을 가짐
결론
FlashAttention은 시스템 사고의 걸작입니다. FLOPs 최적화에 몰려드는 대신, 저자들은 실제 병목(메모리 대역폭)을 파악하고 하드웨어 제약을 존중하는 알고리즘을 설계했습니다. 결과는 단순히 더 빠른 어텐션이 아니라 언어 모델링에서 완전히 새로운 능력을 가능하게 합니다.
핵심 교훈: 때로는 최고의 최적화는 더 적은 작업을 하는 것이 아니라, 기본 하드웨어에 대해 같은 작업을 더 지능적으로 수행하는 것에서 나옵니다.
Leave a comment