Flash Attention 3
- Flash Attention
- Flash Attention 2
- Flash Attention 3
TL;DR
FlashAttention-3는 NVIDIA Hopper GPU의 비동기 실행 및 저정밀도(FP8) 하드웨어 기능을 활용하도록 FlashAttention 알고리즘을 재설계하여, FlashAttention-2 대비 1.5-2.0배의 속도 향상을 달성하고 FP16에서 최대 740 TFLOPs/s(75% 활용률), FP8에서 약 1.2 PFLOPs/s에 근접한 성능을 보인다. 핵심 혁신은 warp 전문화된 생산자-소비자 파이프라이닝, 2단계 비동기 파이프라인을 통한 GEMM-softmax 중첩, 그리고 정확한 FP8 attention을 위한 블록 양자화와 incoherent 처리이다. 처리량 향상은 상당하지만, 이 기법들은 현재 Hopper 아키텍처에 특화되어 있으며 FP8 커널에는 persistent 커널 설계와 최적화된 causal 마스킹이 부재하여 추가 개선의 여지가 있다.
- Paper Link: https://arxiv.org/pdf/2407.08608
Related Papers
- FlashAttention-2 - 2배 빠른 속도와 향상된 병렬화로 발전된 버전
- FlashAttention - 원조 IO-aware 어텐션 알고리즘과 타일링 기법의 기초
Takeaways
1. Contribution
Attention은 Transformer 아키텍처의 계산 병목으로, 시퀀스 길이에 대해 이차적으로 스케일된다. 문서 이해, 코드 분석, 비디오 생성, 에이전트 워크플로우 등의 응용을 위해 수만 또는 수십만 토큰에 이르는 더 긴 컨텍스트를 처리하도록 모델이 발전함에 따라, attention 연산을 더 빠르게 만드는 것은 단순한 엔지니어링 문제가 아니라 완전히 새로운 능력을 가능하게 하는 전제 조건이 되었다. FlashAttention(Dao et al., 2022)과 FlashAttention-2(Dao, 2023)는 attention 계산을 단일 GPU 커널로 융합하고 느린 전역 메모리(HBM)에 대한 읽기/쓰기를 최소화하며, 온라인 softmax 트릭에 기반한 타일링 전략을 사용하여 획기적인 발전을 이뤘다. 그러나 FlashAttention-2는 동기 실행 모델로 설계되어 NVIDIA H100 GPU에서 약 35%의 하드웨어 활용률만 달성했는데, 이는 최적화된 GEMM 커널이 달성하는 80-90% 활용률과 비교된다. 이 격차는 FlashAttention-2가 Hopper 아키텍처의 두 가지 핵심 기능을 활용하지 못하기 때문에 존재한다: 전문화된 하드웨어 유닛의 비동기 실행과 네이티브 FP8 저정밀도 계산이다.
NVIDIA Hopper 아키텍처는 GPU 커널 설계 방식을 근본적으로 변화시키는 여러 하드웨어 혁신을 도입했다. Tensor Memory Accelerator(TMA)는 전역 메모리(HBM)와 공유 메모리(SMEM) 간의 비동기 데이터 이동을 위한 전용 하드웨어 유닛으로, 계산 유닛과 독립적으로 작동한다. Hopper의 Tensor Core는 WGMMA(warpgroup-level matrix multiply-accumulate) 명령어를 통해 노출되며, 비동기적으로 실행되고 공유 메모리에서 직접 입력을 받을 수 있다. 중요한 점은 이러한 유닛들–데이터 이동을 위한 TMA와 행렬 곱셈을 위한 Tensor Core–이 독립적인 하드웨어 에이전트로 작동한다는 것이다. 이는 원칙적으로 데이터 로딩, 행렬 곱셈, 비행렬곱셈 연산(예: softmax)이 알고리즘과 소프트웨어가 이러한 병렬성을 활용하도록 설계된 경우 모두 동시에 진행될 수 있음을 의미한다. 또한 Hopper의 FP8 Tensor Core는 동일한 전력과 칩 면적에서 FP16 Tensor Core의 두 배 처리량을 제공하지만, 엄격한 레이아웃 제약과 감소된 수치 정밀도를 가지고 있어 신중한 알고리즘 처리가 필요하다.
FlashAttention-3는 세 가지 시너지적인 기여로 이러한 격차를 해결한다. 첫째, cooperative thread array(CTA) 내의 warp를 별개의 역할로 분리하는 생산자-소비자 warp 전문화 방식을 도입한다: 생산자 warp는 TMA를 통한 데이터 이동을 처리하고, 소비자 warp는 WGMMA를 통한 계산을 수행한다. 이러한 분리는 순환 공유 메모리 버퍼 및 배리어 동기화와 결합되어 데이터 로딩이 계산과 중첩되도록 하여 메모리 레이턴시를 효과적으로 숨긴다. 둘째, FlashAttention-3는 warpgroup 내 GEMM-softmax 파이프라이닝을 도입하는데, 이는 attention 내부 루프 내에서 softmax 연산과 GEMM 연산 간의 순차적 의존성을 깨는 2단계 알고리즘이다. 점수 행렬 $\mathbf{S}$에 대한 두 개의 버퍼를 유지함으로써, 알고리즘은 한 블록에 대한 softmax 계산을 다음 블록에 대한 WGMMA 실행과 중첩시켜, 상대적으로 낮은 처리량의 softmax 연산(FP16 행렬곱셈의 989 TFLOPs/s에 비해 지수 함수를 위한 다기능 유닛에서 3.9 TFLOPs/s만 사용)이 높은 처리량의 Tensor Core 연산의 그림자에서 실행되도록 보장한다. 셋째, FlashAttention-3는 커널 내 전치 및 바이트 순열 명령어를 사용하여 FP32 누산기와 FP8 피연산자 간의 레이아웃 적합성 문제를 해결하고, 블록 양자화 및 incoherent 처리를 도입하여 감소된 정밀도로 인한 정확도 손실을 완화함으로써 FP8 정밀도에 알고리즘을 적응시켜, 표준 텐서당 FP8 양자화 대비 2.6배 낮은 수치 오차를 달성한다.
이러한 기여의 실용적 의의는 상당하다. H100 SXM5 GPU에서 FP16 FlashAttention-3는 forward pass에서 최대 740 TFLOPs/s(이론적 989 TFLOPs/s 피크의 75%)에 도달하는데, 이는 head dimension 128에서 FlashAttention-2의 약 370 TFLOPs/s에 비교된다. FP8에서는 FlashAttention-3가 긴 시퀀스 길이에서 head dimension 256에 대해 약 1.2 PFLOPs/s에 근접하여 FP16 처리량의 거의 두 배를 보여준다. 중간에서 긴 시퀀스(1k 토큰 이상)에 대해 FlashAttention-3는 H100 GPU를 위해 특별히 최적화된 NVIDIA의 독점 cuDNN attention 구현마저 능가한다. 이러한 개선은 대규모 언어 모델 및 긴 컨텍스트 애플리케이션의 더 빠른 학습 및 추론으로 직접 전환되며, FlashAttention이 Ring Attention과 같은 분산 attention 방법의 기초 primitive로 작용하기 때문에 이점은 전체 생태계에 걸쳐 확산된다.
2. Methodology
2.1 Core Intuition
FlashAttention-3를 주도하는 근본적인 통찰은 현대 GPU 아키텍처가 단일체 계산 엔진이 아니라 동시에 작동할 수 있는 전문화된 하드웨어 유닛의 집합이라는 것이다. Hopper H100 GPU는 attention과 관련된 최소한 세 가지 독립적인 실행 리소스를 가지고 있다: HBM과 SMEM 간 데이터 이동을 위한 Tensor Memory Accelerator(TMA), 행렬 곱셈 수행을 위한 Tensor Core(WGMMA 명령어를 통해), 그리고 softmax에 필요한 지수 함수, 행 최대값, 행 합과 같은 비행렬곱셈 연산을 위한 CUDA 코어(다기능 유닛 포함)이다. FlashAttention-2에서는 이러한 리소스가 순차적으로 사용된다: K 블록을 로드하고, $\mathbf{S} = \mathbf{Q}\mathbf{K}^\top$를 계산하고, 대기하고, softmax를 계산하고, 대기한 다음, V를 로드하고 $\mathbf{O} = \tilde{\mathbf{P}}\mathbf{V}$를 계산하고, 대기한다. 이러한 직렬화는 각 하드웨어 유닛이 다른 유닛이 활성화되어 있는 동안 유휴 상태로 있음을 의미하며, 이것이 35% 활용률을 설명한다.
FlashAttention-3의 설계 철학은 세 가지 하드웨어 리소스를 모두 동시에 바쁘게 유지하는 것이다. 이를 위해서는 근본적인 문제를 해결해야 한다: attention 계산에는 진정한 데이터 의존성이 있다. softmax는 첫 번째 GEMM($\mathbf{Q}\mathbf{K}^\top$)의 출력이 필요하고, 두 번째 GEMM($\tilde{\mathbf{P}}\mathbf{V}$)은 softmax의 출력이 필요하다. 이러한 의존성은 내부 루프의 단일 반복 내에서 제거될 수 없다. 그러나 반복 간에 파이프라이닝을 함으로써, 알고리즘은 반복 $j$의 softmax를 반복 $j+1$의 GEMM 연산 및 반복 $j+2$의 데이터 로딩과 중첩시킬 수 있다. 이는 CPU의 명령어 수준 파이프라이닝을 주도하는 것과 동일한 원리이지만, GPU 하드웨어 유닛의 비동기 기능을 활용하기 위해 알고리즘 수준에서 적용된 것이다.
행렬곱셈과 비행렬곱셈 연산 간의 처리량 격차는 이러한 중첩을 단순히 바람직한 것이 아니라 필수적으로 만든다. H100은 FP16 행렬곱셈의 989 TFLOPs/s 처리량을 제공하지만 지수 함수와 같은 특수 함수에 대해서는 3.9 TFLOPs/s만 제공한다. head dimension $d = 128$에 대해 행렬곱셈 FLOP과 지수 연산의 비율은 $512\times$이지만, 처리량 비율은 $256\times$에 불과하므로 지수 함수가 행렬곱셈 대비 사이클 시간의 약 50%를 소비할 수 있음을 의미한다. FP8에서는 행렬곱셈 처리량이 약 1978 TFLOPs/s로 두 배가 되는 반면 지수 함수 처리량은 3.9 TFLOPs/s로 유지되므로, 비행렬곱셈 연산이 훨씬 더 심각한 병목이 되어 높은 활용률 달성을 위해 중첩이 절대적으로 중요하다.
2.2 Model Architecture
FlashAttention-3는 FlashAttention-2와 동일한 타일링된 융합 커널 접근 방식을 유지하지만 내부 실행 모델을 재구성한다. 전체 계산은 동일하게 유지된다: $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$가 주어지면, 다음을 계산한다:
\[\mathbf{S} = \alpha \mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{N \times N}, \quad \mathbf{P} = \text{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O} = \mathbf{P}\mathbf{V} \in \mathbb{R}^{N \times d}\]여기서 $\alpha = 1/\sqrt{d}$이고 softmax는 행별로 적용된다. 커널은 배치 크기, 헤드 수, 쿼리 시퀀스 길이 타일에 걸쳐 병렬화된다(embarrassingly parallel). 각 CTA는 크기 $B_r \times d$의 하나의 쿼리 블록 $\mathbf{Q}_i$를 처리하고 크기 $B_c \times d$의 모든 키/값 블록 $\mathbf{K}_j, \mathbf{V}_j$에 걸쳐 반복한다.
각 CTA 내의 아키텍처는 다음과 같이 구성된다:
CTA 구조 (FlashAttention-3)
=================================
Producer Warpgroup (1 warp):
- TMA 로드 발행: Q_i, K_j, V_j from HBM -> SMEM
- s-단계 순환 SMEM 버퍼 관리
- 최소 레지스터 사용 (setmaxnreg를 통해)
Consumer Warpgroup(s) (1-2 warpgroup, 각각 4 warp):
- GEMM 연산을 위한 WGMMA 실행
- softmax 계산 (exp, row-max, row-sum)
- 최대 레지스터 할당
- 2개의 consumer warpgroup 사용 시: pingpong 스케줄링
H100의 메모리 계층이 모든 수준에서 활용된다. HBM(3.35 TB/s로 80 GiB)은 Q, K, V, O를 저장한다. 공유 메모리(SM당 228 KiB, 총 31 TB/s)는 순환 버퍼를 통해 스테이징 영역 역할을 한다. 레지스터 파일(SM당 256 KiB)은 실행 중인 누산기 $\mathbf{O}_i$, $\ell_i$, $m_i$, 점수 행렬, 그리고 softmax 중간 값을 보유한다.
2.3 Key Algorithms & Mechanisms
생산자-소비자 Warp 전문화. CTA 내의 warp는 생산자와 소비자 역할로 나뉜다. 생산자 warpgroup은 TMA 로드를 발행하는 데 단일 스레드만 필요하므로 레지스터를 할당 해제한다(Hopper setmaxnreg 명령어를 통해). 이렇게 해제된 레지스터는 WGMMA 누산기 및 중간 값을 위한 광범위한 레지스터 공간이 필요한 소비자 warpgroup에 재할당된다. 생산자와 소비자는 배리어 동기화가 있는 $s$-단계 순환 SMEM 버퍼를 통해 통신한다. 생산자는 비동기적으로 $\mathbf{K}_j$와 $\mathbf{V}_j$를 버퍼로 로드하고, 알림을 커밋하고, 다음 블록을 로드하기 위해 이동한다. 소비자는 알림을 기다리고, 데이터를 소비하고, 버퍼 단계를 생산자에게 다시 릴리스한다. TMA 로드는 비동기적이므로 생산자는 소비자보다 여러 버퍼 단계를 앞서 채울 수 있어 메모리 레이턴시를 숨기는 파이프라인을 생성한다.
Pingpong 스케줄링. 두 개의 소비자 warpgroup이 사용될 때, FlashAttention-3는 한 warpgroup의 softmax를 다른 warpgroup의 GEMM과 중첩시키기 위해 pingpong 스케줄링을 사용한다. 동기화 배리어(bar.sync 명령어)는 warpgroup 1의 GEMM이 warpgroup 2의 GEMM보다 먼저 스케줄링되도록 강제한다. 결과적으로 warpgroup 2가 Tensor Core에서 GEMM을 실행하는 동안, warpgroup 1은 CUDA 코어 및 다기능 유닛에서 softmax를 수행한다. 그런 다음 역할이 바뀐다. 경험적으로 이는 처리량을 head dimension 128 및 시퀀스 길이 8192에 대해 FP16 forward pass에서 약 570 TFLOPs/s에서 620-640 TFLOPs/s로 개선한다.
2단계 GEMM-Softmax 파이프라이닝 (Algorithm 2). 이것이 가장 기술적으로 미묘한 기여이다. 단일 소비자 warpgroup 내에서 알고리즘은 내부 루프의 반복 간 파이프라이닝을 통해 순차적 의존성 체인을 깬다. Algorithm 1(기본 warp 전문화 버전)에서의 핵심 수정 사항은 다음과 같다:
기본 버전에서 각 반복 $j$는 순차적으로 실행된다:
- WGMMA: $\mathbf{S}_i^{(j)} = \mathbf{Q}_i \mathbf{K}_j^\top$ (커밋 후 대기)
- Softmax: $m_i$, $\tilde{\mathbf{P}}_i^{(j)}$, $\ell_i$ 계산
- WGMMA: $\mathbf{O}_i = \mathbf{O}_i + \tilde{\mathbf{P}}_i^{(j)} \mathbf{V}_j$ (커밋 후 대기)
2단계 파이프라인 버전에서 반복 $j$의 메인루프 본문은 다음과 같이 된다:
- WGMMA: $\mathbf{S}_\text{next} = \mathbf{Q}_i \mathbf{K}_j^\top$ (커밋하지만 대기하지 않음)
- WGMMA: $\mathbf{O}i = \mathbf{O}_i + \tilde{\mathbf{P}}\text{cur} \mathbf{V}_{j-1}$ (커밋하지만 대기하지 않음)
- $\mathbf{S}_\text{next}$ WGMMA 완료 대기
- Softmax: $\mathbf{S}\text{next}$에 기반하여 $m_i$, $\tilde{\mathbf{P}}\text{next}$, $\ell_i$ 계산
- $\tilde{\mathbf{P}}\text{cur} \mathbf{V}{j-1}$ WGMMA 대기 후 $\mathbf{O}_i$ 재스케일
- $\mathbf{S}\text{next} \to \mathbf{S}\text{cur}$, $\tilde{\mathbf{P}}\text{next} \to \tilde{\mathbf{P}}\text{cur}$ 복사
중요한 점은 2단계이다: 두 번째 WGMMA($\tilde{\mathbf{P}}\text{cur} \mathbf{V}{j-1}$)가 대기 없이 발행되면 Tensor Core의 비동기 실행 큐에 들어간다. 이 WGMMA가 실행되는 동안 3-4단계는 CUDA 코어 및 다기능 유닛에서 다음 반복에 대한 softmax를 수행한다. WGMMA와 비WGMMA 명령어는 서로 다른 하드웨어 유닛을 사용하므로 진정으로 동시에 실행될 수 있다. 논문의 SASS 분석은 이를 확인한다: 첫 번째 WGMMA(8개의 HGMMA 명령어로 분해됨)는 softmax 연산(FMNMX, MUFU.EX2, FADD, FMUL)과 인터리빙되는 반면, 두 번째 WGMMA(12개의 HGMMA 명령어)는 인터리빙 없이 실행된다.
이 파이프라이닝의 비용은 추가 레지스터 압력이다. 크기 $B_r \times B_c \times \text{sizeof}(\text{float})$의 추가 $\mathbf{S}_\text{next}$ 버퍼를 레지스터에 유지해야 하므로 파이프라인 깊이와 타일 크기 간에 긴장이 생긴다.
FP8 레이아웃 적합성. Hopper의 FP8 WGMMA는 SMEM의 두 피연산자 모두에 k-major 레이아웃 제약을 부과한다. 첫 번째 GEMM($\mathbf{Q}\mathbf{K}^\top$)의 경우 Q와 K는 head dimension(이 GEMM의 K dimension)에서 자연스럽게 연속적으로 저장되므로 변환이 필요하지 않다. 그러나 두 번째 GEMM($\tilde{\mathbf{P}}\mathbf{V}$)의 경우 V는 head dimension이 아닌 시퀀스 길이 dimension(이 GEMM의 K dimension)에서 연속적이어야 한다. FlashAttention-3는 레이아웃을 전치하면서 warp당 128바이트를 집합적으로 로드/저장할 수 있는 LDSM(ldmatrix) 및 STSM(stmatrix) 명령어를 사용하여 커널 내 전치로 이를 해결한다. 이 전치는 생산자 warpgroup에서 실행되며 이전 반복의 두 WGMMA의 그림자에 숨겨진다.
또한 첫 번째 WGMMA의 FP32 누산기 레이아웃은 두 번째 WGMMA가 예상하는 FP8 피연산자 A 레이아웃과 일치하지 않는다. 누산기는 스레드당 요소를 ${d_0, d_1, d_2, d_3, d_4, d_5, d_6, d_7}$ 순서로 저장하지만, FP8 피연산자는 ${d_0, d_1, d_4, d_5, d_2, d_3, d_6, d_7}$를 예상한다. FlashAttention-3는 바이트 순열 명령어를 사용하여 레지스터 내용을 재배열하고, 효과적으로 $\tilde{\mathbf{P}}$의 열을 순열한다. V의 커널 내 전치는 일치하는 행 순열을 적용하도록 그에 맞춰 조정되어 행렬 곱이 올바르게 유지되도록 한다.
블록 양자화. 텐서당 스케일링(텐서당 하나의 스케일 인자) 대신, FlashAttention-3는 Q, K, V에 대해 각각 크기 $B_r \times d$ 또는 $B_c \times d$의 블록당 하나의 스케일 인자를 유지한다. FlashAttention 알고리즘은 자연스럽게 이러한 블록에서 작동하므로, 블록별 스케일 인자는 첫 번째 GEMM 이후 S의 각 블록에 추가 계산 비용 없이 적용될 수 있다. 블록 수준 양자화는 회전 임베딩과 같은 선행 연산과 융합될 수 있는데, 회전 임베딩이 메모리 대역폭 제한적이므로 속도 저하 없이 가능하다.
Incoherent 처리. LLM의 이상치 특징(활성화 값의 작은 부분이 나머지보다 훨씬 큰 경우)을 처리하기 위해, FlashAttention-3는 FP8 양자화 전에 Q와 K에 무작위 직교 행렬 $\mathbf{M}$을 곱한다. $\mathbf{M}\mathbf{M}^\top = \mathbf{I}$이므로:
\[(\mathbf{Q}\mathbf{M})(\mathbf{K}\mathbf{M})^\top = \mathbf{Q}\mathbf{M}\mathbf{M}^\top\mathbf{K}^\top = \mathbf{Q}\mathbf{K}^\top\]이는 이상치를 “퍼뜨린다”. $\mathbf{Q}\mathbf{M}$의 각 항목이 $\mathbf{Q}$ 항목의 무작위 선형 조합이기 때문에 동적 범위를 줄이고 따라서 양자화 오차를 줄인다. 실제로 $\mathbf{M}$은 무작위 대각 $\pm 1$ 행렬과 Hadamard 행렬의 곱으로 선택되어 $O(d^2)$ 대신 $O(d \log d)$ 곱셈을 가능하게 하며 회전 임베딩과 융합 가능하다.
2.4 Implementation Details
FlashAttention-3는 NVIDIA의 CUTLASS 라이브러리 primitive를 사용하여 구현되며, 특히 CUTLASS 3.5의 WGMMA 및 TMA 추상화를 사용한다. 구현은 CUDA 12.3을 사용하는 H100 SXM5 GPU를 대상으로 한다.
블록 크기: 쿼리 블록 크기 $B_r$과 키 블록 크기 $B_c$는 레지스터 압력과 병렬성 간의 균형을 맞추도록 선택된다. 2단계 파이프라인의 추가 레지스터 요구 사항인 $B_r \times B_c \times 4$ 바이트는 최대 타일 크기를 제약한다. head dimension 128의 FP16에 대해 구현은 로드 밸런싱이 있는 persistent 커널을 사용한다. FP8 커널은 아직 persistent 커널 설계를 통합하지 않았으며, 이는 causal 마스킹이 있는 작은 시퀀스 길이에서 감소된 성능을 부분적으로 설명한다.
순환 버퍼 단계: 순환 SMEM 버퍼의 단계 수 $s$는 튜닝 매개변수이다. 더 많은 단계는 더 깊은 프리페칭을 가능하게 하지만 더 많은 SMEM을 소비한다. GEMM-softmax 중첩을 위한 파이프라인 단계 수(2 또는 3)는 $s$에 의해 제한되지만 같을 필요는 없다.
수치 안정성: Softmax 중간 결과(행 최대값 빼기, 재스케일링)는 입력이 FP16일 때도 FP32로 유지되어 FlashAttention-2의 수치 동작과 일치하고 FP16에서 softmax를 수행하는 표준 attention보다 $1.7\times$ 낮은 RMSE를 생성한다.
벤치마크 구성: 벤치마크는 1830 MHz의 고정 GPU 클럭 속도, 512에서 16k까지의 시퀀스 길이, 은닉 차원 2048, head dimension 64/128/256을 사용한다. FLOP은 forward pass에 대해 $4 \cdot N^2 \cdot d \cdot h$로 계산되며(causal의 경우 절반), backward pass에 대해 2.5를 곱한다. 벤치마크는 100회 실행의 평균이다.
Backward pass: backward pass(Algorithm 3)는 warp 전문화를 세 가지 역할로 확장한다: 생산자(데이터 이동), 소비자(계산), dQ-writer($d\mathbf{Q}$의 전역 메모리로의 원자적 누적). dQ-writer 역할은 여러 CTA가 겹치는 $d\mathbf{Q}$ 영역에 쓰기 때문에 분리되며, 그렇지 않으면 계산 warp를 차단할 메모리 경합을 유발한다.
3. Results
FlashAttention-3의 경험적 평가는 H100 80GB SXM5 GPU에서 여러 구성에 걸쳐 수행되어 주장된 속도 향상에 대한 포괄적인 증거를 제공한다.
Forward Pass 성능 (FP16). 헤드라인 결과는 FlashAttention-3가 모든 테스트된 구성에 걸쳐 forward pass에서 FlashAttention-2 대비 1.5-2.0배의 속도 향상을 달성한다는 것이다.
| 구성 | FA-2 (TFLOPs/s) | FA-3 (TFLOPs/s) | cuDNN (TFLOPs/s) | FA-2 대비 속도 향상 |
|---|---|---|---|---|
| hdim 64, no causal, 16k | 324 | 497 | 413 | 1.53x |
| hdim 64, causal, 16k | 299 | 473 | 388 | 1.58x |
| hdim 128, no causal, 16k | 370 | 648 | 595 | 1.75x |
| hdim 128, causal, 16k | 335 | 616 | 539 | 1.84x |
| hdim 256, no causal, 16k | 326 | 756 | 581 | 2.32x |
| hdim 256, causal, 16k | 298 | 642 | 509 | 2.15x |
결과에서 몇 가지 패턴이 나타난다. 속도 향상은 head dimension 256에서 가장 극적인데, FlashAttention-3가 756 TFLOPs/s(76.4% 활용률)에 도달하여 FlashAttention-2의 326 TFLOPs/s보다 두 배 이상이다. 이는 더 큰 head dimension이 메모리 액세스에 대한 계산의 비율을 증가시켜 비동기 중첩이 비행렬곱셈 레이턴시의 더 큰 부분을 숨길 수 있기 때문이다. FlashAttention-3는 중간에서 긴 시퀀스(1k+)에 대해 일관되게 cuDNN(NVIDIA의 독점 폐쇄 소스 라이브러리)을 능가하는데, 이는 cuDNN이 H100 하드웨어를 위해 특별히 최적화되었다는 점을 고려하면 오픈 소스 구현으로는 주목할 만한 성과이다. FlashAttention-3와 cuDNN 간의 격차는 시퀀스 길이가 증가함에 따라 확대되며, 이는 warp 전문화 및 파이프라이닝 기법이 잘 스케일됨을 시사한다.
Backward Pass 성능 (FP16). backward pass는 1.5-1.75배의 속도 향상을 보여주는데, forward pass보다 다소 낮으며, 이는 더 복잡한 데이터 흐름(2개 대신 5개의 행렬곱셈) 및 추가 dQ 누적 단계 때문일 가능성이 있다.
| 구성 | FA-2 (TFLOPs/s) | FA-3 (TFLOPs/s) | cuDNN (TFLOPs/s) |
|---|---|---|---|
| hdim 64, no causal, 16k | 291 | 474 | 433 |
| hdim 128, no causal, 16k | 322 | 561 | 516 |
FP8 Forward Pass 성능. FP8 FlashAttention-3는 긴 시퀀스에서 head dimension 256에 대해 약 1.2 PFLOPs/s에 근접하여 FP16 처리량의 거의 두 배를 보여준다. head dimension 64의 경우 FlashAttention-3 FP8는 긴 시퀀스에서 cuDNN FP8를 능가한다(causal 마스크 없이 16k에서 613 대 438 TFLOPs/s). head dimension 128 및 256의 경우 결과는 혼합적이다: FlashAttention-3는 causal 마스킹 없이는 경쟁력이 있지만 causal 마스킹에서는 뒤처진다. 이는 FP8 커널에 FP16 커널에 있는 persistent 커널 설계 및 로드 밸런싱 전략이 없기 때문으로 설명된다.
Ablation Study. 고정된 매개변수(batch=4, seqlen=8448, nheads=16, hdim=128)에서의 ablation은 개별 기여를 정량화한다:
| 구성 | 시간 (ms) | TFLOPs/s |
|---|---|---|
| FlashAttention-3 (전체) | 3.538 | 661 |
| GEMM-Softmax 파이프라이닝 없음, Warp-전문화 있음 | 4.021 | 582 |
| GEMM-Softmax 파이프라이닝 있음, Warp-전문화 없음 | 4.105 | 570 |
Warp 전문화 단독으로 570에서 582 TFLOPs/s로의 속도 향상(2.1%)에 기여하며, 그 위에 GEMM-softmax 파이프라이닝을 추가하면 661 TFLOPs/s(13.6% 추가 개선)가 된다. 결합된 효과(총 16% 속도 향상)는 두 기법이 모두 필요하며 상보적임을 보여준다. Warp 전문화는 주로 메모리 레이턴시를 숨기는 반면, GEMM-softmax 파이프라이닝은 비행렬곱셈 계산 레이턴시를 숨긴다.
수치 오차 검증. 정확도 평가는 실제 LLM 활성화 분포를 모방하기 위해 시뮬레이션된 이상치 특징($\mathcal{N}(0,1) + \mathcal{N}(0,100) \cdot \text{Bernoulli}(0.001)$)이 있는 입력을 사용한다.
| 방법 | RMSE |
|---|---|
| Baseline FP16 | $3.2 \times 10^{-4}$ |
| FlashAttention-2 FP16 | $1.9 \times 10^{-4}$ |
| FlashAttention-3 FP16 | $1.9 \times 10^{-4}$ |
| Baseline FP8 (텐서당) | $2.4 \times 10^{-2}$ |
| FlashAttention-3 FP8 (블록 양자화 + incoherent) | $9.1 \times 10^{-3}$ |
| FA-3 FP8 (블록 양자화 없음) | $9.3 \times 10^{-3}$ |
| FA-3 FP8 (incoherent 처리 없음) | $2.4 \times 10^{-2}$ |
FP16에서 FlashAttention-3는 FlashAttention-2와 동일한 RMSE($1.9 \times 10^{-4}$)를 달성하며, 둘 다 baseline보다 $1.7\times$ 우수하여 파이프라이닝이 수치 정확도에 영향을 미치지 않음을 확인한다. FP8에서 결합된 블록 양자화 및 incoherent 처리는 RMSE를 $2.6\times$ 감소시킨다($2.4 \times 10^{-2}$ 대비 $9.1 \times 10^{-3}$). ablation은 incoherent 처리가 지배적인 정확도 개선에 기여함을 보여준다(이것 없이는 RMSE가 $2.4 \times 10^{-2}$로 돌아감), 반면 블록 양자화 단독으로는 적은 이점을 제공한다($9.3 \times 10^{-3}$ 대 $9.1 \times 10^{-3}$). 이는 말이 된다: incoherent 처리는 FP8 양자화 오차의 주요 원인인 이상치 문제를 직접 해결한다.
4. Critical Assessment
Strengths
- 원칙적인 하드웨어-알고리즘 공동 설계. 일반적인 최적화 기법을 적용하는 대신, FlashAttention-3는 Hopper 아키텍처의 특정 기능(TMA, async WGMMA, FP8 Tensor Core)을 중심으로 처음부터 설계되어 알고리즘 설계가 하드웨어와 함께 어떻게 진화해야 하는지를 보여준다.
- 상당하고 일관된 속도 향상. FlashAttention-2 대비 1.5-2.0배의 개선은 다양한 시퀀스 길이, head dimension, 마스킹 구성에 걸쳐 일관되어 틈새 최적화가 아닌 광범위하게 적용 가능하다.
- 벤더 최적화 코드와 경쟁력. 중간에서 긴 시퀀스에서 NVIDIA의 독점 cuDNN 라이브러리를 능가하는 것은 오픈 소스 구현으로는 주목할 만한 성과로, 알고리즘 혁신이 벤더 특정 엔지니어링 노력과 일치하거나 초과할 수 있음을 보여준다.
- FP8 정확도 혁신. 블록 양자화와 incoherent 처리의 조합은 감소된 정밀도에서 수치 정확도를 유지하는 원칙적인 접근 방식을 제공하며, 2.6배 오차 감소가 현실적인 이상치 조건에서 경험적으로 검증되었다.
- 철저한 엔지니어링 분석. 예상된 명령어 인터리빙을 확인하는 SASS 코드 분석, 3단계 파이프라인 부정적 결과, 그리고 상세한 레지스터 압력 논의는 일반적인 시스템 논문이 제공하는 것을 넘어서는 진정한 엔지니어링 통찰력을 제공한다.
- 오픈 소스 릴리스. FlashAttention-3를 관대한 라이선스로 제공하고 PyTorch/HuggingFace 통합을 계획하여 커뮤니티 이점을 극대화한다.
Limitations
- Hopper 특정 설계. 논문은 기법이 “충분히 강력한 비동기 실행”을 갖춘 다른 아키텍처에 적용 가능하다고 주장하지만, 모든 구현 및 평가는 H100을 대상으로 한다. Warp 전문화 및 WGMMA 특정 파이프라이닝의 AMD GPU, Intel GPU 또는 커스텀 가속기로의 이식 가능성은 불명확하다.
- FP8 커널 불완전성. FP8 커널에는 FP16 커널에 있는 persistent 커널 설계 및 로드 밸런싱 전략이 없어 작은 시퀀스 및 causal 마스킹에 대해 최적이 아닌 성능을 보인다. 이는 인정되지만 추론 워크로드에 대한 상당한 격차를 나타낸다.
- 종단 간 학습/추론 평가 없음. 모든 벤치마크는 격리된 커널 수준 처리량을 측정한다. 논문은 종단 간 모델 학습 시간 또는 perplexity에 대한 영향을 평가하지 않아 학습 중 FP8 attention의 실용적 의미를 열린 질문으로 남긴다.
- 컨텍스트에서 FP8 정확도의 제한된 평가. 수치 오차 검증은 시뮬레이션된 이상치가 있는 합성 데이터를 사용한다. Llama 또는 GPT와 같은 모델의 실제 LLM 활성화에 대한 평가는 정확도가 실제 학습 및 추론에 충분하다는 더 강력한 증거를 제공할 것이다.
- 레지스터 압력 트레이드오프가 충분히 탐구되지 않음. 논문은 파이프라인 깊이와 타일 크기 간의 긴장을 인정하지만 Pareto 프론티어를 체계적으로 탐구하지 않는다. 3단계 파이프라인은 레지스터 압력으로 인해 2단계보다 성능이 떨어지지만, 대안적인 타일 크기 선택이 이를 해결할 수 있는지 조사되지 않았다.
- Backward pass의 최적화 부족. backward pass는 더 낮은 속도 향상 비율(1.5-1.75배 대 1.5-2.0배)을 달성하고 GEMM-softmax 파이프라이닝을 포함하지 않아 추가 최적화의 여지가 있음을 시사한다.
Future Directions
- LLM 추론 최적화. 자기회귀 생성의 디코드 단계는 LeanAttention의 로드 밸런싱 접근 방식과 같은 다른 최적화 전략이 필요한 근본적으로 다른 특성(긴 KV 캐시에 대한 단일 토큰 쿼리의 배치)을 가지고 있다.
- Persistent FP8 커널. 작은 시퀀스 길이 및 causal 마스킹 구성에서 cuDNN과의 성능 격차를 줄이기 위해 FP8 커널에 persistent 커널 설계를 통합한다.
- 종단 간 FP8 학습 검증. 블록 양자화 및 incoherent 처리가 있는 FP8 attention이 안정적인 대규모 학습과 호환되는지 이해하여 잠재적으로 전체 FP8 학습 파이프라인을 가능하게 한다.
- Blackwell 아키텍처 적응. NVIDIA의 Blackwell GPU는 FP4 정밀도와 추가로 향상된 TMA 기능을 도입한다. FlashAttention-3의 기법을 FP4 및 Blackwell 특정 기능을 활용하도록 확장하는 것이 자연스러운 다음 단계이다.
- 크로스 플랫폼 일반화. 아키텍처 특정 구현을 요구하는 대신 다양한 가속기(AMD MI300, Intel Gaudi, TPU)를 대상으로 할 수 있는 프레임워크로 비동기 파이프라이닝 원리를 추상화한다.
- Flash Attention
- Flash Attention 2
- Flash Attention 3
Enjoy Reading This Article?
Here are some more articles you might like to read next:
Stay updated — subscribe via RSS
Leave a Comment
Found this useful or have questions? Sign in with GitHub to join the conversation.