pytorch分布式训练使用Dataloader/WebDataset进行数据并行加载

1. 使用pytorch原生的DistributedSampler

在pytorch DDP数据并行时会对数据集进行切分,每个rank节点只处理部分数据。使用DistributedSampler来会把dataset数据集采样为一个子数据集。定义如下:

1
torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)
  • dataset:用于采样的数据集
  • num_replicas (int, optional):分布式训练总的进程数。默认对应取process_group中的world_size
  • rank (int, optional):当前进程的rank号。默认对应取process_group中的rank
  • shuffle (bool, optional) :为True表示对indices进行随机打乱。注意使用DistributedSampler时,torch.util.data.Dataloader创建时的shuffle参数,相当于把随机的功能交给了DistributedSampler。默认为True
  • seed (int, optional):随机种子,默认为0
  • drop_last (bool, optional): 为True的话会丢弃结尾的数据,保证数据大小可以被num_replicas整除;为False的话Sampler为增加额外的indices;默认为False

注意在分布式模式下,每个epoch启动前要调用set_epoch()方法,用于在多个epoch执行时打乱顺序,不调用的话读取顺序都会一样。

1
2
3
4
5
6
7
>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
... sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
... if is_distributed:
... sampler.set_epoch(epoch)
... train(loader)

2. DPP使用WebDataset

2.1 基本介绍

WebDataset是专门针对大数据训练服务的。基于Pytorch IterableDataset实现的数据DataLoader,数据存储在一系列的POSIX tar包中,使用squence/streaming的数据访问方式。跟AIStore服务器和Tensorcom RDMA库结合,可以提供高性能的数据访问方式。

WebDataset会把一个大数据集切分为多个shard,一个tar包就是一个shard。跟pytorch的DataLoader不同的是,WebDataset是以shard为粒度进行I/O并行访问和shuffle。一组shard可以用一个文件的列表来表示,也可以写到一个大括号的方式来进行表示,例如字符串openimages-train-{000000..000554}.tar表示数据集中包含有554个shard,每个shard分片中有1G的图像数据。在WebDataset中,这种ShardList字符串的解析是通过braceexpand库来进行的。以下两种表示是等价的:

1
2
dataset = wds.WebDataset(["dataset-000.tar", "dataset-001.tar", "dataset-002.tar", "dataset-003.tar"])
dataset = wds.WebDataset("dataset-{000..003}.tar")

WebDataset基本使用方式如下:

1
2
3
4
5
6
7
import webdataset as wds

dataset = wds.WebDataset(url).shuffle(1000).decode("torchrgb").to_tuple("jpg;png", "json")
dataloader = torch.utils.data.DataLoader(dataset, num_workers=4, batch_size=16)

for inputs, outputs in dataloader:
...

webdataset.Webdataset使用方法简单, 仅用一行代码, 初始化会自动按node数和worker数对shard进行切分:

1
dataset = webdataset.Webdataset(urls)

等价于如下的写法,内部处理对应的类是ShardList,在示例中使用nodesplittersplitter两个函数将URLs切分为多组shard:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
urls = list(braceexpand.braceexpand("dataset-{000000..000999}.tar"))
dataset = wds.ShardList(urls, splitter=wds.split_by_worker, nodesplitter=wds.split_by_node, shuffle=False)
dataset = wds.Processor(dataset, wds.url_opener)
dataset = wds.Processor(dataset, wds.tar_file_expander)
dataset = wds.Processor(dataset, wds.group_by_keys)


def my_split_by_worker(urls):
wi = torch.utils.data.get_worker_info()
if wi is None:
return urls
else:
return urls[wi.id::wi.num_workers]

def my_split_by_node(urls):
node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
return urls[node_id::node_count]

2.2 多节点训练

最简单示例如下,使用resample+with_epoch

1
2
dataset = wds.WebDataset(url, resampled=True).shuffle(1000).decode("rgb").to_tuple("png", "json").map(preprocess).with_epoch(10000)
sample = next(iter(dataset))
  • shuffle:表示对大小为1000的buffer进行shuffle操作
  • resampled:表示使用重采样使得数据流一直有
  • with_epoch:指定为10000表示强制一个epoch有10000个batch数或者样本数,具体是batch数还是样本数跟前面iter的粒度有关。

复杂的pipeline示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
dataset = wds.DataPipeline(
wds.ResampledShards(url),
# at this point we have an iterator over all the shards
wds.tarfile_to_samples(),
wds.shuffle(1000),
wds.decode("torchrgb"),
# at this point, we have an list of decompressed training samples from each shard in this worker in sequence
get_patches, # note that can put iterator->iterator functions into the pipeline directly
wds.shuffle(10000),
wds.to_tuple("big.jpg", "json"),
wds.batched(16)
).with_epoch(10000)

batch = next(iter(loader))
batch[0].shape, batch[1].shape

还有一个with_length可以配合使用,用于指定数据集的总长度

两个实际中使用的例子: * webdataset/webdataset-imagenet/imagenet.py * github.com/mlfoundations/open_clip/src/training/data.py

2.3 注意

在WebDataset文档中还介绍了使用ddp_equalize用于Multinode训练,但这种方式已经废弃, 底层实际还是采用with_epochwith_length来实现,参考:ddp_equalize #194IGNORE_test_ddp_equalizeddp fixes

3. 参考