vLLM源码之PagedAttention

原文:https://www.zhihu.com/people/...

本文主要介绍vLLM推理引擎的PagedAttention算子实现,关于PagedAttention内容,后续会持续更新。

引用

本文内容初版翻译自: https://tech.scatterlab.co.kr/vllm-implementation-details/ 的PagedAttention部分,韩语写的。https://tech.scatterlab.co.kr/vllm-implementation-details/

后续更新版本会对源码分析再加入一些适当的图解和基础知识的补充,尽量做到能把PagedAttention分析全面细致。

感谢 @lipi提供该技术博客的原网址。

Attention

image.png

首先先了解下作为 Transformer 模型核心功能的 Attention(本文中仅介绍 GPT2 的多头 Attention)。如下图所示,右图为Multi-Head Attention,左图是是DotProductAttention,我们平时所接触的FlashAttention、PagedAttention、FlashDecoding都是这个层面的计算。具体计算公式为:

image.png

在 vLLM 的实现中,主要根据上述结构对 Attention 进行了变更,其余部分(除并行相关的实现外)与上述结构相同。

# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)

# Pre-allocate the output tensor.
output = torch.empty_like(query)

# Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens
if num_prompt_tokens > 0:
    self.set_attn_bias(input_metadata)
    self.multi_query_kv_attention(
        output[:num_prompt_tokens],
        query[:num_prompt_tokens],
        key[:num_prompt_tokens],
        value[:num_prompt_tokens],
        input_metadata,
    )

注意力层的输入接收以下三个张量:

  • Query:形状为 [num_tokens, num_heads * head_size] 的张量
  • Key:形状为 [num_tokens, num_heads * head_size] 的张量
  • Value:形状为 [num_tokens, num_heads * head_size] 的张量

换句话说,按顺序排列所有标记,并且根据多头注意力,头数存在维度。将每个张量重新分割为 [num_tokens, num_heads, head_size] 的大小。

之后,为与提示部分相对应的标记(即除了最后 N(=序列数) 个之外的其余标记),对 Q、K、V 应用多头注意力。

也就是说,对于整个 QKV,不对没有高速缓存的 SequenceGroup 执行常规注意力计算。注意力使用 xformers 的内核。(xops.memory_efficient_attention_forward)

# 将键值存储到缓存中
cache_ops.reshape_and_cache(
    key[:num_valid_tokens],
    value[:num_valid_tokens],
    key_cache,
    value_cache,
    input_metadata.slot_mapping,
)

# Single Query Attention
self.single_query_cached_kv_attention(
    output[num_prompt_tokens:num_valid_tokens],
    query[num_prompt_tokens:num_valid_tokens], key_cache,
    value_cache, input_metadata)

在对 Key,Value 值进行注意力计算之后,将这些值存储在缓存中,然后对照缓存和生成标记(语境中的最后标记)应用单一查询注意力。

接下来内容主要从源码角度分析PagedAttention的实现,CUDA 实现需要对 CUDA 架构有预先知识。简而言之,在 GPU 中,按 Warp 单位执行 SIMT(单指令多线程),并且要执行 CUDA 核,必须按网格 / 块大小设置总线程数进行处理。(此时,块是一个与 PagedAttention 的块不同的概念。因此,我们将该块称为 CUDA 块。)

KV 缓存存储

从 CacheEngine 调用 Swap-In/Out,这是在 KV 缓存存储中调用的 cuda 内核。本节目的是了解 存储KV 缓存的函数,存储完KV缓存,该KV缓存后续用于PagedAttention计算。

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));

vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
    key.data_ptr<scalar_t>(),
    value.data_ptr<scalar_t>(),
    key_cache.data_ptr<scalar_t>(),
    value_cache.data_ptr<scalar_t>(),
    slot_mapping.data_ptr<int>(),
    key_stride,
    value_stride,
    num_heads,
    head_size,
    block_size,
    x
);

在计算注意力之前,通过 reshape_and_cache 函数将 Key、Value 值存储在缓存中。此函数将 [num_tokens, num_heads, head_size] 大小的 Key、Value 值复制到缓存中。

  • Key 缓存:[num_blocks, num_heads, head_size/x, block_size, x] 大小
  • Value 缓存:[num_blocks, num_heads, head_size, block_size] 大小

