gpumode · 강의 아카이브
《GPU Mode》 L060 2025 · Linear attn High priority transcript · failed

Optimizing Linear Attention

quadratic attention 의 N² 을 N 으로 떨어뜨리는 linear attention 의 큰 가족 — Linear-Attn / RetNet / Mamba / RWKV / GLA / DeltaNet — 을 한 자리에 모으고, GPU 위에서 진짜로 빠르게 돌리려면 어떤 수학적 형태를 써야 하는지를 깐다. Songlin Yang 의 핵심 contribution 인 chunk-wise parallel form 과 flash-linear-attention 라이브러리가 이 강의의 중심.

linear attention Mamba / SSM RWKV RetNet GLA chunk-wise parallel FLA library hybrid sliding-linear
S
Speaker
Songlin Yang
MIT CSAIL · GLA / FLA / DeltaNet 저자
강의 번호
L060
스피커
Songlin Yang
학습 우선순위
High · 정독
자료 상태
transcript 없음 · 논문 기반
§ 01강의가 풀려는 문제· why linear attn

linear attention 이 “이론적으로 빠르다” 가 아니라 — GPU 위에서 정말 빠르게 도는 형태

2020 년 Linear Attention (Katharopoulos et al.) 이 처음 제안된 이래로 — RWKV, RetNet, Mamba, GLA, DeltaNet 까지 — 가족이 거대해졌다. 이론적 으로는 모두 quadratic 의 N² 을 N 으로 줄인다. 그런데 실제 GPU 위에서 dense softmax-attention 보다 빠르게 도는 구현은 거의 없었다. Songlin Yang 의 일련의 work — flash-linear-attention 라이브러리, GLA / DeltaNet / Gated DeltaProduct — 이 그 gap 을 메운다.

강의 transcript 가 비어 있으므로 본 노트는 flash-linear-attention repo, GLA 논문 (Gated Linear Attention Transformers with Hardware-Efficient Training, 2023), 그리고 SSM/linear-attn duality 관련 자료를 base 로 재구성한다. 강의가 다룬 정확한 비중은 확인 필요.

강의의 인지적 frame

linear attention 의 이야기는 두 단계로 나뉜다. (a) 수학적 형태 — kernel trick / RNN-like state / outer product 누적. (b) hardware 친화적 알고리즘 — chunk-wise parallel form 으로 sequence 를 자르고, intra-chunk 는 quadratic 이지만 작아서 OK, inter-chunk 는 matmul 로 연결. 이 두 단계가 분리되어야 GPU 위에서 진짜로 빨라진다.

“linear attention 의 진짜 도전은 N² 을 N 으로 만드는 게 아니다 — GPU 의 tensor core 를 쓰면서 N 으로 만드는 것.”Songlin Yang / 확인 필요
§ 02quadratic 의 한계· N² 의 분포

softmax attention 의 비용은 길이 N 에 quadratic 으로 폭발한다

표준 softmax attention 의 비용을 다시 한 번 — Attn(Q,K,V) = softmax(QK^T/√d) V. 메모리는 O(N²), FLOPs 도 O(N²·d). FlashAttention 이 메모리는 O(N) 으로 줄였지만 FLOPs 는 여전히 N². long context (32K, 128K, 1M) 에서 이게 곧장 문제.

구체적으로 어디가 아픈지 — traininginference 에서 다르게 나타난다.

  • training: 한 forward 에 N² compute. Llama-3 8B, 32K context = ~10 GFLOPs/token (attention 만), 즉 전체 training compute 의 30%+. context 가 늘면 dominant 가 됨.
  • inference (decode): 매 step 마다 KV cache O(N·layers·d) 를 read. KV memory 가 model weights 보다 커진다 — 32K context 에서 70B 의 KV ≈ 5 GB / req.

linear attention 은 두 비용을 모두 떨어뜨린다 — training 의 N² 을 N 으로, inference 의 KV cache 를 fixed-size state 로 (= d² × layers, sequence 와 무관).

FIG · quadratic vs linear, 길이별 attention FLOPsd=128
N=2K · softmax
N=8K · softmax
16×
N=32K · softmax
256×
N=128K · softmax
4096×
N=128K · linear
64×
128K 에서 quadratic 이 4096×, linear 가 64×. 64× 차이. 32K 까지는 차이가 작아서 softmax 가 여전히 dominant — long-context regime 으로 갈 때 의미가 폭발.

