Flash Attention 2

TL;DR

FlashAttention-2는 thread block 및 warp 수준에서 GPU 작업 분할을 재설계하여 원본 FlashAttention 대비 약 2배의 속도 향상을 달성하며, A100 GPU에서 이론적 최대 FLOPs/s의 50-73%에 도달합니다. 핵심 혁신은 비matmul FLOPs 감소, 시퀀스 길이 차원에 대한 병렬화, 비효율적인 “split-K” warp 분할 방식 제거입니다. GPT 스타일 모델을 end-to-end 학습에 사용할 때, FlashAttention-2는 A100 GPU당 최대 225 TFLOPs/s에 도달하며 이는 모델 FLOPs 활용률 72%에 해당합니다. 이 접근법은 순수한 시스템 수준의 최적화로 알고리즘 근사 없이 표준 attention과 비트 동일한 결과를 생성하면서도 극적으로 빠릅니다.


Related Papers

FlashAttention 시리즈:

하드웨어 최적화 및 시스템 연구:


Takeaways

1. Contribution

Attention 메커니즘은 Transformer 아키텍처의 핵심에 있지만 근본적인 계산 병목을 부과합니다. 즉, 실행 시간과 메모리가 시퀀스 길이 $N$에 대해 제곱으로 확장되며, 중간 attention 행렬 $\mathbf{S}$와 $\mathbf{P}$를 위해 $O(N^2)$ 저장 공간이 필요합니다. AI 커뮤니티가 더 긴 컨텍스트를 지향함에 따라 – GPT-4는 32k 토큰을 지원하고, MosaicML의 MPT는 65k에 도달하며, Anthropic의 Claude는 100k를 처리합니다 – 이 제곱 비용은 점점 더 제한적이 됩니다. 원본 FlashAttention은 타일링과 재계산을 활용하여 메모리를 $O(N^2)$에서 $O(N)$으로 줄이고 2-4배의 wall-clock 속도 향상을 달성했지만, 상당한 성능을 테이블에 남겨두었습니다. 구체적으로, FlashAttention의 forward pass는 A100의 이론적 최대 FLOPs/s의 30-50%에만 도달했고, backward pass는 25-35%에 그쳤으며, 최적화된 GEMM 연산은 일상적으로 80-90%를 달성합니다.

FlashAttention-2는 이 격차의 근본 원인을 식별합니다: GPU 실행 계층의 두 수준에서 최적이 아닌 작업 분할입니다. thread block 수준에서 FlashAttention은 배치 크기와 헤드 수에 대해서만 병렬화했습니다. 이는 긴 시퀀스에서 (배치 크기가 일반적으로 작은 경우) 많은 스트리밍 멀티프로세서(SM)가 유휴 상태로 있다는 것을 의미합니다. 각 thread block 내의 warp 수준에서, FlashAttention은 모든 warp가 중간 결과를 shared memory에 쓰고, 동기화한 다음, 집계해야 하는 “split-K” 스킴을 사용했습니다. 이는 성능 병목이 되는 불필요한 shared memory 트래픽을 도입했습니다.

이 논문은 FlashAttention과 이론적 하드웨어 상한 사이의 격차를 약 절반 줄이는 세 가지 고유한 기여를 제공합니다. 첫째, 비matmul FLOPs를 줄이기 위해 알고리즘이 조정됩니다. 이는 A100의 Tensor Core가 FP16/BF16 matmul에 대해 312 TFLOPs/s를 제공하지만 비matmul FP32 연산에 대해서는 19.5 TFLOPs/s만 제공하기 때문에 불균형하게 중요합니다 – 16배 차이입니다. 출력 누산기의 “un-scaled” 버전을 유지하고 최종 정규화를 지연시키며, $m$과 $\ell$을 별도로 저장하는 대신 logsumexp $L^{(j)} = m^{(j)} + \log(\ell^{(j)})$만 저장함으로써, FlashAttention-2는 귀중한 비matmul 사이클을 소비할 수 있는 여러 블록당 재스케일링 연산을 제거합니다.

