AI学习者 · 8月1日

图解大模型计算加速系列:分离式推理架构2,模糊分离与合并边界的chunked-prefills

分离式推理架构1中,我们以DistServe为例,解释了“为何要使用分离式推理架构”:分离式推理架构可以解耦prefill(compute-bound)和decode(memory-bound)过程,使得不管是在硬件分配还是在并行策略上,这两者都能朝着独立的方向优化,同时改进TTFT和TPOT,而无需再像合并式推理架构那样,总是在这两者之间做trade off。

但是,读完这篇文章,你可能会有这样的疑惑:如果我能采取一种方法,使得处于prefill阶段的请求和处于decode阶段的请求能组成一个batch同时计算,而在组建这样的batch的过程中,我又充分考虑了最大化GPU计算单元利用率、最小化IO读写次数(简而言之,怎么能榨干一块gpu我就怎么来)。那么这时,我是不是在不解耦的情况下,同样也能同时保全TTFT和TPOT呢?

那么在这篇文章中,我们就来看看遵从这个思路设计的推理架构:Sarathi-Serve,以及它背后的核心技术chunked-prefills(切块式prefill)和stall-free schedules(无停滞式调度策略)。虽然本文是讲Sarathi-Serve,但是为了更好理清其设计思路(它也是在借鉴了其余架构的基础上改良而来),本文也会涉及对其余架构的核心技术讲解:

【全文目录如下】

一、传统batching方式
1.1 整体流程
1.2 缺陷

二、Orca:Selective batching
2.1 Iteration-Level Schedule
2.2 Selective Batching
(1) Decoder Block的各种计算
(2) Selective Bathing的计算流程

三、Sarathi-Serve:chunked-prefills
3.1 为什么混合batch能提升整体性能
3.2 为什么有了selective batching还需要chunked-prefills3.3 chunked-prefills运作流程
3.4 stall-free schedules
3.5 chunked-prefills调度流程源码解读3.6 为什么有了chunked-prefills还可能需要分离式架构

【写作与绘图不易,如果本文有帮助,欢迎点赞收藏在看~可以让更多人看见❤️】

一、传统batching方式

1.1 整体流程

我们来看早期一个传统的batching方式的例子(例如FasterTransformer的实现,图片来自Orca论文):

image.png

在这个例子中,我们的batch_size = 2,分别装着长度相等的x1和x2序列(长度不相等时,可以采用诸如左侧padding等方法)

  • 我们把(左padding过后)长度相等的序列送入模型做prefill,产出第一个token。整个prefill的过程,被称为1次iteration中文可以理解成一次迭代,或者1个推理阶段)。
  • 接下来我们对这两个序列做decode。可以发现1次迭代后,x2已经推理完毕,x1依然还在做推理
  • 由于在传统batching方法中,整个batching中的序列是一起行动的,所以尽管x2已经做完推理了,它还是没有办法被“释放”。“释放”的含义是:x2所占据的资源(例如KV cache等)不能被释放。
  • 接下来,x1又做了两次迭代。这下x1也完成推理了。然后整个batch中的数据才可以被真正“释放”。
  • 当这一个batch推理完毕后。其余请求才能继续组成新batch,做下一轮推理。

正是由于在传统batching中,需要所有的request一起行动,因此和传统batching配套的调度方式,又被称为request-level schedules

1.2 传统batching方式的缺陷

由1.1的整体流程,我们可以直观看出传统batching方式的缺点:

  1. 以牺牲TTFT的方式保全TBT(Time Between Tokens,可以理解成和TPOT是等价的)。由于整个batch一起行动,所以在这个batch做推理的过程中,不能接受新的请求,导致prefill的过程停滞了(stall)。所以尽管它一气呵成完成了现有数据的decode过程,它却增加了新请求们在队列中等待被处理的时间。
  2. 以牺牲吞吐(throughput)的方式降低延迟(latency)。由于不能接受新请求,吞吐量(每秒能处理的tokens数量)下降了,但是由于不间断地做decode,对decode来说延迟降低了。
  3. 增加了流水线并行中的气泡

我们对第3点做一些更详细的说明。

在大模型推理中,当模型尺寸过大时,我们需要把它切割到多张卡上,常用的并行方式有pp和tp(这里我们不谈dp,因为确认好tp和pp后,dp维度只是做模型副本拷贝而已)。一般来说,在做推理时,我们希望用一个较大的batch,这样一来我们可以最大化利用gpu的计算单元,二来也减少从显存读取数据到cache的次数(比如同样是从显存中读取模型权重,如果你分成很多小batch,你就要读取多次。当你合成大batch时,你只用读取1次,大家共享就可以了)。

  • 当我们使用tp时,我们是对模型做层内切割,这样一块卡上维护的模型权重占的显存就少了,我们就有空间组织更大的batch了。但是由于tp在前向过程中涉及到2次allreduce,所以它对不同gpu间的通讯性能要求更高。因此一般是在单机内,或者在有更好带宽的集群的情况下,我们会倾向于使用tp。
  • 当我们使用pp时,我们是对模型做层间切割,一块卡上维护的还是完整的层,虽然此时可能batch无法像tp那样打得比较大,但是pp间只涉及层间activation的通讯,对带宽要求更小。所以很多商用的架构都会使用pp作为推理的并行方式。