그런데 hardware 입장에서는 한 가지 미묘한 점이 있다. softmax attention 은 FlashAttention 으로 tensor core 를 100% 쓴다. linear attention 은 자연스럽게는 element-wise 연산이 많아 tensor core 를 안 쓴다. 그래서 같은 N 에서 linear 가 quadratic 보다 느릴 수도 있다. 이게 강의의 핵심 frame — linear 의 이론적 우위가 hardware 우위가 되려면 형태를 다시 짜야 한다.

§ 03linear attention 의 수식· kernel trick · state

softmax 를 떼면 attention 이 RNN 이 된다 — 그 정확한 수식

linear attention 의 출발 idea — softmax 의 비선형성을 떼고 φ(Q)·φ(K)^T 같은 kernel feature map 으로 대체하면 — 결합법칙이 적용된다. (QK^T)V = Q(K^T V). 오른쪽 형태는 N 에 대해 linear.

표준 softmax attention
o_t = Σᵢ₌₁ᵗ exp(qₜ·kᵢ/√d) · vᵢ / Z_t
→ N² 메모리, t 마다 모든 과거 (k,v) 를 봐야 함
linear attention (kernel trick)
o_t = Σᵢ₌₁ᵗ φ(qₜ)·φ(kᵢ) · vᵢ
= φ(qₜ) · ( Σᵢ₌₁ᵗ φ(kᵢ) ⊗ vᵢ )
= φ(qₜ) · S_t
S_t = Σᵢ φ(kᵢ) ⊗ vᵢ 가 d × d 의 fixed-size state. RNN-like recurrence 로 표현 가능.
RNN 형태 (recurrence)
S_t = S_{t-1} + φ(kₜ) ⊗ vₜ
o_t = φ(qₜ) · S_t
매 step O(d²) 비용. 모든 과거 토큰의 정보가 d×d state 안에 압축.

이 형태의 의미 — linear attention 은 정확히 RNN 이다. softmax 라는 non-linear similarity 를 dot-product (또는 그 변형) 로 바꾸면, attention 이 fixed-size hidden state S_t 를 갖는 RNN 으로 다시 쓸 수 있다.

그러면 자연스럽게 트레이드 오프가 보인다.

  • 장점: inference 시 KV cache 가 state d×d 한 개로 압축. sequence 와 무관한 메모리. throughput 이 길이에 거의 무관.
  • 단점: 정보가 d×d state 안에 ‘섞여서’ 들어간다. 특정 토큰을 selectively recall 하기 어렵다 — 이게 RWKV / Linear Attn 이 long-context QA 에서 약한 이유.
왜 그냥 안 빠른가

RNN form 은 매 step 의 의존성 때문에 sequence 차원으로 parallel 이 안 된다. 한 step 의 S_t 가 다음 step 에 필요. GPU 의 1024 thread 가 동시에 일을 못 한다. 이걸 풀려고 — chunk-wise parallel (§05) 와 hardware-friendly form (§06) 이 등장.

이 RNN 등가성은 GPU 입장에서 parallel scan (Mamba 의 selective scan) 또는 chunk-wise matmul (GLA / RetNet) 으로 풀 수 있다는 사실로 이어진다. 두 길이 어떻게 다른지가 가족별 architectural 차이의 근원.

§ 04가족 비교· family tree

같은 RNN 형태의 다른 변형들 — gating / decay / data-dependent

linear attention 가족의 모든 모델은 S_t = α_t · S_{t-1} + β_t · (k_t ⊗ v_t) 의 변형으로 정리된다. 어떤 게 α, β 가 되는지가 모델을 구분한다.

Linear Attn (2020)

Katharopoulos

α_t = 1, β_t = 1. φ(k) 만 쓴다. 단순. recall 약함.

decay없음
gate없음
state sized×d

RetNet (2023)

Microsoft

α_t = γ (scalar). 시간에 따라 exponential decay. parallel form 가능.

decayscalar γ
gate없음
state sized×d

RWKV (v4–v7)

Bo Peng et al.

α_t 가 channel-wise vector. WKV operator. v6, v7 은 receptance 의 data-dep.

decaychannel
gatev5+
state sized×d

GLA (Songlin · 2023)

MIT