Key 缓存的 x 值等于 16 / element_size,这似乎是用于实现快速计算和快速内存访问的技术。(以 bfloat16 为基准,x=8)

网格的大小设置为输入张量中包含的全部令牌数,并将 CUDA 块设置为 min(num_heads * head_size, 512) 大小。

template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,     // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,   // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  scalar_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping, // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x
) {
  // 每个块(Cuda 块)处理一个令牌。
  const int token_idx = blockIdx.x;
  // 找到每个令牌对应的缓存所在的插槽位置的索引。
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

缓存的CUDA内核负责每个CUDA块中的一个令牌。

const int n = num_heads * head_size;
// 因为内部维度中包含多个值,所以允许所有线程同时最大复制 512 个值,并且不会出现重叠。
for (int i = threadIdx.x; i < n; i += blockDim.x) {
  // Key 和 Value 大小为 [num_tokens, num_heads, head_size],所以
  // 通过跨距 = num_heads * head_size 值查找令牌位置。
  // 找到令牌位置后,找到各个线程负责的空间。
  // 也就是 num_heads * head_size 大小的数组被 512 个线程 (=blockDim.x) 复制。
  const int src_key_idx = token_idx * key_stride + i;
  const int src_value_idx = token_idx * value_stride + i;

  // Cache 위치의 Stride 및 인덱스를 계산합니다.
  const int head_idx = i / head_size;
  const int head_offset = i % head_size;
  const int x_idx = head_offset / x;
  const int x_offset = head_offset % x;

  // Key 캐시 크기: [num_blocks, num_heads / x, head_size, block_size, x]
  const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                          + head_idx * (head_size / x) * block_size * x
                          + x_idx * block_size * x
                          + block_offset * x
                          + x_offset;
  // Value 캐시 크기: [num_blocks, num_heads, head_size, block_size]
  const int tgt_value_idx = block_idx * num_heads * head_size * block_size
                            + head_idx * head_size * block_size
                            + head_offset * block_size
                            + block_offset;
  key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
  value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
}

每个线程在cuda块内会将缓存中值复制到全局内存中。为此,计算源索引和目标索引,并利用这些值进行复制。

Single Query Attention

为了加速注意力计算,只使用 KV Cache 和最后一个令牌的 Query 来计算注意力。(可以将其视为 Huggingface 的 use_cache=True 的 CUDA 版本实现。)我不确定确切的名称,但为了方便,我们将之称为单一查询注意力。原始实现位于 FasterTransformer 中,vLLM 是针对 PagedAttention 新移植的版本。

image.png

内核使用以下设置运行:

网格维度:大小为 [num_heads, num_sequences] 的 2D

CUDA 块维度:大小为 [128] 的 1D

也就是说,每个 CUDA 块被并行地操作一个序列和一个头部。由于每个头和序列可以并行执行,因此可以将块视为并行化级别划分。接着,每个块在总共 128 个线程中并行计算,最后进行 Reduce 以合并运算。因此,每个 CUDA 块都会计算以下维度 QKV。

image.png

其中L是序列长度,h是一个头(=head_size)的隐藏维度。

首先,我们来看一下注意力公式:

image.png

使用 Python 代码表示如下:

def calculate_attention(query, key, value, mask):
    # query, key, value: (n_batch, seq_len, d_k)
    # mask: (n_batch, seq_len, seq_len)
    d_k = key.shape[-1]
    attention_score = torch.matmul(query, key.transpose(-2, -1)) # Q x K^T, (n_batch, seq_len, seq_len)
    attention_score = attention_score / math.sqrt(d_k)
    if mask is not None:
        attention_score = attention_score.masked_fill(mask==0, -1e9)
    attention_prob = F.softmax(attention_score, dim=-1) # (n_batch, seq_len, seq_len)
    out = torch.matmul(attention_prob, value) # (n_batch, seq_len, d_k)
    return out

现在让我们分析一下在库达语中编写的实现。

并行计算预处理

/// WARP_SIZE = 32, BLOCK_SIZE = 16(默认值)
// THREAD_GROUP_SIZE = 2, NUM_TOKENS_PER_THREAD_GROUP = 1
// 使块内的 16 个标记由 32 个线程处理。因此,有 2 个线程附着在每个标记上进行计算。
//(每个线程组有 2 个线程,每个线程组(2 个线程)对 1 个标记进行计算。)
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
// 将总线程数除以 WARP_SIZE,计算总 Warp 个数(默认为 4)
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;

每个 cuda 块针对一个 head 和一个序列计算注意力。由于每个 cuda 块有 128 个线程,这些线程被划分到特定组中以便并行工作。首先有一个称为线程组的组,它将一个 Warp 划分为块大小数量,以此确定组大小。因此,存储在块中的标记由 Warp 内的线程均匀分配并处理。(即,每个线程组处理块内的一个标记。)Warp 是 Cuda 的 SIMT(单指令多线程)目标执行单元,以 A100 为标准,Warp 内包含的线程数为 32 个。

// thread_idx 是 0 到 NUM_THREADS-1 之间的值
const int thread_idx = threadIdx.x;
// 使用自己的线程索引计算是第几个 Warp 以及第几个 Lane
// 需要使用 warp_idx 来确保不同 Warp 之间同步,需要使用 lane 来确保 Warp 内部同步。
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;

// 既然网格自身大小为 (num_heads, seq_len),
// 因此可以从这些信息中获取下列值。
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int seq_idx = blockIdx.y;

并指定每个线程要处理的块或令牌等索引。

// 找出每个线程拥有并计算的大小。如果线程组大小为4,则每个线程处理16 / 4 = 4字节,
// 由于元件大小的不同,数组大小也会不同,因此再次除以sizeof(scalar_t)以找出实际数组大小。(如果是fp16,则大小为4 / 2 = 2个数组)
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
// 下面的 Vec 类型实际上不用于存储向量,而是用于利用在编译阶段预先定义的模板
// 在数组类型时使用。
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;

计算 key 的大小,并计算 value 的大小。每个线程组旨在同时处理 16 个字节,对应于 python 解析部分的 x 值 (16 / element_size)。由于每个线程组处理 16 个字节,因此乘以组内包含的线程数和 element size 然后除以 16,以计算每个线程处理的字节大小。

// 每个线程处理的元素数
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
// 需要多少个 Vec 来存储 Vec 类型数组
// 对 Vec 的个人看法如下:
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

// 特定线程属于哪个线程组
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
// 特定线程在哪个线程组中
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

接下来,计算每个线程和组需要处理的元素数和向量的数量。每个线程组必须针对 NUM_TOKENS_PER_THREAD_GROUP 个标记执行计算。每个标记存在 HEAD_SIZE 个元素,因此每个线程从组内均匀划分并获取它们。

Vec 数据结构的含义是什么?

如果 Element 类型为 float16,每个线程需要 32 个,那么 float16 vec[32] 不是可以了吗?为什么要中间使用 Vec 将其设为 Vec<float16, 2> vec[16]?从 Vec 类型来看,可以发现其为每个 Element 类型和 VEC_SIZE 单独声明了一个类型。

// FP16 vector types for Q, K, V.
template<>
struct Vec<uint16_t, 1> {
    using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
    using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
    using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
    using Type = uint4;
};

也就是说,线程组在同一时间必须处理 16 字节,并且必须多次重复处理(必须处理 HEAD_SIZE)以便最快速地运行代码,以下示例可以帮助理解。

如果假设处理 32 个 float16 大小的数组。如果简单地声明为 float16 vec[32],则只能运行 32 次迭代。

FOR(i, 32) vec[i] = vec[i] + some[i];

但是如果核可以同时处理 2 个 float16(SIMD),或者可以应用流水线,那么速度会更快。

// SIMD
FOR(i, 16) simd_vec[i] = simd_vec[i] + simd_some[i];

// Pipelining
FOR(i, 16) {
    vec[2 * i] = vec[2 * i] + some[2 * i];
    vec[2 * i + 1] = vec[2 * i + 1] + some[2 * i + 1];
}

因此可以认为是为了加速像上面一样的计算而设置了中间的 Vec。实际上查看实现时,可以看到 Vec<uint16_t, 2> 的 Type 为 uint32_t。也就是说,一次处理两个 float16 的捆绑。并且当为这中向量类型时,可以看到将两个 float16 类型的加法在汇编单元中仅用一条指令执行(SIMD)。

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
    uint32_t c;
    asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); // <---
    return c;
}