둘째, FlashAttention-2는 forward 및 backward pass 모두에 대해 시퀀스 길이 차원을 따라 병렬화를 도입합니다. forward pass에서 외부 루프는 Q의 행 블록을 반복하며, 각 반복은 당황스러울 정도로 병렬적입니다 – 서로 다른 행 블록은 블록 간 통신 없이 서로 다른 thread block에서 처리될 수 있습니다. backward pass에서 열 블록이 유사하게 병렬화되며, atomic add를 사용하여 유일한 공유 계산($\mathbf{dQ}$ 업데이트)을 처리합니다. 이 시퀀스 수준 병렬성은 배치 $\times$ 헤드 thread block의 수가 A100의 108 SM보다 적을 수 있는 긴 컨텍스트 영역에서 특히 영향력이 있으며, 계산 리소스가 활용되지 않은 상태로 남습니다.

셋째, warp 수준 작업 분할이 근본적으로 재설계되었습니다. $\mathbf{K}$와 $\mathbf{V}$를 warp에 걸쳐 분할하는 대신(warp 간 통신이 필요한 “split-K” 스킴), FlashAttention-2는 $\mathbf{Q}$를 warp에 걸쳐 분할하고 $\mathbf{K}$와 $\mathbf{V}$를 모든 warp가 액세스할 수 있도록 유지합니다. 이는 각 warp가 독립적으로 $\mathbf{QK}^\top$의 슬라이스를 계산하고 공유 $\mathbf{V}$와 곱하여 출력 슬라이스를 생성한다는 것을 의미하며, forward pass에서 shared memory 동기화의 필요성을 제거합니다.

실용적 중요성은 상당합니다. FlashAttention-2는 테스트된 모든 구성에서 FlashAttention 대비 2배의 속도 향상을 달성하며, forward pass에서 최대 230 TFLOPs/s(이론적 최대의 73%)에 도달하고 backward pass에서 최대 196 TFLOPs/s(이론적 최대의 63%)에 도달합니다. end-to-end GPT 학습에서 이는 GPU당 225 TFLOPs/s(모델 FLOPs 활용률 72%)로 전환되며, 이는 FlashAttention 대비 1.3배 개선이고 baseline 구현 대비 2.8배 개선입니다. 이는 실무자들이 이전에 8k 컨텍스트로 학습하던 것과 동일한 비용으로 16k 컨텍스트로 학습할 수 있음을 의미하며, 더 긴 컨텍스트 애플리케이션을 직접 가능하게 합니다.

2. Methodology

2.1 Core Intuition

FlashAttention-2의 근본적인 통찰은 attention과 GEMM 성능 사이의 격차를 좁히려면 여러 세분성 수준에서 GPU 실행 모델을 이해해야 한다는 것입니다. A100과 같은 현대 GPU는 깊이 계층적 구조를 가지고 있습니다: HBM(40-80GB, 1.5-2.0 TB/s 대역폭)이 온칩 SRAM(SM당 192KB, ~19 TB/s 대역폭)에 공급되며, 계산은 SM에 스케줄링된 thread block으로 구성되고, 각 thread block에는 여러 warp(32개 스레드 그룹)가 포함됩니다. 성능은 총 메모리 트래픽을 줄이는 것(FlashAttention이 이미 해결함)뿐만 아니라 작업이 이 계층에 걸쳐 어떻게 분산되는지에 달려 있습니다.

핵심 이론적 통찰은 attention 계산이 GEMM처럼 순수하게 계산 제한적이지 않다는 것입니다. matmul 연산(Tensor Core가 엄청나게 가속화함)과 16배 낮은 처리량으로 실행되는 비matmul 연산(softmax 관련 요소별 계산)이 혼합되어 있습니다. 이는 attention 커널의 효과적인 처리량이 matmul 대 비matmul 연산에 소비된 시간의 비율에 의해 결정된다는 것을 의미합니다. FlashAttention-2는 비matmul FLOPs를 최소화하고 Tensor Core가 적극적으로 작동하는 시간의 비율을 최대화함으로써 이를 공격합니다.

FlashAttention의 온라인 softmax 기술은 전체 $N \times N$ attention 행렬의 구체화 없이 attention 계산을 블록으로 타일링할 수 있게 합니다. 핵심 수학적 속성은 softmax가 나중에 올바른 전역 결과를 생성하기 위해 재스케일링되는 로컬 계산으로 분해될 수 있다는 것입니다. FlashAttention-2는 모든 재스케일링 연산이 모든 반복에서 발생할 필요가 없다는 것을 관찰함으로써 이를 더 발전시킵니다 – 일부는 맨 끝까지 지연될 수 있으며, 내부 루프 반복당 비matmul FLOPs를 절약합니다. 또한, 이 논문은 두 통계($m$과 $\ell$)가 하나(logsumexp $L$)로 압축될 수 있음을 관찰하여 계산과 저장 모두를 줄입니다.

