Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders
Meta info.
- Authors: Senthooran Rajamanoharan, Tom Lieberum, Nicolas Sonnerat, Arthur Conmy, Vikrant Varma, János Kramár, Neel Nanda
- Paper: https://arxiv.org/pdf/2407.14435
- Affiliation: Google DeepMind
- References: DeepMind Blog, gemma-scope HF
TL; DR
기존 vanilla ReLU를 jumpReLU라는 비연속 activation으로 대체하여 새로운 SAE (sparse autoencodesr) SOTA, 비연속적인 activation 사용하지만 straight-through estimator로 효과적으로 학습






Suggestions
- JumpReLU: vanilla ReLU의 변형으로, threshold(
pic1의 x축 \theta에 해당)가 양수, threshold 이하에서는 0이었다가 threshold 도달하면 불연속적으로 점프, 그 이후의 identical- threshold 사용하면 실제 signal(
pic2의 빨간색)의 크기를 왜곡하거나 축소하지 않고도 작지만 positive하게 activate되는 노이즈 (pic2의 파란색)을 0으로 설정 가능 - jumpReLU의 불연속성 > straight-through estimator로 해결 (
pic3): pseudo-derivatives (dirac delta functions 근사) 활용, \theta에서 gradient signal 제공하는 역할 → 학습 가능하도록
- threshold 사용하면 실제 signal(
- L1 대신 reparameterization-invariant L1(RI-L1) 적용 (Gated)
- L1의 역할: SAE의 sparsity 강제
- L1의 한계: encoder parameter를 축소하고 그에 따라 decoder parameter를 확대하면, activated featrue의 수나 SAE의 output 재구성을 변경하지 않고 feature 크기를 임의로 축소, 즉 L1적용으로 목적으로 했던 sparsity의 효과를 줄이는 (
우회하는? 왜곡하는?의도대로 반영되지 못하는) 한계가 있음 - L1의 한계 극복: reparameterization-invariant L1(RI-L1,
pic4)- decoder weight 행렬의 column이 unit norm을 갖도록 해서 sparsity 보장= reparameterization에 불변하므로 안정적으로 sparsity 보장
Effects
pic5- yes (feature가 interpretable 하다) / no (feature가 interpretable하지 않다) / maybe (약간? 다소 interpretable하다) 3분류 척도
- 세 가지 SAE 유형(Gated, JumpReLU, TopK)에서 생성된 feature 모두 인간 기준으로 유사하게 해석 가능 → 해석 가능한 feature 생성에 JumpReLU 사용의 타당성 주장
pic6: SAE 유형별 \delta LM Loss 대비 자주 activated되는 feature 비율 확인- y축: 토큰의 10% 이상에서 activated feature 비율 // x축: \delta LM Loss (LM Loss 변화량). 즉, 특정 feature가 activate 또는 deactivate 될 때 손실이 얼마나 변화하는지 // layer: 9, 10, 31
- 모든 레이어와 activation threshold에서 JumpReLU SAE가 TopK SAE에 비해 high-frequency feature가 비슷하거나 더 적음
- TopK와 JumpReLU SAE가 high-frequency feature(토큰의 10% 이상에서 activate되는 feature) 비율이 Gated SAE보다 높은 한편
- Gated SAE는 특히 낮은 \delta LM Loss 와 같은 조건하에서 더 광범위한 기준 (1% 이상으로 기준을 낮추면) 의 JumpReLU SAE 대비 비슷하거나 더 많은 high-frequency feature를 가질 수 있음