本文可以以图表的形式显示如下:

image.png

查询加载和内存分配

需要跳到由 CUDA 块负责处理的、与头部和序列对应的元素(= head_size 个)所在地址的位置。
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
// 因为一个线程组单位计算令牌,所以需要调用特定令牌的 HEAD_SIZE 个元素。
// 这个时,每个线程组中的线程调用并处理 HEAD_SIZE / THREAD_GROUP_SIZE 个元素。
// 尽管中间有 Vec,但 NUM_VECS_PER_THREAD 不是这种大小,
// 但要调用的元素数量总共为 HEAD_SIZE / THREAD_GROUP_SIZE 个。
Q_vec q_vecs[NUM_VECS_PER_THREAD];
// 通过展开来展开循环。在编译阶段,会针对特定值生成很多函数的重载函数。
// 因此,如下面所示,编译时可以设置的所有循环都通过展开用 unroll 展开,并删除了循环 (jump) 指令。
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
  // 所有线程组中的 CUDA 块读取相同 Q 值。线程组内线程稀疏地读取。
  // 例如,如果线程组大小为 4,则第 0 个线程读取第 0、4、8 个元素(准确来说是 Vec),
  // 第 1 个线程读取第 1、5、9 个元素。
  const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
  // 可以用 Vec<scalar_t, VEC_SIZE> 直接读取标量 t q[size]。
  // 就像可以用 float[2] 指针读取 struct { float x; float y }。
  q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}

