详解PyTorch FSDP数据并行(Fully Sharded Data Parallel)

1. 背景介绍

全切片数据并行(Fully Sharded Data Parallel,简称为FSDP)是数据并行的一种新的方式,FSDP最早是在2021年在FairScale-FSDP中提出的,后来合入了PyTorch 1.11版本中。微软之前Deepspeed框架中提出过三种级别的ZERO算法,FSDP可以看成是ZERO-3的实现。

2. 详细介绍

传统的数据并行(DDP)是在每一个GPU卡上保存整个model的参数/梯度/优化器状态, 然后对数据集切分为 \(N\) 个shard分片给不同的GPU进行训练,计算完梯度后通过all-reduce通信来做梯度的融合。如下图:

在FSDP中的主要思路是想办法把model的梯度/优化器状态/参数都进行切分操作,每个GPU只存部分的参数信息,也就是在ZERO-3的思路。为了能把所有的参数进行分片处理,核心在于要把DDP中的all-reduce操作拆解为reduce-scatterall-gather 操作。

如下图,在进行FSDP前向计算其中的一层Layer时,由于每个GPU都只保存了部分参数,所以需要先通过all-gather操作获得全部的参数;同理,在反向计算过程中,也需要通过all-gather操作,获得全部的参数;最后计算出来的梯度只是部分的结果,需要通过reduce-scatter通信进行累加操作,最终每个GPU卡分别只更新自己那部分参数(也就是local本地weight更新)。

FSDP的应用是对原有model layers加上了一层wrapper封装,只有在FSDP实例中的layer才会在前向和后向过程中执行gather相关操作,通过切分可以利用相同的显存大小训练更大的模型。为了进一步提升显存利用率,FSDP也支持把不活跃的实例全部offload调出到CPU上去。

FSDP计算过程的伪码如下:

1
2
3
4
5
6
7
8
9
10
11
12
FSDP forward pass:
for layer_i in layers:
all-gather full weights for layer_i
forward pass for layer_i
discard full weights for layer_i

FSDP backward pass:
for layer_i in layers:
all-gather full weights for layer_i
backward pass for layer_i
discard full weights for layer_i
reduce-scatter gradients for layer_i

在PyTorch中的示例如下, 通过FullyShardedDataParallel实现对model的封装,通过CPUOffload来决定采用哪种策略把参数调到CPU上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from torch.distributed.fsdp import (
FullyShardedDataParallel,
CPUOffload,
)
from torch.distributed.fsdp.wrap import (
default_auto_wrap_policy,
)
import torch.nn as nn

class model(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(8, 4)
self.layer2 = nn.Linear(4, 16)
self.layer3 = nn.Linear(16, 4)

model = DistributedDataParallel(model())
fsdp_model = FullyShardedDataParallel(
model(),
fsdp_auto_wrap_policy=default_auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=True),
)

使用FSDP训练GPT-175B和GPT-1T参数量大小的模型,词表大小50K,fp16的精度和使用SGD的优化器。

结果如下,使用FSDP时在GPU卡数增大的情况下,对GPU单卡的吞叶没有影响;在A100-40G机器下增大batch_size 但吞吐没有增加, 瓶颈不在于通信而是CUDA cache的分配到了瓶颈;当换为A100-80G机器时,CUDA cache的分配问题得到解决后,增大batch_size后吞吐进一步增加。

3. 参考