Let’s say I have a dataset with 5 samples, with values [1, 2, 3, 4, 5], 2 GPUs (for DDP), and a batch size of 2. The dataset is an IterableDataset since I am streaming it.
Now I split the dataset using split_dataset_by_node to ensure it doesn’t get repeated across GPUs. And since it’s already split, I assume I don’t have to use a DistributedSampler?
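Roughly, the setup looks like the sketch below. It simulates the two ranks in a single process just to show which samples each rank receives; the real run uses DDP with torchrun, and the generator is only a stand-in for my actual streaming source.

```python
from datasets import IterableDataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader

# Stand-in for the real streaming source: 5 samples with values 1..5.
def gen():
    for value in [1, 2, 3, 4, 5]:
        yield {"value": value}

streamed = IterableDataset.from_generator(gen)

world_size = 2
for rank in range(world_size):  # simulate what each of the 2 GPUs would see
    shard = split_dataset_by_node(streamed, rank=rank, world_size=world_size)
    loader = DataLoader(shard, batch_size=2)
    print(f"rank {rank}:", [batch["value"].tolist() for batch in loader])
```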
But in this case I noticed the following:
First iteration:
first GPU will get → [1, 2]
second GPU will get → [3, 4]
Second iteration:
first GPU will get → [5]
second GPU will get → nothing
This actually creates an issue: with a DistributedSampler, samples are repeated internally so that none of the GPUs is missing data at any iteration and gradient sync never stalls (see the sketch after the questions). So my questions are:
- Since the splitting happens beforehand here, how do I make sure each GPU gets a batch at every iteration to avoid gradient sync issues?
- Do we need to use DistributedSampler? If yes, how?
- If dataset.n_shards % world_size != 0, is it possible to shard the streaming dataset on the fly to avoid the case where data is missing?
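For reference, this is roughly the DistributedSampler padding behaviour I mean, sketched on a map-style copy of the same 5 values (just for illustration, not my actual training code):

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))

world_size = 2
for rank in range(world_size):  # simulate the 2 GPUs
    # With drop_last=False the sampler pads by repeating samples,
    # so every rank ends up with the same number of batches.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=False, drop_last=False)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    print(f"rank {rank}:", [batch[0].tolist() for batch in loader])

# Prints something like:
# rank 0: [[1, 3], [5]]
# rank 1: [[2, 4], [1]]  <- sample 1 is repeated so rank 1 still gets 2 batches
```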