首先,加载[num_seqs, num_heads, head_size] 大小的查询到寄存器中。每个线程加载查询值的各个部分。此时,一个 CUDA 块负责一个序列和头,因此,实际加载的大小为 head_size 个元素。

extern __shared__ char shared_mem[];
float* logits = reinterpret_cast<float*>(shared_mem);
__shared__ float red_smem[2 * NUM_WARPS];

接下来为输出和中间计算分配共享内存。将 shared_mem 变量的动态大小设置为较大的值,即 logits 的大小以及 CUDA 块输出所需的大小。使用 red_smem 使 CUDA 块内的各个线程之间进行归并。warp 内部可以不使用内存访问执行此归并操作,因而可以通过此_shlf_XXX 函数在 warp 之间使用归并操作。(大小为每 warp 浮点数 * 2 = 8 字节)

用于 QK 计算的变量计算

// x == THREAD_GROUP_SIZE * VEC_SIZE
// 每个线程组一次读取 x 个 Key 元素。
constexpr int x = 16 / sizeof(scalar_t);
float qk_max = -FLT_MAX;

计算我们要一次处理的元素数量 x。qk_max 用于 Softmax 计算。通常,在执行 Softmax 计算时,仅取 exp() 会导致值变得非常大,因此会调整最大值使其为 0 以防止发生上溢,计算时会用到它。

const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const int context_len = context_lens[seq_idx];
const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;

一个序列包含多个令牌和多个块。获取负责序列的块表、语境长度和块的数量。

QK计算

// 每个 Warp 处理一个区块。
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
  const int physical_block_number = block_table[block_idx];
  // ...

一个 CUDA 块必须针对一个序列进行计算。此时,每个标记都存在于块内的 block_size 中。每个 Warp 会处理一个块。因此,围绕整个块数进行迭代时,会声明一个外部循环,以便一个块由一个 Warp 处理。

// 执行 warp 中每个线程组负责的令牌的第二个迭代。
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
  // thread_group_idx 在 warp 中不是唯一的,而是在 cuda 块中唯一的。
  // 换句话说,如果在 warp 中唯一,则其他 warp 的线程组可能存在重复的 thread_group_idx,
  // 但如果在 cuda 块中唯一,则不会重复,并且意味着可以使用 thread_group_idx 来获得 warp_idx。
  const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
  // 表示要在序列中计算的特定令牌在该序列中位于第几个令牌。
  const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
  K_vec k_vecs[NUM_VECS_PER_THREAD];

