Describe the bug
The sharding of IterableDataset across distributed training processes and dataloader worker processes appears problematic, with significant performance traps and inconsistencies between how distributed train processes and worker processes are handled.
Splitting across num_workers (dataloader worker processes per train process) and world_size (distributed training processes) is done inconsistently:
- worker split: `datasets/src/datasets/iterable_dataset.py`, lines 1266 to 1283 in 9d6d161

  ```python
  if self._is_main_process() and ex_iterable.n_shards < worker_info.num_workers:
      logger.warning(
          f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={ex_iterable.n_shards}). "
          f"Stopping {worker_info.num_workers - ex_iterable.n_shards} dataloader workers."
      )
      logger.info(
          f"To parallelize data loading, we give each process some shards (or data sources) to process. "
          f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={ex_iterable.n_shards}. "
          f"To enable more parallelism, please split the dataset in more files than {ex_iterable.n_shards}."
      )
  # split workload
  _log_prefix = f"node#{self._distributed.rank} " if self._distributed else ""
  shards_indices = self._ex_iterable.split_shard_indices_by_worker(worker_info.id, worker_info.num_workers)
  if shards_indices:
      logger.debug(
          f"{_log_prefix}dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{ex_iterable.n_shards} shards."
      )
      ex_iterable = ex_iterable.shard_data_sources(worker_id=worker_info.id, num_workers=worker_info.num_workers)
  ```

- distributed split: `datasets/src/datasets/iterable_dataset.py`, lines 1335 to 1356 in 9d6d161

  ```python
  if self._distributed:
      rank = self._distributed.rank
      world_size = self._distributed.world_size
      if ex_iterable.n_shards % world_size == 0:
          if self._is_main_process():
              n_shards_per_node = ex_iterable.n_shards // world_size
              plural = "s" if n_shards_per_node > 1 else ""
              logger.info(
                  f"Assigning {n_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
              )
          ex_iterable = ex_iterable.shard_data_sources(rank, world_size)
      else:
          if self._is_main_process():
              logger.info(
                  f"Assigning 1 out of {world_size} examples of the dataset to each node. The others are skipped during the iteration."
              )
              logger.info(
                  f"It is more optimized to distribute the dataset shards (or data sources) across nodes. "
                  f"You can do that by using a dataset with number of shards that is a factor of world_size={world_size}. "
                  f"The current dataset has {ex_iterable.n_shards} which is not a factor of {world_size}"
              )
          ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)
  ```
In the distributed split, a modulus check flips between two very different behaviours. Why is this handled differently from the split across dataloader workers? For an IterableDataset the DataLoader worker processes are independent, so whether the workers live within one train process or are spread across a distributed world, the shards should be distributed the same way: across world_size * num_workers independent workers in either case.
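As a minimal sketch of what a consistent split could look like, each (rank, dataloader worker) pair can be treated as one global worker. The helper below and its name/parameters are hypothetical, not the datasets API:

```python
# Hypothetical helper (not the datasets implementation): shard indices are
# split identically whether the parallelism comes from distributed ranks,
# dataloader workers, or both.
def split_shards_globally(n_shards: int, rank: int, world_size: int, worker_id: int, num_workers: int) -> list[int]:
    global_num_workers = world_size * num_workers
    global_worker_id = rank * num_workers + worker_id
    # round-robin assignment; an uneven shard count just gives some workers one
    # extra shard instead of switching to a per-sample stepping mode
    return list(range(global_worker_id, n_shards, global_num_workers))

# e.g. 10 shards, world_size=2, num_workers=2 -> 4 global workers
# (rank=0, worker=0) -> [0, 4, 8]; (rank=1, worker=1) -> [3, 7]
```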
Further, the fallback taken when the n_shards % world_size == 0
check fails is a rather extreme change in behaviour. I argue it is not desirable to do this implicitly; it should be an explicit option for specific scenarios (i.e. reliable validation). A training scenario would likely be much better handled with improved wrapping/stopping behaviour, which would e.g. also fix #6437. Switching from stepping over shards to stepping over samples means that every single process reads ALL of the shards. This was never an intended default for sharded training: shards gain their performance advantage in large-scale distributed training precisely by avoiding overlap in the data each process reads. By default, only the data allocated to each process via its assigned shards should be read in each pass over the dataset.
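For context, the fallback amounts to something like the following simplified sketch (hypothetical, not the actual StepExamplesIterable code), which makes the full-read behaviour explicit:

```python
# Simplified illustration of the per-sample stepping fallback: every rank
# iterates (and therefore reads/decodes) every shard, keeping only every
# world_size-th example. `shards` is any iterable of example iterables.
def step_examples(shards, rank, world_size):
    idx = 0
    for shard in shards:            # every rank touches every shard
        for example in shard:       # full read on every rank
            if idx % world_size == rank:
                yield example
            idx += 1
```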
Using a large-scale CLIP example: some of the larger datasets have 10-20k shards across 100+ TB of data. Training with 1000 GPUs, we switch from reading ~100 terabytes per epoch to ~100 petabytes if we drop a single GPU node, going from 20k % 1000 == 0 to 20k % 992 != 0.
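A rough back-of-the-envelope calculation with these (approximate, illustrative) numbers:

```python
# Approximate numbers from the example above (illustrative only)
n_shards = 20_000
dataset_tb = 100            # ~100 TB total dataset size

# 20_000 % 1_000 == 0 -> shards are split, each shard read by exactly one rank
read_per_epoch_even_tb = dataset_tb                  # ~100 TB per epoch

# drop one 8-GPU node: 20_000 % 992 != 0 -> per-sample stepping, every rank reads everything
world_size = 992
read_per_epoch_uneven_tb = dataset_tb * world_size   # ~99,200 TB, i.e. ~100 PB per epoch
print(read_per_epoch_even_tb, read_per_epoch_uneven_tb)
```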
The 'step over samples' case might be worth the overhead in specific validation scenarios where at-least-once / at-most-once guarantees on samples seen matter more, and where validation does not make up a significant portion of train time or is run at smaller world sizes outside of training.
Steps to reproduce the bug
N/A
Expected behavior
Given an iterable dataset with N shards, to split it across workers we should (see the sketch after this list):
- shuffle shards (same seed across all train processes)
- step shard iterator across distributed processes
- step shard iterator across dataloader worker processes
- shuffle samples in every worker via a shuffle buffer (different seed in each worker, but ideally controllable, e.g. based on base seed + worker id + epoch)
- end up with a (possibly uneven) number of shards per worker, but with each shard only ever accessed by 1 worker per pass (epoch)
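A minimal end-to-end sketch of this expected behaviour (hypothetical code, not the datasets API; shard contents are plain Python iterables here, and all names and parameters are assumptions):

```python
import random

# Hypothetical sketch of the expected behaviour listed above (not the datasets API).
# Each element of `shards` stands in for one shard and is any iterable of examples.
def worker_shard_iterator(shards, base_seed, epoch, rank, world_size, worker_id, num_workers, buffer_size=1000):
    # 1) shuffle shard order identically in every train process (same seed everywhere)
    shard_order = list(range(len(shards)))
    random.Random(base_seed + epoch).shuffle(shard_order)

    # 2) + 3) step the shuffled shard list across distributed ranks and dataloader
    #         workers, i.e. across world_size * num_workers independent global workers
    global_num_workers = world_size * num_workers
    global_worker_id = rank * num_workers + worker_id
    my_shards = shard_order[global_worker_id::global_num_workers]  # possibly uneven, always disjoint

    # 4) shuffle samples via a per-worker shuffle buffer, seeded from
    #    base seed + worker id + epoch (reproducible but distinct per worker)
    rng = random.Random(base_seed + global_worker_id + epoch)
    buffer = []
    for shard_idx in my_shards:        # 5) each shard is read by exactly one worker per pass
        for example in shards[shard_idx]:
            buffer.append(example)
            if len(buffer) >= buffer_size:
                yield buffer.pop(rng.randrange(len(buffer)))
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))

# toy usage: 10 shards of 4 examples each, world_size=2, num_workers=2
shards = [list(range(i * 4, (i + 1) * 4)) for i in range(10)]
print(list(worker_shard_iterator(shards, base_seed=42, epoch=0, rank=0, world_size=2,
                                 worker_id=1, num_workers=2, buffer_size=8)))
```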
Environment info
N/A