《GPU Mode》
L028
2024 · AUG · 24
High priority
transcript · available
Liger Kernel — Efficient Triton Kernels for LLM Training
학습 시 가장 무거운 자리는 attention 이 아니다 — cross_entropy 의 logits memory spike, RMSNorm 과 RoPE 의 elementwise launch 가 합쳐 만드는 메모리 압박이다. Byron Hsu 가 LinkedIn 에서 만들어 오픈한 Liger Kernel 은 Hugging Face 학습 모델 위에 한 줄 patch 로 끼우는 Triton 커널 모음 — 이 강의의 학습 노트.
Liger Kernel
Triton
LLM training
RMSNorm fused
cross_entropy fused
RoPE
SwiGLU
monkey-patch
memory spike
B
Speaker
Byron Hsu
LinkedIn AI · Liger Kernel 저자 · 전 PyTorch core
§ 01강의가 풀려는 문제· memory spike of LLM training
“학습 시 가장 큰 병목이 뭐냐” — flash attention 이 아니다
Byron 의 진단으로 시작하는 강의. flash attention 이 attention 의 HBM 트래픽을 잡은 뒤, 학습 루프의 새 병목은 elementwise op 들의 launch overhead와 cross_entropy 의 logits memory spike 다. Liger Kernel 은 그 두 자리를 정확히 노린다.
강의가 답하려는 두 줄 —
- 학습 루프에서 시간이 어디로 새는가 — Hugging Face Llama 의 forward + backward 를 profiler 로 까보면 무엇이 dominant 인가.
- 그 자리를 어떻게 Triton 으로 다시 짜는가 — 그리고 사용자가 코드 한 줄도 바꾸지 않고 어떻게 채택하게 하는가 (monkey-patch).
강의의 인지적 frame 은 분명하다 — “flash attention 이 attention 의 fusion 을 풀었으니, 우리는 attention 바깥의 fusion 을 푼다”. cross_entropy, RMSNorm, RoPE, SwiGLU. 학습 루프의 hot path 의 60% 가 이 4 개 안에 있다는 게 Byron 의 측정 결과.
강의의 frame
Liger 의 약속은 “학습 throughput 을 떨어뜨리지 않으면서 메모리만 줄인다”. flash attention 처럼 “속도와 메모리 둘 다 따 낸다” 와 미묘하게 다른 자리. cross_entropy 의 logits spike 를 잡으면 — 큰 vocab 의 모델(LLaMA 의 32K, Qwen 의 152K)에서는 — 실효 batch size 가 2배 늘어난다. 그게 dominant 한 가치.
“학습은 inference 와 다릅니다. forward 끝나고 backward 가 한 번 더 가요. 그래서 cross_entropy 의 logits 는 backward 까지 살아있어야 합니다 — 거기서 메모리 spike 가 생깁니다.”Byron Hsu · 06:52
강의 끝에 손에 잡혀야 할 것은 — 네 커널의 fusion 자리(cross_entropy / RMSNorm / RoPE / SwiGLU), 각 커널의 메모리 절감 메커니즘, 그리고 monkey-patch 한 줄로 채택하는 패턴이다.
§ 02학습 커널의 병목 진단· profile · launch overhead
Hugging Face Llama 의 학습 step 을 profiler 로 까본다
Byron 이 강의 첫 데모로 띄운 것 — Hugging Face 의 vanilla Llama-3 8B 학습 step 의 PyTorch profiler trace. 두 가지가 시각적으로 잡힌다 — elementwise 커널의 dense 한 launch와 cross_entropy 직후의 메모리 곡선 spike.
FIG · 학습 step 의 시간 분포Llama-3 8B · seq 4096 · A100
Byron 이 강의에서 보여준 분포의 재구성. attention 바깥이 합치면 약 44% — 그 자리가 Liger 의 표적이다. 시간뿐 아니라 메모리 spike 도 같은 자리에서 잡힌다.
profiler 가 잡는 두 종류의 비효율 —
- launch overhead — RMSNorm 의 forward 가 separate kernel 로 잡힌다.
x * rsqrt(mean(x²)) 가 통째로 한 커널이지 못하고 mean → rsqrt → mul 의 시퀀스로 분해된 PyTorch op chain. 작은 커널이 여러 번 launch.
- HBM 왕복 — 각 elementwise op 가 input 을 HBM 에서 읽어 output 을 HBM 으로 쓴다. 같은 데이터에 대한 read-write-read-write 가 반복.
두 비효율 모두 한 커널 안에 다 묶으면 사라진다. 이게 Liger 의 일반 전략 — 같은 텐서에 적용되는 elementwise op 시퀀스를 모두 한 Triton 커널로.
FIG · RMSNorm 의 두 구현 — fused vs separateHBM round trip count
unfused (eager)
RW
RW
RW
RW
8 trips
Liger fused
RW
2 trips
같은 RMSNorm forward — 8 번의 HBM 왕복이 2 번으로. backward 까지 합치면 14 → 4. throughput 이 거의 그대로 따라온다.
§ 03fused cross_entropy· logits memory spike
Liger 의 가장 큰 hit — vocab × seq × batch 의 logits 텐서를 절반으로
Byron 이 “이게 Liger 의 가장 영향력 있는 단일 커널이다” 라고 못 박은 자리. 학습 시 cross_entropy 의 logits 텐서가 backward 까지 메모리에 살아있어야 한다는 점이 핵심 문제 — 큰 vocab 모델에서 메모리 곡선의 spike 가 여기서 생긴다.
FIG · cross_entropy 의 메모리 곡선Llama-3 8B · vocab 128K
GB
▲ ▲ logits
80┤ / \ spike
70┤ / \
60┤ ─── activations ─/ \─── backward
50┤ ─
40┤ ─
30┤─
└────────────────────────────────────►
forward backward
▲ logits spike 가 30 GB 추가 → fused 로 0 추가
큰 vocab 일수록 logits = (B × S × V) × bf16 가 dominant. Llama-3 70B + vocab 128K + seq 8K = logits 만으로 13 GB. backward 에서도 살아있어야 하므로 두 배가 잡힘.
fused cross_entropy 의 핵심 아이디어 —
- logits 를 explicit 텐서로 만들지 않는다. forward 에서 hidden state × LM head weight 의 결과를 register/SRAM 에서 즉시 softmax + NLL.
- backward 도 같은 커널 안에서 처리.
∂L/∂logits = softmax(logits) - one_hot(target)의 식이 closed-form 이라 forward 의 softmax 결과만 남기면 된다.
- chunking — vocab 차원이 너무 크면 한 SM 에 안 들어간다. seq 차원으로 chunk 해서 한 row 씩 처리.
- upcast to FP32 inside kernel — bf16 input 을 받아 inside-kernel 에서 fp32 로 누적. 정확도 유지.
측정 — Llama-3 8B 학습
vanilla cross_entropy: peak memory 80 GB · throughput 4500 tok/s.
Liger fused cross_entropy: peak memory 54 GB (-32%) · throughput 4520 tok/s. throughput 거의 그대로. 메모리 절감만으로 batch size 를 2배 키울 수 있다.
# Liger 의 fused cross_entropy 사용 — 통상 PyTorch 대체
from liger_kernel.transformers import LigerCrossEntropyLoss
loss_fn = LigerCrossEntropyLoss(reduction="mean")
loss = loss_fn(logits.view(-1, vocab),
labels.view(-1)) # logits 를 인자로 받지만 ...
# 더 강력한 변형 — LM head 의 weight 를 받아 logits 자체를 안 만든다
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
loss = LigerFusedLinearCrossEntropyLoss()(
hidden_states, # (B*S, hidden)
lm_head_weight, # (vocab, hidden)
labels) # logits = h @ W^T 가 inside-kernel 에서 chunk 별로
“cross_entropy 의 logits 를 메모리에 만들지 않는다 — 이 한 줄이 Liger 의 가장 큰 win 입니다. vocab 이 클수록 효과가 커지고, 모던 LLM 의 vocab 은 점점 커지고 있어요.”Byron Hsu · 28:14
§ 04fused RMSNorm· forward · backward 융합
elementwise op chain 을 한 Triton 커널로 묶는 가장 깨끗한 사례
RMSNorm 은 LLaMA / Mistral / Qwen 의 표준 normalization. 식 자체는 짧지만, eager 모드에서는 4–7 개의 elementwise 커널 시퀀스 로 분해된다. Liger 의 RMSNorm 커널이 이걸 한 커널 안에 묶는다.
# Liger RMSNorm 커널의 핵심 — Triton
@triton.jit
def rms_norm_forward(X, W, Y, RSTD, eps,
N: tl.constexpr, BLOCK: tl.constexpr):
pid = tl.program_id(0)
X_row = X + pid * N
Y_row = Y + pid * N
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(X_row + cols, mask=mask, other=0.).to(tl.float32)
w = tl.load(W + cols, mask=mask, other=0.).to(tl.float32)
# 한 커널 안에서 mean(x²) → rsqrt → multiply → cast 까지
var = tl.sum(x * x, axis=0) / N
rstd = 1.0 / tl.sqrt(var + eps)
y = (x * rstd) * w
tl.store(Y_row + cols, y.to(tl.bfloat16), mask=mask)
tl.store(RSTD + pid, rstd) # backward 에서 재사용
이 커널의 디자인 결정 4 가지 —
- row 한 개 = program 한 개. RMSNorm 은 row 별로 독립이라 program_id 만 row index 로.
- upcast to fp32 inside kernel. bf16 input/output, fp32 accumulation. 정확도가 학습 convergence 에 직결되는 자리.
- RSTD 를 saved tensor 로 명시 보관. backward 에서 다시 계산하지 않도록.
- BLOCK_SIZE = next_power_of_2(N). 한 row 가 한 program 이라 BLOCK 이 hidden size 와 일치해야 한다 — 대부분 4096–8192.
backward 도 같은 패턴 — input · grad_output · saved RSTD 를 받아 grad_x, grad_w 를 한 커널에서 계산.
FIG · RMSNorm forward + backward 의 HBM 트래픽fused vs eager
eager forward
RW
RW
RW
RW
8
eager backward
RW
RW
RW
6
Liger fused
RW
RW
4
forward+backward 합쳐 14 → 4. 메모리는 약 1/3 절감, throughput 도 약 1.2× 빨라짐.
Triton 의 launch 설정 결정
RMSNorm 처럼 row 별 reduction 커널은 BLOCK_SIZE 가 hidden size 와 같아야 한다. num_warps 는 BLOCK_SIZE 에 따라 1024 → 4 warps, 4096 → 8 warps, 8192+ → 16 warps. num_stages 는 보통 4 — pipeline stage 가 더 크면 register spill.
§ 05fused RoPE· rotary embeddings
RoPE — 위치 인코딩이 아니라 “모든 layer 에 박히는 작은 elementwise multiply” 의 누적
RoPE(Rotary Position Embedding)는 LLaMA 이후 표준이 된 위치 인코딩. 식 자체는 sin/cos 와의 elementwise multiply 두 줄 이지만 — 매 layer 의 Q, K 양쪽에 적용되니까 모델 전체에서 호출 횟수가 layer × 2 = 64+ 회. 그 누적이 launch overhead 의 큰 부분.
# RoPE 의 수학 — half-rotation 패턴
# q_rot = q * cos + rotate_half(q) * sin
# 여기서 rotate_half 는 [x1,x2] -> [-x2,x1] 의 swap
@triton.jit
def rope_forward(Q, K, COS, SIN,
Q_OUT, K_OUT, ...):
pid_b = tl.program_id(0) # batch · head · seq
# Q · K 두 텐서를 같은 program 에서 처리
q = tl.load(Q + ...)
k = tl.load(K + ...)
cos = tl.load(COS + ...)
sin = tl.load(SIN + ...)
# half-rotation: 한 row 안에서 self-shuffle
q1, q2 = split(q)
q_rot = concat(q1*cos - q2*sin,
q2*cos + q1*sin)
tl.store(Q_OUT + ..., q_rot)
# K 도 같은 자리에서 — Q/K 두 launch 가 한 launch 로
RoPE 커널의 디자인 —
- Q 와 K 를 같은 program 에서 동시 처리. eager 에서는 Q rope + K rope 가 separate kernel.
- cos/sin table 을 once 만 load. 같은 (batch, seq) 위 cos/sin 이 모든 head 에 공유.
- in-place 옵션. backward 도 in-place 가능 (수학적으로 invertible).
- contiguous 가 아닌 stride 도 지원. Q[B,S,H,D] 의 H 가 stride 차원일 때 — Triton 의 stride 인자를 명시적으로 받음.
왜 RoPE 가 fusion 의 가치가 큰가
RMSNorm 은 model 당 ~64 호출(layer 마다 2). RoPE 는 model 당 ~64 호출(layer 마다 2 — Q, K). 호출 횟수 자체는 비슷하지만, RoPE 는 한 호출 안에서 mul + sin/cos load + concat 의 시퀀스라 separate kernel 이 더 많이 펼쳐진다. fusion 의 win 이 비례적으로 크다.
§ 06fused SwiGLU + MLP· layer norm + gate
MLP 안 SwiGLU 는 sigmoid + multiply + multiply — 작지만 항상 같이 다닌다
LLaMA 의 MLP 는 down_proj(silu(gate_proj(x)) * up_proj(x)). 이 식의 끝 — silu(gate) * up — 이 elementwise activation. eager 에서는 silu 한 번 + multiply 한 번 + 그 사이 임시 텐서 하나가 박힌다. Liger SwiGLU 커널이 그걸 한 커널로.
eager SwiGLU
silu kernel (read gate, write tmp) → multiply kernel (read tmp, read up, write out). 4 HBM trips, 1 임시 텐서. 임시 텐서가 backward 까지 살면 메모리도 그만큼.
Liger fused SwiGLU
한 커널 안 — load gate · load up · silu · multiply · store. 3 HBM trips, 임시 텐서 0. backward 에서 재계산이 더 싸다(silu 가 cheap).
eager backward
grad_out * up · grad_out * silu(gate) · grad_silu · grad_up. 매 단계마다 separate kernel + 임시 텐서.
Liger fused backward
한 커널 안 — gate, up, grad_out 만 받아 grad_gate, grad_up 을 산출. 임시 텐서 없음, 한 launch.
왜 작은 커널의 fusion 이 큰 win 인가
각 호출 자체는 ms 수준이지만, 모델당 호출 횟수가 layer × 2 (forward + backward) 다. Llama-3 8B(32 layer) 는 64 회. SwiGLU launch 만으로 step 당 ~10ms — Liger fused 는 거의 0. 그게 step time 의 ~2-3% 인데, 모델 학습 1주의 GPU 시간 → 3% 가 큰 숫자.
§ 07numerics 정확도· upcast 패턴 · convergence
“메모리 절약했는데 학습이 안 된다” 가 실제로 가능한 자리
강의에서 Byron 이 가장 시간을 많이 쓴 자리 중 하나 — Liger 의 모든 커널은 학습 convergence 가 baseline 과 일치한다는 검증. 정확도 차이가 일정 수준을 넘으면 학습이 망가진다.
numerics 의 표준 패턴 —
- storage 는 bf16, accumulation 은 fp32.
tl.load(...).to(tl.float32) 가 모든 reduction 커널의 첫 줄.
- output cast 는 마지막에 한 번만. bf16 → fp32 → … → bf16. 중간에 down-cast 하지 않는다.
- partial upcasting 의 함정. 강의 Q&A 에서 한 청중이 “두 input 중 하나만 upcast 하면 fp16 mul fp32 일 때 어떻게 되는지” 질문 — Byron 의 답: “Triton 컴파일러가 자동으로 다른 쪽도 fp32 로 promote 한다. 명시적으로 둘 다 cast 하는 게 안전하다”.
- Hugging Face 와의 numerical match. Liger 는 모든 PR 에서 output diff < 1e-3 을 통과해야 merge.
FIG · 학습 loss curve — Liger vs vanillaLlama-3 8B · 1B tokens
loss
▲
2.4┤▓
│▓▒
2.2┤ ▓▒
│ ▓▒
2.0┤ ▓▒ ▓ vanilla
│ ▓▒
1.8┤ ▓▒ ▒ Liger fused
│ ▓▒
1.6┤────────────────────────────────►
0 250M 500M 750M 1B tok
두 곡선이 visible 하게 겹친다 — diff < 1%
Byron 이 강의에서 보여준 학습 곡선의 재구성. 두 곡선이 1B token 까지 시각적으로 분리되지 않음. Liger 는 “더 빠른 같은 학습”이라는 약속이 검증됨.
“우리는 모든 PR 에 numerics gate 를 박았어요. output diff 1e-3 이상이면 merge 안 합니다. 학습 결과를 절대 깨지 않는다 — 이게 채택을 가능하게 한 결정이에요.”Byron Hsu · 1:08:30
§ 08monkey-patch 채택· apply_liger_kernel
코드 한 줄로 — Hugging Face Trainer 가 그대로 쓰는 길
강의에서 Byron 이 가장 자랑한 디자인 결정 — 사용자가 모델 코드를 한 줄도 바꾸지 않는다. 한 import + 한 함수 호출로, 모델 안 RMSNorm/RoPE/SwiGLU/cross_entropy 가 자동으로 Liger 커널로 교체된다.
from liger_kernel.transformers import apply_liger_kernel_to_llama
# Hugging Face 모델 로드 그대로
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3-8B")
# ← 한 줄. 모델 안 RMSNorm/RoPE/SwiGLU 가 monkey-patch 됨
apply_liger_kernel_to_llama(
rope=True,
rms_norm=True,
swiglu=True,
fused_linear_cross_entropy=True,
)
# 그 후로는 Hugging Face Trainer 그대로
trainer = Trainer(model=model, args=...)
trainer.train() # ← Liger 커널이 자동으로 끼어듬
monkey-patch 의 mechanic —
- 모듈 클래스를 동적으로 교체.
transformers.models.llama.modeling_llama.LlamaRMSNorm = LigerRMSNorm 같이.
- autograd Function 으로 wrapping. forward + backward 가 한 쌍의 Triton 커널을 부른다. PyTorch graph 의 노드로 깨끗하게 들어감.
- 모델별 patch 함수 분리.
apply_liger_kernel_to_llama, ...mistral, ...gemma, ...qwen2. 모델별 RMSNorm 위치가 미세히 달라서.
- opt-out 가능. 각 커널을
True/False 로 켜고 끄게 — 디버깅 시 한 자리씩 분리해서 비교.
FIG · 채택의 layer cake모델 코드 변경 0 줄
기존 학습 stack
HF Trainer그대로
Hugging Face 모델그대로
RMSNorm classLlamaRMSNorm
cross_entropyF.cross_entropy
Liger 적용 후
HF Trainer그대로
Hugging Face 모델그대로
RMSNorm classLigerRMSNorm
cross_entropyLigerFusedLinearCE
왜 채택이 이렇게 빨랐는가
Liger 가 release 된 첫 달 안에 Axolotl, OpenRLHF, Hugging Face TRL 등이 옵션으로 채택. 이유는 단순 — “모델 코드 안 바꿔도 메모리 30% 절감”. risk 가 작으니 채택 cost 도 작다. 강의 시점에 GitHub star 수천을 빠르게 모음.
§ 09실측 사례· memory · throughput
“실제 학습 워크로드에서 어떤 숫자가 나오는가”
Liger 의 효과는 모델/세팅에 따라 다르다. 강의에서 Byron 이 보여준 표를 큰 그림으로 정리.
FIG · 메모리 절감 · 모델 크기별seq 4096 · batch 16 · A100 80GB
+ fused linear CE
48 GB (-37%)
cross_entropy fusion 이 단일 가장 큰 hit. 전체 -37% 의 메모리 절감 → batch size 1.5–2× 가능.
FIG · throughput 비교tokens/sec · 같은 batch
Llama-3 8B Liger
5100 (+13%)
Llama-3 70B Liger
940 (+15%)
Qwen2 vocab 152K vanilla
580
Qwen2 vocab 152K Liger
820 (+41%)
vocab 이 클수록 효과가 더 큼. Qwen2 처럼 152K vocab 모델은 cross_entropy fusion 이 dominant.
중요한 관찰들 —
- 모델이 클수록 효과가 약간 더 큼. larger model = more layers = more RMSNorm/RoPE/SwiGLU launch.
- vocab 이 클수록 cross_entropy fusion 의 hit 가 dominant. Qwen2, Aya 같은 다국어 모델에서 큰 win.
- seq length 가 길수록 효과가 커짐. RoPE/RMSNorm 모두 seq 차원에 비례하니까.
- FSDP/DeepSpeed 위에서도 호환. monkey-patch 가 그냥 module class 교체이므로 distributed setup 에 무관.
“같은 GPU 에서 같은 모델을 더 큰 batch 로 학습할 수 있다 — 이게 Liger 의 약속의 본질입니다. 학습 일정의 길이가 줄어들고, 같은 시간에 더 많은 token 을 본다는 거죠.”Byron Hsu · 1:18:00
§ 10기억할 메모와 코드· key takeaways
다시 열었을 때 5분 안에 손에 잡혀야 할 것
학습의 새 병목
flash attention 이후 attention 이 아니라 — RMSNorm·RoPE·SwiGLU·cross_entropy 의 launch overhead + cross_entropy 의 logits memory spike.
cross_entropy 의 hit
큰 vocab 모델에서 logits = (B·S·V)·bf16 가 dominant. fused linear CE 가 이 텐서 자체를 안 만든다. 가장 큰 메모리 win.
RMSNorm 패턴
row-per-program · BLOCK = next_pow2(N) · upcast fp32 · save RSTD for backward. forward+backward 합쳐 14 → 4 HBM trips.
RoPE 패턴
Q 와 K 를 같은 program 에서 동시. half-rotation 은 row 안 self-shuffle. cos/sin 한 번 load 후 모든 head 공유.
SwiGLU 패턴
silu(gate) * up 을 한 커널. 임시 텐서 없음. backward 도 fused — recompute 가 cheap.
numerics 약속
storage bf16 + accum fp32. PR 마다 output diff < 1e-3 검증. 학습 loss curve 가 baseline 과 visible 하게 겹쳐야 함.
monkey-patch
apply_liger_kernel_to_llama(...) 한 줄. 모델 코드 변경 없음. HF Trainer · Axolotl · TRL 그대로 채택.
실측 win
Llama-3 8B 메모리 -37%, throughput +13%. Qwen2 (vocab 152K) throughput +41%. 모델/vocab 클수록 효과 큼.
손에 새기기 — 실습 시퀀스
- baseline 학습 step profiling — Hugging Face Llama 의 한 step 을 PyTorch profiler 로 잡는다. RMSNorm/RoPE/SwiGLU/cross_entropy 가 차지하는 % 를 확인.
- Liger 적용 후 같은 profile —
apply_liger_kernel_to_llama 한 줄 추가. 같은 step 의 trace 를 비교 — 커널 수가 얼마나 줄었는지.
- RMSNorm 직접 짜보기 — Liger 의 RMSNorm 코드를 보지 않고 first principle 로. row-per-program, BLOCK = pow2(N), upcast fp32. 그 후 Liger 와 diff.
- cross_entropy fusion 이해 — fused linear CE 가 logits 를 안 만드는 trick 을 코드로 본다. chunk 단위 hidden × W^T → softmax → NLL 의 흐름.
- numerics gate 재현 — bf16 input · fp32 accum · bf16 output. baseline 과 output diff 가 1e-3 이하인지 확인.
- monkey-patch 직접 짜보기 — 자기가 가진 모델(예: Mistral) 의 RMSNorm class 를 동적으로 교체하는 짧은 함수. 모델 코드 그대로.
- 학습 1B token convergence — Liger on/off 두 가지로 1B token 학습. loss curve 를 plot 해서 두 곡선이 visible 하게 겹치는지 확인.
- vocab 큰 모델에서 효과 측정 — Qwen2 같은 152K vocab 모델에서 fused linear CE 의 메모리 절감을 직접 측정.
§ 11다른 강의로 이어지는 길· connections
이 강의의 도구가 시리즈 안에 어떻게 다시 등장하는지
§ 12열린 질문· open questions
다음에 다시 들었을 때 직접 검증해야 할 것들
- torch.compile + Liger 의 호환성 — torch.compile 의 inductor backend 가 같은 자리에서 fusion 하려고 한다. 두 시스템이 어떻게 합쳐지는가? graph break 없이 monkey-patch 가 작동하는지.
- flash attention 과의 fusion — Liger 는 attention 바깥. attention 자체와 RMSNorm/RoPE 를 더 fusion 할 여지가 있는가? fa3 의 tile-based 디자인과 Liger 패턴의 합집합.
- FP8 시대의 Liger — bf16 + fp32 accum 이 현재 패턴. H100/H200 의 FP8 mma 와 어떻게 통합되는가? Transformer Engine 과의 조율.
- Inference 에도 적용 가능한가 — Liger 는 학습 위주. inference 에서도 같은 패턴이 win 인가? cross_entropy 가 없으니 효과가 다를 듯.
- 큰 vocab 의 한계 — vocab 이 1M 이상으로 가면 fused linear CE 도 chunk 단위가 너무 커진다. 어떤 점에서 split 이 필요한가.
- 비-Llama 모델 지원 — RWKV, Mamba 같은 다른 아키텍처에 같은 패턴이 적용되는가? 이미 patch 함수가 모델별 분리됐는데, 새 모델 추가의 cost 는 어디.
검증 메모
이 노트의 % 와 절대 수치(메모리 -37%, throughput +13% 등)는 강의 시점(2024 8월)의 측정. 이후 Liger 가 계속 개선되고 있으니 현재 GitHub 의 README benchmark 를 다시 확인할 것.