PyTorch Async Checkpoint Save
- PyTorch博客资料:https://pytorch.org/blog/redu...
- PyTorch实现和使用Demo:https://github.com/pytorch/py...
功能介绍
在PyTorch 2.4之后,我们可以尝试使用PyTorch开发的异步Checkpoint保存功能,这个功能是和IBM联合开发的,在7B的大模型训练中,Checkpoint保存的时间从平均 148.8 秒缩短至 6.3 秒,快了 23.62 倍。这可以转化为以下两种好处:
- 在继续鲁棒的保存Checkpoint的同时,在每给定的 24 小时内实现更多净训练进度;
- 可以更频繁地进行Checkpoint保存以缩短训练恢复窗口或者时间。
从结果图来看,无论是单机的FSDP还是多机的HSDP,Async Checkpoint Save都展现出了很大的速度优势,对于参数量更大的模型预计收益会更大。目前在TorchTian(https://github.com/pytorch/to...)中已经集成了这个新的功能,相信其他的主流训练框架都会很快跟进此feature。
博客内容
背景
模型Checkpoint是大模型训练的重要组成部分,但Checkpoint是一个昂贵的过程,因为每个Checkpoint过程都需要阻止训练进度以保存最新的模型权重。但是,不进行Checkpoint或降低Checkpoint频率会导致训练进度损失很大。例如,死锁、straggler(落后者)和 GPU 错误等故障导致需要重新启动训练过程。为了从故障中重启,所有(训练)工作者必须停止其训练过程并从上次保存的Checkpoint重新启动。
因此,对故障的鲁棒性与训练进度之间很难做到权衡,但现在有了异步Checkpoint,PyTorch 分布式训练能够显著缓解这种张力,并以最小的影响整体训练时间的方式实现频繁Checkpoint。
大约一年前(https://pytorch.org/blog/perf...),我们展示了分布式Checkpoint如何大幅加速Checkpoint时间,从最初的 torch.save()
功能开始。正如 IBM 研究团队指出的那样,torch.save
可能需要长达 30 分钟才能检查一个 11B 模型(PyTorch 1.13)。
随着分布式Checkpoint的进步,对于高达 30B 的模型大小,Checkpoint可以在 4 分钟内完成。使用异步Checkpoint,Checkpoint导致的训练时间损失现在降至 30 秒以下,通常仅需 6 秒。
需要明确的是,异步Checkpoint不会压缩实际的序列化Checkpoint时间,如之前的更新所展示的那样。相反,它将最终的Checkpoint过程移出关键路径(到 CPU 线程),以允许 GPU 训练在单独的线程下完成Checkpoint的同时继续进行。
如上图所示,异步Checkpoint比一年前的改进进一步提高了 10 倍到 23 倍。
Async Checkpoint Save如何工作
异步Checkpoint将Checkpoint过程模块化分为两个部分,而不是一个单一的整体过程。第一阶段将每个 GPU/rank 的数据从 GPU 复制到 CPU。这是用户可见的停机时间,对于 7B-13B 的模型大小可能需要 6 到 14 秒。第二阶段异步地将数据从 CPU 内存复制到磁盘以持久保存Checkpoint。
一旦数据在第一阶段复制到 CPU,GPU 就可以立即恢复训练。因此,使用异步Checkpoint,Checkpoint的停机时间仅仅是将最新的模型状态复制到 CPU 所需的时间。在训练恢复的同时,非阻塞 CPU 线程使用内存中新到达的数据完成完整的Checkpoint/序列化过程到磁盘(即持久保存)。
注意,PyTorch 的分布式Checkpoint依赖于集合通信调用来获取必要的每个等级元数据以优化保存,以及最终的同步,该同步将Checkpoint标记为已完成并使操作成为原子操作。如果Checkpoint线程使用与训练相同的进程组,这可能会干扰分布式训练(因为分布式训练也依赖于类似的调用来跨多个 GPU 同步训练)。具体来说,调用之间的竞争条件可能会导致训练和异步Checkpoint保存线程同时等待集合通信调用,从而导致真正的集合通信卡死。我们通过为异步Checkpoint初始化一个单独的进程组来避免这种情况。这将Checkpoint集合通信分离到其自己的逻辑进程组中,从而确保它不会干扰主训练线程中的集合通信调用。
如何使用PyTorch Async Checkpoint Save
这里是最小的使用PyTorch Async Checkpoint Save的demo:
需要注意的是第12行,为异步的Checkpoint集合通信操作建立了一个新的group,然后调用dcp.save
的时候我们需要传入这个group。
https://github.com/pytorch/to... 里面也已经使用上了这个功能,可以用于预训练自己的 Llama2 或 Lllama3 模型。在配置文件里面就可以选择使用Async Checkpoint Save。如下图所示:
代码流程粗略浏览
代码实现在 https://github.com/pytorch/py... 这个文件中。核心部分为以下2个函数,这里简单展示下流程:
# 创建 state_dict 的浅拷贝,对于每个 Stateful 对象调用其 state_dict() 方法。
def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
"""Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object."""
stateful_state_dict = {}
for key, elem in state_dict.items():
stateful_state_dict[key] = (
elem.state_dict() if isinstance(elem, Stateful) else elem
)
return stateful_state_dict
@_dcp_method_logger(log_exceptions=True)
def async_save(
state_dict: STATE_DICT_TYPE,
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Future:
torch._C._log_api_usage_once("torch.distributed.checkpoint.async_save")
# 检查分布式环境设置
if dist.is_available() and dist.is_initialized():
pg = process_group or _get_default_group()
assert (
torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
# 设置存储写入器
storage_writer = cast(
StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
)
# 处理状态字典(调用 _stateful_to_state_dict)
state_dict = _stateful_to_state_dict(state_dict)
# 如果存储写入器支持异步暂存,则使用它;否则将状态字典卸载到 CPU
if isinstance(storage_writer, AsyncStager):
staged_state_dict = storage_writer.stage(state_dict)
else: # provides bwc for storage_writers not implementing AsyncStager
staged_state_dict = _offload_state_dict_to_cpu(state_dict, type_check=False)
# 创建线程池执行器,提交保存任务。这里是一个线程
executor = ThreadPoolExecutor(max_workers=1)
f: Future = executor.submit(
save,
staged_state_dict,
checkpoint_id=checkpoint_id,
storage_writer=storage_writer,
planner=planner,
process_group=process_group,
)
# 设置任务完成后的回调函数(关闭执行器)
f.add_done_callback(lambda f: executor.shutdown(wait=False))
# 如果需要,同步暂存操作
if (
isinstance(storage_writer, AsyncStager)
and storage_writer.should_synchronize_after_execute
):
storage_writer.synchronize_staging()
# 返回 Future 对象
return f
将来的改进
PyTorch Blog中提到Checkpoint在过去的一年中取得了巨大进步。从近半个小时的Checkpoint变为使用分布式Checkpoint不到 5 分钟,现在又变为使用异步Checkpoint不到 30 秒。最后一个前沿是零开销Checkpoint,即使是小于 30 秒的时间也可以通过在反向传递期间流式传输更新的权重来消除,这样Checkpoint数据在异步Checkpoint开始时就已经在 CPU 上了。这将有效地将大型模型训练转移到Checkpoint没有中断或停机时间的程度,从而既提高了鲁棒性(因为可以更频繁地进行Checkpoint),又因为没有Checkpoint的停机时间而加快了训练进度。
作者:BBuf
来源:GiantPandaCV
推荐阅读
- SGLang:LLM推理引擎发展新方向
- CUDA-MODE课程笔记 第7课: Quantization Cuda vs Triton
- CUDA-MODE 第一课课后实战(上)
- 一文弄懂 LLM 结构化数据生成原理
欢迎大家点赞留言,更多Arm技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。