
Meta info.
  • Authors: Zicheng Lin, Tian Liang, Jiahao Xu, Xing Wang, Ruilin Luo, Chufan Shi, Siheng Li, Yujiu Yang, Zhaopeng Tu
  • Paper: https://arxiv.org/pdf/2411.19943
  • Affiliation: Tencent AI Lab, Tsinghua University
  • Published: November 29, 2024

TL;DR

์˜ค๋ฅ˜ ์ถ”๋ก ์ด ๋ฐœ์ƒํ•˜๋Š” ๊ณผ์ •์— ์ค‘์š” ์—ญํ• (์›์ธ)์„ ํ•˜๋Š” ํ† ํฐ (critical token)์„ ์‹๋ณ„ํ•˜์—ฌ ์ด ํ† ํฐ์„ ๋ชจ๋ธ ์ถ”๋ก  ๊ฐœ์„ ์— ์ ์šฉ(cDPO)ํ•˜๋Š” ๋ฐฉ๋ฒ•๋ก  ์ œ์•ˆ


Background

critical token์˜ ์‹ฌ๊ฐ์„ฑ

  • critical token: a token that steers the LLM onto an incorrect reasoning path
    • e.g. the "owed" → "paid" change in Figure 2
    • experiment: run GSM8K inference with Llama-3-8B and identify critical tokens (a sketch of the probe follows this list)
      1) w/ critical token (token score by resampling): for each token in an incorrect reasoning trace, resample 64 completions; the first token whose score is 0 leads to persistently wrong results → designate it the critical token
      2) w/o critical token (forced decoding with an alternative token): apply a generation constraint that forcibly replaces that token with an alternative token, then resample identically → measure pass@k (Figure 1)
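
A minimal sketch of both probes, assuming a HuggingFace causal LM and a GSM8K-style answer checker `is_correct` (a hypothetical placeholder, not from the paper's code):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"  # backbone used in the paper
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto")

def token_score(prompt_ids, traj_ids, n_prefix, is_correct, n_samples=64):
    """Fraction of n_samples resampled completions (from the prefix that
    includes the first n_prefix trajectory tokens) that reach the answer."""
    prefix = torch.cat([prompt_ids, traj_ids[:n_prefix]])
    prefix = prefix.unsqueeze(0).to(model.device)
    outs = model.generate(prefix, do_sample=True, temperature=0.7,
                          max_new_tokens=512, num_return_sequences=n_samples,
                          pad_token_id=tok.eos_token_id)
    texts = tok.batch_decode(outs[:, prefix.shape[1]:], skip_special_tokens=True)
    return sum(is_correct(t) for t in texts) / n_samples

def first_critical_token(prompt_ids, wrong_traj_ids, is_correct):
    """Scan an incorrect trajectory left to right; the first token whose
    score drops to 0 (no resampled continuation recovers) is critical."""
    for pos in range(len(wrong_traj_ids)):
        if token_score(prompt_ids, wrong_traj_ids, pos + 1, is_correct) == 0.0:
            return pos
    return None

def score_with_alternative(prompt_ids, traj_ids, pos, alt_token_id, is_correct):
    """The 'w/o critical token' probe: force the alternative token at the
    critical position by swapping it into the prefix, then resample."""
    patched = traj_ids.clone()
    patched[pos] = alt_token_id
    return token_score(prompt_ids, patched, pos + 1, is_correct)
```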

Problem Statement

critical token์„ ์‹๋ณ„ํ•ด์„œ ๋ฐ”๊พธ๋Š” ๊ฒŒ ์„ฑ๋Šฅ์— ํฌ๊ฒŒ ๊ธฐ์—ฌํ•˜๊ธด ํ•˜๋Š”๋ฐ, ๊ทธ ์‹๋ณ„๋น„์šฉ์ด ๋„ˆ๋ฌด ํฌ๊ณ  ํ™•์žฅ์„ฑ์ด ๋‚ฎ์Œ.

  • Research Question: can critical tokens be found automatically?

Suggestions

critical token์„ contrastive estimation์œผ๋กœ ์ž๋™์‹๋ณ„ํ•ด์„œ ๋‹ค๋ฅธ token์œผ๋กœ ๋Œ€์ฒดํ•˜์ž = cDPO

  • critical token ์ž๋™ ์‹๋ณ„: positive/negative (reasoning) trajectory ๋น„๊ต
    • contrastive estimationย Eq. (1)ย : positive model๊ณผ negative model ๊ฐ๊ฐ์˜ trajectory token level distribution ๋น„๊ตย Figure 3
  • cDPO: alignment learning์ค‘์— (1) critical token์„ ์‹๋ณ„ (2) token-level reward modeling 1) Figure 4 step1ย incorrect/correct trajectory ๊ตฌ์„ฑ: positive model์— ๋Œ€ํ•ด ๋ฌด์ž‘์œ„ ์„ ํƒ, negative model์— ๋Œ€ํ•ด ์ƒ์œ„ n% ์„ ํƒ (๋‹ค์–‘ํ•œ error ์ค‘ ํ™•์‹คํ•œ๊ฑฐ ์„ ํƒํ•˜๊ธฐ ์œ„ํ•จ์ธ๋“ฏ) โ†’ ๊ฐ trajectory๋ณ„ model SFT ํ•™์Šต 2) Figure 4 step2ย token level DPO ์ˆ˜ํ–‰ 1. ์„ค๊ณ„: incorrect trajectory์˜ ํ† ํฐ๋งŒ ์ ์ˆ˜ํ™” (DPO tuning์—์„œ critical token๋งŒ ๋ฌธ์ œ๊ณ  ๋‹ค๋ฅธ token์€ ๋ฌธ์ œ ์—†์„ ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ ์ •ํ™•ํ•œ penalty ์ฃผ๋„๋ก ์ฒ˜๋ฆฌํ•ด์•ผํ•˜๋Š”๋ฐ, ๊ทธ๋ ‡๋‹ค๊ณ  ์ด๋ฏธ ์ž˜ ์ƒ์„ฑ๋œ token ๋“ค์— (๊ณผํ•˜๊ฒŒ) ์ ์ˆ˜๋ฅผ ์ฃผ๋ฉด token distribution์— ์•…์˜ํ–ฅ ์šฐ๋ ค) 2. token-level preference dataset ์ƒ์„ฑ: ํ”„๋กฌํ”„ํŠธ๋งˆ๋‹ค ๊ฐ trajectory pair ๊ตฌ์„ฑ,ย Eq. (1)ย ๊ณ„์‚ฐํ•˜์—ฌ critical token ์ž๋™ labeling 3. reward modeling: token-level๋กœ DPO ํ™•์žฅย Eq. (4, 5)
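
A minimal sketch of the contrastive estimation step, assuming `pos_model`/`neg_model` are the two SFT models from step 1; the softmax normalization of the likelihood gap is my reading of Eq. (1), not the paper's verbatim formula:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def token_logprobs(model, input_ids):
    """Per-token log p(y_t | y_<t, x) for one sequence; shape [T-1]."""
    logits = model(input_ids.unsqueeze(0)).logits[0, :-1]
    return F.log_softmax(logits, dim=-1).gather(
        1, input_ids[1:].unsqueeze(1)).squeeze(1)

def contrastive_scores(pos_model, neg_model, traj_ids):
    """Tokens the negative model likes far more than the positive model
    does are the critical-token candidates."""
    gap = (token_logprobs(neg_model, traj_ids)
           - token_logprobs(pos_model, traj_ids))
    return torch.softmax(gap, dim=0)  # normalize over the trajectory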
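
And a sketch of how the token-level objective of Eq. (4, 5) could look: standard DPO, except the rejected trajectory's per-token log-ratios are weighted by the scores above so the penalty concentrates on critical tokens. This is my paraphrase of the loss, not the paper's reference implementation:

```python
import torch.nn.functional as F

def cdpo_loss(pol_chosen_lp, pol_rejected_lp,
              ref_chosen_lp, ref_rejected_lp,
              rejected_scores, beta=0.1):
    """All *_lp are per-token log-probs (shape [T]); rejected_scores are
    the contrastive weights for the rejected trajectory's tokens."""
    chosen_ratio = (pol_chosen_lp - ref_chosen_lp).sum()
    # concentrate the penalty on critical tokens, not the whole sequence
    rejected_ratio = (rejected_scores
                      * (pol_rejected_lp - ref_rejected_lp)).sum()
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio))
```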

Effects

  • experimental setup:
    • backbone: Llama-3-8B/70B, DeepSeek-Math-7B-Base
    • dataset: GSM8K, MATH
    • baselines: SFT on positive trajectories only, and the ~PO series (DPO, TDPO, RPO, …)
  • results
    • Table 1: SOTA on every benchmark examined, with p < .005
    • Figure 5: replacing the critical tokens identified by contrastive estimation consistently improves performance

Personal note. Prior work already showed that identifying and fixing the tokens that trigger reasoning failures improves performance; the authors claim this is the first attempt to identify those tokens automatically and run DPO at the token level. The experiments are limited to mathematical reasoning, but the concept is simple enough that it should transfer easily to various downstream tasks.