文章修改自:
https://research.colfax-intl....之前解读过两期 LMDeploy Turbomind 里的源码,针对 Hopper 架构。NV 推出了新的指令集。今天我们就来一起看看 TMA
前言
张量内存加速器(TMA)是 NVIDIA Hopper™ 架构中引入的一项新功能,用于在 GPU 的全局内存(GMEM)与其线程块(即 CTA)的共享内存(SMEM)之间进行异步内存复制。与以前的方法相比,TMA 具有许多优势,例如(1)通过使用名为 warp specialization 的异步方法调度来提高 GPU 利用率,以及(2)通过 TMA 复制描述符以单线程方式处理辅助复制数据(如地址和步长)的计算,这既提高了寄存器效率,又处理了 handler(例如,越界检查)。这篇博客文章专注于实现对如何编写使用 TMA 的内核的操作理解。
在整个过程中,我们依赖于 CuTe 库,其中 TMA 通过包装较低级别 GPU 指令的 API 公开。这些指令包括 PTX 指令 cp.async.bulk.tensor 和 cp.reduce.async.bulk.tensor,以及 cuTensorMap 操作数,我们也将在本文中讨论这些内容。
我们将本博文组织为三个主要部分:首先讲解 TMA load,然后是 TMA store,最后介绍更高级的操作,如 TMA store reduce 和 TMA load multicast。简而言之,TMA load 将数据从 GPU 的全局内存(GMEM)复制("加载")到协作线程阵列(CTA)的共享内存(SMEM)中,而 TMA store 则将数据从 CTA 的共享内存复制("存储")到 GPU 的全局内存中。由于 TMA load、TMA store 及其高级变体共享许多概念,我们将在 TMA load 部分介绍大部分必要概念,并在后续章节中只关注其余差异。
此外,由于 TMA 是一种异步操作(在异步代理中执行),我们需要使用特定的内存一致性强制工具,如异步内存屏障(即 mbarrier)和异步内存栅栏(即 fence.proxy.async),以确保内核的正确行为。同步本身是一个广泛的讨论话题,因此我们仅会在实际应用所需的程度上介绍这些概念。
TMA Load
TMA load 操作将数据从全局内存(GMEM)复制到共享内存(SMEM)。在本节中,我们将演示如何编写一个使用 TMA load 实现此目标的内核。使用 TMA load 的内核与使用其他内存复制方法的内核有很大不同,因此我们将首先展示如何为一个简单的示例任务编写这样的内核,然后再解释所涉及的概念。
与传统的内存复制方法相比,TMA load 操作采用了一种全新的异步执行模式,这要求我们以不同的方式组织代码和管理内存同步。在接下来的内容中,我们将通过具体的代码示例,逐步说明如何正确实现和使用 TMA load,并深入解析其工作原理和关键概念。
通过本节的学习,读者将了解 TMA load 的基本用法、内存布局要求以及如何处理异步操作中的同步问题,为后续更复杂的 TMA 操作奠定基础。
Example
为了演示 TMA load 的用法,我们考虑一个简单的任务:对二维行主序矩阵进行分块。给定一个形状为 [m,n] 的矩阵 A 和两个正整数 CTA_M 和 CTA_N。需要注意的是,CTA_M 和 CTA_N 在编译时已知,而 m 和 n 是通过矩阵 A 在运行时提供给我们的。为简化起见,我们假设 m % CTA_M == 0 且 n % CTA_N == 0,不过稍后我们会看到这一要求可以放宽。
我们启动一个大小为 {m/CTA_M, n/CTA_N, 1} 的 CTA 网格,其中第 (i,j) 个 CTA 的共享内存(SMEM)保存来自 A 的第 (i,j) 个形状为 [CTA_M, CTA_N] 的分块。我们可以用 NumPy 伪代码描述这种分配方式:
A = np.random.uniform(M, N)
for i in range(M):
for j in range(N):
cta_i_j = A.reshape(M // CTA_M, CTA_M, N // CTA_N, N)[i, :, j, :]
- 两步执行:为执行这项任务,我们使用 TMA load。在 CuTe 中,TMA load 操作通过两个步骤实现。第一步是在主机代码中构建 TMA 复制描述符,第二步是在内核代码中使用该描述符执行实际的 TMA load。值得注意的是,这种两步实现过程与我们通常使用 CuTe 的 TiledCopy 的方式不同——在 TiledCopy 中,所有复制步骤都写在内核代码中,就像在相关教程中展示的那样。
Host code
在主机端,我们创建三个对象:作为复制源的 GMEM 张量,每个 CTA 上用作复制目标的 SMEM 张量的布局,以及一个接受这两者作为参数的 tma_load 对象。需要注意的是,由于我们在主机端创建 SMEM 布局,所有 CTA 将共享相同的 SMEM 布局用于 TMA load 操作。一旦创建了这些对象,它们就可以传递给设备上的内核,然后在内核中调用 TMA load 操作。
主机端的完整代码块如下:
template <typename T, int CTA_M, int CTA_N>
void host_fn(T* data, int M, int N) {
using namespace cute;
// create the GMEM tensor
auto gmem_layout = make_layout(make_shape(M, N), LayoutRight{});
auto gmem_tensor = make_tensor(make_gmem_ptr(T), gmem_layout);
// create the SMEM layout
auto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});
// create the TMA object
auto tma_load = make_tma_copy (SM90_TMA_LOAD{}, gmem_tensor, smem_layout);
// invoke the kernel
tma_load_kernel<CTA_M, CTA_N>
<<<1, dim3{M / CTA_M, N / CTA_N, 1}>>>
(tma_load, gmem_tensor, smem_layout);
}
创建 gmem_layout、gmem_tensor 和 smem_tensor 的代码行仅使用了基本的 CuTe 概念,因此我们建议读者参考这些 CuTe 教程来复习相关知识。在这里,我们专注于解释 tma_load 对象。该对象是 cute::TiledCopy 的一个实例,它包含执行 CTA 范围复制操作所需的信息并实现相关方法。在代码片段中,tma_load 对象通过 cute::make_tma_copy 函数的显式默认值创建。这个函数的完整实现有一些微妙之处,我们将在本博文后面讨论 MULTICAST 时深入探讨,但对于大多数用例(如我们的示例任务)来说,使用显式默认值就足够了。我们建议使用显式默认值以避免不必要的复杂性(和潜在的错误)。
深入理解 make_tma_copy
关于 make_tma_copy:
- 它的最后两个参数是 gmem_tensor 和 smem_layout。在底层,make_tma_copy 使用这些信息创建一个 TmaDescriptor,这实际上是 CUtensorMap 的别名。这个描述符对象在 TMA 内核中使用。
- 它的第一个参数是 SM90_TMA_LOAD 的实例。这个对象将复制操作分派给所需的 cp.async.bulk.tensor PTX 调用,我们将在下面的第三部分更深入地讨论这一点。
Kernel code
相关的内核代码片段如下所示。这些代码行包含了许多重要的 TMA 概念,我们将在下文中进行解释。
template <typename T, int CTA_M, int CTA_N, class TmaLoad, class GmemTensor>
void tma_load_kernel(__grid_constant__ const TmaLoad tma_load, GmemTensor gmem_tensor) {
using namespace cute;
constexpr int tma_transaction_bytes = CTA_M * CTA_N * sizeof(T);
__shared__ T smem_data[CTA_M * CTA_N];
__shared__ uint64_t tma_load_mbar;
auto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});
auto smem_tensor = make_tensor(make_smem_ptr(smem_data), smem_layout);
if (threadIdx.x == 0) {
auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));
auto gmem_tensor_coord_cta = local_tile(
gmem_tensor_coord,
Tile<Int<CTA_M>, Int<CTA_N>>{},
make_coord(blockIdx.x, blockIdx.y));
initialize_barrier(tma_load_mbar, /* arrival count */ 1);
set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);
auto tma_load_per_cta = tma_load.get_slice(0);
copy(tma_load.with(tma_load_mbar),
tma_load_per_cta.partition_S(gmem_tensor_coord_cta),
tma_load_per_cta.partition_D(smem_tensor));
}
__syncthreads();
wait_barrier(tma_load_mbar, /* phase */ 0);
// after this line, the TMA load is finished
}
首先,在第 2 行,内核的 tma_load 参数必须使用 grid_constant const 进行注释。如果我们有两个张量需要从 GMEM 复制到 SMEM,每个张量都必须有自己的 TiledCopy 实例,并且每个实例都必须是 grid_constant const。这是将 cuTensorMap 从主机传递到设备的必要条件,相关文档在此处有所说明。
下一个重要的点是,对于 TMA 复制,只有一个线程负责发起 TMA 操作。在代码片段中,所有与 TMA 相关的变量和指令都包含在从第 12 行开始的 if 块中,该块仅由线程 0 执行。另一方面,第 30 行包含一条指令,要求 CTA 中的所有线程等待 TMA 操作完成。
现在,让我们深入了解 TMA load 的逻辑。这从第 13 行开始,我们创建一个 gmem_tensor_coord 对象,用于保存要复制的 GMEM 张量的坐标。如果我们尝试以下操作:
if(cute::thread(0)) { cute::print(gmem_tensor_coord); }
我们能看到输出是
ArithTuple(_0,_0) o (1024,1024):(_1@1,_1@0)
对于熟悉 CuTe 中平铺复制工作方式的读者来说,第 15-18 行是不言自明的。在这种方式中,GMEM 张量被平铺成更小的分区,每个 CTA 根据块坐标切片到平铺张量中,以获取其对 GMEM 的视图。但需要注意的是,分区应用于前面提到的表示 gmem_tensor 坐标的 ArithTuple,而不是 gmem_tensor 本身。具体来说,ArithTuple 被分区成形状为 [CTA_M,CTA_N] 的块,然后每个 CTA 取其对应的块。
如果我们使用以下方式通过 print_tensor 打印 gmem_tensor_coord_cta:
if(cute::block(7)) { cute::print_tensor(gmem_tensor_coord_cta); }
当 CTA_M == CTA_N == 16:
ArithTuple(0,112) o (_16,_16):(_1@1,_1@0):
(0,112) (1,112) (2,112) (3,112) (4,112) (5,112) (6,112) (7,112) (8,112) (9,112) (10,112) (11,112) (12,112) (13,112) (14,112) (15,112)
(0,113) (1,113) (2,113) (3,113) (4,113) (5,113) (6,113) (7,113) (8,113) (9,113) (10,113) (11,113) (12,113) (13,113) (14,113) (15,113)
// more lines
(0,127) (1,127) (2,127) (3,127) (4,127) (5,127) (6,127) (7,127) (8,127) (9,127) (10,127) (11,127) (12,127) (13,127) (14,127) (15,127)
这些数字是 gmem_tensor 中的坐标,其值将被复制到 CTA 7 的 smem_tensor 中。我们鼓励读者尝试运行这段代码,同时将 cute::block(7) 替换为其他索引,以理解不同的 CTA 从 gmem_tensor 中的哪些坐标复制数据。
接下来,在第 25-27 行发出的复制操作具有 TiledCopy 操作的常见签名,其中源张量被分区坐标所替代。
内存屏障
我们省略了第 20、22 和 30 行,这些行都涉及存在于 SMEM 中的 uint64_t 变量 tma_load_mbar。这是异步事务屏障,我们用它来同步 TMA load 操作与内核中消费已加载到 SMEM 的数据的其余部分。NVIDIA 关于 Hopper 架构的技术博客中对这种类型的屏障给出了高级描述。就我们的内核而言,重要的点如下:
- 我们在第 20 行的共享内存中初始化 mbarrier 对象。CuTe 方法 initialize_barrier 封装了 PTX 指令 mbarrier.init.shared.b64,该指令需要一个额外的到达计数参数。在我们的上下文中,由于只有单个线程会启动 TMA load,我们应该将到达计数设置为 1。此外,mbarrier 的起始阶段将始终设置为 0。
- 我们在第 22 行既执行了到达操作,又使用 CuTe 方法 set_barrier_transaction_bytes 为 mbarrier 对象设置了预期的事务计数,该方法封装了 PTX 指令 mbarrier.arrive.expect_tx.shared::cta.b64。事务计数被设置为等于 TMA load 传输的字节数,这是我们在第 4 行计算的。
- 在第 25-27 行,复制指令(它会分派到所需类型的 cp.async.bulk.tensor)的完成机制始终是 mbarrier::complete_tx::bytes,并使用提供的 mbarrier 对象。
- 在第 30 行,我们对 mbarrier 对象执行等待操作。请注意,所有线程都在 mbarrier 上等待,而与之相对的是只有线程 0 到达 mbarrier,且在 wait_barrier 之前调用 __syncthreads() 是必要的,以解决线程分歧。
- 这里,wait_barrier 封装了 PTX 指令 mbarrier.try_wait.parity.shared::cta.b64。try_wait 限定符(与 test_wait 相对)表示等待是一个阻塞指令。parity 限定符(其使用需要提供一个相位位)表示线程休眠直到 mbarrier 的相位位翻转。因为这是初始化后首次使用 mbarrier 来跟踪完成情况,我们提供 0 作为相位。如果我们要进行另一次 TMA load,则必须翻转相位以重用 mbarrier。
- 一般来说,CUTLASS Pipeline API 提供了一种更高级的方式来处理 mbarrier 对象的生命周期,特别是在进行一系列 TMA loads 时,就像在软件流水线方案中可能做的那样。
- 在 wait_barrier 之后,内存一致性模型为我们提供了以下保证:TMA load 对 SMEM 的写入对所有调用了 mbarrier 等待的线程可见(因此在我们的示例内核中,是 CTA 中的所有线程)。
TMA 步长要求的剩余图块
TMA Store
在我们上面的例子中,我们假设 m%CTA_M==0 和 n%CTA_N==0。然而,对于执行 TMA load 而言,我们完全可以放弃这一假设。在从 GMEM 加载剩余部分到 SMEM 时,我们不需要自己处理越界逻辑,TMA 复制单元会必然地对内存复制进行断言,以避免读取越界。这与上述 TMA load 中使用带有 ArithTuple 的特殊"隐式" CuTe 张量的用法一致——如果我们改用普通的 CuTe 张量,那么它们可能会被切片以产生新的 CuTe 张量,这些新张量可能包含指向 GMEM 的越界指针,不可避免地导致错误。
然而,对于 TMA,我们需要牢记 GMEM 张量本身的步长有一个重要要求,即 16 字节边界要求。正如人们所预期的那样,TMA 不支持复制 GMEM 中任意步长的区域。相反,我们需要假设被复制的块具有:(i) 一个连续的方向(步长为 1),以及 (ii) 其他步长是 16 字节的倍数。这在 CUTLASS 代码库中有所断言。
例如,对于我们的行主序 GMEM 浮点张量,形状为 (m, n),步长为 (n, 1),这就要求 n%4==0。如果不满足这一要求,那么可以在调用内核之前对输入张量进行填充,使其具有正确的范围。
Example task and code
为了说明目的,让我们考虑 TMA load 的反向示例,即从多个 CTA 的 SMEM 复制到分区 GMEM 张量中对应的块。这里的一个区别是,在复制到 GMEM 之前,我们将用简单的数字模式填充 CTA 中的 SMEM 块(否则,我们将复制未定义的值)。功能代码片段如下:
template <typename T, int CTA_M=32, int CTA_N=32>
void host_fn(T* data, int M, int N) {
using namespace cute;
// create the GMEM tensor
auto gmem_layout = make_layout(make_shape(M, N), LayoutRight{});
auto gmem_tensor = make_tensor(make_gmem_ptr(T), gmem_layout);
// create the SMEM layout
auto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});
// create the TMA object
auto tma_store = make_tma_copy(SM90_TMA_STORE{}, gmem_tensor, smem_layout);
// invoke the kernel
tma_store_kernel<CTA_M, CTA_N>
<<<CTA_M, dim3{M / CTA_M, N / CTA_N, 1}>>>
(tma_store, gmem_tensor, smem_layout);
}
template <typename T, int CTA_M, int CTA_N, class TmaStore, class GmemTensor>
void tma_store_kernel(__grid_constant__ const TmaStore tma_store, GmemTensor gmem_tensor) {
using namespace cute;
__shared__ T smem_data[CTA_M * CTA_N];
auto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});
auto smem_tensor = make_tensor(make_smem_ptr(T), smem_layout);
// fill the rows of smem_data
for (int j = 0; j < CTA_N; ++j) {
smem_data(threadIdx.x, j) = threadIdx.x;
}
__syncthreads();
tma_store_fence();
if (threadIdx.x == 0) {
auto gmem_tensor_coord = tma_store.get_tma_tensor(shape(gmem_tensor));
auto gmem_tensor_coord_cta = local_tile(
gmem_tensor_coord,
Tile<Int<CTA_M>, Int<CTA_N>>{},
make_coord(blockIdx.x, blockIdx.y));
auto tma_store_per_cta = tma_store.get_slice(0);
copy(tma_store,
tma_store_per_cta.partition_S(smem_tensor),
tma_store_per_cta.partition_D(gmem_tensor_coord_per_cta));
// tma_store_arrive();
}
// tma_store_wait<0>();
}
主机代码看起来几乎与 TMA load 的相同,除了对 tma_store_kernel 的调用。注意,我们已安排每个 CTA 拥有 CTA_M 个线程。在我们的示例中,每个 CTA 在 SMEM 中保存一个 [CTA_M,CTA_N] 的块,使得在第 29-32 行,线程 i 用值 i 填充第 i 行。
在内核代码中,第 39-49 行的 if 块与 tma_load_kernel 中的 if 块类似。特别是,只有线程 0 发出 TMA store 操作。所有张量平铺逻辑在概念上是相同的。然而,复制方向是相反的:对于 TMA store,tma_store_per_cta.partition_S 方法应用于 smem_tensor,而 tma_store_per_cta.partition_D 方法应用于 GMEM 张量的坐标。注意,坐标也表示为一个 ArithTuple,类似于 TMA load。
内存屏障
TMA load 和 store 代码之间最重要的区别是,在使用 TMA store 时不再看到任何 mbarrier 对象。这是因为 TMA store 使用另一种机制来强制内存一致性:内存栅栏(memory fence)。
内存栅栏的目的是在栅栏前后执行线程请求的内存访问之间建立保证顺序。在我们的示例中,我们需要确保在第 29-32 行对 SMEM 的所有写入对线程 0 执行的 TMA store 可见。为此,在第 35 行,我们使用了封装 PTX 指令 fence.proxy.async.shared::cta 的 CuTe 方法 tma_store_fence()。
这条指令包含两个重要的限定符,描述了栅栏的效果:范围(scope)和代理类型(proxykind)。范围表示参与栅栏强制执行顺序的线程集合。在我们的例子中,限定符 cta 将范围定义为 CTA 中的所有线程(这是内存一致性模型所能使用的最小可能范围)。代理类型表示除通用代理外,将参与栅栏强制执行顺序的代理类型。在我们的例子中,我们选择代理类型为 async.shared,因为 TMA store 在异步代理中执行(相对于每个 CTA)。如果我们用不涉及异步代理的其他内存栅栏原语(如 __threadfence_block())替换异步栅栏,我们将破坏内核正确行为所需的保证,在实践中导致竞争条件。
TMA 存储到达并等待
在第 49 和 51 行,我们有 tma_store_arrive(),它提交 TMA store 操作(技术上,作为 cp.async.bulk-group),以及 tma_store_wait(),它等待直到最多有 Count 个已提交的 TMA store 操作处于挂起状态(例如,如果所有操作都应该完成,则将 Count 设置为 0)。当内核中的其他工作需要等待 TMA store 完成时,这些操作很有用——例如,在写出后重用释放的 SMEM 时就需要这样做。然而,由于我们的内核在 TMA store 完成后简单地退出,因此我们在这里不需要 TMA store 的到达和等待模式,所以我们将这些行注释掉。
A Deeper Look at TMA Operations
到目前为止,我们已经学习了如何调用 TMA load 和 TMA store 操作。上表比较和对比了这些操作。要调用任一操作,我们需要在主机代码中通过 cute::make_tma_copy 方法创建一个类似于 TiledCopy 的对象,然后将该对象传递到内核函数中,在那里我们在 cute::copy 中使用它们来实际调用操作。在本节中,我们将深入探讨当我们在内核函数中调用这些 TiledCopy 对象时实际发生了什么。基于这一深入探讨,我们将讨论两个扩展:TMA store reduce 和 TMA load multicast。
TMA Load 和 Store 和 PTX 指令
PTX (Parallel Thread Execution) 是 NVIDIA GPU 的一种低级中间语言。对于我们的讨论,PTX 的相关部分包括一组可以通过 asm volatile 关键字包装的块插入到 CUDA 代码中的指令。特别是,当我们按照前面章节所述调用 cute::copy(tma_load, ...) 或 cute::copy(tma_store, ...) 时,会调用某些 PTX 指令来执行这些操作。通过研究 PTX,我们可以更好地理解 TMA load 和 TMA store。
让我们从 TMA load 开始。回想一下,当我们在主机代码中创建 tma_load 对象时,我们必须提供 GMEM 张量(包含要复制的源数据)和 SMEM 布局(描述数据在每个 CTA 内的样子)。使用这个张量和布局,CuTe 确定当在内核中调用 cute::copy(tma_load, ...) 时要执行的底层 PTX 指令。选择的 PTX 指令取决于 GMEM 张量的秩(注意,这里的秩指的是张量的维度数量,而不是线性代数中的矩阵秩/零度)。在我们的示例中,GMEM 张量的秩为二,因此将执行以下 PTX 指令:
asm volatile (
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1)
: "memory");
查看这条 PTX 指令,我们看到许多熟悉的概念。例如,gmem_int_desc 指的是 TMA 描述符中保存的坐标,而 mbarrier::complete_tx::bytes 和 smem_int_mbar 指的是内存屏障。还要注意,tensor.2d 表示我们正在复制一个秩为 2 的张量,即二维矩阵。
事实证明,不仅 TMA load,所有 TMA 操作都是某些 cp.async.bulk 指令的包装器。NVIDIA PTX 文档专门用一整节来讨论 cp.async.bulk 指令,特别是它们的语法和操作数。我们鼓励读者阅读该部分及其中的参考资料,以更全面地研究 TMA 操作,这些操作涵盖的范围远大于本博文所打算介绍的。在这里,我们将讨论通过这些 cp.async.bulk 指令公开的两种 TMA 扩展。
TMA Store 和 Reduce 组合
回顾一下,TMA store 将数据从多个 CTA 的 SMEM 复制到 GMEM 张量中的相应块。我们可以将 TMA store 解释为以下 Python 伪代码所示的赋值操作:
for cta_idx in range(number_of_ctas):
gmem_dst[cta_idx] = smem_src[cta_idx]
如果我们想要做以下操作呢?也就是保存+reduce 的组合操作。
for cta_idx in range(number_of_ctas):
gmem_dst[cta_idx] += smem_src[cta_idx]
# or this:
gmem_dst[cta_idx] = max(gmem_dst[cta_idx], smem_src[cta_idx])
# or this:
gmem_dst[cta_idx] = min(gmem_dst[cta_idx], smem_src[cta_idx])
所有这些操作——即规约求和(reduce sum)、规约求最大值(reduce max)和规约求最小值(reduce min)——在张量程序中都相当常见。特别是,规约求和是 Split-K GEMM 中不可避免的子例程,而规约求最大值和规约求最小值经常用于注意力机制中。尽管这些操作看起来很简单,但在 CUDA 内核中实现它们并不是很直观。在阅读下一段之前,我们邀请读者简要思考一下,为了实现这些目标,必须在 GMEM 和 SMEM 之间进行多少轮数据移动。
将 CTA 的 SMEM 中的值"累积"到 GMEM 张量中的块的规约操作的原始实现包括一次 GMEM 读取、一个处理块和一次 GMEM 写入。首先,从 GMEM 加载原始值到 CTA 的 SMEM 或寄存器中,然后执行规约操作,最后将结果写回。这个过程很慢。
对 TMA store TiledCopy 对象的构造函数稍作修改,可以将这个三步过程压缩为仅一条 PTX 指令,即使用 cp.reduce.async.bulk 而不是 cp.async.bulk。准确地说,我们可以在主机代码上进行以下一行更改:
// original: create a TMA store object
auto tma_store = make_tma_copy(SM90_TMA_STORE{}, gmem_tensor, smem_layout);
// to create a TMA reduce sum object
auto tma_reduce_sum = make_tma_copy(SM90_TMA_REDUCE_ADD{}, gmem_tensor, smem_layout);
然后改用 tma_reduce_sum,它在底层调用的是 cp.reduce.async.bulk 而不是 cp.async.bulk。
顺便说一句,PTX 指令 cp.reduce.async.bulk 自 CUDA 12.0 发布以来就已经可用,但直到 CUTLASS 3.5 版本才通过 CUTLASS 和 CuTe 公开。我们希望其他规约操作将在未来的版本中公开,但如果没有,修改 CuTe 代码以使 TMA reduce add 执行最大值和最小值规约,以及 cp.reduce.async.bulk 提供的其他位运算规约(与、或、异或、递增和递减)是相当简单的。
TMA Load Multicast
在上一节中,我们已经看到研究 PTX 指令使我们能够发现 TMA 规约操作,它可以在某些应用中替代 TMA store。在本节中,我们将研究 TMA load 的多播(multicast)扩展。
为了帮助我们理解,我们首先来看看 cp.async.bulk.tensor 的完整语法:
// global -> shared::cluster:
cp.async.bulk.tensor.dim.dst.src{.load_mode}.completion_mechanism
{.multicast}{.level::cache_hint}
[dstMem],
[tensorMap, tensorCoords],
[mbar]
{, im2colOffsets}
{, ctaMask}
{, cache-policy}
.dst = { .shared::cluster }
.src = { .global }
.dim = { .1d, .2d, .3d, .4d, .5d }
.completion_mechanism = { .mbarrier::complete_tx::bytes }
.load_mode = { .tile, .im2col }
.level::cache_hint = { .L2::cache_hint }
.multicast = { .multicast::cluster }
再次强调,无需完全理解 PTX 指令的语法,我们已经看到了许多熟悉的概念,如 .dim、源的 .global 和完成机制的 .mbarrier。本节重点关注多播(multicast)操作数。
多播指的是这样一种情况:我们有 GMEM 张量中的一个块,想要将其复制到多个 CTA 的多个 SMEM 位置。这在 GEMM 内核(即矩阵乘法)中是典型情况,其中一个输入矩阵的列块被多个行块需要,反之亦然。在这种情况下,虽然 TMA load 仍然完全可用——我们只需向需要它的多个 CTA 提供相同的 TMA 描述符——但 .multicast 操作数使我们能够保证 L2 缓存命中。
让我们考虑将上述 TMA load 示例扩展到包含多播的情况。首先,我们需要定义内核的集群(cluster)维度为非平凡的,因为一组 CTA 共同参与 TMA load 多播操作的要求是它们属于同一个(线程块)集群。为了保持简单,我们将只更改网格维度,如下所示:
// old grid dimensions and implicit trivial cluster dimensions
dim3 grid_dims = dim3{M / CTA_M, N / CTA_N, 1};
dim3 cluster_dums = dim3{1, 1, 1};
// new grid dimensions and cluster dimensions
dim3 grid_dims = dim3{M / CTA_M, N / CTA_N, 2};
dim3 cluster_dums = dim3{1, 1, 2};
请注意,当使用集群时,集群维度必须能够整除网格维度,否则内核将无法启动。在我们的新内核中,我们将安排同一个 GMEM 块被加载到同一集群中每对 CTA 的 SMEM 中,这种情况当且仅当两个 CTA 具有相同的 blockIdx.x 和 blockIdx.y 时发生。
首先,在主机代码中,我们对 TMA load TiledCopy 对象的定义做如下更改:
// original: create a TMA load object
auto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor, smem_layout);
// new: create a TMA load multicast object for the given cluster size
auto tma_load = make_tma_copy(SM90_TMA_LOAD_MULTICAST{},
gmem_tensor, smem_layout, cute::_2{});
我们为最后一个参数(集群大小)写入 _2{},以将其作为编译时常量传递,使用为此目的提供的 CuTe 整数类型。在实践中,更符合惯例的做法是预先定义 ClusterShape 类型(在我们的例子中,为 Shape<_1_1,_2>),然后为该参数写入 size<2>ClusterShape{}。
然后,我们将内核代码更改如下:
template <typename T, int CTA_M, int CTA_N, class ClusterShape,
class TmaLoad, class GmemTensor>
void tma_load_kernel(__grid_constant__ const TmaLoad tma_load,
GmemTensor gmem_tensor) {
using namespace cute;
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
constexpr uint32_t cluster_size = size<2>(ClusterShape{}));
constexpr uint16_t tma_mcast_mask = (uint16_t(1) << cluster_size) - 1;
constexpr int tma_transaction_bytes = CTA_M * CTA_N * sizeof(T);
__shared__ T smem_data[CTA_M * CTA_N];
__shared__ uint64_t tma_load_mbar;
auto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});
auto smem_tensor = make_tensor(make_smem_ptr(T), smem_layout);
auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));
auto gmem_tensor_coord_cta = local_tile(
gmem_tensor_coord,
Tile<Int<CTA_M>, Int<CTA_N>>{},
make_coord(blockIdx.x, blockIdx.y));
if (threadIdx.x == 0) {
initialize_barrier(tma_load_mbar, /* arrival count */ 1);
}
__syncthreads();
cute::cluster_sync();
cutlass::arch::fence_barrier_init();
if (threadIdx.x == 0) {
set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);
auto tma_load_per_cta = tma_load.get_slice(block_rank_in_cluster);
copy(tma_load.with(tma_load_mbar, tma_mcast_mask),
tma_load_per_cta.partition_S(gmem_tensor_coord_per_cta),
tma_load_per_cta.partition_D(smem_tensor));
}
__syncthreads();
wait_barrier(tma_load_mbar, /* phase */ 0);
// after this line, the TMA load is finished
cute::cluster_sync();
}
我们已经突出显示了相关更改。首先,我们现在需要跟踪 CTA 在其集群内的内部索引,这是通过 CuTe 方法 block_rank_in_cluster() 获取的。这将返回特殊寄存器 %cluster_ctarank 的值,在我们的例子中将取值为 0 和 1。为简洁起见,让我们将其称为 ctaid。然后,我们对代码进行了以下三项修改,需要解释:
- 额外的集群同步原语。
- 在多播操作中使用 uint16 位掩码。
- 使用 ctaid 确定用于分割 GMEM 和 SMEM 张量的 TiledCopy 对象的切片。
对于 (1),我们使用 CuTe 方法 cluster_sync(),它按顺序同时执行集群屏障到达和等待操作。我们在两个地方插入这个方法:在第 7-8 行,我们将 cluster_sync() 与栅栏一起使用,以确保 mbarrier 初始化在整个集群范围内可见;在第 41 行,我们使用另一个 cluster_sync() 来确保集群中的两个 CTA 中的一个在另一个仍在等待多播加载完成时不会过早退出。一般来说,会对加载到 SMEM 中的数据进行计算,而最后的 cluster_sync() 会出现在内核代码的最后。
对于 (2),我们向复制操作传递一个 uint16 位掩码,以指定哪些 CTA 将参与 TMA 多播加载。掩码中设置为 1 的位表示哪些 CTA 处于活动状态,一个集群中最多有 16 个 CTA(最大非可移植大小),位的位置对应于 ctaid。因此,在我们的例子中,通过将 tma_mcast_mask 设置为 0b11,我们指定集群中的两个 CTA 都将参与。
最后,对于 (3),ctaid 用于指定从给定 CTA 启动 TMA 多播加载操作时,在 GMEM 中切片时使用的偏移量。为了清楚地解释这一点,考虑以下示例:从 GMEM 将一个 16 x 16 的整数块(以升序行主序初始化为 0-255)加载到集群中两个 CTA 的 SMEM 中。假设我们错误地为两个 CTA 都给出 0 作为 tma_load.get_slice 的参数。那么在加载完成后,我们在两个 CTA 的 SMEM 中获得以下内容:
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 1516 17 18 19 20 21 22 23 24 25 26 27 28 29 30 3132 33 34 35 36 37 38 39 40 41 42 43 44 45 46 4748 49 50 51 52 53 54 55 56 57 58 59 60 61 62 6364 65 66 67 68 69 70 71 72 73 74 75 76 77 78 7980 81 82 83 84 85 86 87 88 89 90 91 92 93 94 9596 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
相比之下,如果我们为两个 CTA 都将参数设为 1,那么在两个 CTA 的 SMEM 中我们将得到:
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
最后,如果从 ctaid 为 1 的 CTA 给出 0 而从 ctaid 为 0 的 CTA 给出 1,或者从 ctaid 为 0 的 CTA 给出 0 而从 ctaid 为 1 的 CTA 给出 1,就会正确地将整个块加载到两个 CTA 的 SMEM 中。这些打印输出说明,从集群中一个 CTA 发出多播操作会将 GMEM 的一半加载到两个 CTA 的 SMEM 中,而 TiledCopy 的切片决定相应的一半。这与 PTX 文档中对 cp.async.bulk.tensor 多播的描述一致:
源数据被多播到每个目标 CTA 的共享内存中,与 dstMem 在 CTA 相对偏移量相同的位置。
就 TiledCopy 对象而言,它通常具有将线程-值元组映射到块的逻辑坐标的 TiledLayout_TV 布局,CuTe 将 ctaid 视为切片目的的线程索引。例如,在我们的 16 x 16 示例中打印 TiledCopy 会产生以下结果:
TiledCopy
Tiler_MN: (_16,_16)
TiledLayout_TV: (_2,((_16,_16))):(_8,((_16,_1)))
Copy_Atom
ThrID: _1:_0
ValLayoutSrc: (_1,_256):(_0,_1)
ValLayoutDst: (_1,_256):(_0,_1)
ValLayoutRef: (_1,_256):(_0,_1)
ValueType: 32b
这有两个"线程",对应于集群中的两个 CTA,对于 ctaid 为 1 的 CTA,在 (16,16) 块中的逻辑坐标 (8,0) 处给出偏移位置。
Conclusion
在本博文中,我们通过几个简化的示例,使用 CUTLASS 库提供的方法,演示了如何在 CUDA 内核中使用 TMA load、store、store reduce 和 load multicast 在全局内存(GMEM)和共享内存(SMEM)之间执行内存复制。
我们首先提供了 TMA 的概述,并介绍了用户如何在 GPU 内核中调用这些操作。然后,我们深入研究了底层 PTX 指令,以获得对 TMA 的更深入理解。我们希望本博文对希望了解 TMA、复习相关知识或调试使用 TMA 的现有项目的读者有所帮助。
我们省略了一些重要主题,如 TMA 支持的交织模式以及 TMA 将 GMEM 以交错格式复制到 SMEM 的能力,这些能力可以对连续维度之外的步长进行置换。当将 TMA 与 Hopper 架构新增的 Warpgroup 矩阵乘法累加(WGMMA)指令结合使用时,这些特性非常重要,因为它们可以以与 WGMMA 兼容的内存格式加载张量数据。我们将在未来的文章中讨论基于 Hopper 的 GEMM 时解释这些要点。
END
作者:企鹅火烈鸟
来源:GiantPandaLLM
推荐阅读
- LLM 技术报告系列 | Google 团队正式放出 Gemma 3 技术报告
- CARL2010:一种利用领域特定语言可重构性的方法论
- Strong-Baseline架构,无特征增强问鼎反无人机挑战赛
- Tensor-001 矩阵乘法分块乘法概述
欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。