병렬성 수준에서의 통찰은 원본 FlashAttention의 루프 순서(K/V 열 블록에 대한 외부 루프, Q 행 블록에 대한 내부 루프)가 불필요한 직렬화를 생성했다는 것입니다. Q 행 블록에 대한 외부 루프로 전환함으로써 전체 외부 루프가 당황스러울 정도로 병렬적이 되며, 배치 $\times$ 헤드가 SM 수에 비해 작을 때 중요한 시퀀스 수준 병렬성을 가능하게 합니다. backward pass의 경우, 유사한 분석은 열 블록 병렬성이 자연스럽다는 것을 보여주며, 유일한 블록 간 종속성은 atomic 연산을 통해 처리할 수 있는 $\mathbf{dQ}$ 누적입니다.

2.2 Model Architecture

FlashAttention-2는 모델 아키텍처가 아니라 표준 multi-head attention (MHA) 연산의 GPU 커널 구현입니다. 표준 attention과 정확히 동일한 함수를 계산합니다:

\[\mathbf{S} = \mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{N \times N}, \quad \mathbf{P} = \text{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O} = \mathbf{P}\mathbf{V} \in \mathbb{R}^{N \times d}\]

시스템의 아키텍처는 GPU 메모리 계층을 통한 데이터 흐름 측면에서 가장 잘 이해됩니다. 입력 $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$는 HBM에 상주합니다. 계산은 타일을 SRAM에 로드하고, SRAM 내에서 부분 attention을 계산하고, 최종 출력 $\mathbf{O}$와 logsumexp $L \in \mathbb{R}^N$만 HBM에 다시 씁니다. 중간 $N \times N$ 행렬 $\mathbf{S}$와 $\mathbf{P}$는 HBM에 구체화되지 않습니다.

forward pass의 병렬화 구조는 다음과 같이 시각화할 수 있습니다:

Thread Block Grid: [batch_size x num_heads x T_r]

  각 thread block 처리:
    - 하나의 행 블록 Q_i (크기 B_r x d)
    - 모든 열 블록 K_j, V_j를 반복 (내부 루프)
    - 하나의 출력 블록 O_i 생성 (크기 B_r x d)

  각 thread block 내 (4개 warp):
    FlashAttention:        FlashAttention-2:
    K를 warp에 걸쳐 분할   Q를 warp에 걸쳐 분할
    V를 warp에 걸쳐 분할   K,V는 모든 warp가 공유
    -> 동기화 필요         -> warp 간 동기화 불필요

backward pass의 경우, 병렬화가 전치됩니다: 각 thread block은 K/V의 열 블록을 처리하고 Q, dO 등의 행 블록을 반복하며, $\mathbf{dK}_j$와 $\mathbf{dV}_j$를 온칩에서 누적하고 분산된 $\mathbf{dQ}$ 업데이트에 atomic add를 사용합니다.

커널은 또한 causal 마스킹(autoregressive 모델용), multi-query attention (MQA), 및 grouped-query attention (GQA)을 기본적으로 지원합니다. causal 마스킹의 경우, 모든 열 인덱스가 모든 행 인덱스를 초과하는 블록은 완전히 건너뛰어 약 1.7-1.8배의 추가 속도 향상을 제공합니다. MQA/GQA의 경우, 인덱스 조작이 K/V 헤드를 물리적으로 복제하는 것을 피합니다.

2.3 Key Algorithms & Mechanisms

Forward Pass Algorithm (Algorithm 1). forward pass는 $\mathbf{Q}$를 $T_r = \lceil N / B_r \rceil$ 행 블록으로 나누고 $\mathbf{K}, \mathbf{V}$를 $T_c = \lceil N / B_c \rceil$ 열 블록으로 나눕니다. 각 thread block은 하나의 행 블록 $\mathbf{Q}_i$를 처리하고 모든 열 블록을 반복합니다.

