

TL;DR

LongLLaDA: thanks to the "local perception" of diffusion-based LLMs, stable performance is achieved even beyond the context length seen during training. Using NTK-based RoPE extrapolation, the input length of a diffusion-based LLM is extended up to 24k tokens, with no training required.


Background

  • Diffusion-based LLMs are drawing attention as an alternative to the auto-regressive manner (Transformers), e.g., for handling long dependencies and overcoming the reversal curse
    • So far their performance has only been verified on short-to-medium texts; reported evidence on long-context (LC) capacity is lacking
  • Methods such as NTK-based RoPE scaling have shown that AR LLMs can extend context length at inference time without retraining
    • RoPE scaling: rotary position embedding
      • Injects position information into the QK dot product via sinusoidal encoding (used in most AR-manner models)
      • The limitation is that it only works up to the trained max length: under extrapolation, the long-period high-dimensional sin/cos components enter phase regions never seen during training, causing OOD behavior
    • NTK-based RoPE Scaling: viewing RoPE from the Neural Tangent Kernel perspective, scaling its period enables extrapolation
      • https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
      • tl;dr: by enlarging the rotary base (lowering the per-dimension frequencies, i.e., stretching the periods), contexts longer than the trained length can be made to "look like" lengths seen during training
        • Large rotary base (long period) > sin/cos varies slowly > expresses distant position differences
        • Small rotary base > fast oscillation > sensitive to short-range positions
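The base-vs-period relationship above can be sketched in a few lines (a minimal illustration, not code from the paper; `dim = 64` is an arbitrary head dimension):

```python
import numpy as np

def rope_frequencies(dim: int, base: float) -> np.ndarray:
    # Per-pair RoPE rotation frequencies: theta_i = base^(-2i/dim), i = 0..dim/2-1
    return base ** (-np.arange(0, dim, 2) / dim)

dim = 64
fast = rope_frequencies(dim, base=10_000.0)   # smaller base -> faster oscillation
slow = rope_frequencies(dim, base=500_000.0)  # larger base -> slower sin/cos variation

# Apart from the shared i=0 term, every frequency drops as the base grows,
# i.e., the corresponding sin/cos periods get longer (good for distant positions).
assert slow[0] == fast[0] == 1.0
assert np.all(slow[1:] < fast[1:])
```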

Problem Statement

D-LLM์˜ LC ์„ฑ๋Šฅ์€ ์–ด๋– ํ•œ๊ฐ€?

  • NTK-based RoPE scaling ์ ์šฉ์ด ๊ฐ€๋Šฅํ•œ๊ฐ€?
  • training ์—†์ด context length ํ™•์žฅ์ด ๋ ๊นŒ?

Suggestions

LongLLaDA: apply the NTK-based RoPE scaling used for AR LLMs to a Diffusion LLM, training-free

  • Original rotary base: beta_0 = 500,000
  • Scaling factor: λ = 4, 14, …
    • The scaling factor λ is set according to the extrapolation length Eq. (1)
    • Increasing the number of sampling steps increases LLaDA's retrieval depth Fig 3
  • From the perspective of RoPE's sin/cos periods, the OOD region is smaller than for AR (causal attention) > more robust to extrapolation
    • Since D-LLMs use a bi-directional attention structure, their position embedding coverage is symmetric Fig 4 (elaborated in the NIAH experiment, Fig 3)

Effects

  • NIAH (needle-in-a-haystack) - checking retrieval accuracy: confirms the local-perception strength of the diffusion architecture Fig 2-3
    • AR LLMs train only on the interval [0, T-1] > they see only half (+) of the sin/cos curves > all (-) offsets are OOD, so LC performance collapses (NIAH 0%)
    • Diffusion LLMs see the interval [-T+1, T-1] > they observe both sides (+/-) of sin/cos symmetrically > even if the front of a long context is lost, the most recent context at the back is still handled properly (local perception; NIAH stays around 10-25%)
    • λ ablation Fig 6, 7 (+ 8, 9, 10, 11, 12)
      • λ = 4, 14 (8K and 16K max length, respectively) extend stably
      • From λ = 13 onward, AR-style lost-in-the-middle appears: retrieval fails for context in the middle
      • At λ = 55, extrapolation fails.
  • LongBench: LLaDA is on par with LLaMA3 and ahead on synthetic tasks; applying scaling improves the average score by about 2-4 points Tab 1
  • RULER: better at QA, but worse at aggregation (e.g., variable tracing): the AR architecture seems stronger at combining global information Tab 2
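The asymmetric vs. symmetric position coverage behind the NIAH discussion can be checked with a toy computation (illustrative only, `T = 8` is an arbitrary training length):

```python
T = 8
positions = range(T)

# Relative offsets (query_pos - key_pos) observed during training:
ar_offsets    = {q - k for q in positions for k in positions if k <= q}  # causal attention
bidir_offsets = {q - k for q in positions for k in positions}            # bi-directional

assert ar_offsets    == set(range(0, T))         # [0, T-1]: only one half of sin/cos seen
assert bidir_offsets == set(range(-(T - 1), T))  # [-T+1, T-1]: symmetric coverage
```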

Personal note. This looks like one of the most recent works I've come across while continuing to follow diffusion LLMs; the bi-directional nature keeps reminding me of BERT. Like the reddit author's approach, it demonstrates the extension possibility in a simple and convincing way, which I find interesting.