Parallel Decoding 随笔

[Lifan]

2025/02/15

1. Introduction

Parallel Decoding 是一种对 LLM Inference 加速的一种方法,这个随笔聊一聊它是怎么 work 的,并且聊一下它的几个流派。

2. What is Parallel Decoding

这一 part 聊聊什么是 Parallel Decoding 以及为什么它为什么能够 LLM Inferece 加速。 ref: https://arxiv.org/pdf/2302.07863

传统的 llm model 会用 autoregressive 的方式一个个生成 token。也就是说如果需要生成 5 个 token 就需要调用大模型 5 次,每生成一个 token 都依赖之前的输出。

而 Parallel Decoding 则利用了小模型和大模型协作的方式,由小模型先生成多个token (例如 4个),然后由大模型对它们进行 validation 并且也生成一个 token。如果所有小模型生成的 token 都被大模型接受了,那么生成 5 个 token 的时间就变成了 4 次小模型推理加上 1 次大模型的推理,小模型的推理时间低于大模型,由此总体时间会更快。

上面只是说的理想情况,如下图所示,一般情况下不会所有 token 都被 accept,一些 token 会被 reject,然后 target model regenerate 这个被 reject 的 token,然后进行下一轮 decode。

ref: https://arxiv.org/pdf/2302.07863

在后面的章节里,我们把这个大的模型叫 target model,这个小的模型叫 drafter,并且我们可以把 Parallel Decoding 分成这样三部分:

  1. Drafting: The draft model produces a set of potential tokens for the next sequence position, considering the previous context. 此时的输出是 draft tokens,以及 draft tokens 所对应的 logits。
  2. Validation: The target model performs a single forward pass to calculate the probability of each draft token given the current context. 输入 draft tokens,输出的是 target logits。
  3. Sampling: The process of selecting the final tokens based on the validation results. 输入是 draft tokens, draft logits 和 target logits,最终输出想要返回给用户的 tokens。

3. Evaluating the Speedup Achieved by Parallel Decoding

假设:

\[ T_{\text{auto}} = N \times T_{\text{target}} \]

而并行解码的过程可以分为三个阶段:

  1. Drafting: 小模型生成 draft token 的时间为 \(T_{\text{drafting}}\);
  2. Validation: 大模型 validate 这些 draft token 的时间为 \(T_{\text{validation}}\);
  3. Sampling: 最终从 draft token 中采样确定输出的时间为 \(T_{\text{sampling}}\);

那么单次 Parallel Decoding 的 latency 为

\[ T_{\text{parallel_iteration}} = T_{\text{drafting}} + T_{\text{validation}} + T_{\text{sampling}} \]

如果一次并行生成 \(k\) 个 token,则总的并行解码延迟为

\[ T_{\text{parallel}} = \lceil N/k \rceil \times \Big( T_{\text{drafting}} + T_{\text{validation}} + T_{\text{sampling}} \Big) \]

为量化延迟提升(或加速比),我们可以定义速度提升因子(Speedup)为:

\[ \text{Speedup} = \frac{T_{\text{auto}}}{T_{\text{parallel}}} = \frac{N \times T_{\text{target}}}{\lceil N/k \rceil \times \Big( T_{\text{drafting}} + T_{\text{validation}} + T_{\text{sampling}} \Big)} \]

只有当 \(\text{Speedup} > 1\) 时,Parallel Decoding 会给我们带来提升。

4. Naive vs Mudusa vs Eagle2 [In Progress]

ref: https://aclanthology.org/2024.findings-acl.456.pdf

根据 drafting 和 verification 的不同,Parallel Decoding 有非常多的变种,除了原始的 speculative decoding,后面比较受到关注的是 Medusa 和 Eagle2,理论上来说 Eagle2 对于对于 performance 的提升是最好的,这章就主要聊聊这三种方式的不同实现和各自的优缺点。

4.1 Drafting