각 행 블록 $i$와 열 블록 $j$에 대해 알고리즘은 다음을 수행합니다:

  1. $\mathbf{K}_j, \mathbf{V}_j$를 HBM에서 SRAM으로 로드합니다.
  2. 로컬 attention 점수를 계산합니다: $\mathbf{S}_i^{(j)} = \mathbf{Q}_i \mathbf{K}_j^\top \in \mathbb{R}^{B_r \times B_c}$.
  3. 실행 최대값을 업데이트합니다: $m_i^{(j)} = \max(m_i^{(j-1)}, \text{rowmax}(\mathbf{S}_i^{(j)})) \in \mathbb{R}^{B_r}$.
  4. 로컬 softmax 가중치를 계산합니다: $\tilde{\mathbf{P}}_i^{(j)} = \exp(\mathbf{S}_i^{(j)} - m_i^{(j)}) \in \mathbb{R}^{B_r \times B_c}$.
  5. 지수의 실행 합계를 업데이트합니다: $\ell_i^{(j)} = e^{m_i^{(j-1)} - m_i^{(j)}} \ell_i^{(j-1)} + \text{rowsum}(\tilde{\mathbf{P}}_i^{(j)}) \in \mathbb{R}^{B_r}$.
  6. un-scaled 출력 누산기를 업데이트합니다: $\mathbf{O}_i^{(j)} = \text{diag}(e^{m_i^{(j-1)} - m_i^{(j)}})^{-1} \mathbf{O}_i^{(j-1)} + \tilde{\mathbf{P}}_i^{(j)} \mathbf{V}_j$.

FlashAttention과의 중요한 차이점은 6단계에 있습니다. 원본 FlashAttention은 매 반복마다 $\text{diag}(\ell_i^{(j)})^{-1}$로 정규화하여 반복당 요소당 두 개의 나눗셈 연산이 필요했습니다. FlashAttention-2는 대신 un-scaled 누산기 $\tilde{\mathbf{O}}_i^{(j)}$를 유지하고 모든 열 블록이 처리된 후 한 번만 최종 정규화를 적용합니다:

\[\mathbf{O}_i = \text{diag}(\ell_i^{(T_c)})^{-1} \tilde{\mathbf{O}}_i^{(T_c)}\]

또한, logsumexp는 $L_i = m_i^{(T_c)} + \log(\ell_i^{(T_c)})$로 계산되어 backward pass를 위해 저장되며, $m$과 $\ell$을 별도로 저장할 필요를 대체합니다.

Backward Pass Algorithm (Algorithm 2). backward pass는 열 블록에 대한 외부 루프(인덱스 $j$)와 행 블록에 대한 내부 루프(인덱스 $i$)로 구성됩니다. 각 열 블록 $j$에 대해 알고리즘은 $\mathbf{K}_j, \mathbf{V}_j$를 SRAM에 로드하고 $\mathbf{dK}_j$와 $\mathbf{dV}_j$를 온칩에서 누적합니다. 각 행 블록 $i$에 대해:

  1. $\mathbf{Q}_i, \mathbf{O}_i, \mathbf{dO}_i, \mathbf{dQ}_i, L_i, D_i$를 HBM에서 로드합니다.
  2. 재계산: $\mathbf{S}_i^{(j)} = \mathbf{Q}_i \mathbf{K}_j^\top$.
  3. 저장된 logsumexp를 사용하여 attention 가중치를 재계산합니다: $\mathbf{P}i^{(j)} = \exp(\mathbf{S}{ij} - L_i)$.
  4. 누적: $\mathbf{dV}_j \leftarrow \mathbf{dV}_j + (\mathbf{P}_i^{(j)})^\top \mathbf{dO}_i$.
  5. 계산: $\mathbf{dP}_i^{(j)} = \mathbf{dO}_i \mathbf{V}_j^\top$.
  6. 계산: $\mathbf{dS}_i^{(j)} = \mathbf{P}_i^{(j)} \circ (\mathbf{dP}_i^{(j)} - D_i)$, 여기서 $D_i = \text{rowsum}(\mathbf{dO} \circ \mathbf{O})$.
  7. 누적: $\mathbf{dQ}_i \leftarrow \mathbf{dQ}_i + \mathbf{dS}_i^{(j)} \mathbf{K}_j$ (HBM에 다시 씀).
  8. 누적: $\mathbf{dK}_j \leftarrow \mathbf{dK}_j + (\mathbf{dS}_i^{(j)})^\top \mathbf{Q}_i$.