那么如果使用pp做推理,有一个优化点肯定是避不开的:减小pp的bubble,也就是减少gpu的空闲时间。

我们来看传统batching方式下的pp bubble情况,如下图(图片来自Orca论文):

image.png

其中,batch_size = 2,它装了A和B两个序列,下标表示序列正在进行第几个迭代。我们假设A和B此时都处于decode阶段。partition1~3可以理解成是3张gpu,上面维护着模型的不同层。

由于decode阶段是token by token的,所以A和B必须在第1次迭代产出一个token后,才能做第2次迭代。这就造成了每块gpu上的bubble(空闲时间)。

看见传统batching方式的这3个缺陷,此时的你一定觉得很可惜,因为:

  • 已经做完推理的请求,为什么还要占据着资源呢?把位置让给新的请求,让新请求做prefill,旧请求继续做decode,那不是更好吗?
  • 在使用pp的前提下,我在那些气泡处,塞入新请求做prefill或者decode,不就既能把那些气泡填满,又不影响当前请求做推理吗?

所以,这一切都指向了两个迫切需要被改进的方向:

  • 更改request-level的限制,让新请求和旧请求能接连不断组成新的batch(Orca iteration-level schedule
  • 让prefill和decode能在一个batch中一起做(Orca selective batching

二、Orca:Selective Batching

2.1 Iteration-Level Schedule

image.png

再复习一下:传统推理架构的调度流程如上图(图片来自Orca论文)。调度器(Scheduler)每次从请求队列中组织一个新的batch(如图中的x1和x2),然后与执行引擎(Execution Engine)交互做推理,等engine把这个batch的数据都做完推理并且返回给用户后,调度器才会继续从请求队列中组织新的batch。由于batch中的所有请求必须一起行动,我们管这种调度策略叫Request-Level Schedule

而现在我们的目标是:及时检测出推理完毕的请求,将其从batch中移出,好腾出位置给新的请求。

那怎么实现这点呢?还记得我们在1.1中给出的那张推理流程示意图吗?在那张图里,我们管请求做完prefill产出第一个token的过程叫1次iteration,请求每做一次decode也被称为1次iteration。所以,对于一个batch内的数据,如果我是按iteration维度调度的,也就是一个batch中的所有请求每做完1次iteration,scheduler就和engine交互一次,去检查batch中是否有做完推理的请求,以此决定是否要更新batch,这样不就能达到我们的目的吗?我们管这样的调度策略叫Iteration-Level Schedule,整体流程可用下图表示(图片来自anyscale blog)

image.png
这里,我们先不要管如何使用特殊的方法让这个batch中的数据能同时做推理(我们马上在下文讲解),只着重关注调度流程。这个batch中原始有4个序列s1~s4,黄色表示prefill tokens,蓝色表示decode tokens。左图展示了这4个序列刚做完prefill的过程。在此之后序列进入decode阶段,每生成1个token,scheduler就和engine做交互,即时检查序列的完成情况。在右图中,s3最先做完推理。此时scheduler检测到了这点,就把s3从batch中移除,再从队列里塞入新请求s5组成新batch继续做推理。s6~s7的推理过程同理可推。

2.2 Selective Batching

了解了iteration-level schedule后,现在我们来看一个大家都非常好奇的问题:同一个batch中,那些形态、计算方式各异的请求,要如何同时做推理?

举例来说:

  • prefill过程是长序列并行计算的,decode过程是token by token的
  • prefill过程不需要读取KV cache,decode过程需要读取KV cache
  • 对于prefill,各个请求的prompt长度是不一致的
  • 对于decode,不同请求的decode token的index不一样,意味着它们计算attention的mask矩阵也不一样。

诸如此类,真是令人头大。

而解决这些问题的一个好思路是:尽量找到这些请求计算时的共同之处,使得计算能最大化合并。对于有差异的部分再单独处理。这样说你可能觉得比较抽象,不要紧,我们先以一个transformer decode block为例,回顾一下序列要经过哪些计算,然后我们再慢慢讲解合并batch计算的细节。

(1)Decoder block中的各种计算类型

(下图来自sarathi论文)

image.png

  • preproj:即序列经过矩阵产出的过程。观察table1中给出的input和weights权重,可以发现重要的两点:

    • preproj计算时需要从显存读取模型权重。
    • preproj计算时和input序列长度无关(只是在hidden_size维度上做线性转换)
  • attn:利用计算出的计算attention分数的过程,可以发现:

    • attention分数计算时不需要从显存读取模型权重,你只需要利用算好的QKV即可
    • atttention分数计算时依赖mask矩阵,而不同序列的mask矩阵是不同的
  • postproj:使用权重矩阵,对经过attention计算后的序列做映射,它的两个特性和preproj一致。
  • FFN1与FFN2:道理同preproj/postproj,不再赘述。

我们把上面的介绍稍作提炼,得到如下重要信息:

  • preproj/postproj/FFN1/FFN2:做这些计算时,需要从显存读取模型权重,且这些计算和input序列长度无关。
  • attn:做attention分数计算时,不需要从显存读取模型权重,且不同序列的mask矩阵不同。
(2)selective batching的计算细节
  • preproj/postproj/FFN1/FFN2的计算和序列长度无关,这意味着你可以把一个batch中所有的tokens都展平成一行进行计算(维护好各自的位置向量就好)。而这些计算都要读取模型权重,这意味着我们可以尽量增大batch size,使得一次读取能造福更多request,以此减少IO次数。
  • attn的计算受各个序列的差异性影响(例如mask矩阵、是否需要读取KV cache),所以需要将序列拆分开独立处理,也即batch维度是重要的(cuBLAS batch matrix multiplication)。而由于attn部分本身不涉及到权重读取,因此你把序列拆分开处理,也不会在这一方面上带来额外的IO开销。

整体流程如下(图片来自Orca论文):

image.png

在图中,序列x1和x2正在decode阶段(因此需要KV cache Manager帮它们取出KV cache),序列x3和x4正在prefill阶段,它们被组成了一个batch。在非attention的部分,batch中的7个tokens被拉平成一行进行计算(忽略了batch维度),等实际计算attention时,再split开。计算完毕后再拉平。

三、Sarathi-Serve:chunked-prefills

我们来小结一下目前为止的内容:

  • 我们以分离式架构为引子,讨论了解耦prefill和decode过程带来的好处:能独立优化TTFT和TPOT/TBT,同时提升吞吐和降低延迟。
  • 基于此,我们又产生了疑问:如果不采用解耦的方式,只是修改传统的batching里非prefill即decode的方法,在最大化榨干一块gpu的前提下,让prefill和decode能同时放在一个batch里做推理,是不是也能达到一样的效果?
  • 为了解答这个问题,我们先回顾了以FasterTransformer为代表的早期batching方法:在推理的每个时刻,batch中的序列总是一起做prefill,或一起做decode。
  • 接下来,我们介绍了Orca是如何能让各种请求(prefill+decode,长度不同的prefill,index不同的decode等)混合在一个batch里做同时做推理的。

关于混合batch对性能带来的提升,大家可以去看Orca论文中的实验部分(以FasterTransformer等更早期的推理架构为baseline),这里就不赘述了。我们来看一个更有趣的问题:为什么混合batch可以带来性能上的提升?

3.1 为什么混合batch可以带来性能上的提升

我们来看sarathi-serve做的一个实验(图片来自sarathi-serve论文)

image.png

左右两图分别刻画了在不同的batch size下,prefill和decode阶段的吞吐量(tokens per second,每秒能处理的tokens数量)。

  • 观察到,对于prefill阶段来说,提升batch size时,吞吐量的有增长但不太显著。甚至当batch size更高时(比如从4~8),还发生了吞吐量的下降这是因为prefill阶段是compute-bound的,也即相比于读数时间,它消耗在计算上的时间更大(由于数据是可以边读边算的,所以我们可以大致认为总时间)。prefill阶段读取数据(例如从显存读取模型权重)的时间成本是固定的,但是计算时间却会随着batch中tokens的数量而增长,因此当gpu的计算单元还没有被打满时,吞吐量还可以上去;被打满时就会下降了。
  • 对于decode阶段来说,提升batch size时,吞吐量增长的线性趋势非常明显。这是因为decode是memory-bound的,也就是它花在读数上的时间更大(回想一下,当你用一个token做decode时,你其实要做的新计算很少,大部分时间你都花在读取KV cache和模型权重上)。decode阶段的算力严重打不满,所以当你增大batch size时,你不仅能多利用算力,也能把多次读取合并成一次读取,吞吐量自然就上升显著了。但是你也不能无止尽地增加batch size,因为gpu的存储是有限的,decode还要读取前面那一长串的KV cache呢。

既然decode和prefill阶段都需要读一些固定的数据(比如模型权重),且decode阶段的算力没有打满,那我们把他们组装在一起,让他们互相搭便车,肯定能取得更好的效果,也即:

  • prefill搭上decode的便车,能用上decode阶段被浪费的算力。
  • decode搭上prefill的便车,合并数据的读取次数,做到1次读取,大家共享。

3.2 为什么有了selective batching,还需要chunked-prefills

在3.1中,我们介绍了prefill和decode组成混合batch对性能提升的好处:乍一眼看,既不耽误做prefill(TTFT),也不耽误做decode(TPOT/TBT)。那么目前为止,Orca应该做得挺好了哇,那这个Sarathi-Serve的chunked-prefills,是干什么的呢?

当你回顾Orca组装batching的过程时,你可能会发现这个过程比较随机:一个batch中做prefill和做decode的请求有多少条是不确定的,只是大体按照先来后到的原则做动态组装。这就造成了一些问题:

  • 如果一个batch中做prefill的请求非常多,或者做prefill的请求非常长,那么prefill tokens会占据大量计算资源,使得整个batch变成compute-bound。
  • 如果一个batch中做decode的请求非常多(比如当所有的请求都没做完推理时,或者请求队列中没有新序列可以调度时),这个batch就可能变成memory-bound的。
  • 随机的batch同样可能产生pp并行气泡

哦咦,熟悉的感觉,我们再来看看第三点,还是关于pp并行气泡的问题。

我们知道相比于FasterTransformer,Orca已经能在一定程度上改善pp气泡问题了,但是由于其batch组装的随机性,它仍然可能导致气泡问题,我们以下图为例(图片来自Sarathi论文):

image.png

ABCD表示4个队列,下标p表示prefill阶段,di表示decode的第i个阶段。在采用micro-batch的前提下(也是减少pp气泡的一种办法),micro-batch size = 2,AB组成一个小batch,CD组成一个小batch。注意到这两个batch虽然size一致,但tokens数量更不一致。

观察到图中一共有3种类型的bubble:

  • PB1: 因为micro-batches中prefill序列长度不一致而产生的bubble
  • PB2: 因为prefill和decode阶段计算时间的差异而产生的bubble
  • PB3: 不同micro-batch的decode差异性而产生的bubble,这是因为不同micro-batch在做decode时,要读取的KV cache的长度不一致,这也导致了在读取数据上所花费的时间不一致

基于Orca selective batching的这些缺陷,我们不禁想:如果我们在保持selective batching这种混合机制的情况下,根据gpu资源的上限(FLOPS/MemBandwidth),找到一个最大batch size,即定义好一个batch内最多能处理的tokens数量,然后在每个batch内,在按照一定比例去分配做prefill的tokens和做decode的tokens,不就既能解决pp并行中的气泡问题,又能让这个batch得到性能最大化吗?

而在这种解决办法下,一个请求用于做prefill的序列必定是要被拆开的,所以我们就管这种方法为:chunked-prefills

3.3 chunked-prefills运作流程

基于pp的chunked-prefills运作流程如下(图片来自Sarathi论文):

image.png

  • 首先,我们通过3.2中的思路,从我们所使用的gpu性能出发,确定每个batch中最多能处理的tokens数量(可以通过profiling做模拟实验得到)。
  • 然后,我们在各个batch中进一步确定prefill tokens和decode tokens的比例。确认的原则被称为“decode-maximal batching":即优先往batch中添加需要做decode的序列,直到添加不动为止(即我们预留给decode的KV cache空间已经不足了,无法存放新的KV cache了)。然后我们再根据这个batch中剩余的tokens预算,对需要做prefill的序列做chunk切割,把对应的prefill tokens添加进batch中
  • 最后,Sarathi-Serve依然采用的是iteration-level schedules,即推理的每一步后,scheduler都会重新组建batch。

【📒:我们会在本章最后一节解读Sarathi-Serve调度器策略的源码,给大家展示更多上述流程的细节,这里大家只需要大致了解chunked-prefills的运作流程即可】

chunked-prefills的额外开销

看完了运作流程,你肯定有这样的疑惑:原来一条序列做prefill时,我是一起计算的。现在我把它拆成了多个chunk,那么每个chunk去计算时,肯定要去读前一个chunk的KV cache(如下图),那不就增加了IO复杂度了吗?这会影响到prefill计算的性能吗?

image.png

这个读取KV cache的额外开销肯定是有的,但它对prefill的影响大吗?基于此,Sarathi-Serve的作者们做了两个实验。

第一个实验:证明prefill阶段是强compute-bound特性,以及计算attention的时间在总计算时长里占比不高。

image.png

我们知道KV cache仅用在attention的计算中,所以这里作者把时间消耗拆成了attention和非attention(linear + others)的部分。可以发现:

  • 对于prefill的部分,不管prefill tokens数量如何,attention部分的计算时间在总时长里占比并不高。
  • 对于prefill部分,随着seq_length的变长,tokens的处理时间也变长。但是在128~512的长度内,tokens的处理时间增长不显著。这是因为在这个范围内,gpu的算力还没有打满。在这之后进入强compute-bound区域,此时读取数据的时间对prefill来说影响更小。

第二个实验:直接比较chunked-prefills和正常prefill下的延迟

image.png

这里以正常prefill为baseline(设其overhead = 1,即没有额外开销),比较不同chunk size下的额外开销。不出意外,prefill chunk分得越细(例如512),开销越大,但是总体来说,开销增长都控制在1.25倍内。稍微影响到TTFT,但是考虑到它对TBT/TPOT的更多提升(可以参见论文别的实验,这里不再写出),这样的开销还是可以接受的。

3.4 stall-free schedules

在Sarathi-Serve的设计思想下,无论是prefill过程还是decode过程,都不会产生停滞(stall)。以Sarathi-Serve作者的观点来看:在其余的推理架构中(比如vllm,Orca,FasterTransformer),他们都或多或少存在停滞一方以保存另一方的策略,我们来看一个整体流程图(图片来自Sarathi-Serve论文):

image.png

假设最开始有A、B两个序列,他们都处在decode阶段。从上帝视角来看,A和B分别要经过2次、4次decode迭代才能完成推理。

  • 对于这4个框架,A和B首先进入第1次decode迭代(图中第一个红色方块)。到这一步为止这4个框架没有什么差异。
  • 当A和B完成第一次decode迭代后。新来了请求C和D。
  • 对vllm,我们在之前的源码解读系列说过,它是prefill优先的,所以它会先处理C和D,这就使得decode暂停了(stall)。这其实是在保吞吐弃延迟(使得TBT增加了)
  • 对Orca,它在硬件资源允许的情况下,是可以让CD做prefill,AB继续做decode的(黄色部分)。但是由于decode和prefill的完整序列绑定,也使得整个decode的计算时间变长了(特别是在CD是长序列的情况下)。所以这其实也算是一种decode暂停
  • 对于FT,它是保延迟弃吞吐的。这使得prefill暂停了。
  • 对于sarathi-serve,它和orca一样,也是允许decode和prefill一起做的,但是它通过合理控制每个batch中prefill tokens的数量,使得decode阶段几乎没有延迟(把sarathi的绿色块和FT的红色块相比,可以发现绿色块只长了一点)。这样即保了延迟,又保了吞吐。

3.5 Sarathi-Serve调度流程源码解析

由于Sarathi-Serve论文中的调度流程伪代码,和实际的源码实现存在一定的差异。所以我这里直接根据源码来分析使用chunked-prefills方法时的调度流程(给出了非常详细的注释,大家可以关注下~):

class SarathiScheduler(BaseScheduler):

    def __init__(
        self,
        model_config: ModelConfig,
        scheduler_config: SarathiSchedulerConfig,
        cache_config: CacheConfig,
    ) -> None:
        super().__init__(model_config, scheduler_config, cache_config)
        
        # =================================================================
        # 【固定chunk_size策略】
        # 人为定好的chunk_size。如果你不想动态变更chunk_size大小,你可以固定使用这个。
        # 我们可以通过profiling等方式,在调度开始前确定好能够
        # saturate gpu computation的最大chunk_size
        # (注:在代码中,chunksize不是指prefill的chunksize,是指每次
        #  调度中,整个batch的tokens数量,也包括要做decode的tokens数)
        # =================================================================
        self.chunk_size = self.scheduler_config.chunk_size
        
        # =================================================================
        # 【动态chunk_size策略】
        # 使用动态变化的chunk_size
        # (随着调度次数增加,历史累积的要做decode的序列可能会变多,以及
        # 可能会进来更多的新请求。假设某个序列的prompt特别长,那么它就会持续占据着计算
        # 资源,影响到别的请求。所以对于这样的prompt,我们可以在迭代中逐渐减小它的preill
        # tokens数量)
        # 
        # 为了执行这个chunk_size动态变更的策略,我们需要如下4个参数:
        # 【low_chunk_size】:人为设定的最小chunk_size
        # 【high_chunk_size】: 人为设定的最大chunk_size
        # 【chunk_schedule_stages】:用于刻画调度阶段数。例如该值若等于5,则说明随着
        # 调度次数的增加,我们希望有5种逐步递减的chunk_size可以选择
        # 【chunk_schedule_max_tokens】: 这个变量比较难说明,我们直接看它怎么用。
        # 事实上,在源码中真正有意义的变量是_tokens_per_stage
        # (=chunk_schedule_max_tokens/chunk_schedule_stages)
        # 你可以理解成:对于一个正在做prefill的长序列,我们它的prefill tokens数量
        # 随着迭代阶段(stage)的增加而递减。我们设其做prefill时,每处理_tokens_per_stage
        # 个tokens就算完成了1个stage,然后就要递减一次prefill tokens。简而言之,这些
        # 参数的作用是帮助我们确定某个正在做prefill的序列当前位于哪个stage上
        # =================================================================
        self.enable_dynamic_chunking_schedule = (
            self.scheduler_config.enable_dynamic_chunking_schedule
        )
        # next four params apply only when using dynamic schedule
        self.low_chunk_size = self.scheduler_config.low_chunk_size
        self.high_chunk_size = self.scheduler_config.high_chunk_size
        self.chunk_schedule_max_tokens = self.scheduler_config.chunk_schedule_max_tokens
        self.chunk_schedule_stages = self.scheduler_config.chunk_schedule_stages

        if self.enable_dynamic_chunking_schedule:
            assert self.chunk_schedule_stages > 0
            assert self.chunk_schedule_max_tokens > 0
            assert self.low_chunk_size % 32 == 0
            assert self.high_chunk_size % 32 == 0
            # 计算在动态变更chunk_size的情况下,我们可选的chunk_size列表(详情参见相关函数注释)
            self._chunk_sizes = self._compute_chunk_size_schedule()
            # 用于计算每个stage能处理的token数(详细解释见上)
            self._tokens_per_stage = int(
                np.ceil(self.chunk_schedule_max_tokens / self.chunk_schedule_stages)
            )

    def _compute_chunk_size_schedule(self):
        # =================================================================
        # create num_steps equally spaced chunk sizes 
        # between low_chunk_size and high_chunk_size
        # 
        # self.low_chunk_size = 64
        # self.high_chunk_size = 256
        # self.chunk_schedule_stages = 5
        # 则chunk_sizes = [64, 108, 152, 196, 256]
        # 按照从大到小排序后 = [256, 196, 152, 108, 64]
        # =================================================================
        chunk_sizes = np.linspace(
            self.low_chunk_size,
            self.high_chunk_size,
            self.chunk_schedule_stages,
            dtype=np.int32,
        )[::-1]

        # =================================================================
        # 这里是调整每个备选的分块大小,让其能够被32整除
        # 这样做是考虑到tile-quantization effect,让gpu做gemm时的并行性能最大化
        # =================================================================
        round_of_chunk_sizes = min(32, self.low_chunk_size)
        chunk_sizes = (
            np.round(chunk_sizes / round_of_chunk_sizes) * round_of_chunk_sizes
        )
        chunk_sizes = chunk_sizes.astype(np.int64).tolist()

        return chunk_sizes

    def get_block_space_manager_class(self):
        return SarathiBlockSpaceManager

    def _get_seq_next_num_prefill_tokens(
        self, seq: Sequence, num_batched_tokens: int
    ) -> int:
        """
        对于一条还没做完prefill的seq,根据当前batch中已经存放的tokens数量,决定要送
        这个seq的多少tokens去做prefill
        """
        assert not seq.is_finished()
        # =================================================================
        # 如果使用动态chunk_size的方法
        # =================================================================
        if self.enable_dynamic_chunking_schedule:
            # =================================================================
            # 先计算当前seq目前一共处理了多少prefill tokens,然后根据每个阶段其最多能处理
            # 的prefill tokens数量,确定它在第几阶段(stage)
            # =================================================================
            request_stage_idx = int(
                np.ceil(seq.get_num_prompt_tokens_processed() // self._tokens_per_stage)
            )
            # =================================================================
            # 取出这个阶段的chunk_size
            # =================================================================
            assert request_stage_idx < len(self._chunk_sizes)
            chunk_size = self._chunk_sizes[request_stage_idx]
        # =================================================================
        # 如果没有使用动态变更chunk_size的策略,就用固定尺寸的chunk_size
        # (例如代码中的默认值512)
        # =================================================================
        else:
            chunk_size = self.chunk_size

        # =================================================================
        # 对于这个正在做prefill的seq,确定它在下一次迭代中要送去做prefill的tokens数量。
        # 这个数量 = min(该序列还没有做prefill的tokens数,batch中可用的prefill tokens配额)
        # =================================================================
        next_num_tokens = min(
            seq.get_prompt_len() - seq.get_num_prompt_tokens_processed(),
            chunk_size - num_batched_tokens,
        )

        return next_num_tokens

    def _schedule(self) -> SchedulerOutputs:
        # Fix the current time.
        now = time.monotonic()

        running: List[Sequence] = [] # 应该是用来存放确定要被本轮调度的数据
        ignored_seq_ids: List[str] = []
        preempted_seq_ids: List[str] = []
        scheduled_seq_metadata_list: List[SequenceScheduleMetadata] = []

        num_batched_tokens: int = 0

        ######################################################################
        # Phase 1: Add existing running sequence groups to the batch.
        # There are two cases:
        # 1. The sequence group has incomplete prefill. The routine
        # remains identical to the one in sarathi scheduler for such sequences.
        # 2. The sequence group has completed prefill. In this case, we need to
        # check for memory availability for the next chunk of decode tokens, and preempt
        # some sequence groups if necessary. Note that, the preempted sequence groups
        # might belong to either of the two categories.
        ######################################################################

        # NOTE(woosuk): Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        
        # =================================================================
        # 把self.running中的数据按照FCFS原则(先来后到)进行排序
        # =================================================================
        self.running = self.policy.sort_by_priority(now, self.running)

        # in first pass process all the requests with prefill completed
        # this allows us to accurately account for the number of decode tokens
        running_prefills: List[Sequence] = []

        # =================================================================
        # 先去看上一次iteration中被选中的序列
        # =================================================================
        while self.running:
            seq = self.running.pop(0)
            # =================================================================
            # 如果这个seq没有被暂停,那么就把它继续添加到本轮running队列中
            # 
            # (上一轮调度结束后,所有running状态的序列都会被设置为pause状态,
            # 这里可以参考base_sequence_manager的on_step_completed函数,
            # 这个函数是对每轮调度结束后序列的状态和推理结果做处理),
            # 
            # (当然也可能有别的条件会触发pause状态设置,这里没有看完全部源码,所以暂不知道)
            # =================================================================
            if not seq.is_paused():
                running.append(seq)
                continue

            # =================================================================
            # 如果这个seq还没有做完prefill,就把它添加到running_prefill的列表中
            # =================================================================
            if not seq.prompt_processing_finished:
                running_prefills.append(seq)
                continue

            # =================================================================
            # (走到这一步,剩下的都是上一次调度中处于decode阶段的seq了)
            # 如果现在没有足够的空间给处于decode阶段的seq做推理了
            # =================================================================
            while not self.block_manager.can_append_slot():
                # =================================================================
                # 如果self.running队列中有数据,就从running队列中抢占最晚到来的那个
                # sarathi中的抢占是直接做重计算,即把seq重新放回waiting队列中
                # =================================================================
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq = self.running.pop(-1)
                    self._preempt(victim_seq)
                    preempted_seq_ids.append(victim_seq.seq_id)
                # =================================================================
                # 如果self.running队列中已经没有数据了,就抢占当前seq
                # =================================================================
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq)
                    preempted_seq_ids.append(seq.seq_id)
                    break
            # =================================================================
            # 如果现在有足够空间给处于decode阶段的seq做推理
            # =================================================================
            else:
                # 给decode阶段的seq分配KV cache空间,并将其添加到本轮的running队列中
                self._append_slot(seq)
                running.append(seq)
                # 当前batch的token数量 += 1
                num_batched_tokens += 1
                scheduled_seq_metadata_list.append(
                    SequenceScheduleMetadata.from_sequence(seq)
                )
        
        # =================================================================
        # 接下来处理上一次调度中没有做完prefill的seq
        # 他们的KV cache空间肯定是够的,因为对于一个seq,我们在一开始是根据
        # 它完整的prefill序列长度来分配KV cache,而不是根据prefill chunk大小分配
        # KV cache。所以无论是那一轮iteration,我们都给这个seq的prefill留足了
        # KV cache空间
        # now add the requests with prefill incomplete
        # the memory for all these prefills has already been allocated
        # so we should be able to run all of them
        # =================================================================
        for seq in running_prefills:
            assert not seq.prompt_processing_finished
            # =================================================================
            # 计算对于这个seq,这次调度可以放多少tokens去做prefill
            # =================================================================
            next_num_prefill_tokens = self._get_seq_next_num_prefill_tokens(
                seq, num_batched_tokens
            )

            # as long as the request could fit in the batch previously
            # it should be able to fit in the batch now
            # so in non-pipeline case this condition should always be false
            # however, in pipeline case, the grouping of requests can change
            # between different microbatches, so this is not guaranteed to be always true
            if next_num_prefill_tokens == 0:
                running.append(seq)
                continue

            num_batched_tokens += next_num_prefill_tokens
            scheduled_seq_metadata_list.append(
                SequenceScheduleMetadata.from_sequence(
                    seq, prompt_chunk_len=next_num_prefill_tokens
                )
            )
            running.append(seq)

        ######################################################################
        # Phase 2: Add waiting (new) sequence groups to the batch.
        # This routine is nearly-identical to the one in sarathi scheduler
        # 在phase1中,我们遍历了上一个iteration的batch,来决定有哪些seq可以继续做
        # 这一轮的推理。
        # 在phase2中,我们去waiting队列中继续搜寻,看看是否有新请求能加入这一轮推理
        # 也就是每次调度中,batch = 上一轮batch筛选后的结果 + waiting队列中筛选的结果
        ######################################################################
        # Optimization: We do not sort the waiting queue since the preempted
        # sequence groups are added to the front and the new sequence groups
        # are added to the back.
        while self.waiting:
            seq = self.waiting[0]

            # This is required to handle benchmarking where we set request arrival time ahead of time
            if seq.arrival_time > now:
                break

            if not self._check_request_prompt_length(seq):
                ignored_seq_ids.append(seq.seq_id)
                continue

            # =================================================================
            # If the sequence group cannot be allocated, stop.
            # 直接用了vllm的allocate方法,即不是根据seq的prefill chunk大小
            # 预分配物理块的,而是直接根据整个seq的prefill大小分配物理块的
            # =================================================================
            if not self.block_manager.can_allocate(seq):
                # this is different from vllm scheduler
                # even if we cannot allocate this sequence group
                # there might be other sequence groups that can be allocated
                break

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            if len(running) >= self.scheduler_config.max_num_seqs:
                break

            # check if we can fit the prefill in the batch
            next_num_prefill_tokens = self._get_seq_next_num_prefill_tokens(
                seq, num_batched_tokens
            )

            if next_num_prefill_tokens == 0:
                break

            seq = self.waiting.pop(0)
            self._allocate(seq) # 直接为完整的seq prefill(而不是chunk prefill)分配KV cache空间
            num_batched_tokens += next_num_prefill_tokens
            scheduled_seq_metadata_list.append(
                SequenceScheduleMetadata.from_sequence(
                    seq, prompt_chunk_len=next_num_prefill_tokens
                )
            )
            running.append(seq)

        # make sure that prefills are at the start of the batch, so that we don't violate assumptions
        # made in the original vllm codebase
        self.running = running

        return SchedulerOutputs(
            id=self._iteration_id,
            ignored_seq_ids=ignored_seq_ids,
            preempted_seq_ids=preempted_seq_ids,
            scheduled_seq_metadata_list=scheduled_seq_metadata_list,
        )

我们可以配合着下面这张图来解读源码:

image.png

总体来说,Sarathi的源码其实是基于vllm源码框架修改而来的(最新版本的vllm源码中也做了chunked-prefills的优化,等我有时间把这块写进vllm源码解读里)。注释中已经给出了所有的细节,这里额外强调几点:

  • 当整个系统刚启动时,batch中只有做prefill的序列。这时走的是源码中从waiting队列里调度的逻辑。在sarathi中,我们是根据整个prefill的长度预先分配好KV cache空间(而不是根据prefill chunk长度来分配的)。这确保了在后面所有的iteration中,我们不用再操心这个batch中这条prefill序列的KV cache问题,它一定是留足了空间。
  • sarathi提供了“固定”“动态”两种chunk size策略:

    • 在固定chunk_size策略中,默认值为512。这是sarathi根据硬件和profiling实验计算出来的能最大化saturate gpu computation的单batch中的tokens数量。从源码中不难知道,在系统刚启动时,每个请求的头512个prefill tokens各组成一个batch(如上图所示),进行前向推理。
    • 随着推理迭代的进行,陆续有请求完成了prefill,进入decode过程,比如上图中产出了Ad1。那么根据源码,A所在的这个batch,此时要分配1配额的tokens给Ad1继续做decode。同时,它要去waiting队列中按FCFS(先到先服务)的原则找出请求C。由于batch总tokens配额是512,所以它切割了C的511个tokens装进这个新batch中,以此类推。
    • 随着迭代的继续进行,这个batch中总有一些序列是在prefill中,有一些序列是在decode中。每一次在做新的调度迭代时,对于正在做decode的策略,我们会先检查当前是否有足够的KV cache空间留给他们做新一轮迭代,如果没有的话就需要抢占decode序列(细节在源码注释中)。而对于这个batch的prefill序列,正如前文所说,当他进入这个batch的那一刻起,就已经给他分配了完整的KV cache空间,所以它无需再担心这点。
    • 可能在你的印象中,固定大小hunked-prefills意味着每个batch中prefill tokens的数量是不变的,但是通过sarathi的源码解读,你可以发现,尽量保持不变的是batch中的总tokens配额,而prefill tokens数量是随着decode tokens的增减而变动的(只不过decode tokens的数量一般也不多,所以prefill tokens数量和整体batch tokens配额也不会相差很多)
    • 在动态chunk_size策略中,我们希望对于一个请求,它的prefill tokens的数量能随着迭代次数的增加而减少,这主要是为了解决较长序列带来的影响。当一条prompt特别长时,它在每一次迭代中都会占据一定计算资源,导致历史累积的decode序列和新来的请求受到影响。所以干脆,对于进入这个batch中的请求,在一开始我们多给它一些prefill tokens配额,然后随着迭代次数的增加,递减这个配额,降低它对别人的影响。

【📒论文中其实做了非常多关于性能的实验,篇幅原因这里不再一一给出,大家可以自行阅读论文。】

3.6 chunked-prefills VS 分离式推理架构

通过以上的介绍,你已经知道,在使用chunked-prefills的策略下,通过合理划分prefill tokens和decode tokens比例,最大化利用好gpu,似乎也能同时保全TTFT和TPOT/TBT。那么在这样的前提下,分离式推理架构还有什么优势呢?

其实如果想更好回答这一点,最好的方式是做消融实验并分析。我没有做过相关的实验,所以只能从原理上给出我自己的一些猜想:即有了chunked-prefills,为什么我们还可能需要分离式推理架构?

我觉得最主要的一点,是chunked-prefills可能还没有完全实现在达到TPOT/TBT SLO的情况下,最大化prefill阶段对GPU FLOPS的利用率(MFU)。我们从3.3的分析中可以发现,chunked-prefills是会产生额外开销的(overhead),这个开销不仅体现在他需要额外读取KV cache,还体现在prefill chunk size的设定上。我们知道GPU的矩阵计算是存在tile-quantization effect的,也即矩阵是被切分成tiles后送到thread blocks上去做并行计算的。如果你的矩阵尺寸是tiles尺寸的整数倍数,那么就可以最大化并行计算,否则那些除不尽的部分就可能产生额外的开销(Sarathi做过相关实验,257的矩阵尺寸比256的矩阵尺寸产生的prefill time多了32%)。而在chunk-prefill中,我们只是用profiling估算出在特定设备上一个batch的最大tokens配额而已,这些tokens包括prefill和decode。这个size是对整体的,而不是单独对prefill或decode的。所以仍然存在prefill阶段无法最大化MFU的可能。

第二个,也是从无法最大化prefill MFU上衍生出来的问题:chunked-prefills对长序列的处理可能还差强人意。从3.5的源码解读中,我们发现在chunked-prefills中,长序列持久地占据着KV cache的存储空间以及gpu的计算资源。尽管我们可以采用动态减少chunk_size的办法,来减少长序列的影响。但是一来,这个chunk_size递减的策略要怎么设置更合理(而不是像3.5源码中那样可能是自己凭经验拍了一个),还有待研究。二来即使是实现了更好的chunk_size递减策略,但它却使得长序列的TTFT变大了,同样影响用户体验。

所以,基于以上这些对chunked-prefills策略缺陷的猜想,或许使用分离式架构,对prefill阶段独立开发一套策略,可能可以更加针对性地解决以上问题。当然,这也取决于各策略的具体实现、业务场景和真实的实验效果。

四、参考

1、https://arxiv.org/abs/2306.02707
2、https://arxiv.org/abs/2308.16369
3、https://arxiv.org/abs/2403.02310
4、https://github.com/microsoft/sarathi-serve
5、https://www.anyscale.com/blog/continuous-batching-llm-inference
5、vllm、FasterTransformer相关资料,不一一列举

作者: 猛猿
来源:GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多Arm技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18812
内容数
1354
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息