α_t 가 data-dependent vector — input 에서 학습. hardware-friendly chunk form.

decaydata-dep
gateα_t = σ(Wx)
state sized×d

Mamba S6 (2023)

Albert Gu

SSM 출발. selective scan 으로 input-dep A,B,C. continuous-time discretize.

decaydata-dep A
gatedata-dep B,C
state sizeN (≪ d²)

Mamba2 (2024)

Tri Dao et al.

SSM ↔ linear-attn duality 를 명시. matmul-form 으로 tensor core 활용.

decayscalar A
gateB, C
state sized×N

DeltaNet (Songlin · 2024)

MIT

delta rule — state update 가 (β k ⊗ (v − S k)) 형태. associative recall 강함.

decayβ k
updatedelta rule
state sized×d

Gated DeltaNet (2024+)

Songlin · NVIDIA

DeltaNet + GLA 의 gate 추가. 현재 가장 강한 linear-attn 변형 중 하나 (확인 필요).

decayα_t
updatedelta + gate
state sized×d
가족의 진화 축

(1) decay 가 점점 더 fine-grained — none → scalar → channel → data-dependent. 각 단계가 long-context recall 을 향상. (2) state update 가 단순 누적에서 delta-rule 로 — 새 정보를 단순히 더하는 게 아니라 기존 S 를 보고 “수정”. associative recall 향상. (3) 전부 hardware-friendly chunk form 을 갖추기 시작 — Songlin 의 FLA 가 통일된 framework.

§ 05chunk-wise parallel· intra · inter chunk

sequence 를 chunk 로 자르고 — chunk 안은 quadratic, chunk 사이는 matmul

linear attention 을 GPU 위에서 빠르게 돌리는 핵심 트릭. sequence 를 chunk (예: C=64) 로 자르고:

  1. intra-chunk: chunk 안에서는 표준 attention 처럼 quadratic. C² = 4096 — 작아서 BF16 tensor core 위에서 빠름.
  2. inter-chunk: chunk i 의 정보를 chunk i+1 에 전달할 때 — chunk-aggregated state S_i 를 한 번 만들고 matmul 로 다음 chunk 에 적용. 이 부분이 N/C 번 = linear.
FIG · chunk-wise parallel attention patternN=8 chunks · 각 C 토큰
i₁
i₂
·
i₃
·
·
i₄
·
·
·
i₅
·
·
·
·
i₆
·
·
·
·
·
i₇
·
·
·
·
·
·
i₈
■ intra-chunk attention (작은 quadratic, tensor core), ■ inter-chunk propagation (state matmul), ■ via chunk state S (이전 chunks 의 누적). N² 의 N=N_chunks·C 가 sub-quadratic 으로 분해.
chunk-wise form
// chunk i 의 출력
O_i = Q_i · S_{i-1} // inter-chunk: 이전 state 사용
+ Σ (Q_i K_j^T D_{ij}) V_j // intra-chunk: quadratic in C

// chunk i 끝나면 state 갱신
S_i = α_i · S_{i-1} + Σⱼ k_j ⊗ v_j // 이 chunk 의 새 정보 누적
D_{ij} 는 decay mask (RetNet 의 γ, GLA 의 data-dep α 등). 가족마다 모양이 약간 다르지만 이 framework 가 통합.
“linear attention 을 정말 GPU 위에서 빠르게 돌리려면 — RNN 으로 풀 수 있다는 사실은 잊고, chunk 안의 작은 quadratic + chunk 간 matmul 의 hybrid 로 다시 쓴다.”Songlin Yang / GLA paper · 확인 필요
§ 06hardware-friendly form· matmul-form · tensor cores

tensor core 를 쓰지 못하는 구현은 dense attention 보다 느리다

linear attention 의 reference 구현은 대부분 element-wise + small reductions. 그러면 — H100 의 BF16 Tensor Core 가 990 TFLOPs 인데 — element-wise 는 ~100 TFLOPs 도 못 친다. 같은 FLOPs 라도 tensor core 가 10× 더 빠르다는 뜻. linear attention 을 정말 빠르게 만들려면 모든 계산을 matmul 형태로 다시 써야 한다.