3단계에서 $\exp(\mathbf{S}_{ij} - L_i)$의 사용은 $m$과 $\ell$을 별도로 저장하고 사용하는 원본 FlashAttention의 접근법보다 더 효율적입니다. 블록당 하나의 벡터 뺄셈과 하나의 요소별 나눗셈을 제거하기 때문입니다.

Warp Partitioning (Forward Pass). 4개의 warp를 포함하는 각 thread block 내에서 핵심 설계 결정은 어떤 행렬을 warp에 걸쳐 분할할 것인지입니다.

FlashAttention의 “split-K” 스킴에서 $\mathbf{K}$는 4개의 warp에 걸쳐 분할되므로 각 warp는 $\mathbf{Q} \cdot (\mathbf{K}{\text{warp}})^\top$를 계산하여 attention 점수의 부분 슬라이스를 얻습니다. 부분 $\tilde{\mathbf{P}}{\text{warp}} \cdot \mathbf{V}_{\text{warp}}$를 계산한 후, 모든 warp는 shared memory에 쓰고, 동기화하고, 감소시켜야 합니다. 이 동기화는 (a) shared memory 쓰기/읽기가 대역폭을 소비하고, (b) 동기화 장벽이 warp가 독립적으로 진행하는 것을 방해하기 때문에 비용이 많이 듭니다.

FlashAttention-2에서는 대신 $\mathbf{Q}$가 4개의 warp에 걸쳐 분할됩니다. 각 warp는 $\mathbf{Q}_{\text{warp}} \cdot \mathbf{K}^\top$(attention 점수 행렬의 행 슬라이스)를 계산한 다음 전체 $\mathbf{V}$와 곱하여 출력의 해당 행을 얻습니다. 각 warp가 출력 행의 고유한 집합을 생성하므로 warp 간 감소가 필요하지 않습니다. 이것은 중간 결과에 대한 shared memory 쓰기를 완전히 제거합니다.

Causal Masking Optimization. $j > i$에 대해 $S_{ij} = -\infty$인 causal (autoregressive) attention의 경우, FlashAttention-2는 두 가지 최적화를 구현합니다: (1) 모든 열 인덱스가 모든 행 인덱스를 초과하는 블록은 완전히 건너뜁니다(큰 $N$에 대해 블록의 약 절반), (2) 일부 항목만 마스킹이 필요한 블록의 경우, 마스크는 행당 해당 단일 경계 블록에만 적용되어 완전히 유효한 블록에 대한 불필요한 마스킹 연산을 피합니다.

2.4 Implementation Details

Block Sizes. 일반적인 블록 크기는 $(B_r, B_c)$에 대해 ${64, 128} \times {64, 128}$에서 선택됩니다. 선택은 헤드 차원 $d$와 GPU shared memory 용량(A100의 SM당 192KB)에 따라 달라집니다. 더 큰 블록은 계산에 대한 메모리 로드의 비율을 줄이지만 레지스터 압력과 shared memory 사용을 증가시킵니다. 임계 블록 크기를 초과하면 로컬 메모리로의 레지스터 스필링이 심각한 성능 저하를 유발합니다. 블록 크기는 헤드 차원당 수동으로 조정되며 4가지 가능한 조합만 있습니다.

Memory Requirements. 입력 행렬 $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$와 출력 $\mathbf{O} \in \mathbb{R}^{N \times d}$ 외에도, FlashAttention-2는 $O(N)$ 추가 메모리만 필요합니다: logsumexp 벡터 $L \in \mathbb{R}^N$(backward pass를 위해 저장됨)과 벡터 $D = \text{rowsum}(\mathbf{dO} \circ \mathbf{O}) \in \mathbb{R}^N$(backward pass 시작 시 계산됨). 이는 전체 attention 행렬에 대해 $O(N^2)$ 메모리가 필요한 표준 attention과 극명한 대조를 이룹니다.

