聊聊PyTorch的ZeroRedundancyOptimizer优化器

最近在调研 PyTorch 的一些 features,偶然发现了一些有意思的小细节,小小记录下心得。ZeroRedundancyOptimizer 优化器是 PyTorch1.10 版本发布的功能,如果了解最近大模型训练方向的成果,对这个更加不陌生。ZeroRedundancyOptimizer 是微软提出的一种大规模模型在分布式训练的一种优化策略 [ZeRO],通过将优化器状态进行切分或存储分配,达到节省存储空间的目的。

image.png

如上图所示, ZeroRedundancyOptimizer 即是进行第一阶段的优化,在保留模型参数和相应参数梯度的情况下,均分模型优化器状态。

image.png

具体细节可以参考 fairscale 和 DeepSpeed 的实现,总的来说ZeroRedundancyOptimizer 只是实现了 ZeRO-DP-1 的功能。这里主要想要介绍一下 PyTorch 在 ZeroRedundancyOptimizer 功能实现上的设计,不得不说 PyTorch 的二次开发性和扩展功能接口真的是丰富。下面从 Join 和 overlap_with_ddp 开始介绍。

Join

Join 是一个上下文管理器,可以围绕每个 rank 上的循环对不平衡的输入进行训练。即对先结束训练的 rank 在上下文管理器里和未结束训练的 rank 进行集合通信 (AllReduce等)。由于在设计上带有一定同步检查机制,可以有效避免无意义的通信 hang 等现象,下面我们以官方代码进行解释。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

如上示例,当 rank0 结束 for 循环,会触发环境管理器的 main_hook 函数,在 main_hook 函数中会进行通信检查(这点在 DDP源码中有涉及,用来触发不同的运算逻辑),看看是否大家进行了相同的通信操作。即通过 AllReduceSum 进行加和处理,检查是否是判定值。我们可以看到在最后一个 rank 计算完成之后(join之后),最后的 rank 会进行 broadcast 广播 counts 到各个 rank。

Join 的设计,巧妙的给 Module 和 Optimizer 挂载一个外接 hook,我们可以在 forward 和 step 函数中增加与 hook 的函数通信或者其他函数功能,简化一些现有的复杂功能。

overlap_with_ddp

overlap_with_ddp 是 ZeroRedundancyOptimizer 的一个类入参,这里介绍它主要是因为 PyTorch 实现了一个很有意思的参数更新流程。即通过 register_comm_hook 同步使用 buckets 进行反向的梯度同步(AllReduceSum)和 参数更新过程,通过 buckets 粒度的 overlapping 实现时间维度的复用,如下图所示。这里可能稍微有点复杂,有兴趣可以先了解下 reducer 的实现和 DistributedDataParallel 调度逻辑。其实这个想法 我之前在 SE-MoE: A Scalable and Efficient Mixture-of-Experts Distributed Training and Inference System 有提到,不过 PyTorch 的实现也不错。当然这里 PyTorch 复用了一些检查转换操作在前两个 step,这里会自动给 buckets 和 ZeroRedundancyOptimizer 划分后的参数进行对应管理。

image.png

参考

  1. https://pytorch.org/
  2. https://www.deepspeed.ai/tutorials/zero/
  3. https://arxiv.org/pdf/1910.02054.pdf"%20 t%20"_blank
  4. https://github.com/features/copilot

The End

作者:无恶不作
文章来源:GiantPandaCV

推荐阅读

更多嵌入式AI干货请关注嵌入式AI专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18838
内容数
1371
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息