Flash Attention 3
TL;DR
FlashAttention-3는 AI 모델(ChatGPT 같은)의 어텐션 메커니즘을 1.5-2.0배 빠르게 만들면서 절반 정밀도(FP8)를 사용해도 정확도 손실 없이 작동하는 획기적인 최적화 기술입니다.
문제점: FlashAttention-2가 H100 GPU에서 단 35%의 활용률만을 달성 - 즉, 비싼 AI 하드웨어가 가장 중요한 연산 중에 대부분 놀고 있었습니다.
해결책: 함께 작동하는 세 가지 핵심 혁신:
- 비동기 처리: 데이터 로딩과 연산을 위한 별도의 작업자를 두는 것과 같음
- 중첩 연산: 기다리는 대신 여러 작업을 동시에 계산
- 스마트 저정밀도: 정확도 손실 없이 16비트 대신 8비트 숫자 사용
주요 결과:
- 속도: 최대 740 TFLOPs/s (GPU 활용률 75% vs 이전 35%)
- 효율성: FP8 정밀도로 거의 1.2 PFLOPs/s
- 정확도: 표준 FP8 방법보다 2.6배 더 정확
- 임팩트: 장문맥 AI 애플리케이션(전체 책, 코드베이스 분석)을 실용적으로 만듦
어텐션이 현대 AI의 연산 병목이기 때문에 이를 빠르게 만드는 것은 더 강력한 AI 애플리케이션을 더 낮은 비용으로 가능하게 합니다.
- Paper Link: https://arxiv.org/pdf/2407.08608
Related Papers
FlashAttention 시리즈 발전:
- FlashAttention - 원조 IO-aware 어텐션 알고리즘과 타일링 기법의 기초
- FlashAttention-2 - 2배 빠른 속도와 향상된 병렬화로 발전된 버전
- From Online Softmax to FlashAttention - 온라인 소프트맥스의 이론적 기초와 수치 안정성
효율적인 어텐션 메커니즘:
- Memory-efficient Attention - 메모리 최적화 어텐션 기법들의 비교 분석
- MQA - 멀티쿼리 어텐션으로 키/값 공유를 통한 추론 가속
- GQA - 그룹드 쿼리 어텐션의 품질-속도 균형점 탐색
위치 인코딩과 긴 시퀀스 처리:
- RoFormer - 회전 위치 임베딩과의 최적화된 통합
- YaRN - RoPE 기반 컨텍스트 길이 확장 기법
- Scaling Laws of RoPE-based Extrapolation - 위치 인코딩 확장과 성능 스케일링
하드웨어 최적화 및 시스템 연구:
- Tensor Parallelism - 텐서 병렬화와 FlashAttention의 결합 전략
- Reducing Activation Recomputation in Large Transformer Models - 메모리 효율적인 병렬 훈련 기법
- GPipe - 파이프라인 병렬화와의 통합 방법론
Takeaways
1. 동기: 왜 FlashAttention-3가 필요했는가
성능 격차 문제
실험적 동기: 저자들은 기존 어텐션 구현의 심각한 비효율성을 발견했습니다. FlashAttention-2가 새로운 GPU에서 최적화된 행렬 곱셈(GEMM) 커널 대비 낮은 활용률을 달성하는데, Hopper H100 GPU에서 35% 대 80-90%였습니다.
이는 포뮬러 1 자동차를 가지고 있지만 엔진 출력의 35%만 사용하는 것과 같습니다 - 프리미엄 하드웨어 비용을 지불하지만 평범한 성능을 얻고 있었습니다. 저자들은 기존 어텐션 알고리즘이 새로운 GPU 기능을 활용하도록 설계되지 않았음을 깨달았습니다.
하드웨어 진화 격차: 현대 GPU(NVIDIA H100 같은)는 새로운 기능들을 도입했습니다:
- 비동기 텐서 코어: 데이터를 로딩하면서 동시에 연산 가능
- FP8 정밀도: FP16보다 2배 빠르지만 신중한 처리 필요
- 전용 메모리 유닛: 특정 연산을 위해 설계된 하드웨어
하지만 FlashAttention-2는 구형 하드웨어를 위해 설계되어 이러한 기능들을 효과적으로 사용할 수 없었습니다.
장문맥 도전
실제 임팩트: 어텐션을 더 긴 맥락으로 확장하면 새로운 능력(여러 긴 문서와 대용량 코드베이스의 파일들에 대한 모델링과 추론), 새로운 모달리티(고해상도 이미지, 오디오, 비디오), 새로운 애플리케이션(긴 기록과의 사용자 상호작용, 긴 지평선의 에이전트 워크플로)이 열릴 것입니다.
현재 AI 모델들은 다음과 같은 작업에서 어려움을 겪습니다:
- 전체 책이나 연구 논문 분석
- 완전한 영화 스크립트나 긴 대화 이해
- 고해상도 이미지나 긴 오디오 파일 처리
- 장기간 상호작용에서 맥락 유지
병목은? 어텐션 연산이 시퀀스 길이에 대해 제곱으로 증가하여 긴 맥락을 금지적으로 비싸게 만듭니다.
2. 핵심 방법: 세 가지 상승효과 혁신
혁신 1: 생산자-소비자 비동기성
통찰: 전통적인 어텐션은 데이터가 로드될 때까지 기다린 후 연산합니다. 현대 GPU는 동시에 로드하고 연산할 수 있지만, 작업을 다르게 구성해야 합니다.
Python 의사코드:
import asyncio
import torch
class AsyncAttention:
def __init__(self):
self.data_loader_workers = [] # 생산자 워프
self.compute_workers = [] # 소비자 워프
self.shared_buffer = CircularBuffer(stages=3)
async def producer_worker(self, K_blocks, V_blocks):
"""공유 버퍼에 데이터를 지속적으로 로드"""
for i, (K_i, V_i) in enumerate(zip(K_blocks, V_blocks)):
# 버퍼 슬롯이 비어있을 때까지 대기
await self.shared_buffer.wait_for_slot(i % 3)
# 비동기적으로 데이터 로드 (블록하지 않음)
await self.gpu_load_async(K_i, self.shared_buffer.slot(i % 3))
await self.gpu_load_async(V_i, self.shared_buffer.slot(i % 3))
# 데이터 준비 완료 신호
self.shared_buffer.mark_ready(i % 3)
async def consumer_worker(self, Q_block):
"""로드된 데이터에 대해 지속적으로 연산"""
result = torch.zeros_like(Q_block)
for i in range(len(K_blocks)):
# 데이터가 로드될 때까지 대기
await self.shared_buffer.wait_for_data(i % 3)
# 다음 데이터가 백그라운드에서 로드되는 동안 연산
K_i = self.shared_buffer.get_K(i % 3)
V_i = self.shared_buffer.get_V(i % 3)
# 어텐션 연산
scores = torch.matmul(Q_block, K_i.transpose(-1, -2))
probs = torch.softmax(scores, dim=-1)
result += torch.matmul(probs, V_i)
# 슬롯이 재사용을 위해 비었다고 신호
self.shared_buffer.mark_consumed(i % 3)
return result
작동 원리 - 레스토랑 비유:
- 전통적 접근: 셰프가 웨이터가 재료를 가져올 때까지 기다린 후 요리하고, 다시 기다림
- FlashAttention-3: 웨이터가 셰프가 사용 가능한 재료로 요리하는 동안 지속적으로 재료를 가져옴
- 결과: 주방이 놀지 않고 최대 용량으로 운영됨
내 분석: 이는 현대 CPU의 파이프라이닝 작동 방식과 일치하지만 GPU 어텐션 연산에 적용한 점이 우아합니다. 핵심 통찰은 GPU 메모리 대역폭과 연산 능력을 순차적이 아닌 동시에 활용할 수 있다는 것입니다.
혁신 2: GEMM-소프트맥스 중첩
문제: H100 SXM5 GPU는 989 TFLOPS의 FP16 행렬곱을 가지지만 소프트맥스에 필요한 지수함수 같은 특수 함수는 단 3.9 TFLOPS입니다. 소프트맥스 연산(비싼 지수 함수 포함)이 빠른 행렬 곱셈을 막고 있었습니다.
Python 의사코드:
class PipelinedAttention:
def __init__(self):
self.stage_current = 0
self.stage_next = 1
async def two_stage_pipeline(self, Q, K_blocks, V_blocks):
"""QK^T, 소프트맥스, PV 연산을 중첩"""
O = torch.zeros_like(Q)
# 단계 0: 첫 번째 연산 초기화
S_current = await self.async_matmul(Q, K_blocks[0].T)
await self.wait_completion()
m, P_current, l = self.compute_softmax_stats(S_current)
# 메인 파이프라인 루프
for i in range(1, len(K_blocks) - 1):
# 파이프라인 단계 1: 다음 QK^T 시작 (기다리지 않음)
S_next_future = self.async_matmul(Q, K_blocks[i].T)
# 파이프라인 단계 2: 현재 PV 시작 (기다리지 않음)
O_update_future = self.async_matmul(P_current, V_blocks[i-1])
# 파이프라인 단계 3: QK^T 대기, 소프트맥스 연산
S_next = await S_next_future
m, P_next, l = self.compute_softmax_stats(S_next, m, l)
# 파이프라인 단계 4: PV 대기, 출력 업데이트
O_update = await O_update_future
O = self.rescale_and_add(O, O_update, m, l)
# 다음 반복을 위해 교체
P_current = P_next
return O
def compute_softmax_stats(self, S, m_old=None, l_old=None):
"""실행 통계를 사용한 수치적으로 안정적인 소프트맥스"""
if m_old is None:
m_old = torch.full((S.shape[0],), float('-inf'))
l_old = torch.zeros(S.shape[0])
# 행 최댓값 업데이트
m_new = torch.maximum(m_old, torch.max(S, dim=-1)[0])
# 확률 계산 및 행 합계 업데이트
P = torch.exp(S - m_new.unsqueeze(-1))
l_new = torch.exp(m_old - m_new) * l_old + torch.sum(P, dim=-1)
return m_new, P, l_new
조립 라인 비유:
- 전통적: 자동차 조립이 각 스테이션에서 정지 - 엔진 설치, 그 다음 페인팅, 그 다음 시트 설치
- FlashAttention-3: 여러 자동차가 동시에 다른 스테이션에서 - A차가 페인팅되는 동안 B차는 엔진, C차는 시트
- 결과: 같은 자원으로 3배 높은 처리량
숫자를 사용한 실제 예시:
# 전통적 순차 타이밍
QK_time = 10ms # 행렬 곱셈
Softmax_time = 5ms # 지수 연산
PV_time = 10ms # 행렬 곱셈
Total = 25ms per block
# 파이프라인 타이밍
# 초기 설정 후 세 연산 모두 동시 실행
# 병목은 가장 느린 연산 (10ms)
Pipeline_time = 10ms per block
Speedup = 25/10 = 2.5x 이론적
내 통찰: 이는 CPU 명령어 파이프라이닝과 유사하지만 고수준 수학적 연산에 적용된 것입니다. 천재적인 점은 다른 연산들이 다른 하드웨어 유닛을 사용하므로 동시에 실행될 수 있다는 것을 인식한 것입니다.
혁신 3: 정확도 보존을 가진 FP8
도전: FP8은 2배 속도를 제공하지만 일반적으로 수치 정확도를 파괴합니다. 대형 모델들은 일반적으로 대부분의 다른 값들보다 크기가 훨씬 큰 이상치 값들을 가지고 있어 양자화를 어렵게 만듭니다.
두 부분 해결책:
파트 A: 블록 양자화
Python 의사코드:
def block_quantization(tensor, block_size=64):
"""전체 텐서 대신 블록 단위로 양자화"""
batch, seq_len, hidden = tensor.shape
# 문제: 텐서별 양자화
# tensor_max = tensor.abs().max() # 이상치에 의해 지배됨!
# scale = tensor_max / 127
# 결과: 대부분의 값이 0 또는 ±1이 됨
# 해결책: 블록별 양자화
blocks = tensor.view(batch, seq_len // block_size, block_size, hidden)
# 각 블록이 자체 스케일을 가짐
block_maxes = blocks.abs().max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]
scales = block_maxes / 127.0
# 각 블록을 별도로 양자화
quantized_blocks = torch.round(blocks / scales).clamp(-127, 127)
return quantized_blocks.view(tensor.shape), scales
# 이상치가 있는 예시
tensor = torch.randn(1, 1024, 128)
tensor[0, :10, :10] = 50.0 # 50배 더 큰 1% 이상치
# 텐서별 양자화 (나쁨)
global_scale = tensor.abs().max() / 127 # = 50/127 = 0.39
quantized_global = torch.round(tensor / global_scale)
# 결과: 일반 값들이 0, ±1, ±2가 됨 (끔찍한 정밀도)
# 블록 양자화 (좋음)
quantized_blocks, scales = block_quantization(tensor)
# 결과: 일반 블록은 전체 ±127 범위 사용, 이상치 블록은 별도 스케일
파트 B: 비일관성 처리
Python 의사코드:
def incoherent_processing(Q, K):
"""랜덤 직교 변환을 사용하여 이상치 분산"""
d = Q.shape[-1]
# 랜덤 직교 행렬 M = D1 * Hadamard * D2 생성
# 핵심 통찰: M이 직교이므로 MM^T = I
# 따라서 (Q*M) @ (K*M)^T = Q @ M @ M^T @ K^T = Q @ K^T
# 어텐션 출력은 변하지 않지만 이상치가 분산됨!
D1 = torch.randint(0, 2, (d,)) * 2 - 1 # 랜덤 ±1
D2 = torch.randint(0, 2, (d,)) * 2 - 1
def fast_hadamard_transform(x):
"""O(d^2) 대신 O(d log d) 행렬 곱셈"""
return hadamard_recursive(x) # Walsh-Hadamard 변환
# Q와 K 모두 변환
Q_transformed = fast_hadamard_transform(Q * D1) * D2
K_transformed = fast_hadamard_transform(K * D1) * D2
return Q_transformed, K_transformed
# 이상치가 있을 때 작동하는 이유 - 예시
original_Q = torch.tensor([[1, 1, 100]]) # 하나의 거대한 이상치
# 변환 후: 각 요소가 원래 요소들의 합이 됨
# transformed_Q ≈ [[34, 34, 34]] # 이상치가 모든 차원에 분산됨
# 양자화하기 훨씬 쉬워짐!
복권 비유:
- 문제: 한 사람이 100만원, 999명이 0원 (공정하게 표현하기 어려움)
- 해결책: 재분배하여 모든 사람이 1000원씩 (공정하게 표현하기 쉬움)
- 핵심: 총 가치는 변하지 않지만 분포가 더 균일함
내 평가: 이는 수학적으로 우아합니다 - 직교 행렬의 성질을 사용하여 최종 결과를 보존하면서 중간 연산을 양자화에 더 친화적으로 만듭니다. 수학적 손재주와 같습니다.
3. 중요한 성공 조건
하드웨어 요구사항
class HardwareChecklist:
required_features = {
'hopper_gpu': 'Hopper 아키텍처를 가진 H100 이상',
'async_tensor_cores': 'WGMMA 명령어 지원',
'tensor_memory_accelerator': '비동기 메모리 연산을 위한 TMA',
'fp8_support': 'E4M3 형식 텐서 코어',
'shared_memory': 'SM당 최소 228KB',
'register_flexibility': '동적 레지스터 할당'
}
def check_compatibility(self):
# 실제로는 실제 하드웨어를 쿼리할 것
for feature, description in self.required_features.items():
print(f"✓ {feature}: {description}")
내 분석: 이는 강점이자 제한사항입니다. 최적화들이 특정 하드웨어에 깊이 연결되어 H100에서는 믿을 수 없이 효과적이지만 다른 가속기로의 이식성은 떨어질 수 있습니다.
알고리즘 가정
-
메모리 대역폭 균형: 생산자 워프가 메모리 대역폭을 포화시킬 수 있어야 함
- 함정: 데이터 로딩이 연산 대비 너무 빠르면 생산자가 유휴 상태가 됨
- 해결책: 블록 크기와 파이프라인 단계 수의 신중한 조정
-
파이프라인 단계 타이밍: GEMM과 소프트맥스 연산이 유사한 타이밍을 가져야 함
- 함정: 소프트맥스가 너무 빠르면 중첩 이득 없음; 너무 느리면 병목이 됨
- 성공: H100의 GEMM과 지수함수 간 256배 속도 차이가 이를 가능하게 함
-
컴파일러 협력: 컴파일러가 의도된 명령어 순서를 생성해야 함
- 함정: 컴파일러가 명령어를 재배열하여 파이프라인을 깨뜨릴 수 있음
- 해결책: 메모리 장벽과 인라인 어셈블리 힌트의 신중한 사용
4. 실험 결과 분석
주요 성능 결과
구성 | FlashAttention-2 | FlashAttention-3 | 가속비 | 내 해석 |
---|---|---|---|---|
헤드차원 64, 16k seq | 324 TFLOPs/s | 497 TFLOPs/s | 1.53× | 메모리 바운드 영역이 가장 많은 이득 |
헤드차원 128, 16k seq | 370 TFLOPs/s | 648 TFLOPs/s | 1.75× | 최적화의 스위트 스팟 |
헤드차원 256, 16k seq | 581 TFLOPs/s | 756 TFLOPs/s | 1.30× | 연산 바운드, 상대적 이득 적음 |
최고 활용률 | 35% | 75% | 2.1× | 하드웨어가 마침내 효과적으로 사용됨 |
내 분석: 결과는 논리적 패턴을 따릅니다 - 더 작은 헤드 차원은 더 메모리 바운드이므로 메모리 최적화가 더 도움이 됩니다. 75% 활용률은 놀랍습니다; 비교하자면, 복잡한 시스템에서 75% 효율성을 얻는 것은 뛰어납니다.
FP8 성능 돌파
정밀도 | 최고 성능 | 정확도 (RMSE) | 내 평가 |
---|---|---|---|
FP16 표준 | 139 TFLOPs/s | 3.2e-4 | 기준 정확도 |
FP16 FlashAttention-3 | 648 TFLOPs/s | 1.9e-4 | 더 빠른 속도 AND 정확도 |
FP8 표준 | ~800 TFLOPs/s | 2.4e-2 | 빠르지만 사용할 수 없을 정도로 부정확 |
FP8 FlashAttention-3 | 1008 TFLOPs/s | 9.1e-3 | 빠르고 사용 가능한 정확도 |
핵심 통찰: 표준 FP8은 FP16보다 75배 더 나쁜 정확도 - 완전히 사용 불가능합니다. FlashAttention-3의 FP8은 48배만 더 나쁨 - 여전히 완벽하지 않지만 많은 애플리케이션에서 잠재적으로 사용 가능합니다.
제거 연구 결과
구성 | 시간 (ms) | TFLOPs/s | 구성요소 |
---|---|---|---|
기준선 | 4.105 | 570 | 최적화 없음 |
+ 워프 특화 | 4.021 | 582 | +2% 개선 |
+ 파이프라이닝만 | 4.105 | 570 | 단독으로는 개선 없음 |
+ 둘 다 | 3.538 | 661 | +16% 총합 |
중요한 통찰: 이는 가산적이 아닌 시너지 효과를 보여줍니다. 파이프라이닝 단독으로는 0% 이득을 주는데, 이는 워프 특화를 기반으로 필요하기 때문입니다. 이는 시스템 최적화에서 흔한 일입니다 - 개별 기술들은 효과적이지 않아 보일 수 있지만, 조합은 변혁적일 수 있습니다.
내 해석: 이는 “시스템 사고” 접근법을 검증합니다. 저자들은 개별 구성요소만 최적화한 것이 아니라 전체 연산 파이프라인을 재설계했습니다. 7.6배 시너지 인수(16% 조합 vs 2% 개별)는 아키텍처 변경이 가산적이 아닌 곱셈적 이득을 어떻게 풀어낼 수 있는지 보여줍니다.
5. 더 넓은 의미와 미래 임팩트
즉각적인 애플리케이션
- 장문맥 언어 모델: 전체 책, 코드베이스, 또는 대화 처리
- 고해상도 비전: 의료 이미지나 위성 데이터 분석
- 멀티모달 AI: 오디오와 텍스트가 있는 긴 비디오 처리
- 과학 컴퓨팅: 어텐션 메커니즘을 가진 분자 역학, 기후 모델링
아키텍처 교훈
- 하드웨어-소프트웨어 공동 설계의 중요성: 범용 알고리즘은 성능을 탁자 위에 남겨둠
- 비동기성이 핵심: 현대 하드웨어는 병렬 실행을 위해 설계됨
- 정밀도는 협상 가능: 모든 연산이 완전한 정밀도를 필요로 하지 않음
- 시스템 최적화: 전체가 부분의 합보다 클 수 있음
내 최종 생각: FlashAttention-3는 분야의 성숙을 나타냅니다 - “작동하게 만들기”에서 “최적으로 작동하게 만들기”로의 이동입니다. 여기의 기술들은 미래 가속기를 위한 알고리즘 설계 방식에 영향을 줄 것 같습니다. 가장 중요한 것은, 어텐션 같은 잘 연구된 문제에서도 영리한 알고리즘 설계를 통해 상당한 성능 향상이 여전히 가능하다는 것을 보여준다는 점입니다.
75% GPU 활용률 달성은 특히 주목할 만합니다 - 이는 우리가 마침내 비효율적인 알고리즘에 더 많은 하드웨어를 투입하는 대신, 우리가 구축한 강력한 하드웨어를 효과적으로 사용하는 방법을 배우고 있음을 시사합니다.
Leave a comment