在3月25日的MegEngine Meetup中,旷视研究院周亦庄讲师分享了《利用 MegEngine 分布式通信算子实现复杂的并行训练》。
分享内容主要分为四个部分:
1、 介绍 MegEngine 的分布式通信算子;
2、 简单参数并行,用于熟悉模型并行的一些基本概念;
3、 层内模型并行;
4、 层间模型并行和流水线并行,同时介绍了如何实现一个简单的 GPipe。
本文为该分享实录的上篇,主要包含:背景介绍、MegEngine 的分布式通讯算子及简单参数并行。Enjoy~
背景
并行训练是开展深度学习研究和业务非常重要的一环,很多基础研究都需要大规模的计算集群甚至是超级计算机来完成。比如,像我们知道的 DeepMind 下围棋的 AlphaGo,还有OpenAI 的 1750 亿(175 billion)参数的超大语言模型 GPT-3,最近 OpenAI 还搞了一个 CLIP 和 DALL-E,他们都是用非常大的集群来进行分布式训练的。而因为旷视研究院有 Brain++ 这个分布式的计算平台,所以我们也有很多优秀的成果。大模型在各类视觉和语言任务上相比于小模型都有显著优势,所以最近的一种趋势是模型规模、数据规模越大越好,“大即正义”,因此更需要大规模的并行训练。
并行训练,一方面可以调动上百甚至上千块 GPU(图形处理器,又称”显卡”,简称”卡”,是深度学习最常见的计算设备)进行训练,第二部分也是根据业务或模型的特点,我们可以设计出最高效的并行模式。这是我今天讲的并行训练的一个现实意义。
先来讲一下深度学习当中有三种比较常用的并行模式,三种并行模式的关系用下面这张图就可以表达清楚。
第一种(层内模型并行),是利用矩阵乘法天然的并行特性,把每层(比如全连接层或卷积层)内部的矩阵乘法计算给拆开,表现为沿着输入/输出 通道(channel) 拆开进行分组计算,这就叫层内模型并行。
第二种(层间模型并行),是利用神经网络串行执行的特性,把网络按照执行顺序拆开,分别放到不同的设备上进行计算,比如说我们一个 ResNet18,它有 17 层卷积层加上最后一层全连接层,如果我们把前九层的和后九层的计算放到两块卡(即 GPU/显卡)上,它就是叫层间模型并行。层间与层内这两种模型并行方式是“正交”的,互不影响,可以同时存在。
以上说的两种并行,它的模型参数都是拆开来的,每个计算节点(计算节点是底层计算设备的一种抽象,它可以是一张卡,也可以是一台或者一组 8 卡机,即装载 8 块 GPU 的计算机)只负责管理整个网络的一部分参数以及这部分参数参与的相应计算。
最后一种就是我们最常用的数据并行,它又是另外一个维度,在数据并行维度上,模型参数都是共享的,但是接收的数据是不一样的。通过增加计算设备,我们可以近似线性地增加单次迭代的 batch size(批量,即训练图片的数量),从而节省训练模型的时间。
这三种并行维度是两两正交的,意思是在实际训练中我们既会用到两种模型并行也会用到数据并行。小模型可能数据并行就足够了,但大模型由于参数特别多、计算量非常大,计算难以用单个 GPU 完成,这时候就要将计算拆解到不同 GPU 上,此即模型并行。
MegEngine 的通信算子
接下来,进入到今天要讲的正题。先说通信算子。
人类的历史它其实就是一个信息交互的历史,也就是一个通信的历史——人与人之间说话就是通信,我今天做直播,它其实也是通信,我把信息广播给大家,这也是通信,电视和广播当然也是通信。
对于深度学习框架来说,通信是最重要的功能之一,否则数据并行和模型并行难以实现。简单来说就是我有很多个计算设备(GPU),我需要让信息在所有计算设备之间进行交互,那就需要集合通信——集合通信是一个求导完备的一套通信规则。
表中列了有 8 种集合通信算子和 2 种点对点通信算子,这就是 MegEngine 全部的通信算子。8 种集合通信算子,构成一套求导完备的通信的规则,它们互相各自为导数。MegEngine 提供了对通信算子的自动求导,所以和其它所有用于计算的算子(如卷积、ReLU、转置等)一样,我们可以自由地把通信算子加入前向计算图,框架将负责对其求导。
考虑到有些同学没有背景知识,我们一一介绍一下集合通信算子的功能。
Broadcast
Broadcast 即广播。
它表示的是数据的一个同步的过程,将一张 GPU 上的信息同步给其它所有 GPU。这在数据并行中非常有用,因为数据并行的话,每张卡上面的参数应该确保都是一样的,因此在初始化时我们会通过 Broadcast 进行参数同步,我们也会周期性同步一些缓存信息(buffer,比如 BatchNorm 的统计量)。
ReduceSum
第二个是 ReduceSum。ReduceSum 叫做求和或者归约,将所有 GPU 上的数据收集到一个 GPU 上并相加。
我们刚才讲的 Broadcast 和 ReduceSum 这两个通信算子是构成参数服务器 Parameter Server 的一个基石,它是中心式的,在这里面 GPU0 就起到一个中心的作用,我先把中心参数通过 Broadcast 同步给各张卡进行前传,反传后通过Reduce收集各张卡的梯度,进行参数更新。Broadcast 和 ReduceSum ,互为导数的, ReduceSum 的导数就是 Broadcast,Broadcast 的导数是 ReduceSum 。
AllReduce
我们再介绍 AllReduce,本来 Reduce 是归约到一张卡上,AllReduce 则是归约到每一张卡上。它即可以理解为 Reduce Broadcast 的组合,即我先 Reduce 到一张卡,然后再 Rroadcast 到所有卡;也可以理解为每张卡都同时调用了 Reduce,AllReduce 它的导数就是 AllReduce 本身。
尽管只用 Reduce 和 Broadcast 就可以实现 AllReduce,但是 AllReduce 的高效实现(即 Ring-AllReduce)才是构成现代 分布式深度学习框架的基石,它的通信时间基本不随 GPU 数量的增加而增加,因此可以高效地实现分布式训练的规模化。在数据并行中,我们用 AllReduce 将所有梯度求和,并用于模型参数更新。
Gather
Gather 简单来说就是把每张卡上不同的信息都给收集过来,并沿着第一维相连(Concatenate)。
AllGather
AllGather 就是全收集,和 AllReduce 类似,它可以理解为 Gather 后接 Broadcast。
AllGather 是我们层内模型并行当中的一个很重要的操作,因为你的参数在不同的卡上,你的数据也在不同的卡上,在进行模型并行的时候,我需要把数据或者参数都收集起来放到一张卡上才能进行接下来的计算,这就是 AllGather 的作用。
AllToAll
AllToAll 也是层内模型并行中经常用到的一个操作,特别是在模型并行和数据并行进行切换的时候,它本质上对一个矩阵进行了转置,我们后面在具体应用中会进一步说明。AllToAll 的导数是它本身。
Scatter 和 ReduceScatter
最后 Scatter 和 ReduceScatter 合起来讲,Scatter 就是分发,它将一张卡上的数据拆分给各张卡,它和 Gather 互为导数。
ReduceScatter 可以理解为在分发之前先进行了求和,它和 AllGather 互为导数。
简单参数并行
介绍完 MegEngine 的通信算子,我们来了解它们如何使用。首先,让我们从简单参数并行开始,它只涉及 AllGather 这一通信算子。
简单参数并行是怎么一回事?我们先用一个简单的全连接层(即矩阵乘法)来回顾一下数据并行——数据并行中,W 是我们的模型(即我们的权重 weight,每张卡拥有一份同样的拷贝),x 是数据。数据并行要求我们将数据平均拆分到每张卡上,2 卡拆 2 份,即 x0 和 x1,4 卡则拆成 4 份,依此类推,各张卡分别进行矩阵乘法计算,得到对应的结果 y。
简单参数并行本质是数据并行的优化?
我们不必在每张卡上都放完整的模型,而是只放部分模型,只有在我们需要(即前传)的时候,把分散在各张卡上的参数收集(AllGather)起来参与计算。
如何实现?
我们在做矩阵乘法操作之前,先对参数进行 AllGather,从各个节点上收集被我们拆开的参数,AllGather 以后每张卡都有全部的权重了,计算就变得和数据并行一模一样的。所以,简单参数并行的核心操作就是 AllGather,本质用通信来节省显存。
为什么能节省显存呢?
我们现在把整个求导过程也画出来了,我们知道在训练一份参数的时候,它其实是会占掉三份显存——参数一份,梯度一份,优化器的 momentum 一份,所以一个参数量 1 million 的模型,如果我们使用数据并行,会占用 3 * 4G = 12G(1 million fp32 类型的数据占用 4G)的显存,那我们一张 2080ti 就完全没有显存可以用于训练了。
我们再来研究一下这张图,我们在前传的时候做了一次 AllGather,在反传的时候,我们知道 AllGather 的导数是 ReduceScatter,所以,它反传的时候会进行一次 ReduceScatter 。这和数据并行不一样,数据并行前传不需要通信,反传需要进行 AllReduce这是他们的区别。
我们用 MegEngine 写了一套数据并行和简单参数并行的代码,它们有三个不同:
- 一个不同是它们的前传是不一样的——右边(简单参数并行)就是要做一次 AllGather ;
- 还有一个不同就是他们在参数初始化的时候。在数据并行中我们需要参数同步,所以我们要 Broadcast,但是在简单参数并行里面,我们需要的是参数分发,所以用 Scatter,就把它们给分发出去。
- 最后一个不同就是在求导的时候,求导的时候在数据并行当中我们需要进行 AllReduce(MegEngine 使用 AllReduce callback 来支持数据并行),但是在简单参数并行里面不需要进行 AllReduce,自动微分器会负责反传时正确调用 ReduceScatter。
专栏文章推荐
欢迎关注旷视研究院极术社区专栏,定期更新最新旷视研究院成果
加入旷视:career@megvii.com