Critical Tokens Matter: Token-Level Contrastive Estimation Enhances LLMโs Reasoning Capability
Meta info.
- Authors: Zicheng Lin, Tian Liang, Jiahao Xu, Xing Wang, Ruilin Luo, Chufan Shi, Siheng Li, Yujiu Yang, Zhaopeng Tu
- Paper: https://arxiv.org/pdf/2411.19943
- Affiliation: Tencent AI, Tsinghua Univ.
- Published: November 29, 2024
TL; DR
์ค๋ฅ ์ถ๋ก ์ด ๋ฐ์ํ๋ ๊ณผ์ ์ ์ค์ ์ญํ (์์ธ)์ ํ๋ ํ ํฐ (critical token)์ ์๋ณํ์ฌ ์ด ํ ํฐ์ ๋ชจ๋ธ ์ถ๋ก ๊ฐ์ ์ ์ ์ฉ(cDPO)ํ๋ ๋ฐฉ๋ฒ๋ก ์ ์







Background
critical token์ ์ฌ๊ฐ์ฑ
- critical token: LLM์ด ์๋ชป๋ ์ถ๋ก ๊ฒฝ๋ก๋ฅผ ๋ฐ๋ฅด๋๋ก ์ ๋ํ๋ ํ ํฐ
- e.g.ย
Figure 2ย ์ โowedโ โ โpaidโ ๋ณ๊ฒฝ - experiment: GSM8K๋ฅผ Llama-3-8b๋ก ์ถ๋ก ์ํ, critical token ์๋ณ
1) w/ critical token (token score by resampling): ํ๋ฆฐ reasoning์ ์๋ ๊ฐ ํ ํฐ์ ๋ํด 64ํ ๋ฆฌ์ํ๋ง. ์ ์๊ฐ 0์ ์ธ ์ฒซ ํ ํฐ์ ์ง์์ ์ผ๋ก ์๋ชป๋ ๊ฒฐ๊ณผ > critical token์ผ๋ก ์ง์
2) w/o critical token (alternative token์ผ๋ก decoding ๊ฐ์ ): ํด๋น token์ ๋์ฒด ํ ํฐ์ผ๋ก ๊ฐ์ ๋ณ๊ฒฝํ๋๋ก ์์ฑ ์ ์ฝ ์ ์ฉ ํ ๋์ผํ๊ฒ resampling โ pass@k ์ธก์ ย
Figure 1
- e.g.ย
Problem States
critical token์ ์๋ณํด์ ๋ฐ๊พธ๋ ๊ฒ ์ฑ๋ฅ์ ํฌ๊ฒ ๊ธฐ์ฌํ๊ธด ํ๋๋ฐ, ๊ทธ ์๋ณ๋น์ฉ์ด ๋๋ฌด ํฌ๊ณ ํ์ฅ์ฑ์ด ๋ฎ์.
- Research Question: critical token์ ์๋์ผ๋ก ์ฐพ์ ์๋ ์๋?
Suggestions
critical token์ contrastive estimation์ผ๋ก ์๋์๋ณํด์ ๋ค๋ฅธ token์ผ๋ก ๋์ฒดํ์ = cDPO
- critical token ์๋ ์๋ณ: positive/negative (reasoning) trajectory ๋น๊ต
- contrastive estimationย
Eq. (1)ย : positive model๊ณผ negative model ๊ฐ๊ฐ์ trajectory token level distribution ๋น๊ตยFigure 3
- contrastive estimationย
- cDPO: alignment learning์ค์ (1) critical token์ ์๋ณ (2) token-level reward modeling
1)
Figure 4 step1ย incorrect/correct trajectory ๊ตฌ์ฑ: positive model์ ๋ํด ๋ฌด์์ ์ ํ, negative model์ ๋ํด ์์ n% ์ ํ (๋ค์ํ error ์ค ํ์คํ๊ฑฐ ์ ํํ๊ธฐ ์ํจ์ธ๋ฏ) โ ๊ฐ trajectory๋ณ model SFT ํ์ต 2)Figure 4 step2ย token level DPO ์ํ 1. ์ค๊ณ: incorrect trajectory์ ํ ํฐ๋ง ์ ์ํ (DPO tuning์์ critical token๋ง ๋ฌธ์ ๊ณ ๋ค๋ฅธ token์ ๋ฌธ์ ์์ ์ ์์ผ๋ฏ๋ก ์ ํํ penalty ์ฃผ๋๋ก ์ฒ๋ฆฌํด์ผํ๋๋ฐ, ๊ทธ๋ ๋ค๊ณ ์ด๋ฏธ ์ ์์ฑ๋ token ๋ค์ (๊ณผํ๊ฒ) ์ ์๋ฅผ ์ฃผ๋ฉด token distribution์ ์ ์ํฅ ์ฐ๋ ค) 2. token-level preference dataset ์์ฑ: ํ๋กฌํํธ๋ง๋ค ๊ฐ trajectory pair ๊ตฌ์ฑ,ยEq. (1)ย ๊ณ์ฐํ์ฌ critical token ์๋ labeling 3. reward modeling: token-level๋ก DPO ํ์ฅยEq. (4, 5)
Effects
- experimental setup:
- backbone: llama-3-8B/70B, deepseek-math-7b-base
- dataset: GSM8K, MATH
- baseline: positive๋ง ํ์ตํ SFT, ํน์ ~PO series(DPO, TDPO, RPO, โฆ)
- results
Table 1ย : ํ์ธํ ๋ชจ๋ ๋ฒค์น๋งํฌ์์ p<.005๋ก SOTAFigure 5ย : contrastive estimation์ผ๋ก ์๋ณํ critical token์ ๋ฐ๊พธ๋ฉด ์ผ๊ด๋๊ฒ ์ฑ๋ฅ ํฅ์
Personal note. ์ถ๋ก ์คํจ๋ฅผ ์ ๋ํ๋ ์ฃผ์ ์์์ด ๋๋ ํ ํฐ์ ์๋ณํด์ ์์ ํ๋ฉด ์ฑ๋ฅ์ด ๊ฐ์ ๋๋ค๋ ์ด์ ์ฐ๊ตฌ๋ ์์๋๋ฐ, ์ด๊ฑธ ์๋ ์๋ณํด์ token level๋ก DPO ํ๊ฒ ๋ค๋ ์ฒซ ์๋๋ผ๊ณ ํ๋ค์. ์คํ์ผ๋ก ํ์ธํ ๊ฑด ์๋ฆฌ์ถ๋ก ์ ํ์ ํ๊ณ ์๋๋ฐ, ์ปจ์ ์ด ๋จ์ํด์ ๋ค์ํ downstream task์ ์ฝ๊ฒ ์ ์ฉ ๊ฐ๋ฅํ ๊ฒ ๊ฐ์ต๋๋ค.