Songlin 의 flash-linear-attention 의 디자인 원칙:

  • state update 를 outer-product matmul 로S = αS + k ⊗ vS = αS + K^T V (chunk 차원으로 펼친) matmul 로.
  • decay 를 mask 로 곱한 K, V 형태 — D_{ij} 를 pre-compute 해서 K' = K ⊙ D, V' = V ⊙ D 로 만든 뒤 표준 matmul 로 attention.
  • tile 단위 fused kernel — Triton 으로 짠다. SRAM 안에서 chunk 의 S 가 머무르도록.
  • autograd 에서도 matmul — backward 도 같은 형태로 (역방향 chunk traversal).

이 결과로 — flash-linear-attention 의 GLA 커널이 같은 N=32K 에서 FlashAttention2 보다 ~2× 빠르다 (확인 필요 — 모델·hardware 별 변동). 그리고 정확한 N → ∞ scaling 에서 attention 보다 격차가 폭발적으로 커진다.

FIG · element-wise vs matmul-form 의 throughputH100 BF16
element-wise (naive)
~80 TF
tile-fused
~250 TF
matmul-form
~890 TF
peak BF16
990 TF
같은 algorithm, 같은 FLOPs 인데 — 형태에 따라 10× 차이. tensor core 활용이 단순한 “구현 디테일” 이 아니라 architecture choice 의 일부.
flash-linear-attention library

Songlin 이 메인테인하는 fla-org/flash-linear-attention 가 — Triton 으로 짠 chunked-matmul-form 커널 묶음. GLA, RetNet, RWKV, Mamba2, DeltaNet, Gated DeltaNet 모두 같은 framework 위에. 이 라이브러리가 곧 강의의 코드 부록이라고 봐야 함.

§ 07정확도 vs 속도· recall · long ctx

linear attention 의 약점은 항상 recall — needle in a haystack

linear attention 가족의 일관된 약점 — 특정 토큰을 정확히 회상하는 능력. 이유는 명확하다 — 모든 정보가 fixed-size state d×d 에 압축되므로, 충분히 긴 sequence 에서는 정보가 “찌그러진다”. needle in a haystack, multi-key retrieval, copying 같은 task 에서 softmax attention 보다 약한 게 표준 결과.

softmax · recall
97%
RetNet · recall
55%
Mamba · recall
65%
GLA · recall
78%
DeltaNet · recall
88%
Hybrid (slide+linear)
95%

예시 수치 (확인 필요) — MQAR (multi-query associative recall) 같은 synthetic benchmark 의 일반적 트렌드. 절대값은 모델 크기 / context 길이 / training data 에 따라 변동.

그래서 가족의 진화 축이 명확하다 — recall 향상이 거의 모든 새 변형의 motivation.

  • RetNet → GLA: data-dependent decay 로 “이 토큰은 더 오래 기억해” 라고 학습.
  • GLA → DeltaNet: delta rule — 새 토큰이 들어오면 기존 state 의 관련 영역만 수정. 깨끗한 overwrite.
  • DeltaNet → Gated DeltaNet: gate + delta 둘 다 — recall 과 forget 모두 학습.

그럼에도 불구하고 — pure linear 가 softmax 의 recall 을 완전히 따라잡았다고 보긴 어렵다. 이게 hybrid 패턴 (§08) 이 실용적으로 가장 의미 있는 이유.

언제 linear 가 충분한가

recall 이 critical 하지 않은 task — language modeling 의 perplexity, long-form generation, 일반적 reasoning — 에서는 GLA / DeltaNet 이 softmax 와 거의 같은 품질을 1/10 비용으로 낸다. retrieval-heavy (needle, multi-key, 정확한 copy) 에서는 hybrid 가 안전.

최근 (2024–25) 트렌드 — 완전 linear 모델 (RWKV, Mamba) 보다 hybrid (대부분 linear, 일부 layer 만 softmax) 가 production 채택률이 높아지는 중. Jamba, Zamba, Samba 같은 모델들.

§ 08hybrid 패턴· sliding window + linear

대부분의 layer 는 linear, 몇 개만 softmax — 비용은 linear 에 가깝고 recall 은 softmax 에 가깝다

2024 의 가장 실용적인 architecture 트렌드 — hybrid. transformer 의 일부 layer 를 linear 로, 일부를 softmax (또는 sliding window) 로. 비용의 대부분은 linear 가 차지하지만, softmax layer 가 “정밀 recall” 의 detail 을 가져간다.