ref: https://zhuanlan.zhihu.com/p/15955544919

Naive Speculative Sampling

Medusa

Eagle2

4.2 Validation & Sampling

这个 part 大致说一说为什么 Parallel Decoding 说可以不牺牲 quality。

先看一看 LLM 每生成一个 token 都在做什么:

  1. 每一步生成中,LLM 会输出一个针对所有可能下一个 token 的概率分布。
  2. 然后从这个分布中采样出一个 token。

那在 Parallel Decoding 中我们是怎么做的呢:

  1. 草稿阶段(Drafting Stage): draft model 会为每个 token 提供一个概率分布 \(p(x)\)。
  2. 验证阶段(Verification Stage): target model 会重新计算出它自己对这些 token 的概率分布 \(q(x)\)。
  3. 采样阶段(Sampling Stage): 根据这两个分布决定是否接受草稿中的 token:
    • Case1: 如果 \(q(x) \geq p(x)\),则直接接受该 token。
    • Case2: 如果 \(q(x) < p(x)\),则以 \(\frac{q(x)}{p(x)}\) 的概率接受该 token,否则拒绝并从 \(q(x)\) 中重新采样。

ref: https://www.youtube.com/watch?v=S-8yr_RibJ4

这样做可以让从 Parallel Decoding 采样出来的 token 概率分布和原始的模型一致。

4.3 优缺点

Naive Speculative Sampling

Medusa

Eagle2

暂时来看,如果想要快速用 Parallel Decoding 来加速一个 model 的 inference,可以考虑 Naive Speculative Decoding, 如果想要追求最好的效果,应该选择 Eagle2。

4.4 Tree Attention for Eagle

既然 Eagle 是当前的 SOTA(最先进方法),我们有必要来聊一下它的 Tree Attention(树状注意力机制)。

这样做可以提升每个 iteration 被接受的 token 数量,从而进一步加快了推理速度。

ref: https://arxiv.org/pdf/2406.16858v1

5. Challenges

  1. Compute Bound Transition: 当 batch size 变大时,decode 的阶段有可能会从原本的 memory bound 转向 compute bound,而这时 Parellel Decoding 就可能失效
  2. Higher TTFT: 虽然 Parallel Decoding 可以帮助减小 TTIT, 可是也会增加 TTFT 的 latency,从而使得在一些情况没办法真的投入使用。
  3. Draft Model Training: 需要具备训练 draft model 的能力
  4. Extra Overhead: 虽然说最重要的 latency 还是来源于 validation part,drafting,sampling 等步骤有时也会带来挺多的latency,特别是当 batch 增加一种,也需要对它们进行优化
  5. Integration Complexity: 它需要一直和其他技术适配,例如 PD 分离,Batching,Paged KV 等等,所以在真实应用的时候可能会需要解决这些问题。
  6. More: 上面就是我想到的一些挑战,那在具体部署的过程中一定还有很多其他的问题。

6. Closing Notes

Parallel Decoding 还真是一个很神奇的技术,需要涉及系统,也涉及算法。既要考虑怎么和系统里其他的优化 integrate 在一起,也需要考虑怎么训练一个好的 draft 模型,并且在不同硬件中的表现也非常不同。学习和实践它对于我来说是一个宝贵的机会,帮助我往下窥探了很多 ML System 的内容。当我真正写完这篇东西的时候 Eagle3 也发布了。如果后面有新的发现和感想希望可以写第二篇随笔 follow up 一下 Parallel Decoding 的发展吧。

7. Reference

  1. https://arxiv.org/abs/2211.17192
  2. https://arxiv.org/pdf/2302.07863
  3. https://arxiv.org/pdf/2408.08146v1
  4. https://aclanthology.org/2024.findings-acl.456.pdf
  5. https://arxiv.org/pdf/2406.16858v1
  6. https://arxiv.org/pdf/2401.10774
  7. https://zhuanlan.zhihu.com/p/704755926