Parallel Decoding 随笔 [WIP]

2025/02/15

1. Introduction

Parallel Decoding 是一种对 LLM Inference 加速的一种方法,这个随笔聊一聊它是怎么 work 的,并且聊一下它的几个流派。

2. What is Parallel Decoding

这一 part 聊聊什么是 Parallel Decoding 以及为什么它为什么能够 LLM Inferece 加速。 ref: https://arxiv.org/pdf/2302.07863

传统的 llm model 会用 autoregressive 的方式一个个生成 token。也就是说如果需要生成 5 个 token 就需要调用大模型 5 次,每生成一个 token 都依赖之前的输出。

而 Parallel Decoding 则利用了小模型和大模型协作的方式,由小模型先生成多个token (例如 4个),然后由大模型对它们进行 validation 并且也生成一个 token。如果所有小模型生成的 token 都被大模型接受了,那么生成 5 个 token 的时间就变成了 4 次小模型推理加上 1 次大模型的推理,小模型的推理时间低于大模型,由此总体时间会更快。

上面只是说的理想情况,事实上 Parallel Decoding 是否真的能带来 latency improvement 由很多因素,例如如果小模型的质量太差,大模型一个都不接受,那反而增加了 latency,这个在一下 part 可以进一步分析。

在后面的章节里,我们把这个大的模型叫 target model,这个小的模型叫 drafter,并且我们可以把 Parallel Decoding 分成这样三部分:

  1. Drafting: The draft model produces a set of potential tokens for the next sequence position, considering the previous context. 此时的输出是 draft tokens,以及 draft tokens 所对应的 logits。
  2. Validation: The target model performs a single forward pass to calculate the probability of each draft token given the current context. 输入 draft tokens,输出的是 target logits。
  3. Sampling: The process of selecting the final tokens based on the validation results. 输入是 draft tokens, draft logits 和 target logits,最终输出想要返回给用户的 tokens。

3. Evaluating the Speedup Achieved by Parallel Decoding

假设:

\[ T_{\text{auto}} = N \times T_{\text{target}} \]

而并行解码的过程可以分为三个阶段:

  1. Drafting: 小模型生成 draft token 的时间为 \(T_{\text{drafting}}\);
  2. Validation: 大模型 validate 这些 draft token 的时间为 \(T_{\text{validation}}\);
  3. Sampling: 最终从 draft token 中采样确定输出的时间为 \(T_{\text{sampling}}\);

那么单次 Parallel Decoding 的 latency 为

\[ T_{\text{parallel_iteration}} = T_{\text{drafting}} + T_{\text{validation}} + T_{\text{sampling}} \]

如果一次并行生成 \(k\) 个 token,则总的并行解码延迟为

\[ T_{\text{parallel}} = \lceil N/k \rceil \times \Big( T_{\text{drafting}} + T_{\text{validation}} + T_{\text{sampling}} \Big) \]

为量化延迟提升(或加速比),我们可以定义速度提升因子(Speedup)为:

\[ \text{Speedup} = \frac{T_{\text{auto}}}{T_{\text{parallel}}} = \frac{N \times T_{\text{target}}}{\lceil N/k \rceil \times \Big( T_{\text{drafting}} + T_{\text{validation}} + T_{\text{sampling}} \Big)} \]

只有当 \(\text{Speedup} > 1\) 时,Parallel Decoding 会给我们带来提升。

4. Naive vs Mudusa vs Eagle2 [In Progress]

ref: https://aclanthology.org/2024.findings-acl.456.pdf

根据 drafting 和 verification 的不同,Parallel Decoding 有非常多的变种,除了原始的 speculative decoding,后面比较受到关注的是 Medusa 和 Eagle2,理论上来说 Eagle2 对于对于 performance 的提升是最好的,这章就主要聊聊这三种方式的不同实现和各自的优缺点。

4.1 Drafting

ref: https://zhuanlan.zhihu.com/p/15955544919

Naive Speculative Sampling

Medusa

Eagle2

4.2 Validation & Sampling [In Progress]

4.3 优缺点

Naive Speculative Sampling

Medusa

Eagle2

暂时来看,如果想要快速用 parallel 来加速一个 model 的 inference,可以考虑 Naive Speculative Decoding, 如果想要追求最好的效果,应该选择 Eagle2。

5. Open Questions

5.1 why parallel decoding can improve the latency

5.2 how will it work with continuous batching

6. Closing Notes

7. Reference

  1. https://arxiv.org/abs/2211.17192
  2. https://arxiv.org/pdf/2302.07863
  3. https://arxiv.org/pdf/2408.08146v1
  4. https://aclanthology.org/2024.findings-acl.456.pdf
  5. https://arxiv.org/pdf/2406.16858v1
  6. https://arxiv.org/pdf/2401.10774
  7. https://zhuanlan.zhihu.com/p/704755926