详解PyTorch FSDP数据并行(Fully Sharded Data Parallel)
1. 背景介绍
全切片数据并行(Fully Sharded Data
Parallel,简称为FSDP)是数据并行的一种新的方式,FSDP最早是在2021年在FairScale-FSDP中提出的,后来合入了PyTorch
1.11版本中。微软之前Deepspeed框架中提出过三种级别的ZERO算法,FSDP可以看成是ZERO-3
的实现。
2. 详细介绍
传统的数据并行(DDP)是在每一个GPU卡上保存整个model的参数/梯度/优化器状态, 然后对数据集切分为 \(N\) 个shard分片给不同的GPU进行训练,计算完梯度后通过all-reduce通信来做梯度的融合。如下图:
在FSDP中的主要思路是想办法把model的梯度/优化器状态/参数都进行切分操作,每个GPU只存部分的参数信息,也就是在ZERO-3
的思路。为了能把所有的参数进行分片处理,核心在于要把DDP中的all-reduce操作拆解为reduce-scatter和all-gather
操作。
如下图,在进行FSDP前向计算其中的一层Layer时,由于每个GPU都只保存了部分参数,所以需要先通过all-gather操作获得全部的参数;同理,在反向计算过程中,也需要通过all-gather操作,获得全部的参数;最后计算出来的梯度只是部分的结果,需要通过reduce-scatter通信进行累加操作,最终每个GPU卡分别只更新自己那部分参数(也就是local本地weight更新)。
FSDP的应用是对原有model layers加上了一层wrapper封装,只有在FSDP实例中的layer才会在前向和后向过程中执行gather相关操作,通过切分可以利用相同的显存大小训练更大的模型。为了进一步提升显存利用率,FSDP也支持把不活跃的实例全部offload调出到CPU上去。
FSDP计算过程的伪码如下:
1 | FSDP forward pass: |
在PyTorch中的示例如下,
通过FullyShardedDataParallel
实现对model的封装,通过CPUOffload
来决定采用哪种策略把参数调到CPU上。
1 | from torch.distributed.fsdp import ( |
使用FSDP训练GPT-175B和GPT-1T参数量大小的模型,词表大小50K,fp16的精度和使用SGD的优化器。
结果如下,使用FSDP时在GPU卡数增大的情况下,对GPU单卡的吞叶没有影响;在A100-40G机器下增大batch_size
但吞吐没有增加, 瓶颈不在于通信而是CUDA
cache的分配到了瓶颈;当换为A100-80G机器时,CUDA
cache的分配问题得到解决后,增大batch_size
后吞吐进一步增加。