一个线程组负责处理总 NUM_TOKENS_PER_THREAD_GROUP 个数的标记。该值小于 block_size,因此,像上方的代码一样,第二个嵌套循环得以执行,每次迭代都会处理一个标记。physical_block_offset 值是特定线程组计算它本身在块内的标记位置的索引。

for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
  // key 缓存大小为 [num_blocks, num_heads, head_size / x, block_size, x]
  // 可对特定线程组调用的公共索引为 [num_blocks, num_heads, block_size]
  // 部分的偏移量预先生成并存储在 k_ptr 中。
  const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
                                  + head_idx * HEAD_SIZE * BLOCK_SIZE
                                  + physical_block_offset * x;
  // 特定线程组内的特定线程应调用 head_size / THREAD_GROUP_SIZE 个元素。
  // 将剩余部分 [head_size / x, x] 缩减为 2D 以进行思考,
  // head_size / x 部分的索引将存储为 offset1,x 部分的索引将存储为 offset2,
  // 并调用。在这种情况下,每个线程可调用一个 x 个元素。
  const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
  const int offset1 = (vec_idx * VEC_SIZE) / x;
  const int offset2 = (vec_idx * VEC_SIZE) % x;
  k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
}

现在,每个线程都将数据从键缓存中加载出来。此时每个线程组一次只加载一个令牌,因此线程组内的线程会按顺序加载令牌的数据。也就是说,线程组加载 head_size 个元素,每个线程加载 head_size / THREAD_GROUP_SIZE 个元素。

此时键缓存中的形状为 [head_size, block_size](已转置)。

// --- attention_utils.cuh ---
// Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  using A_vec = typename FloatVec<Vec>::Type;
  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  // 每个线程分别对 Q 和 K 的元素求积并累加。
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Finalize the reduction across lanes.
  float qk = sum(qk_vec);
#pragma unroll
  // 每个线程组内的线程彼此执行规约,以便线程组内的所有线程最终都拥有一个加和了的所有线程计算结果的标量值。
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
  }
  return qk;
}
// --- end of attention_utils.cuh ---

// 在线程内执行点积并在线程组内执行规约,最终线程组内的所有线程都拥有各个线程计算结果的和。
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);

如果已经读取了 Key 值,则将 Q 与 K 相乘以生成一个标量值。QK_dot() 函数包括一个缩减操作,在线程组中计算点积并将其加总到一个值中。

image.png

// 坡度为 1/2^n 的坡度,后面的项是 ALiBi 公式的罚分,两值相乘并相加。
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;

// 线程组的代表将对数添加到 logit 中。
if (thread_group_offset == 0) {
  // 将部分约简后的值存储在共享内存中。
  const bool mask = token_idx >= context_len;
  logits[token_idx] = mask ? 0.f : qk;
  // 更新最大值。
  qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}

每个线程组的第一个线程作为代表将自己所属线程的令牌对应的QK值存储在共享存储器中,并拥有对应令牌中的最大QK值。

QK 值的最大值

当整个过程结束时,logits 变量中现已储存了 Q、K、T 值。现在需要找出这些值中的最大值。在此过程中使用的运算技术称为 Butterfly Reduction 技术。

Butterfly Reduction 是什么?

在并行处理中,Reduction 是一种在组内快速、高效地处理运算的方式。由于可以在寄存器内进行运算,特别是在 Warp 内部高速执行运算,因此非常快速。具体方法如下:

  1. 对于线程数 N(=2^k),将掩码 M 初始化为 N/2,每个线程的编号从 0 到 N-1。
  2. 每个线程将自己的编号与掩码 M 进行 XOR 运算,将结果 T 与线程 T 所拥有的数据交换,然后执行运算。
  3. 每个线程运算后保存结果,并将其 M 更改为 M/2。
  4. 重复上述过程,直至 M 为 0。
def reduce_sum(tid, value):
    mask = NUM_THREADS >> 1

    while mask > 0:
        value += exchange(tid ^ mask, value)
        mask >>= 1

        __synchronize()

    return values[0]

以下是在 8 个线程中运行的原理。

