最近在调研 PyTorch 的一些 features,偶然发现了一些有意思的小细节,小小记录下心得。ZeroRedundancyOptimizer 优化器是 PyTorch1.10 版本发布的功能,如果了解最近大模型训练方向的成果,对这个更加不陌生。ZeroRedundancyOptimizer 是微软提出的一种大规模模型在分布式训练的一种优化策略 [ZeRO],通过将优化器状态进行切分或存储分配,达到节省存储空间的目的。
如上图所示, ZeroRedundancyOptimizer 即是进行第一阶段的优化,在保留模型参数和相应参数梯度的情况下,均分模型优化器状态。
具体细节可以参考 fairscale 和 DeepSpeed 的实现,总的来说ZeroRedundancyOptimizer 只是实现了 ZeRO-DP-1 的功能。这里主要想要介绍一下 PyTorch 在 ZeroRedundancyOptimizer 功能实现上的设计,不得不说 PyTorch 的二次开发性和扩展功能接口真的是丰富。下面从 Join 和 overlap_with_ddp 开始介绍。
Join
Join 是一个上下文管理器,可以围绕每个 rank 上的循环对不平衡的输入进行训练。即对先结束训练的 rank 在上下文管理器里和未结束训练的 rank 进行集合通信 (AllReduce等)。由于在设计上带有一定同步检查机制,可以有效避免无意义的通信 hang 等现象,下面我们以官方代码进行解释。
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
class CounterJoinHook(JoinHook):
r"""
Join hook for :class:`Counter`.
Arguments:
counter (Counter): the :class:`Counter` object using this hook.
sync_max_count (bool): whether to sync the max count once all ranks
join.
"""
def __init__(
self,
counter,
sync_max_count
):
self.counter = counter
self.sync_max_count = sync_max_count
def main_hook(self):
r"""
Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
"""
t = torch.zeros(1, device=self.counter.device)
dist.all_reduce(t)
def post_hook(self, is_last_joiner: bool):
r"""
Synchronizes the max count across all :class:`Counter` s if
``sync_max_count=True``.
"""
if not self.sync_max_count:
return
rank = dist.get_rank(self.counter.process_group)
common_rank = self.counter.find_common_rank(rank, is_last_joiner)
if rank == common_rank:
self.counter.max_count = self.counter.count.detach().clone()
dist.broadcast(self.counter.max_count, src=common_rank)
class Counter(Joinable):
r"""
Example :class:`Joinable` that counts the number of training iterations
that it participates in.
"""
def __init__(self, device, process_group):
super(Counter, self).__init__()
self.device = device
self.process_group = process_group
self.count = torch.tensor([0], device=device).float()
self.max_count = torch.tensor([0], device=device).float()
def __call__(self):
r"""
Counts the number of inputs processed on this iteration by all ranks
by all-reducing a dim-1 one tensor; increments its own internal count.
"""
Join.notify_join_context(self)
t = torch.ones(1, device=self.device).float()
dist.all_reduce(t)
self.count += t
def join_hook(self, **kwargs) -> JoinHook:
r"""
Return a join hook that shadows the all-reduce in :meth:`__call__`.
This join hook supports the following keyword arguments:
sync_max_count (bool, optional): whether to synchronize the maximum
count across all ranks once all ranks join; default is ``False``.
"""
sync_max_count = kwargs.get("sync_max_count", False)
return CounterJoinHook(self, sync_max_count)
@property
def join_device(self) -> torch.device:
return self.device
@property
def join_process_group(self):
return self.process_group
def find_common_rank(self, rank, to_consider):
r"""
Returns the max rank of the ones to consider over the process group.
"""
common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
common_rank = common_rank.item()
return common_rank
def worker(rank):
assert torch.cuda.device_count() >= WORLD_SIZE
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
with Join([counter], sync_max_count=True):
for _ in inputs:
counter()
print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
print(f"{int(counter.max_count.item())} inputs processed across all ranks!")
def main():
mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
main()
如上示例,当 rank0 结束 for 循环,会触发环境管理器的 main_hook 函数,在 main_hook 函数中会进行通信检查(这点在 DDP源码中有涉及,用来触发不同的运算逻辑),看看是否大家进行了相同的通信操作。即通过 AllReduceSum 进行加和处理,检查是否是判定值。我们可以看到在最后一个 rank 计算完成之后(join之后),最后的 rank 会进行 broadcast 广播 counts 到各个 rank。
Join 的设计,巧妙的给 Module 和 Optimizer 挂载一个外接 hook,我们可以在 forward 和 step 函数中增加与 hook 的函数通信或者其他函数功能,简化一些现有的复杂功能。
overlap_with_ddp
overlap_with_ddp 是 ZeroRedundancyOptimizer 的一个类入参,这里介绍它主要是因为 PyTorch 实现了一个很有意思的参数更新流程。即通过 register_comm_hook 同步使用 buckets 进行反向的梯度同步(AllReduceSum)和 参数更新过程,通过 buckets 粒度的 overlapping 实现时间维度的复用,如下图所示。这里可能稍微有点复杂,有兴趣可以先了解下 reducer 的实现和 DistributedDataParallel 调度逻辑。其实这个想法 我之前在 SE-MoE: A Scalable and Efficient Mixture-of-Experts Distributed Training and Inference System 有提到,不过 PyTorch 的实现也不错。当然这里 PyTorch 复用了一些检查转换操作在前两个 step,这里会自动给 buckets 和 ZeroRedundancyOptimizer 划分后的参数进行对应管理。
参考
- https://pytorch.org/
- https://www.deepspeed.ai/tutorials/zero/
- https://arxiv.org/pdf/1910.02054.pdf"%20 t%20"_blank
- https://github.com/features/copilot
The End
作者:无恶不作
文章来源:GiantPandaCV
推荐阅读
- IA-YOLO数据增强+感知损失,做到大雾天气无痛即可完成YOLO检测器的场景升级
- 地平线提出VAD v2端到端自动驾驶模型 | 远超SP-T3/VAD/DriveMLM等方法
- 从 Intel 与 ARM 的成功历史看 RISC-V
- 马斯克实现承诺,开源Grok-1!3140亿参数迄今最大,远高于ChatGPT 3.5!
更多嵌入式AI干货请关注嵌入式AI专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。