拥抱 PyTorch,来自 Gauss 的自我革命

导语

自 2015 年 TensorFlow 开源以来,伴随着深度学习的迅猛发展,通用深度学习框架经历了 10 年的高速发展,大浪淘沙,余者寥寥。曾几何时,也有过性能与易用性之争,也有过学术界和工业界之分,但随着本轮大模型应用的推波助澜,PyTorch 无疑已经成为事实上的大模型“标准框架”。时至今日,PyTorch AOTCompile 特性的发布更是直接撕下了 TensorFlow 最后一块“易于部署”的遮羞布。社区活跃度、性能和易用性,数张无形的大手,推搡着我们去拥抱更加现代化的新质生产力 —— PyTorch。

一、Torch 做推荐,到底行不行?

区别于一统天下的大模型场景,PyTorch 在推荐场景上的应用谈不上广泛,但好在 Gauss 并不是第一个在微信内尝试使用 PyTorch 搭建整条推荐流水线的团队。早在 2022 年,微信内部就进行了初步的尝试[1],并在部分场景实际操作落地,但随着大模型的爆发,研发团队的职责发生了变化,当年的这股 PyTorch “劲风”并没有化为改革的“春风”铺展开来,无疑是一种遗憾。

时间来到 2024 年,Transformers 结构也从大模型蔓延到推荐领域,年初 Meta 一篇介绍“生成式推荐”[2]的工作在推荐系统领域投下了一枚重磅炸弹,业界各大公司纷纷跟进尝试复现,Gauss 团队也不例外。但在复现的过程中我们发现,TensorFlow 在大模型上的落后是全面的。首当其冲的问题是性能,因为大模型社区已经实质上抛弃了 TensorFlow ,这就导致所有针对 Transformers 的开源优化基本都只能找到 PyTorch 版本,比如 Attention 的“标准”实现——FlashAttention 至今都没有提供 TF 的支持。再者是开发复杂度,这点几乎是不言自明的,无论对工程还是算法同学,相比 PyTorch,TensorFlow 都是“难用”的代名词,举个例子,自 PyTorch 2.3 以来,用户自定义的 Triton kernel(什么是 Triton?) 已被全面支持,熟手定制融合 Attention 核的耗时从数周显著缩减至数天,这对快速验证尚不成熟的算法来说简直是不可或缺的特性。

基于上述背景,Gauss 总结了实时推荐系统和生成式推荐对建模工具链的 4 大要求,并依此重新审视 PyTorch 框架,决策其应用于推荐场景的可能性。

关键技术 1 —— 动态词向量

大规模稀疏的词向量,是时下个性化推荐系统的基础配置,TensorFlow 生态中存在 TFRA[3] 扩展,以便支持百亿量级动态变化的词向量。在 PyTorch 生态中,原不具备对应的功能,感谢 TorchRec 团队的贡献,其在 PyTorch 官方扩展 TorchRec 中引入了动态词向量的支持[4],极大的降低了这部分的入门门槛。

关键技术 2 —— 高性能、强一致的推理实现

相较于 PyTorch,很长一段时间内 TensorFlow 的静态图特性使其成为“易于部署上线”的代名词,但随着 PyTorch 2.2 AOTInductor 特性的发布,基于预编译技术可以方便的将用户模型编译成数个动态链接库文件发布上线,不仅规避了 Python 运行时,性能也得到了较为显著的提升,同时训练和推理的一致性也有充分的保障(和训练使用相同的编译技术和核实现),一举三得。

关键技术 3 —— FlashAttention

Attention 相关模块可以说是本轮大模型和生成式推荐中最为重要的核心技术,同时也是计算资源消耗最多的部分,在部分生成式推荐训练过程中,我们发现约 70% 的时间都消耗在 Attention 相关的计算中。因此,支持 FlashAttention 及其变体,快速开发融合算子,从加分项变为必选项。正如前文所言,FlashAttention 的原版实现仅支持 PyTorch,此外,得益于 Triton-lang 和 PyTorch 的有机结合,以及 2.5 版本新增的 FlexAttention 特性,在 PyTorch 生态下定制融合 Attention 算子,相较于 TensorFlow 而言,显得轻而易举。

关键技术 4 —— 容灾与弹性训练

实时推荐系统通常要求流式的进行模型训练,即便是分钟级别的训练中断也会对推荐效果产生不良影响[9]。此外,流式的样本管线通常存在波峰与波谷效应,如果训练资源一直保持波峰期的级别,将会造成极大的浪费。支持在不影响模型效果的前提下,在波峰波谷期动态的扩容与缩容,是评判一个实时训练框架是否完备的重要标准之一。额外的,低优(低价) GPU 资源在公司内的推广使用,也为弹性训练打下了政策基础。框架方面,TensorFlow 本身并不具备动态扩缩容的能力,但配合第三方框架(例如 Horovod)可以实现简单的弹性训练。PyTorch 内置了 弹性训练 的能力,但具体应用到推荐领域的大规模稀疏模型上,也面临恢复续训和扩缩容速度较慢,不能秒级容错的局限。综上,在容灾与弹性方面,TensorFlow 与 PyTorch 的能力均是捉襟见肘,需要额外补齐。

除了上述关键技术之外,分布式训练、异构设备(尤其是 910B)的支持,对推荐模型的训练和推理也十分重要,PyTorch 在这些方面与 TensorFlow 相比,各有千秋,总体上来说,均不会形成掣肘,这里就不过多展开了。