Computational Complexity. 총 FLOPs는 forward 및 backward pass 모두에 대해 $O(N^2 d)$로 유지되며, 표준 attention과 동일합니다. forward pass는 블록당 2개의 행렬 곱셈($\mathbf{QK}^\top$과 $\tilde{\mathbf{P}}\mathbf{V}$)을 포함하며, backward pass는 5개($\mathbf{QK}^\top$, $\mathbf{P}^\top \mathbf{dO}$, $\mathbf{dO}\mathbf{V}^\top$, $\mathbf{dS}\mathbf{K}$, $\mathbf{dS}^\top\mathbf{Q}$)와 $\mathbf{S}$와 $\mathbf{P}$의 재계산을 포함합니다.

FLOPs Counting. Forward pass FLOPs: $4 \cdot N^2 \cdot d \cdot h$ 여기서 $h$는 헤드 수입니다. causal 마스크의 경우 이는 절반이 됩니다. Backward pass FLOPs: forward pass FLOPs의 $2.5 \times$(5개 matmul 대 2개 matmul의 비율). end-to-end 모델 FLOPs는 Megatron-LM 공식을 따릅니다: $6 \cdot N \cdot P + 12 \cdot L \cdot d_{\text{hidden}} \cdot N^2$, 여기서 $P$는 파라미터 수이고 $L$은 레이어 수입니다.

Hardware. 모든 벤치마크는 A100 80GB SXM4 GPU(312 TFLOPs/s 이론적 FP16/BF16 matmul 피크)를 사용합니다. end-to-end 학습은 8개의 A100을 사용합니다. H100 결과도 보고되며(TMA 또는 4세대 Tensor Core와 같은 H100 특정 최적화 없이) 최대 335 TFLOPs/s를 보여줍니다.

Atomic Operations. backward pass에서 서로 다른 열 블록 thread block의 $\mathbf{dQ}$ 업데이트는 HBM에 대한 atomic add를 사용합니다. 이것은 필요한 유일한 블록 간 동기화이며, 서로 다른 열 블록이 동일한 $\mathbf{dQ}$ 행에 기여하기 때문에 필요합니다.

3. Results

실증 평가는 두 가지 보완적 평가를 통해 FlashAttention-2의 개선 사항을 보여줍니다: 격리된 attention 커널의 마이크로 벤치마크와 GPT 스타일 모델의 end-to-end 학습 처리량입니다.

Attention Kernel Benchmarks. 실험은 A100 80GB SXM4 GPU에서 512에서 16k 토큰까지 시퀀스 길이를 변화시키며, hidden 차원 2048과 헤드 차원은 64 또는 128입니다. 배치 크기는 총 토큰이 16k로 유지되도록 조정됩니다. 모든 구성(causal 마스크 유무, $d = 64$ 또는 $d = 128$)에서 FlashAttention-2는 일관되게 원본 FlashAttention 대비 약 2배의 속도 향상을 달성합니다.

Setting FlashAttention FlashAttention-2 Speedup
Fwd+Bwd, no mask, $d=64$, seq=16k 110 TFLOPs/s 176 TFLOPs/s 1.6x
Fwd+Bwd, no mask, $d=128$, seq=16k 98 TFLOPs/s 203 TFLOPs/s 2.1x
Fwd+Bwd, causal, $d=64$, seq=16k 97 TFLOPs/s 171 TFLOPs/s 1.8x
Fwd+Bwd, causal, $d=128$, seq=16k 83 TFLOPs/s 189 TFLOPs/s 2.3x
Fwd only, no mask, $d=128$, seq=2k 71 TFLOPs/s 227 TFLOPs/s 3.2x
Fwd only, causal, $d=128$, seq=16k 71 TFLOPs/s 197 TFLOPs/s 2.8x

forward pass는 최대 이론적 최대 처리량의 73%에 도달하며(312 TFLOPs/s 중 230 TFLOPs/s), 이는 FlashAttention의 30-50%에서 극적인 개선입니다. backward pass는 최대 이론적 최대의 63%에 도달하며(196 TFLOPs/s), 25-35%에서 증가했습니다. forward와 backward 효율성 사이의 격차는 backward pass가 5개의 행렬 곱셈과 더 복잡한 데이터 종속성을 가지고 있어 덜 효율적인 파이프라이닝을 초래하기 때문에 예상됩니다.

다른 구현과의 비교도 마찬가지로 인상적입니다. FlashAttention-2는 FlashAttention의 Triton 구현보다 1.3-2.5배 빠르고 표준 PyTorch attention보다 3-10배 빠릅니다. 주목할 점은 PyTorch attention이 16k 시퀀스 길이에서 메모리 부족(OOM)이 발생하는 반면, FlashAttention-2는 $O(N)$ 메모리 풋프린트 덕분에 편안하게 처리한다는 것입니다.