image.png

最终无论选择哪一个线程,都会反映所有线程的值。

image.png

// Butterfly Reduction
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
  qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
if (lane == 0) {
  red_smem[warp_idx] = qk_max;
}
__syncthreads();

首先在 Warp 中使用 Butterfly Reduction 方法进行归约。每个线程所持有的 QK 的最大值会相互同步,最终 Warp 中的所有线程都持有相同的值。然后,Warp 的代表(0 号 Lane)将最大值写入共享内存的 Warp 索引位置。此值随后将用于在不同的 Warp 之间同步最大值。

qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
  qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}

接下来,Warp 中的一些线程会从共享内存中读取存储的最大值。通过这种方式,一个 Warp 可以持有由另一个 Warp 计算出的最大值。持有最大值的线程之间会执行归约操作(蝶形归约)。

qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);

最后,Warp 中的代表线程负责将自己的值传递给 Warp 中的其他线程,这样最终 CUDA 块内的所有线程都可以看到相同的最大值。

image.png

Softmax

float exp_sum = 0.f;
// 从每个线程的编号中获取 Logit。
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
  // 从前面获得的 QK 最大值中减去 Exp。
  float val = __expf(logits[i] - qk_max);
  logits[i] = val;
  exp_sum += val;
}
// 就像获取 QK Max 那样,在 cuda 块内执行归约。(归约和)
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

// 计算 Softmax。
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
  logits[i] *= inv_sum;
}
__syncthreads();

现在取 Softmax 以获得值。线程将 Logit 中的值取与其线程号 (thread_idx) 对应的标量值一个,取 exp(),然后使用拥有的值执行 Reduction,就像上面做的那样。这里,执行的不是取最大值,而是取加和。(block_sum 函数) 然后计算 Softmax 的分母部分并进行除法。

QKV 计算

作为注意力计算的最后一步,我们需要通过对 𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑄𝐾𝑇/√𝑑𝑘) 值进行乘法来获得 V。现在,与在我们获得 QK 值时不同,不再使用线程组了。

constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type;

首先,像计算 QK 时那样使用了 Vec 概念,并将 Vec 的大小设置为单个线程可以读取 16 字节。logits 已转换为 float32 以进行更精确的 Softmax 计算。在这里,将它们再次转换为 scalar_t 值以进行反向计算,因此声明了 Float_L_vec 类型以便能够将 float32 的值读入 scalar_t 类型。

image.png

然后,计算各个线程应处理的大小。当将注意力分数乘以 V 时,V 需要按列单位的数据,如上图所示。(存储的形状为 [head_size * block_size],因此实际访问的是行单位)

一个 Warp 负责一个块。也就是说,32 个线程计算 block_size 个 token 的 QKV。按每个线程分割,然后稍后合并,并确定每个线程应提取多少数据进行计算。最终输出必须针对所有块(= 所有 token)进行合并,因此在最后执行。

// V 矩阵需要上图中对应于列的元素。
// 由于每个线程都需要一次读取 16 字节,因此可通过将块大小除以 V_Vec 大小来计算需要分几次来读取。
// 换句话说,表示在矩阵乘法中,需要多少个 Vec 才能对块内的一个 Row 执行操作。
// 由于每个 Row 包含 Block Size 个元素,因此以下公式得以完善。
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
// 针对每个 Row,由 Warp 中的每个线程分配来计算乘积,当均匀分配时,表示可以对多少个 Row 执行计算。
// 例如,如果 Row 的大小(=Block Size)为 32,V_Vec 大小为 8,那么每个 Row 需要 4 个 V_Vec,
// 并且 Warp 中可以一次获取 32 个 V_Vec,因此一个 Warp 可以一次处理 32 / 4 = 8 个 Row。
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
// 一个 Warp 可以一次处理 NUM_ROWS_PER_ITER 个 Row。
// 这个过程需要重复,直到全部处理完 Head Size 个,因此需要重复 Head Size / NUM_ROWS_PER_ITER 次。
constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;

在上述代码中,ROW 的含义是存储于 V 缓存中的表的 [head_size, block_size] 行列中的 head_size 部分。(即,在上述图片中的 V 的列部分)

