Flash Attention
- Flash Attention
- Flash Attention 2
- Flash Attention 3
TL;DR
FlashAttention은 GPU의 HBM(high bandwidth memory) 읽기/쓰기를 최소화하기 위해 타일링과 재계산을 사용하는 IO-aware 정확한 attention 알고리즘을 제안하며, attention 계산에서 최대 7.6배 속도 향상과 시퀀스 길이에 대한 선형 메모리 스케일링을 달성합니다. 핵심 혁신은 attention 최적화를 FLOP 감소 문제가 아닌 메모리 접근 문제로 재정의하는 것이며, 알고리즘의 IO 복잡도가 $\Theta(N^2 d^2 M^{-1})$로 SRAM 크기 범위에서 최적임을 증명합니다. FlashAttention은 훨씬 긴 컨텍스트(블록 희소 변형으로 최대 64K)로 Transformer 학습을 가능하게 하며, 16K 시퀀스 길이에서 Path-X 챌린지를 해결한 최초의 Transformer를 제공합니다. 주요 한계는 각 새로운 attention 변형에 대해 커스텀 CUDA 커널 작성이 필요하여 상당한 엔지니어링 오버헤드와 GPU 아키텍처별 구현이 필요하다는 것입니다.
- Paper Link: https://arxiv.org/pdf/2205.14135
Related Papers
FlashAttention 시리즈 발전:
- FlashAttention-2 - 2배 빠른 속도와 향상된 병렬화
- FlashAttention-3 - 비대칭 어텐션과 FP8 지원
- From Online Softmax to FlashAttention - 이론적 기초와 온라인 소프트맥스
긴 시퀀스 처리:
- Context Parallelism for Scalable Million-Token Inference - 추론 시 컨텍스트 분산
- Scaling Laws of RoPE-based Extrapolation - 위치 인코딩 확장
- YaRN - RoPE 기반 컨텍스트 길이 확장
- RoFormer - 회전 위치 임베딩
시스템 최적화:
- Reducing Activation Recomputation in Large Transformer Models - 메모리 효율적인 병렬 훈련
- USP - 통합 시퀀스 병렬화 프레임워크
- Tensor Parallelism - 텐서 병렬화와의 결합
- GPipe - 파이프라인 병렬화와의 통합
관련 어텐션 최적화:
- GQA - Grouped-Query Attention으로 추론 효율성 향상
- MQA - Multi-Query Attention으로 메모리 대역폭 최적화
- Memory-efficient Attention - 메모리 효율적인 어텐션 알고리즘
- GLU Variants Improve Transformer - Transformer 피드포워드 레이어 개선
Takeaways
1. Contribution
Self-attention 메커니즘은 Transformer 아키텍처의 핵심이지만, 시퀀스 길이에 대한 제곱 시간 및 메모리 복잡도는 더 긴 컨텍스트로 확장하는 데 지속적인 병목 현상이었습니다. FlashAttention 이전에는 이 병목 현상을 해결하기 위한 지배적인 접근 방식은 근사 attention 방법 설계였습니다. Reformer와 같은 희소 근사, Performer와 Linformer와 같은 저랭크 근사, 또는 BigBird와 Longformer와 같은 하이브리드 방법들이 있었습니다. 이러한 접근 방식은 계산 복잡도(FLOP)를 $O(N^2)$에서 $N$에 대해 선형 또는 거의 선형으로 줄이는 것을 목표로 했습니다. 그러나 당혹스러운 관찰이 지속되었습니다. 이러한 방법 중 많은 수가 이론적 FLOP 감소에도 불구하고 표준 attention에 비해 의미 있는 실제 시간 속도 향상으로 이어지지 않았습니다. 일부는 더 느렸습니다. 이론적 효율성과 실제 성능 사이의 이러한 불일치는 커뮤니티가 attention 최적화에 대해 추론하는 방식에 근본적인 격차가 있음을 지적했습니다.
FlashAttention은 IO-awareness라는 누락된 원칙을 명확히 하여 이 격차를 식별하고 해결합니다. 핵심 통찰은 최신 GPU에서 attention의 병목 현상은 산술 연산의 수가 아니라 GPU 메모리 계층의 서로 다른 레벨 간 메모리 읽기 및 쓰기의 수라는 것입니다. A100과 같은 최신 GPU는 온칩 SRAM(약 20MB, 19 TB/s 대역폭)과 high bandwidth memory 또는 HBM(40-80GB, 1.5-2.0 TB/s) 사이에 큰 격차가 있습니다. SRAM은 HBM보다 약 10배 빠르지만 크기는 수 배 작습니다. 표준 attention 구현은 HBM에 전체 $N \times N$ attention 행렬을 구체화하여 $\Theta(Nd + N^2)$ HBM 접근을 발생시킵니다. 대부분의 attention 연산(softmax, masking, dropout)은 compute-bound가 아니라 memory-bound이므로 이러한 HBM 접근이 런타임을 지배합니다.
FlashAttention은 정확한 attention을 계산하면서(근사가 아님) $\Theta(N^2 d^2 M^{-1})$ HBM 접근만 필요로 하며, 여기서 $d$는 head 차원이고 $M$은 SRAM 크기입니다. 일반적인 값($d = 64$-$128$, $M \approx 100$KB)에 대해 이는 HBM 접근의 많은 배수 감소를 나타냅니다. 알고리즘은 HBM에 전체 $N \times N$ attention 행렬을 구체화하지 않고 빠른 온칩 SRAM을 사용하여 블록별로 attention을 계산함으로써 이를 달성합니다. 두 가지 핵심 기술이 이를 가능하게 합니다. (1) 온라인 softmax 정규화를 사용하여 softmax 계산을 블록으로 분해하는 타일링, (2) SRAM에 저장된 Q, K, V 블록에서 재계산하여 역전파를 위해 $O(N^2)$ attention 행렬 저장을 피하는 재계산. 중요한 것은 재계산이 총 FLOP을 증가시키지만, 감소된 HBM 접근이 추가 계산을 보상하고도 남기 때문에 알고리즘이 더 빠르다는 것입니다.
논문은 모든 SRAM 크기에서 정확한 attention 알고리즘이 FlashAttention의 HBM 접근 횟수를 점근적으로 개선할 수 없음을 보여주는 하한을 추가로 증명하여 알고리즘의 최적성을 확립합니다. 블록 희소 확장은 IO 복잡도를 희소성 비율 $s$에 비례하는 인수만큼 더 줄여 $\Theta(Nd + N^2 d^2 M^{-1} s)$ HBM 접근을 달성합니다. 이를 통해 64K 토큰의 시퀀스 길이로 확장할 수 있습니다.
FlashAttention의 실질적 중요성은 상당하고 다차원적입니다. 첫째, 직접적인 학습 속도 향상을 제공합니다. BERT-large에서 MLPerf 1.1 기록보다 15% 빠르고, GPT-2에서 HuggingFace보다 3배 빠르며, Long-Range Arena 벤치마크에서 2.4배 빠릅니다. 둘째, 메모리나 런타임의 비례적 증가 없이 더 긴 컨텍스트 창을 가능하게 합니다. FlashAttention으로 학습된 4K 컨텍스트의 GPT-2는 Megatron-LM을 통한 1K 컨텍스트의 표준 GPT-2보다 30% 빠르면서 perplexity는 0.7 더 좋습니다. 셋째, 더 긴 컨텍스트는 진정으로 새로운 능력을 열어줍니다. FlashAttention은 Path-X(16K 시퀀스 길이)와 Path-256(64K)에서 무작위보다 나은 성능을 달성한 최초의 Transformer를 생성하며, 이는 이전의 모든 Transformer 변형이 실패한 작업입니다. 이러한 결과는 attention 병목 현상이 Transformer 아키텍처 자체에 내재된 것이 아니라 메모리 계층을 무시한 차선의 구현의 산물임을 보여줍니다.
2. Methodology
2.1 핵심 직관
FlashAttention의 이론적 기초는 메모리 계층의 서로 다른 레벨 간 메모리 전송 횟수를 연구하기 위해 Aggarwal과 Vitter(1988)가 원래 개발한 프레임워크인 IO 복잡도 분석에 기반합니다. GPU 계산의 맥락에서 관련 계층은 HBM(주요 GPU 메모리, 크지만 상대적으로 느림)과 SRAM(온칩 메모리, 매우 빠르지만 극히 작음)으로 구성됩니다. 핵심 관찰은 메모리 바운드 연산의 경우 런타임이 산술을 수행하는 데 소비되는 시간이 아니라 이러한 메모리 레벨 간 데이터 전송에 소비되는 시간에 의해 지배된다는 것입니다.
표준 attention은 세 가지 순차적 연산을 포함합니다. (1) $\mathbf{S} = \mathbf{Q}\mathbf{K}^\top$ 계산, (2) $\mathbf{P} = \text{softmax}(\mathbf{S})$ 계산, (3) $\mathbf{O} = \mathbf{P}\mathbf{V}$ 계산. 이러한 각 연산은 HBM에서 입력을 읽고 출력을 HBM에 다시 씁니다. 중간 행렬 $\mathbf{S}$와 $\mathbf{P}$는 모두 $N \times N$이므로 $O(N^2)$ 메모리를 소비하고 읽고 쓰기 위해 $O(N^2)$ HBM 접근이 필요합니다. 2단계의 softmax 연산은 $N \times N$ 행렬에 적용되는 원소별 연산(memory-bound)이기 때문에 행렬 곱셈에 비해 계산적으로 사소함에도 불구하고 상당한 런타임 병목 현상이 되므로 특히 문제가 됩니다.
IO 효율적인 attention 알고리즘을 설계하는 데 있어 중심 과제는 softmax가 K의 모든 열을 결합한다는 것입니다. $\text{softmax}(\mathbf{S}_{i:})$를 계산하려면 행 $i$의 모든 요소의 최댓값과 합을 알아야 하며, 이는 모든 key 벡터에 의존합니다. 이는 겉보기에 블록별 계산을 방지합니다. FlashAttention은 온라인 softmax 기술을 사용하여 이를 극복합니다. 두 벡터의 연결의 softmax는 실행 통계, 구체적으로 실행 최댓값 $m(x)$와 지수의 실행 합 $\ell(x)$를 추적하여 점진적으로 계산할 수 있습니다. $x = [x^{(1)}, x^{(2)}]$가 주어지면:
\[m(x) = \max(m(x^{(1)}), m(x^{(2)}))\] \[\ell(x) = e^{m(x^{(1)}) - m(x)} \ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)} \ell(x^{(2)})\]이 분해는 각 블록을 처리한 후 통계 $(m, \ell)$를 유지하고 업데이트하는 한 softmax를 한 번에 한 블록씩 계산할 수 있음을 의미합니다. 출력 $\mathbf{O}$는 수정된 정규화 인수로 실행 출력을 재스케일링하여 유사하게 점진적으로 업데이트됩니다.
두 번째 핵심 통찰은 역전파와 관련이 있습니다. attention을 통한 표준 역전파는 저장된 $N \times N$ 행렬 $\mathbf{S}$와 $\mathbf{P}$가 필요합니다. FlashAttention은 이것들을 저장하는 대신 출력 $\mathbf{O}$와 softmax 통계 $(m, \ell)$만 저장합니다. 둘 다 크기 $O(N)$이며, 역전파 중에 블록별로 $\mathbf{S}$와 $\mathbf{P}$를 재계산합니다. 이것은 선택적 gradient checkpointing의 한 형태이지만, 속도를 메모리로 교환하는 표준 checkpointing과 달리 FlashAttention의 재계산은 저장된 행렬의 HBM 읽기가 아닌 빠른 SRAM에서 재계산이 발생하기 때문에 실제로 더 빠릅니다.
2.2 모델 아키텍처
FlashAttention은 Transformer 아키텍처를 전혀 수정하지 않습니다. 수학적으로 동일한 출력을 생성하는 표준 attention 계산의 드롭인 대체품입니다. 따라서 아키텍처 수준 뷰는 알고리즘이 GPU 하드웨어에 매핑되는 방식에 관한 것입니다.
데이터 흐름은 입력 행렬의 블록에 대한 중첩 루프 구조로 설명할 수 있습니다:
HBM (40 GB, 1.5 TB/s) SRAM (20 MB, 19 TB/s)
+---------------------------+ +-------------------+
| Q: [N x d] | | |
| K: [N x d] | Load K_j | K_j: [Bc x d] |
| V: [N x d] | ---------> | V_j: [Bc x d] |
| O: [N x d] (accumulator) | | |
| l: [N] (sum stats) | Load Q_i | Q_i: [Br x d] |
| m: [N] (max stats) | ---------> | O_i: [Br x d] |
+---------------------------+ | l_i, m_i: [Br] |
^ | |
| Write O_i, l_i, m_i | Compute on SRAM: |
| <-------------------------- | S_ij = Q_i K_j^T |
| P_ij = softmax |
| O_i += P_ij V_j |
+-------------------+
외부 루프는 K와 V의 블록($j$로 인덱싱)을 반복하고, 내부 루프는 Q의 블록($i$로 인덱싱)을 반복합니다. 블록 크기는 SRAM 활용을 최대화하도록 선택됩니다. $B_c = \lceil M / (4d) \rceil$와 $B_r = \min(\lceil M / (4d) \rceil, d)$이며, 여기서 인수 4는 SRAM에 네 개의 블록(Q, K, V 및 중간 S)을 동시에 맞춰야 함을 설명합니다. 모든 attention 연산(행렬 곱셈, softmax, masking, dropout 및 두 번째 행렬 곱셈)은 단일 CUDA 커널로 융합되어 모든 중간 HBM 읽기/쓰기를 제거합니다.
2.3 핵심 알고리즘 및 메커니즘
순전파(Algorithm 1). 순전파는 다음과 같이 진행됩니다. 먼저, 출력 누적기 $\mathbf{O}$, log-sum-exp 통계 $\ell$, 실행 최댓값 $m$이 HBM에서 초기화됩니다($\mathbf{O} = \mathbf{0}$, $\ell = \mathbf{0}$, $m = -\infty$). 입력 $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$는 블록으로 나뉩니다. $\mathbf{Q}$는 크기 $B_r \times d$의 $T_r = \lceil N / B_r \rceil$ 블록으로, $\mathbf{K}$, $\mathbf{V}$는 크기 $B_c \times d$의 $T_c = \lceil N / B_c \rceil$ 블록으로 나뉩니다.
외부 루프는 각 블록 $\mathbf{K}_j, \mathbf{V}_j$를 HBM에서 SRAM으로 한 번 로드합니다. 각 블록에 대해 내부 루프는 모든 Q 블록을 반복합니다. 각 쌍 $(i, j)$에 대해 알고리즘은 HBM에서 SRAM으로 $\mathbf{Q}_i, \mathbf{O}_i, \ell_i, m_i$를 로드하고 로컬 attention 블록을 계산합니다:
\[\mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^\top \in \mathbb{R}^{B_r \times B_c}\]로컬 softmax 통계가 계산됩니다:
\[\tilde{m}_{ij} = \text{rowmax}(\mathbf{S}_{ij}) \in \mathbb{R}^{B_r}, \quad \tilde{\mathbf{P}}_{ij} = \exp(\mathbf{S}_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c}, \quad \tilde{\ell}_{ij} = \text{rowsum}(\tilde{\mathbf{P}}_{ij}) \in \mathbb{R}^{B_r}\]그런 다음 전역 통계가 업데이트됩니다:
\[m_i^{\text{new}} = \max(m_i, \tilde{m}_{ij}), \quad \ell_i^{\text{new}} = e^{m_i - m_i^{\text{new}}} \ell_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{\ell}_{ij}\]마지막으로 출력 누적기가 재스케일링되고 업데이트됩니다:
\[\mathbf{O}_i \leftarrow \text{diag}(\ell_i^{\text{new}})^{-1} \left( \text{diag}(\ell_i) e^{m_i - m_i^{\text{new}}} \mathbf{O}_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{\mathbf{P}}_{ij} \mathbf{V}_j \right)\]이 재스케일링이 정확성을 보장하는 중요한 단계입니다. 이전 출력 $\mathbf{O}_i$는 이전 통계로 정규화되었으므로 현재 블록의 기여를 추가하기 전에 이전 정규화 인수와 새 정규화 인수의 비율로 재가중치되어야 합니다. 모든 $T_c$ K 블록을 처리한 후 최종 $\mathbf{O}$는 $\text{softmax}(\mathbf{Q}\mathbf{K}^\top)\mathbf{V}$와 정확히 같습니다. 논문은 $j$에 대한 귀납법으로 형식적 증명을 제공합니다.
역전파(Algorithm 4). 역전파는 저장된 $N \times N$ attention 행렬에 접근하지 않고 그래디언트 $d\mathbf{Q}$, $d\mathbf{K}$, $d\mathbf{V}$를 계산해야 하므로 더 미묘합니다. 출력 그래디언트 $d\mathbf{O}$가 주어지면, 핵심 통찰은 중간 양 $D_i = d\mathbf{o}_i^\top \mathbf{o}_i$(출력 그래디언트와 출력 벡터의 내적, 둘 다 크기 $d$)가 일반적으로 필요한 $N$ 차원 attention 분포에 대한 비싼 리덕션을 대체할 수 있다는 것입니다.
역전파는 순전파와 동일한 블록 구조를 사용합니다. 각 블록 쌍 $(i, j)$에 대해 저장된 통계에서 attention 블록 $\mathbf{P}{ij} = \text{diag}(\ell_i)^{-1} \exp(\mathbf{S}{ij}^{\text{masked}} - m_i)$를 재계산한 다음 계산합니다:
\[d\mathbf{V}_j \leftarrow d\mathbf{V}_j + (\mathbf{P}_{ij}^{\text{dropped}})^\top d\mathbf{O}_i\] \[d\mathbf{S}_{ij} = \mathbf{P}_{ij} \circ (d\mathbf{P}_{ij} - D_i) \quad \text{where } d\mathbf{P}_{ij} = d\mathbf{O}_i \mathbf{V}_j^\top\] \[d\mathbf{Q}_i \leftarrow d\mathbf{Q}_i + \tau \cdot d\mathbf{S}_{ij} \mathbf{K}_j, \quad d\tilde{\mathbf{K}}_j \leftarrow d\tilde{\mathbf{K}}_j + \tau \cdot d\mathbf{S}_{ij}^\top \mathbf{Q}_i\]Dropout 마스크는 순전파에서 저장되지 않습니다. 대신 의사 난수 생성기 상태가 저장되고 역전파 중에 동일한 dropout 마스크를 재생성하는 데 사용되어 $O(N^2)$ 저장을 피합니다.
IO 복잡도 분석(Theorem 2). FlashAttention의 IO 복잡도는 형식적으로 분석됩니다. K와 V의 각 요소는 HBM에서 한 번 로드됩니다(외부 루프에서). 알고리즘은 Q와 O에 대해 $T_c = \Theta(Nd / M)$ 패스를 만들며(내부 루프에서), 각 패스는 $\Theta(Nd)$ 요소를 로드합니다. 이는 총 HBM 접근을 산출합니다:
\[\Theta(Nd \cdot T_c) = \Theta\left(\frac{N^2 d^2}{M}\right)\]비교를 위해 표준 attention은 $\Theta(Nd + N^2)$ HBM 접근이 필요합니다. 일반적인 값($d = 64$, $M \approx 100$KB = $100{,}000$ 바이트)에 대해 $d^2 \ll M$이므로 FlashAttention은 큰 상수 인수 개선을 제공합니다.
하한(Proposition 3)은 최적성을 증명합니다. $M = \Theta(Nd)$에 대해 모든 알고리즘은 입력을 최소한 한 번 읽어야 하며, $\Omega(Nd)$ 접근이 필요하며, 이는 해당 영역에서 $\Omega(N^2 d^2 M^{-1})$와 일치합니다. 따라서 어떤 알고리즘도 모든 SRAM 크기에서 FlashAttention을 점근적으로 능가할 수 없습니다.
블록 희소 확장. 블록 희소 변형은 간단한 수정입니다. 블록 희소성 마스크 $\mathbf{M}_{ij} = 0$인 블록 쌍 $(i, j)$를 건너뜁니다. 이는 IO 복잡도를 $\Theta(Nd + N^2 d^2 M^{-1} s)$로 줄이며, 여기서 $s$는 0이 아닌 블록의 비율입니다. 나비 희소성 패턴($s = N^{-1/2}$)을 사용하면 $\Theta(N\sqrt{N})$ IO 복잡도를 산출하여 64K 시퀀스 길이로 확장할 수 있습니다.
2.4 구현 세부 사항
FlashAttention은 모든 attention 연산을 단일 GPU 커널 실행으로 융합하는 커스텀 CUDA 커널로 구현됩니다. 블록 크기는 SRAM 용량에 의해 결정됩니다. $B_c = \lceil M / (4d) \rceil$와 $B_r = \min(\lceil M / (4d) \rceil, d)$. 스트리밍 멀티프로세서당 192KB SRAM과 $d = 64$를 가진 A100 GPU에서 이는 약 $B_c = B_r \approx 64$의 블록 크기를 산출합니다.
FLOP 카운트는 순전파에 대해 $O(N^2 d)$(표준 attention과 동일)이고 역전파에 대해 $O(N^2 d)$입니다. 그러나 역전파는 attention 행렬 블록을 재계산하기 때문에 표준 attention보다 더 많은 FLOP을 수행합니다. 총 FLOP이 더 높음에도 불구하고 알고리즘은 IO 절감으로 인해 더 빠릅니다. GPT-2 medium(시퀀스 길이 1024, head 차원 64, 16 head, 배치 크기 64)의 경우 FlashAttention은 표준 attention의 66.6 GFLOP에 비해 75.2 GFLOP을 사용하지만 HBM 읽기/쓰기는 40.3 GB에 비해 4.4 GB만 사용하여 41.7ms에 비해 7.3ms 런타임을 산출합니다.
메모리 사용량은 시퀀스 길이에 선형으로 스케일링됩니다. softmax 통계 $(m, \ell)$에 대해 $O(N)$과 입력/출력 행렬($O(Nd)$). 8 head와 head 차원 64를 가진 시퀀스 길이 64K에서 FlashAttention은 약 13.4 GB를 사용하는 반면 표준 attention은 훨씬 이전에 메모리가 부족합니다. 메모리 절감은 $N \times N$ attention 행렬을 구체화하지 않고 dropout 마스크를 저장하지 않는 것에서 비롯됩니다.
구현은 head 차원 16, 32, 64, 128을 지원하며 모든 Turing 및 Ampere GPU 아키텍처에서 실행됩니다. 블록 희소 변형의 경우 임의의 희소성 패턴을 근사하는 것으로 나타난 나비 희소성 패턴이 사용됩니다. 모든 실험은 FP16 혼합 정밀도 학습을 사용합니다.
3. Results
FlashAttention의 실험 평가는 학습 속도, 더 긴 시퀀스로 모델 품질, attention 런타임 및 메모리의 마이크로벤치마킹의 세 가지 차원에 걸쳐 있습니다.
학습 속도. BERT-large 학습 결과는 고도로 최적화된 업계 기준을 나타내는 MLPerf 1.1 속도 기록에 대해 측정되기 때문에 특히 설득력이 있습니다. FlashAttention은 8xA100-80GB GPU에서 72.0% masked language modeling 정확도에 도달하기 위해 17.4분(10회 실행 평균)을 달성하며, Nvidia의 MLPerf 제출의 20.0분에 비해 15% 개선되었습니다. 이 속도 향상은 FlashAttention이 더 많은 총 FLOP을 수행함에도 불구하고 달성되어 IO 효율성이 지배적 요인임을 강조합니다.
| 모델 | 구현 | Perplexity | 학습 시간(속도 향상) |
|---|---|---|---|
| GPT-2 small | HuggingFace | 18.2 | 9.5일(1.0배) |
| GPT-2 small | Megatron-LM | 18.2 | 4.7일(2.0배) |
| GPT-2 small | FlashAttention | 18.2 | 2.7일(3.5배) |
| GPT-2 medium | HuggingFace | 14.2 | 21.0일(1.0배) |
| GPT-2 medium | Megatron-LM | 14.3 | 11.5일(1.8배) |
| GPT-2 medium | FlashAttention | 14.3 | 6.9일(3.0배) |
GPT-2 결과는 HuggingFace에 비해 최대 3.5배, Megatron-LM에 비해 1.7-2.0배 속도 향상을 보여주며 perplexity는 동일합니다. 검증 perplexity 곡선은 구현 간에 거의 구분할 수 없으며 수치적 안정성을 확인합니다.
Long-Range Arena 벤치마크(시퀀스 길이 1K-4K)에서 FlashAttention은 표준 attention과 비슷한 정확도(5개 작업에서 평균 정확도 59.8 대 59.3)로 2.4배 속도 향상을 달성합니다. 블록 희소 FlashAttention은 2.8배 속도 향상을 달성합니다. 특히 모든 근사 attention 기준(Linformer, Performer, Local Attention, Reformer, Smyrf)은 비슷하거나 더 나쁜 속도 향상을 제공하면서 FlashAttention보다 낮은 평균 정확도를 달성하여 메모리 접근 패턴을 무시하는 근사 방법보다 IO-aware 정확한 attention이 우수하다는 논문의 논제를 검증합니다.
더 긴 시퀀스와 모델 품질. 더 긴 시퀀스를 효율적으로 처리하는 능력은 개선된 모델 품질로 직접 변환됩니다. FlashAttention과 4K 컨텍스트를 사용하는 GPT-2 small은 17.5 perplexity를 달성하며, 1K 컨텍스트로 달성한 18.2에 비해 0.7 개선되었습니다. 그리고 여전히 1K 컨텍스트에서 Megatron-LM보다 30% 빠릅니다.
| 데이터셋 | 512 | 1024 | 2048 | 4096 | 8192 | 16384 |
|---|---|---|---|---|---|---|
| MIMIC-III(micro-$F_1$) | 52.8 | 50.7 | 51.7 | 54.6 | 56.4 | 57.1 |
| ECtHR(micro-$F_1$) | 72.2 | 74.3 | 77.1 | 78.6 | 80.7 | 79.2 |
장문서 분류는 명확한 이점을 보여줍니다. MIMIC-III(의료 퇴원 요약)에서 4.3포인트 개선, ECtHR(법률 사례)에서 8.5포인트 개선이 시퀀스 길이를 512에서 최적 길이로 늘릴 때 나타납니다. MIMIC-III의 매우 긴 길이에서 수익 감소는 의료 텍스트의 도메인별 분포 이동을 반영할 수 있습니다.
Path-X와 Path-256 결과는 아마도 가장 놀랍습니다. 모든 이전 Transformer 변형(Linformer, Performer, Reformer, Local Attention, Smyrf 포함)은 Path-X(16K 시퀀스 길이)에서 무작위 기회보다 나은 성능을 달성하지 못합니다. FlashAttention은 61.4% 정확도를 달성하고 블록 희소 FlashAttention은 Path-256(64K 시퀀스 길이)에서 63.1%를 달성합니다. 이것들은 이진 분류 작업이므로 무작위 성능은 50%입니다. 이러한 결과는 Transformer 아키텍처 자체가 극도로 긴 시퀀스를 처리할 수 있음을 보여줍니다. 장벽은 순전히 구현 산물이었습니다.
마이크로벤치마킹. 40GB HBM을 가진 A100의 상세한 벤치마크는 이론적 예측을 확인합니다. FlashAttention은 2K까지의 시퀀스에 대해 PyTorch attention보다 최대 3배 빠르고 64K로 확장됩니다. 메모리 풋프린트는 선형으로 스케일링됩니다. 64K에서 FlashAttention은 약 13.4 GB를 사용하는 반면 표준 attention은 훨씬 짧은 길이(약 4K)에서 메모리가 부족합니다. FlashAttention은 표준 attention보다 최대 20배 더 메모리 효율적이고 64K에서 Linformer보다 2배 더 효율적입니다. 블록 희소 FlashAttention은 정확한 방법과 근사 방법을 모두 포함하여 모든 시퀀스 길이에서 테스트된 모든 기준보다 빠릅니다.
HBM 접근과 런타임 간의 관계는 경험적으로 검증됩니다. 블록 크기 $B_c$를 변화시키면 HBM 접근 감소가 계산이 compute-bound가 되는 지점(블록 크기 약 256)까지 런타임을 직접 감소시킨다는 것을 보여줍니다. 이는 FlashAttention이 IO 최적화가 가장 중요한 메모리 바운드 영역에서 작동함을 확인합니다.
4. Critical Assessment
Strengths
- 논문은 이전 근사 attention 방법이 FLOP을 줄였음에도 불구하고 실제 시간 속도 향상을 달성하지 못한 이유를 설명하는 근본적이면서도 간과된 원칙(IO-awareness)을 식별하여 진정한 개념적 명확성을 제공합니다.
- FlashAttention은 정확한 attention을 계산하므로 근사 오차가 없고, 정확도-속도 트레이드오프를 위한 하이퍼파라미터 조정이 없으며, 동일한 학습 역학을 가진 드롭인 대체품입니다.
- 이론적 분석은 엄격하고 완전합니다. IO 복잡도가 형식적으로 증명되고 최적성 하한은 정확한 attention 알고리즘이 모든 SRAM 크기에서 FlashAttention을 점근적으로 개선할 수 없음을 보여줍니다.
- 실험 평가는 여러 모델(BERT, GPT-2), 벤치마크(LRA, Path-X, Path-256), 실용적 메트릭(실제 시간, 메모리, perplexity)에 걸쳐 있으며 MLPerf 기록을 포함한 강력한 기준에 대한 검증을 포함합니다.
- 이 접근 방식은 이전에 모든 Transformer 변형으로 달성할 수 없었던 진정으로 새로운 능력(Path-X, Path-256)을 가능하게 하여 기여가 단순히 점진적 속도 향상이 아니라 질적 능력 확장임을 보여줍니다.
Limitations
- 구현은 저수준 언어로 커스텀 CUDA 커널 작성이 필요하여 각 새로운 attention 변형(예: 다른 masking 패턴, 상대 위치 인코딩, cross-attention)에 대해 상당한 엔지니어링 오버헤드가 발생합니다.
- 이 접근 방식은 GPU 아키텍처별입니다. 최적 블록 크기는 GPU 세대(A100 대 V100 대 T4)에 따라 달라지는 SRAM 용량에 의존하며, 구현은 근본적으로 다른 하드웨어(TPU, 커스텀 가속기)에 대해 적응이 필요할 수 있습니다.
- 속도 향상은 더 큰 head 차원($d = 128$)에 대해 감소합니다. 블록이 더 많은 SRAM을 소비하므로 더 작은 블록 크기를 사용해야 하며 IO 이점이 감소합니다. 논문은 $d = 64$에 비해 $d = 128$에 대해 현저히 낮은 속도 향상을 보여줍니다.
- IO 복잡도는 여전히 $N$에 대해 제곱입니다($\Theta(N^2 d^2 M^{-1})$). FlashAttention은 점근적 복잡도 클래스를 변경하지 않고 큰 상수 인수 개선을 제공합니다. 극도로 긴 시퀀스의 경우 제곱 스케일링이 여전히 지배할 것입니다.
- 다중 GPU 설정은 다루지 않습니다. 분석은 단일 GPU 내의 SRAM-HBM 계층만 고려합니다. 대규모 모델 학습에서 일반적인 분산 attention 계산의 경우 GPU 간 통신은 향후 작업으로 남겨진 또 다른 IO 분석 계층을 추가합니다.
- 블록 희소 확장은 학습 전에 선택된 고정 희소성 패턴(나비 패턴)이 필요합니다. 논문은 잠재적으로 더 나은 품질-속도 트레이드오프를 산출할 수 있는 학습된 또는 적응형 희소성 패턴을 탐색하지 않습니다.
Future Directions
- PyTorch의 고수준 attention 알고리즘 설명에서 IO-aware CUDA 커널을 자동으로 생성할 수 있는 컴파일러 기반 접근 방식(이미지 처리를 위한 Halide와 유사)을 개발하여 수동 커널 엔지니어링의 필요성을 제거합니다.
- GPU 간 통신 대역폭을 추가 메모리 계층 레벨로 고려하여 attention이 장치 간에 병렬화되는 다중 GPU 설정으로 IO-aware 분석을 확장합니다.
- IO-aware 최적화 원칙을 attention 이외의 딥 네트워크의 다른 메모리 바운드 연산(MLP 레이어, 정규화 레이어, 손실 함수 계산)에 적용하여 잠재적으로 복합 속도 향상을 산출합니다.
- 모델 학습과 공동으로 최적화할 수 있는 적응형 또는 학습된 블록 희소 패턴을 탐색하여 잠재적으로 FlashAttention의 IO 이점을 작업별 희소성 구조와 결합합니다.
- 혼합 정밀도 학습, 모델 병렬화, activation checkpointing과 같은 다른 효율성 기술과 FlashAttention이 상호 작용하는 방식을 조사하여 IO-aware 딥러닝 학습을 위한 통합 프레임워크를 개발합니다.
- Flash Attention
- Flash Attention 2
- Flash Attention 3
Enjoy Reading This Article?
Here are some more articles you might like to read next:
Stay updated — subscribe via RSS
Leave a Comment
Found this useful or have questions? Sign in with GitHub to join the conversation.