헤드 차원 $d = 128$의 결과는 일반적으로 절대 TFLOPs/s에서 $d = 64$보다 우수합니다. 이는 더 큰 행렬 곱셈(더 큰 $d$)이 메모리 액세스에 대한 계산의 더 나은 비율을 가지므로 Tensor Core가 더 효율적으로 활용될 수 있기 때문입니다. causal 마스크 구성은 FLOPs 공식이 절반 계산을 고려하기 때문에 절대 TFLOPs/s가 낮지만, causal 마스킹으로 인한 실제 wall-clock 속도 향상은 논문에서 언급된 약 1.7-1.8배 속도 향상에 반영됩니다.

H100 Results. 동일한 A100 최적화 커널을 H100 GPU에서 실행하면(H100 특정 최적화 없이) forward+backward 결합에 대해 최대 335 TFLOPs/s를 산출하며, 이는 TMA 및 4세대 Tensor Core와 같은 H100 특정 기능으로 또 다른 1.5-2배 개선이 가능함을 시사합니다.

End-to-End Training. end-to-end 평가는 8개의 A100 80GB SXM GPU에서 1.3B 및 2.7B 파라미터의 GPT 스타일 모델을 학습하며, 컨텍스트 길이는 2k 및 8k입니다.

Model No FlashAttention FlashAttention FlashAttention-2
GPT3-1.3B, 2k ctx 142 TFLOPs/s 189 TFLOPs/s 196 TFLOPs/s
GPT3-1.3B, 8k ctx 72 TFLOPs/s 170 TFLOPs/s 220 TFLOPs/s
GPT3-2.7B, 2k ctx 149 TFLOPs/s 189 TFLOPs/s 205 TFLOPs/s
GPT3-2.7B, 8k ctx 80 TFLOPs/s 175 TFLOPs/s 225 TFLOPs/s

속도 향상은 더 긴 시퀀스(8k 컨텍스트)에서 가장 두드러지며, FlashAttention-2는 baseline 대비 2.8배, FlashAttention 대비 1.3배를 달성합니다. 2k 컨텍스트에서 attention 레이어는 총 계산의 더 작은 비율을 구성하며(선형 레이어가 지배적), attention 속도 향상은 더 완만한 end-to-end 개선으로 전환됩니다. 8k 컨텍스트에서 2.7B 모델이 달성한 225 TFLOPs/s는 72%의 모델 FLOPs 활용률을 나타내며, 이는 GEMM 연산이 지배하는 모델의 효율성에 놀라울 정도로 가깝습니다.

커널 수준 속도 향상과 end-to-end 개선 사이의 연결은 attention 계산이 실제로 긴 컨텍스트 학습의 병목임을 검증합니다. 더 짧은 컨텍스트 길이(2k)에서 감소된 속도 향상은 Amdahl의 법칙과 일치합니다: attention이 계산의 더 작은 비율일 때 이를 가속화하는 것이 더 적은 영향을 미칩니다.

4. Critical Assessment

Strengths

  1. 근사 없는 정확한 계산. 근사 attention 방법(Longformer, Performer 등)과 달리 FlashAttention-2는 표준 attention과 정확히 동일한 출력을 계산합니다. 이는 품질 저하에 대한 모든 우려를 제거하고 drop-in 교체가 가능하게 합니다.
  2. 원칙적인 시스템 최적화. 개선 사항은 임시 조정이 아니라 thread block 및 warp 수준에서의 신중한 GPU 성능 분석에 기반합니다. 비matmul FLOPs를 불균형적인 병목으로 식별하는 것(16배 처리량 격차로 인해)은 attention을 넘어 일반화되는 핵심 통찰입니다.
  3. 포괄적인 벤치마킹. 논문은 여러 구성(시퀀스 길이, 헤드 차원, causal 마스크 유무)과 여러 baseline(PyTorch, xformers, Triton FlashAttention)에 걸쳐 평가하여 언제 그리고 왜 속도 향상이 발생하는지에 대한 철저한 그림을 제공합니다.
  4. 실용적 영향. FlashAttention-2는 사실상 모든 현대 LLM 학습 및 추론 프레임워크에서 attention 계산의 사실상의 표준이 되었습니다. github.com/Dao-AILab/flash-attention의 오픈 소스 릴리스는 즉각적인 광범위한 채택을 가능하게 했습니다.
  5. 현대 attention 변형 지원. 물리적 복제가 아닌 암시적 인덱스 조작을 통한 MQA 및 GQA의 기본 지원은 추론 최적화 아키텍처에 실용적으로 중요합니다.
  6. GPU 실행 모델의 명확한 설명. 논문은 thread block, warp, shared memory 및 matmul/비matmul 처리량 격차에 대한 접근 가능한 설명을 제공하여 최적화를 광범위한 청중이 이해할 수 있게 합니다.