float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
  accs[i] = 0.f;
}

为每个线程分配并初始化供其计算中使用的结果保存的空间。

for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
  // 就像计算 QK 值时一样计算偏移量。
  const int physical_block_number = block_table[block_idx];
  const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
  const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;

  // logits 中储存了 float32 类型元素。将其强制转换为 scalar_t 类型的空间进行储存。实际计算将重新将其强制转换为 float32 以进行运算。
  // 一次读取 16 个字节。
  L_vec logits_vec;
  from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));

   // 移动到 V 缓存中的合适位置。这时特定的 CUDA 块所看到的缓存是 [head_size, block_size] 形状的空间,因此,如果您将下面的图片转置,就可以看到它。
  const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
                                  + head_idx * HEAD_SIZE * BLOCK_SIZE;

image.png

一个 Warp 操作处理多个块。也就是说,它执行大小为 [1, block_size] 的注意分数和大小为 [block_size, head_size] 的 V 矩阵的乘法。计算以 float32 为单位进行转换后执行计算。

#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
  const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
  if (row_idx < HEAD_SIZE) {
    const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
    // 读取 16 字节大小的元素,乘以 accs 变量并存储
    V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
   // 如果某个 Warp 已经对其他块执行,则从总体上看,相同行将执行点积,因此可以将前一个块计算的值加起来。
    accs[i] += dot(logits_vec, v_vec);
  }
}

image.png

各个线程对负责的部分执行点积。

#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
  float acc = accs[i];
#pragma unroll
  // 同一行针对计算线程执行 Reduction 操作
  for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
    acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
  }
  // 最终存储特定块的点积结果
  accs[i] = acc;
}

image.png

一个线程负责 V 行列中的一部分行。由于实际的点积要求对整个行进行求和,所以负责同一行的线程之间执行 Butterfly Reduction 以获取实际的和。

// 设置同步所需的内存。在此过程中,它将重新利用用于计算和存储 logits 的 Shared 内存。
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
// 选择同步对象的 Warp。随着进行,需要同步的 Warp 的数量将减少一半。
for (int i = NUM_WARPS; i > 1; i /= 2) {
  int mid = i / 2;
  // 上一半 Warp 中的线程仅写入 Shared 内存。位置是将其分配的所有行。
  if (warp_idx >= mid && warp_idx < i) {
    float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      // 分配到特定行的一个线程作为代表将写入到 Shared 内存中。(lane % NUM_V_VECS_PER_ROW == 0)
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        dst[row_idx] = accs[i];
      }
    }
  }
  __syncthreads();

  // 下一半 Warp 中的线程提取由其他 Warp 写入到 Shared 内存中的值并将其添加到它们自己的结果中。位置是将其分配的所有行。
  if (warp_idx < mid) {
    const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      // 分配到特定行的一个线程作为代表从 Shared 内存中提取。(lane % NUM_V_VECS_PER_ROW == 0)
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        accs[i] += src[row_idx];
      }
    }
  }
  __syncthreads();
}

到目前为止,我们一直在以块为单位执行矩阵乘法。现在我们需要与由其他 Warp 处理的块同步。使用的方式与 Butterfly Reduction 类似。不同之处在于,在 Butterfly Reduction 之后,所有线程最终都将具有相同的值,而此同步方式仅使 0 号 Warp 的线程具有最终值。

image.png

此时的同步对象并不是各个不同的块,而是各个 Warp 之间的同步。因为一个 Warp 已经针对多个块进行了计算和求和。

重复这个过程之后,Warp 0 最终将会有大小为 [1, head_size] 的注意力值。

// warp ID 为 0 的线程执行复制操作。
if (warp_idx == 0) {
  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
    if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
      from_float(*(out_ptr + row_idx), accs[i]);
    }
  }
}

最后一步是将值复制到最终输出张量。

结束

到目前为止,我们已详细了解了 vLLM 是如何实现的。由于内容深入,说明文字较长且复杂,但我们已尽最大努力使其易于理解。另外,当前 vLLM 还在不断升级,其实现体有很大一部分可能不同,请您知晓。

作者:手抓饼熊​
来源:GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多Arm技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18835
内容数
1369
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息