pytorch分布式训练使用Dataloader/WebDataset进行数据并行加载
1. 使用pytorch原生的DistributedSampler
在pytorch DDP数据并行时会对数据集进行切分,每个rank节点只处理部分数据。使用DistributedSampler来会把dataset数据集采样为一个子数据集。定义如下:
1 | torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False) |
- dataset:用于采样的数据集
- num_replicas (int,
optional):分布式训练总的进程数。默认对应取process_group中的
world_size
- rank (int,
optional):当前进程的rank号。默认对应取process_group中的
rank
- shuffle (bool, optional)
:为True表示对indices进行随机打乱。注意使用
DistributedSampler
时,torch.util.data.Dataloader创建时的shuffle参数,相当于把随机的功能交给了DistributedSampler
。默认为True - seed (int, optional):随机种子,默认为0
- drop_last (bool, optional):
为True的话会丢弃结尾的数据,保证数据大小可以被
num_replicas
整除;为False的话Sampler为增加额外的indices;默认为False
注意在分布式模式下,每个epoch启动前要调用set_epoch()
方法,用于在多个epoch执行时打乱顺序,不调用的话读取顺序都会一样。
1 | if is_distributed else None sampler = DistributedSampler(dataset) |
2. DPP使用WebDataset
2.1 基本介绍
WebDataset是专门针对大数据训练服务的。基于Pytorch IterableDataset实现的数据DataLoader,数据存储在一系列的POSIX tar包中,使用squence/streaming的数据访问方式。跟AIStore服务器和Tensorcom RDMA库结合,可以提供高性能的数据访问方式。
WebDataset会把一个大数据集切分为多个shard,一个tar包就是一个shard。跟pytorch的DataLoader不同的是,WebDataset是以shard为粒度进行I/O并行访问和shuffle。一组shard可以用一个文件的列表来表示,也可以写到一个大括号的方式来进行表示,例如字符串openimages-train-{000000..000554}.tar
表示数据集中包含有554个shard,每个shard分片中有1G的图像数据。在WebDataset中,这种ShardList
字符串的解析是通过braceexpand
库来进行的。以下两种表示是等价的:
1 | dataset = wds.WebDataset(["dataset-000.tar", "dataset-001.tar", "dataset-002.tar", "dataset-003.tar"]) |
WebDataset基本使用方式如下:
1 | import webdataset as wds |
webdataset.Webdataset
使用方法简单, 仅用一行代码,
初始化会自动按node数和worker数对shard进行切分:
1 | dataset = webdataset.Webdataset(urls) |
等价于如下的写法,内部处理对应的类是ShardList
,在示例中使用nodesplitter
和splitter
两个函数将URLs切分为多组shard:
1 | urls = list(braceexpand.braceexpand("dataset-{000000..000999}.tar")) |
2.2 多节点训练
最简单示例如下,使用resample+with_epoch
1 | dataset = wds.WebDataset(url, resampled=True).shuffle(1000).decode("rgb").to_tuple("png", "json").map(preprocess).with_epoch(10000) |
- shuffle:表示对大小为1000的buffer进行shuffle操作
- resampled:表示使用重采样使得数据流一直有
- with_epoch:指定为10000表示强制一个epoch有10000个batch数或者样本数,具体是batch数还是样本数跟前面iter的粒度有关。
复杂的pipeline示例:
1 | dataset = wds.DataPipeline( |
还有一个with_length可以配合使用,用于指定数据集的总长度
两个实际中使用的例子: * webdataset/webdataset-imagenet/imagenet.py * github.com/mlfoundations/open_clip/src/training/data.py
2.3 注意
在WebDataset文档中还介绍了使用ddp_equalize用于Multinode训练,但这种方式已经废弃,
底层实际还是采用with_epoch
和with_length
来实现,参考:ddp_equalize
#194、IGNORE_test_ddp_equalize、ddp
fixes
3. 参考
- torch.utils.data.distributed.DistributedSampler
- torch.nn.parallel.DistributedDataParallel
- torch.util.data.Dataloader
- webdataset/webdataset
- WebDataset Document
- Efficient PyTorch I/O library for Large Datasets, Many Files, Many GPUs
- Using DDP with WebDataset in pytorch lightning #250
- webdataset#multinode-training