综上,今日之 PyTorch 不仅完全胜任推荐系统和生成式推荐的核心要求,与 TensorFlow 相比,在其中若干方面甚至更胜一筹。Gauss 团队基于 PyTorch 生态快速搭建系统验证了上述想法,此外,还针对生态中的关键技术进行了一轮优化,在生成式推荐场景下,相比 TensorFlow 实现,训练速度提升 3x 倍,结合 multi-targets[2] 特性,推理吞吐提升数十倍。

二、Gauss-Torch 工具链

Gauss-Torch 是 Gauss 团队针对推荐场景和微信环境定制的一组 Torch 扩展包,目前主要应用于生成式推荐相关场景(小模型流式实时训练和大模型离线预训练),核心训练方式采用经典的分布式数据并行(DDP),当然也可以配合 FSDP[5] 或 DeepSpeed[6] 实现 Tensor 并行与流水线并行,但只针对层数较多、训练中每步涉及超十亿甚至百亿以上参数的模型有效。

image.png

图 1 Gauss-Torch 模型训练推理示意

上图展示了 Gauss-Torch 在流式模型实时训练中的核心流程。在 Embedding 存储方面,Gauss-Torch 沿用了 Gauss-TensorFlow 中 TierPS 与 WePS[8] 的双引擎设计,在训练时采用成本友好的 TierPS,在推理时采用高可用的 WePS(不上线的模型、非实时模型可以不部署)。与 Gauss-TensorFlow 不同的是,TierPS 中去除了 GPU 缓存的部分,这部分作为 Gauss-Torch 的核心组件单独进行设计。训练过程中,框架会定期(秒级)向 WePS 同步有更新的 Embedding 参数,同时借助 HDFS/WFS 同步剩余模型参数,保障推理参数的实时性。

在 Embedding 查询部分,Gauss-Torch 并没有直接采用 TorchRec 中的分布式缓存策略,因为其在多机多卡场景下的扩展性并不好(在通讯占比较高的场景下,2 机性能约为单机的 1.6x 左右),而是借鉴发表在 SOSP'23 的工作 Bagpipe[7] 实现了一套流水线感知的单机 Embedding 缓存,通过启发式的预取与淘汰策略,实现了高通讯压力下线性的扩展能力。

自定义 Attention 融合算子方面,Gauss-Torch 初期基于 Triton 实现了一版 HSTU[2] 专用融合算子,性能相比于使用 Torch 算子拼接而成的实现已有不错的提升,近期团队内使用更加底层的 cuda + cutlass 接口重构了这部分实现,单算子性能相较于 Triton 版本又有了 2x ~ 4x 的提升。当然,性能只是一个方面,更重要的是赋能算法同学尝试不同的自定义 Attention 实现,快速实现业务价值。Gauss 团队目前正基于 Torch 提供的编译优化技术,着手抽象更加灵活的 Attention 接口,届时,一天内开发高性能融合 Attention 核将不再是奢望。

推理方面,Gauss-Torch 采用了社区最近几个版本发布的 AOTInductor 特性,其通过 TorchDynamo 将用户模型转化为一组顺序执行的代码,再使用 TorchInductor 将其编译为可执行的二进制程序发布上线,几乎完全兼容所有 Torch 算子和用户自定义算子。与 Torch 其他在线推理方案相比(ONNX,TensorRT)保障性能的同时,具备优秀的兼容性与训练/推理一致性。

最后,在容灾和弹性训练方面,Gauss-Torch 基于 Ray, 以 Torch-Elastic 为基础,摒弃了必须从 checkpoint 恢复的机制,结合多副本的设计实现了秒级故障恢复、分钟级弹性扩缩容的能力。

三、结语

微信 Gauss 团队致力于提供业界领先的推荐工程技术服务,我们将持续跟进、推动关键领域的技术进步,与业界同仁一道,为提升用户价值而不懈努力。

参考文献

[1] PyTorch 也能做推荐?TorchRec 的初步尝试,https://www.sohu.com/a/560275...
[2] Zhai J, Liao L, Liu X, et al. Actions speak louder than words: Trillion-parameter sequential transducers for generative recommendations[J]. arXiv preprint arXiv:2402.17152, 2024.
[3] TensorFlow Recommenders Addons(TFRA), https://github.com/tensorflow...
[4] torchrec, dynamic_embedding https://github.com/pytorch/to...
[5] FSDP, https://pytorch.org/tutorials...
[6] DeepSpeed, https://github.com/microsoft/...
[7] Agarwal S, Yan C, Zhang Z, et al. Bagpipe: Accelerating deep recommendation model training[C]//Proceedings of the 29th Symposium on Operating Systems Principles. 2023: 348-363.
[8] Sima C, Fu Y, Sit M K, et al. Ekko: A {Large-Scale} deep learning recommender system with {Low-Latency} model update[C]//16th USENIX Symposium on Operating Systems Design and Implementation (OSDI 22). 2022: 821-839.
[9] He X, Pan J, Jin O, et al. Practical lessons from predicting clicks on ads at facebook[C]//Proceedings of the Eighth International Workshop on Data Mining for Online Advertising. 2014: 1-9.

END

作者:cedric
文章来源:腾讯技术工程

推荐阅读

更多腾讯 AI 相关技术干货,请关注专栏腾讯技术工程 欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。
推荐阅读
关注数
8150
内容数
230
腾讯AI,物联网等相关技术干货,欢迎关注
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息