FIG · 흔한 hybrid 패턴 비교32-layer 모델
A
all-softmax
표준 transformer · cost 100%
B
all-linear
RWKV/GLA · cost 15%, recall 78%
C
hybrid 7:1
7 linear + 1 softmax/sliding · cost 25%, recall 95%
D
interleaved
4 linear + 4 sliding 반복 · cost 30%, recall 97%
hybrid 의 sweet spot — 매 7~8 layer 마다 한 layer 의 softmax. Jamba, Samba, Zamba 가 이 비율 근처. 더 줄이면 recall, 더 늘리면 cost.
왜 sliding window 가 좋은 hybrid 짝인가

linear attention 이 약한 영역은 정밀 local lookup. “방금 그 토큰” 을 정확히 찾는 일. sliding window attention 이 정확히 이 영역을 — 메모리/시간 모두 적게 — 보강. linear (long-range, fuzzy) + sliding (short-range, exact) 의 조합이 인지적으로 자연스러움.

§ 09다음 방향· DeltaNet · Gated · Test-time

2024 후반 ~ 2025 의 새로운 변형들

“다음 5년의 attention 은 — 한 가지 형태가 아니라 — quadratic / linear / sliding 이 layer 단위로 섞이고, 각각이 자기 자리를 학습하는 형태로 갈 가능성이 크다.”학습 노트
§ 10기억할 메모와 코드· FLA repo

다시 열었을 때 빠르게 잡혀야 할 것

linear attn = RNN
softmax 빼면 fixed-size state d×d 의 RNN. KV cache → state.
chunk-wise parallel form
sequence 를 chunk 로 자름. intra=quadratic(작음), inter=matmul. tensor core 활용.
decay 의 진화
none → scalar (Ret) → channel (RWKV) → data-dep (GLA, Mamba2). 각 단계 recall 향상.
delta rule
단순 누적 대신 (β k ⊗ (v − Sk)). associative recall 강함. DeltaNet 핵심.
recall 약점
needle/copy 에서 softmax 보다 떨어짐. fixed state 압축의 본질적 한계.
hybrid sweet spot
7~8 layer 중 1 개를 softmax 또는 sliding. 비용 ~25%, recall ~95%.
matmul-form 필수
element-wise 구현은 tensor core 못 씀 → 같은 algorithm 의 10× 차이.
FLA library
fla-org/flash-linear-attention · Triton chunked-matmul. 가족 전체 한 framework.

손에 새기기 — 실습 시퀀스

  1. FLA 설치 + GLA 한번 forwardpip install flash-linear-attention. fla.ops.gla.fused_chunk_gla 로 같은 (Q,K,V) 에서 forward. 결과 모양 / 속도 확인.
  2. FlashAttention2 와 head-to-head — 같은 hidden=128, N ∈ {2K, 8K, 32K, 128K} sweep. ms 측정. crossover 가 어디서 일어나는지 plot.
  3. chunk size sweep — GLA 의 BLOCK_SIZE_C ∈ {16, 32, 64, 128, 256}. chunk 가 너무 크면 intra 가 quadratic 으로 폭발, 너무 작으면 launch overhead.
  4. recall 검증 — synthetic MQAR task 에 GLA / DeltaNet 학습. softmax baseline 과 비교. context 별 정확도 곡선.
  5. hybrid 모델 학습 — 6 GLA + 2 sliding-attn layer 의 hybrid. 같은 token budget 에서 perplexity 비교.
  6. state 크기 vs recall 의 trade-off — head dim 을 키우면 state d² 도 커짐. recall 향상되지만 state 메모리 폭발. sweet spot 찾기.
§ 11다른 강의로· connections

이 강의의 frame 이 다음에 어디에 다시 등장하는지

§ 12열린 질문· open questions

transcript 가 비어 있어 직접 검증해야 할 것들

검증 메모

본 노트의 모든 수치 (TFLOPs, recall %, throughput 비교) 는 일반적 산수와 FLA / 가족 논문들의 트렌드를 재구성한 예시값이다. 자기 hardware / 자기 모델에서 직접 측정 후에만 절대값을 인용할 것.

← Lecture 059 FastVideo — sliding tile 이 video DiT 에서 어떻게 구현되는지 Lecture 061 → D-Matrix Corsair — linear attn 과 dataflow inference 칩의 만남