《GPU Mode》
L060
2025 · Linear attn
High priority
transcript · failed
Optimizing Linear Attention
quadratic attention 의 N² 을 N 으로 떨어뜨리는 linear attention 의 큰 가족 — Linear-Attn / RetNet / Mamba / RWKV / GLA / DeltaNet — 을 한 자리에 모으고, GPU 위에서 진짜로 빠르게 돌리려면 어떤 수학적 형태를 써야 하는지를 깐다. Songlin Yang 의 핵심 contribution 인 chunk-wise parallel form 과 flash-linear-attention 라이브러리가 이 강의의 중심.
linear attention
Mamba / SSM
RWKV
RetNet
GLA
chunk-wise parallel
FLA library
hybrid sliding-linear
S
Speaker
Songlin Yang
MIT CSAIL · GLA / FLA / DeltaNet 저자
§ 01강의가 풀려는 문제· why linear attn
linear attention 이 “이론적으로 빠르다” 가 아니라 — GPU 위에서 정말 빠르게 도는 형태
2020 년 Linear Attention (Katharopoulos et al.) 이 처음 제안된 이래로 — RWKV, RetNet, Mamba, GLA, DeltaNet 까지 — 가족이 거대해졌다. 이론적 으로는 모두 quadratic 의 N² 을 N 으로 줄인다. 그런데 실제 GPU 위에서 dense softmax-attention 보다 빠르게 도는 구현은 거의 없었다. Songlin Yang 의 일련의 work — flash-linear-attention 라이브러리, GLA / DeltaNet / Gated DeltaProduct — 이 그 gap 을 메운다.
강의 transcript 가 비어 있으므로 본 노트는 flash-linear-attention repo, GLA 논문 (Gated Linear Attention Transformers with Hardware-Efficient Training, 2023), 그리고 SSM/linear-attn duality 관련 자료를 base 로 재구성한다. 강의가 다룬 정확한 비중은 확인 필요.
강의의 인지적 frame
linear attention 의 이야기는 두 단계로 나뉜다. (a) 수학적 형태 — kernel trick / RNN-like state / outer product 누적. (b) hardware 친화적 알고리즘 — chunk-wise parallel form 으로 sequence 를 자르고, intra-chunk 는 quadratic 이지만 작아서 OK, inter-chunk 는 matmul 로 연결. 이 두 단계가 분리되어야 GPU 위에서 진짜로 빨라진다.
“linear attention 의 진짜 도전은 N² 을 N 으로 만드는 게 아니다 — GPU 의 tensor core 를 쓰면서 N 으로 만드는 것.”Songlin Yang / 확인 필요
§ 02quadratic 의 한계· N² 의 분포
softmax attention 의 비용은 길이 N 에 quadratic 으로 폭발한다
표준 softmax attention 의 비용을 다시 한 번 — Attn(Q,K,V) = softmax(QK^T/√d) V. 메모리는 O(N²), FLOPs 도 O(N²·d). FlashAttention 이 메모리는 O(N) 으로 줄였지만 FLOPs 는 여전히 N². long context (32K, 128K, 1M) 에서 이게 곧장 문제.
구체적으로 어디가 아픈지 — training 과 inference 에서 다르게 나타난다.
- training: 한 forward 에 N² compute. Llama-3 8B, 32K context = ~10 GFLOPs/token (attention 만), 즉 전체 training compute 의 30%+. context 가 늘면 dominant 가 됨.
- inference (decode): 매 step 마다 KV cache
O(N·layers·d) 를 read. KV memory 가 model weights 보다 커진다 — 32K context 에서 70B 의 KV ≈ 5 GB / req.
linear attention 은 두 비용을 모두 떨어뜨린다 — training 의 N² 을 N 으로, inference 의 KV cache 를 fixed-size state 로 (= d² × layers, sequence 와 무관).
FIG · quadratic vs linear, 길이별 attention FLOPsd=128
128K 에서 quadratic 이 4096×, linear 가 64×. 64× 차이. 32K 까지는 차이가 작아서 softmax 가 여전히 dominant — long-context regime 으로 갈 때 의미가 폭발.
그런데 hardware 입장에서는 한 가지 미묘한 점이 있다. softmax attention 은 FlashAttention 으로 tensor core 를 100% 쓴다. linear attention 은 자연스럽게는 element-wise 연산이 많아 tensor core 를 안 쓴다. 그래서 같은 N 에서 linear 가 quadratic 보다 느릴 수도 있다. 이게 강의의 핵심 frame — linear 의 이론적 우위가 hardware 우위가 되려면 형태를 다시 짜야 한다.
§ 03linear attention 의 수식· kernel trick · state
softmax 를 떼면 attention 이 RNN 이 된다 — 그 정확한 수식
linear attention 의 출발 idea — softmax 의 비선형성을 떼고 φ(Q)·φ(K)^T 같은 kernel feature map 으로 대체하면 — 결합법칙이 적용된다. (QK^T)V = Q(K^T V). 오른쪽 형태는 N 에 대해 linear.
표준 softmax attention
o_t = Σᵢ₌₁ᵗ exp(qₜ·kᵢ/√d) · vᵢ / Z_t
→ N² 메모리, t 마다 모든 과거 (k,v) 를 봐야 함
linear attention (kernel trick)
o_t = Σᵢ₌₁ᵗ φ(qₜ)·φ(kᵢ) · vᵢ
= φ(qₜ) · ( Σᵢ₌₁ᵗ φ(kᵢ) ⊗ vᵢ )
= φ(qₜ) · S_t
S_t = Σᵢ φ(kᵢ) ⊗ vᵢ 가 d × d 의 fixed-size state. RNN-like recurrence 로 표현 가능.
RNN 형태 (recurrence)
S_t = S_{t-1} + φ(kₜ) ⊗ vₜ
o_t = φ(qₜ) · S_t
매 step O(d²) 비용. 모든 과거 토큰의 정보가 d×d state 안에 압축.
이 형태의 의미 — linear attention 은 정확히 RNN 이다. softmax 라는 non-linear similarity 를 dot-product (또는 그 변형) 로 바꾸면, attention 이 fixed-size hidden state S_t 를 갖는 RNN 으로 다시 쓸 수 있다.
그러면 자연스럽게 트레이드 오프가 보인다.
- 장점: inference 시 KV cache 가 state d×d 한 개로 압축. sequence 와 무관한 메모리. throughput 이 길이에 거의 무관.
- 단점: 정보가 d×d state 안에 ‘섞여서’ 들어간다. 특정 토큰을 selectively recall 하기 어렵다 — 이게 RWKV / Linear Attn 이 long-context QA 에서 약한 이유.
왜 그냥 안 빠른가
RNN form 은 매 step 의 의존성 때문에 sequence 차원으로 parallel 이 안 된다. 한 step 의 S_t 가 다음 step 에 필요. GPU 의 1024 thread 가 동시에 일을 못 한다. 이걸 풀려고 — chunk-wise parallel (§05) 와 hardware-friendly form (§06) 이 등장.
이 RNN 등가성은 GPU 입장에서 parallel scan (Mamba 의 selective scan) 또는 chunk-wise matmul (GLA / RetNet) 으로 풀 수 있다는 사실로 이어진다. 두 길이 어떻게 다른지가 가족별 architectural 차이의 근원.
§ 04가족 비교· family tree
같은 RNN 형태의 다른 변형들 — gating / decay / data-dependent
linear attention 가족의 모든 모델은 S_t = α_t · S_{t-1} + β_t · (k_t ⊗ v_t) 의 변형으로 정리된다. 어떤 게 α, β 가 되는지가 모델을 구분한다.
Linear Attn (2020)
Katharopoulos
α_t = 1, β_t = 1. φ(k) 만 쓴다. 단순. recall 약함.
decay없음
gate없음
state sized×d
RetNet (2023)
Microsoft
α_t = γ (scalar). 시간에 따라 exponential decay. parallel form 가능.
decayscalar γ
gate없음
state sized×d
RWKV (v4–v7)
Bo Peng et al.
α_t 가 channel-wise vector. WKV operator. v6, v7 은 receptance 의 data-dep.
decaychannel
gatev5+
state sized×d
GLA (Songlin · 2023)
MIT
α_t 가 data-dependent vector — input 에서 학습. hardware-friendly chunk form.
decaydata-dep
gateα_t = σ(Wx)
state sized×d
Mamba S6 (2023)
Albert Gu
SSM 출발. selective scan 으로 input-dep A,B,C. continuous-time discretize.
decaydata-dep A
gatedata-dep B,C
state sizeN (≪ d²)
Mamba2 (2024)
Tri Dao et al.
SSM ↔ linear-attn duality 를 명시. matmul-form 으로 tensor core 활용.
decayscalar A
gateB, C
state sized×N
DeltaNet (Songlin · 2024)
MIT
delta rule — state update 가 (β k ⊗ (v − S k)) 형태. associative recall 강함.
decayβ k
updatedelta rule
state sized×d
Gated DeltaNet (2024+)
Songlin · NVIDIA
DeltaNet + GLA 의 gate 추가. 현재 가장 강한 linear-attn 변형 중 하나 (확인 필요).
decayα_t
updatedelta + gate
state sized×d
가족의 진화 축
(1) decay 가 점점 더 fine-grained — none → scalar → channel → data-dependent. 각 단계가 long-context recall 을 향상. (2) state update 가 단순 누적에서 delta-rule 로 — 새 정보를 단순히 더하는 게 아니라 기존 S 를 보고 “수정”. associative recall 향상. (3) 전부 hardware-friendly chunk form 을 갖추기 시작 — Songlin 의 FLA 가 통일된 framework.
§ 05chunk-wise parallel· intra · inter chunk
sequence 를 chunk 로 자르고 — chunk 안은 quadratic, chunk 사이는 matmul
linear attention 을 GPU 위에서 빠르게 돌리는 핵심 트릭. sequence 를 chunk (예: C=64) 로 자르고:
- intra-chunk: chunk 안에서는 표준 attention 처럼 quadratic.
C² = 4096 — 작아서 BF16 tensor core 위에서 빠름.
- inter-chunk: chunk i 의 정보를 chunk i+1 에 전달할 때 — chunk-aggregated state S_i 를 한 번 만들고 matmul 로 다음 chunk 에 적용. 이 부분이 N/C 번 = linear.
FIG · chunk-wise parallel attention patternN=8 chunks · 각 C 토큰
i₁
↓
i₂
·
↓
i₃
·
·
↓
i₄
·
·
·
↓
i₅
·
·
·
·
↓
i₆
·
·
·
·
·
↓
i₇
·
·
·
·
·
·
↓
i₈
■ intra-chunk attention (작은 quadratic, tensor core), ■ inter-chunk propagation (state matmul), ■ via chunk state S (이전 chunks 의 누적). N² 의 N=N_chunks·C 가 sub-quadratic 으로 분해.
chunk-wise form
// chunk i 의 출력
O_i = Q_i · S_{i-1} // inter-chunk: 이전 state 사용
+ Σ (Q_i K_j^T D_{ij}) V_j // intra-chunk: quadratic in C
// chunk i 끝나면 state 갱신
S_i = α_i · S_{i-1} + Σⱼ k_j ⊗ v_j // 이 chunk 의 새 정보 누적
D_{ij} 는 decay mask (RetNet 의 γ, GLA 의 data-dep α 등). 가족마다 모양이 약간 다르지만 이 framework 가 통합.
“linear attention 을 정말 GPU 위에서 빠르게 돌리려면 — RNN 으로 풀 수 있다는 사실은 잊고, chunk 안의 작은 quadratic + chunk 간 matmul 의 hybrid 로 다시 쓴다.”Songlin Yang / GLA paper · 확인 필요
§ 06hardware-friendly form· matmul-form · tensor cores
tensor core 를 쓰지 못하는 구현은 dense attention 보다 느리다
linear attention 의 reference 구현은 대부분 element-wise + small reductions. 그러면 — H100 의 BF16 Tensor Core 가 990 TFLOPs 인데 — element-wise 는 ~100 TFLOPs 도 못 친다. 같은 FLOPs 라도 tensor core 가 10× 더 빠르다는 뜻. linear attention 을 정말 빠르게 만들려면 모든 계산을 matmul 형태로 다시 써야 한다.
Songlin 의 flash-linear-attention 의 디자인 원칙:
- state update 를 outer-product matmul 로 —
S = αS + k ⊗ v 를 S = αS + K^T V (chunk 차원으로 펼친) matmul 로.
- decay 를 mask 로 곱한 K, V 형태 — D_{ij} 를 pre-compute 해서 K' = K ⊙ D, V' = V ⊙ D 로 만든 뒤 표준 matmul 로 attention.
- tile 단위 fused kernel — Triton 으로 짠다. SRAM 안에서 chunk 의 S 가 머무르도록.
- autograd 에서도 matmul — backward 도 같은 형태로 (역방향 chunk traversal).
이 결과로 — flash-linear-attention 의 GLA 커널이 같은 N=32K 에서 FlashAttention2 보다 ~2× 빠르다 (확인 필요 — 모델·hardware 별 변동). 그리고 정확한 N → ∞ scaling 에서 attention 보다 격차가 폭발적으로 커진다.
FIG · element-wise vs matmul-form 의 throughputH100 BF16
element-wise (naive)
~80 TF
같은 algorithm, 같은 FLOPs 인데 — 형태에 따라 10× 차이. tensor core 활용이 단순한 “구현 디테일” 이 아니라 architecture choice 의 일부.
flash-linear-attention library
Songlin 이 메인테인하는 fla-org/flash-linear-attention 가 — Triton 으로 짠 chunked-matmul-form 커널 묶음. GLA, RetNet, RWKV, Mamba2, DeltaNet, Gated DeltaNet 모두 같은 framework 위에. 이 라이브러리가 곧 강의의 코드 부록이라고 봐야 함.
§ 07정확도 vs 속도· recall · long ctx
linear attention 의 약점은 항상 recall — needle in a haystack
linear attention 가족의 일관된 약점 — 특정 토큰을 정확히 회상하는 능력. 이유는 명확하다 — 모든 정보가 fixed-size state d×d 에 압축되므로, 충분히 긴 sequence 에서는 정보가 “찌그러진다”. needle in a haystack, multi-key retrieval, copying 같은 task 에서 softmax attention 보다 약한 게 표준 결과.
예시 수치 (확인 필요) — MQAR (multi-query associative recall) 같은 synthetic benchmark 의 일반적 트렌드. 절대값은 모델 크기 / context 길이 / training data 에 따라 변동.
그래서 가족의 진화 축이 명확하다 — recall 향상이 거의 모든 새 변형의 motivation.
- RetNet → GLA: data-dependent decay 로 “이 토큰은 더 오래 기억해” 라고 학습.
- GLA → DeltaNet: delta rule — 새 토큰이 들어오면 기존 state 의 관련 영역만 수정. 깨끗한 overwrite.
- DeltaNet → Gated DeltaNet: gate + delta 둘 다 — recall 과 forget 모두 학습.
그럼에도 불구하고 — pure linear 가 softmax 의 recall 을 완전히 따라잡았다고 보긴 어렵다. 이게 hybrid 패턴 (§08) 이 실용적으로 가장 의미 있는 이유.
언제 linear 가 충분한가
recall 이 critical 하지 않은 task — language modeling 의 perplexity, long-form generation, 일반적 reasoning — 에서는 GLA / DeltaNet 이 softmax 와 거의 같은 품질을 1/10 비용으로 낸다. retrieval-heavy (needle, multi-key, 정확한 copy) 에서는 hybrid 가 안전.
최근 (2024–25) 트렌드 — 완전 linear 모델 (RWKV, Mamba) 보다 hybrid (대부분 linear, 일부 layer 만 softmax) 가 production 채택률이 높아지는 중. Jamba, Zamba, Samba 같은 모델들.
§ 08hybrid 패턴· sliding window + linear
대부분의 layer 는 linear, 몇 개만 softmax — 비용은 linear 에 가깝고 recall 은 softmax 에 가깝다
2024 의 가장 실용적인 architecture 트렌드 — hybrid. transformer 의 일부 layer 를 linear 로, 일부를 softmax (또는 sliding window) 로. 비용의 대부분은 linear 가 차지하지만, softmax layer 가 “정밀 recall” 의 detail 을 가져간다.
FIG · 흔한 hybrid 패턴 비교32-layer 모델
A
all-softmax
표준 transformer · cost 100%
B
all-linear
RWKV/GLA · cost 15%, recall 78%
C
hybrid 7:1
7 linear + 1 softmax/sliding · cost 25%, recall 95%
D
interleaved
4 linear + 4 sliding 반복 · cost 30%, recall 97%
hybrid 의 sweet spot — 매 7~8 layer 마다 한 layer 의 softmax. Jamba, Samba, Zamba 가 이 비율 근처. 더 줄이면 recall, 더 늘리면 cost.
- Jamba (AI21): Mamba + softmax + MoE. 8 layer 중 1 이 softmax. 256K context 까지 학습.
- Samba (Microsoft): Mamba + sliding window. softmax 대신 짧은 window attention 으로 — local recall 만 살림. 성능 좋음.
- Zamba: Mamba + shared softmax block (한 block 을 여러 layer 가 공유).
- 최근 NVIDIA Nemotron-H: Mamba2 + softmax 의 hybrid (확인 필요).
왜 sliding window 가 좋은 hybrid 짝인가
linear attention 이 약한 영역은 정밀 local lookup. “방금 그 토큰” 을 정확히 찾는 일. sliding window attention 이 정확히 이 영역을 — 메모리/시간 모두 적게 — 보강. linear (long-range, fuzzy) + sliding (short-range, exact) 의 조합이 인지적으로 자연스러움.
§ 09다음 방향· DeltaNet · Gated · Test-time
2024 후반 ~ 2025 의 새로운 변형들
- Gated DeltaNet / DeltaProduct — Songlin 그룹과 NVIDIA 의 collaboration. delta rule + gate 의 결합. 같은 cost 에서 가장 강한 recall (확인 필요 — 빠르게 진화 중).
- Test-time Training (TTT) — Sun et al. 2024. linear attention 을 “학습 가능한 RNN” 으로 보고, hidden state 를 inner gradient descent 로 update. attention 이 곧 implicit fine-tune.
- state expansion — d×d state 를 여러 head 에 분산해 effective state 를 키움. RWKV-7, Mamba2 의 한 갈래.
- SSM ↔ linear-attn duality 의 통합 view — Mamba2 paper 가 SSM 과 linear-attn 이 본질적으로 같은 framework 임을 명시. 두 영역의 trick 이 서로 옮겨짐.
- compiler 통합 — flash-linear-attention 이 PyTorch 의
torch.compile 과 jointly. 새 모델을 빠르게 prototyping 할 수 있게.
“다음 5년의 attention 은 — 한 가지 형태가 아니라 — quadratic / linear / sliding 이 layer 단위로 섞이고, 각각이 자기 자리를 학습하는 형태로 갈 가능성이 크다.”학습 노트
§ 10기억할 메모와 코드· FLA repo
다시 열었을 때 빠르게 잡혀야 할 것
linear attn = RNN
softmax 빼면 fixed-size state d×d 의 RNN. KV cache → state.
chunk-wise parallel form
sequence 를 chunk 로 자름. intra=quadratic(작음), inter=matmul. tensor core 활용.
decay 의 진화
none → scalar (Ret) → channel (RWKV) → data-dep (GLA, Mamba2). 각 단계 recall 향상.
delta rule
단순 누적 대신 (β k ⊗ (v − Sk)). associative recall 강함. DeltaNet 핵심.
recall 약점
needle/copy 에서 softmax 보다 떨어짐. fixed state 압축의 본질적 한계.
hybrid sweet spot
7~8 layer 중 1 개를 softmax 또는 sliding. 비용 ~25%, recall ~95%.
matmul-form 필수
element-wise 구현은 tensor core 못 씀 → 같은 algorithm 의 10× 차이.
FLA library
fla-org/flash-linear-attention · Triton chunked-matmul. 가족 전체 한 framework.
손에 새기기 — 실습 시퀀스
- FLA 설치 + GLA 한번 forward —
pip install flash-linear-attention. fla.ops.gla.fused_chunk_gla 로 같은 (Q,K,V) 에서 forward. 결과 모양 / 속도 확인.
- FlashAttention2 와 head-to-head — 같은 hidden=128, N ∈ {2K, 8K, 32K, 128K} sweep. ms 측정. crossover 가 어디서 일어나는지 plot.
- chunk size sweep — GLA 의 BLOCK_SIZE_C ∈ {16, 32, 64, 128, 256}. chunk 가 너무 크면 intra 가 quadratic 으로 폭발, 너무 작으면 launch overhead.
- recall 검증 — synthetic MQAR task 에 GLA / DeltaNet 학습. softmax baseline 과 비교. context 별 정확도 곡선.
- hybrid 모델 학습 — 6 GLA + 2 sliding-attn layer 의 hybrid. 같은 token budget 에서 perplexity 비교.
- state 크기 vs recall 의 trade-off — head dim 을 키우면 state d² 도 커짐. recall 향상되지만 state 메모리 폭발. sweet spot 찾기.
§ 11다른 강의로· connections
이 강의의 frame 이 다음에 어디에 다시 등장하는지
§ 12열린 질문· open questions
transcript 가 비어 있어 직접 검증해야 할 것들
- 강의에서 다룬 가족의 정확한 범위 — RWKV / Mamba / RetNet / GLA / DeltaNet 중 어느 것에 시간을 더 썼는지 미확인. Songlin 본인 work 인 GLA / DeltaNet 위주일 가능성.
- chunk-wise 의 backward 처리 — backward 도 같은 chunked-matmul 형태로 짜야 효율이 나오는데 — 강의에서 이 부분의 디테일이 어디까지 다뤄졌는지 미확인.
- 구체적 benchmark 수치 — 본 노트의 수치는 일반적 트렌드 재구성. 강의에서 보여줬을 정확한 H100 / A100 수치는 미확인.
- 훈련 안정성 — gated linear attn 은 학습 초기에 gate 의 saturation / decay 의 vanishing 같은 문제가 있을 수 있다. 강의에서 학습 hyperparameter 권장값을 보여줬는지 미확인.
- SSM (Mamba) 와의 duality 깊이 — Mamba2 paper 의 framework 를 강의에서 어떻게 깔았는지. SSM 출발 vs linear-attn 출발 둘 중 어디에 더 무게를 뒀는지.
- production 채택 사례 — Jamba, Falcon-Mamba 외에 강의에서 언급된 다른 production 모델이 있는지.
- hardware 별 차이 — H100 / A100 / consumer GPU 에서 chunk-wise form 의 sweet spot 이 다를 수 있음. 강의 안 자세한 분석 미확인.
검증 메모
본 노트의 모든 수치 (TFLOPs, recall %, throughput 비교) 는 일반적 산수와 FLA / 가족 논문들의 트렌드를 재구성한 예시값이다. 자기 hardware / 자기 모델에서 직접 측정 후에만 절대값을 인용할 것.