Limitations

  1. 하드웨어 특정 최적화. 커널은 NVIDIA A100 GPU에 대해 수동으로 조정되었으며 블록 크기는 헤드 차원당 수동으로 선택됩니다. 논문 자체는 자동 조정 및 다른 하드웨어(H100, AMD GPU)에 대한 적응이 향후 작업임을 언급합니다.
  2. 표준 attention 패턴으로 제한됨. causal 마스킹은 지원되지만 더 복잡한 attention 패턴(로컬 attention, dilated attention, block-sparse attention)은 별도의 커널 구현이 필요합니다. 논문은 저수준 최적화를 고수준 알고리즘 변경과 결합하는 것이 향후 방향임을 인정합니다.
  3. 공식적인 roofline 분석 없음. 논문은 이론적 최대 처리량의 백분율을 보고하지만, 나머지 비효율성을 특정 병목(메모리 대역폭, 계산, 지연 시간)에 정확하게 귀속시킬 상세한 roofline 모델 분석을 제공하지 않습니다.
  4. Backward pass 효율성 격차. backward pass는 이론적 최대의 63%에만 도달하며(forward는 73%), 논문은 이 격차가 존재하는 이유나 어떻게 좁힐 수 있는지에 대한 제한적인 분석을 제공합니다.
  5. Atomic 연산 오버헤드. backward pass는 $\mathbf{dQ}$ 업데이트에 atomic add를 사용하며, 이는 경합을 유발할 수 있습니다. 논문은 이 오버헤드를 정량화하거나 언제 병목이 되는지 논의하지 않습니다.
  6. 수동 블록 크기 조정. 4가지 선택만 있으므로 수동 조정이 가능하지만 일반화되지 않습니다. 논문은 이것을 제한으로 인정하고 자동 조정을 향후 작업으로 연기합니다.

Future Directions

  1. H100/Hopper 특정 최적화. TMA (Tensor Memory Accelerator), 4세대 Tensor Core 및 FP8 데이터 타입을 활용하여 또 다른 1.5-2배 속도 향상을 달성합니다. 논문에서 언급되었고 이후 FlashAttention-3에서 실현되었습니다.
  2. 블록 크기 자동 조정. 검색 기반 또는 머신 러닝 기반 접근법을 사용하여 서로 다른 하드웨어 및 구성 조합에 대한 최적의 블록 크기를 자동으로 선택합니다.
  3. 희소/구조화된 attention과의 통합. FlashAttention-2의 저수준 커널 최적화를 알고리즘적 attention 희소성 패턴(로컬, dilated, block-sparse)과 결합하여 매우 긴 컨텍스트(100k+ 토큰)로 학습을 가능하게 합니다.
  4. 다중 GPU 병렬화. 타일링 전략을 GPU 경계를 넘어 확장하여 분산 학습에서 시퀀스 병렬성을 가능하게 하고 더 긴 컨텍스트를 지원합니다.
  5. 컴파일러 통합. 손으로 작성된 CUDA 커널을 요구하는 대신 컴파일러 인프라(예: Triton, MLIR)를 통해 이러한 최적화 기술을 프로그래밍 가능하게 만들어 사용자 정의 attention 변형의 장벽을 낮춥니다.




    Enjoy Reading This Article?

    Here are some more articles you might like to read next:

  • Tensor Parallel 구현 비교
  • Flash Attention 3
  • Flash Attention
  • Unified Sequence Parallelism
  • Reducing Activation Recomputation in Large Transformer Models
  • Stay updated — subscribe via RSS




    Leave a Comment

    Found this useful or have questions? Sign in with GitHub to join the conversation.