Problem in training iterable dataset #6437

@21Timothy

Description

Describe the bug

I am using PyTorch DDP (DistributedDataParallel) to train my model. Since the data is too large to load into memory at once, I use load_dataset with streaming=True to read it as an iterable dataset, and datasets.distributed.split_dataset_by_node to distribute it across processes. However, I have noticed that this split gives different processes different amounts of data to train on. As a result, the earliest process to finish training starts predicting on the test set while the other processes are still training, which makes the overall training very slow.
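
For context, as far as I can tell split_dataset_by_node assigns whole parquet shards to each rank when the shard count is a multiple of world_size, so ranks can end up with different numbers of examples whenever the shards have different row counts. A minimal sketch (with a hypothetical "data/*.parquet" path, launched via torchrun) to check the per-rank counts:

import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# gloo is enough here: we only count examples, no GPU work involved
dist.init_process_group("gloo")
rank, world_size = dist.get_rank(), dist.get_world_size()

ds = load_dataset("parquet", data_files="data/*.parquet", split="train", streaming=True)
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)

n = sum(1 for _ in ds)  # one full pass, just to count the examples this rank receives
print(f"rank {rank}: n_shards={ds.n_shards} examples={n}")

If the counts differ across ranks, each rank iterates a different number of batches per epoch, which matches the logs below.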

Steps to reproduce the bug

import datetime
import time

import torch
import datasets.distributed
from datasets import load_dataset


def train(args, model, device, train_loader, optimizer, criterion, epoch, length):
    model.train()
    idx_length = 0
    for batch_idx, data in enumerate(train_loader):
        s_time = time.time()
        X = data['X']
        target = data['y'].reshape(-1, 28)
        X, target = X.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        idx_length += 1
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} Batch_idx: {} Process: {} [{}/{} ({:.0f}%)]\t'.format(
                epoch, batch_idx, torch.distributed.get_rank(),
                batch_idx * len(X), length / torch.distributed.get_world_size(),
                100. * batch_idx * len(X) * torch.distributed.get_world_size() / length))
            if args.dry_run:
                break
    # each process reports how many batches it actually saw
    print('Process %s length: %s time: %s' % (torch.distributed.get_rank(), idx_length, datetime.datetime.now()))

train_iterable_dataset = load_dataset("parquet", data_files=data_files, split="train", streaming=True)
test_iterable_dataset = load_dataset("parquet", data_files=data_files, split="test", streaming=True)
train_iterable_dataset = train_iterable_dataset.map(process_fn)
test_iterable_dataset = test_iterable_dataset.map(process_fn)
train_iterable_dataset = train_iterable_dataset.map(scale)
test_iterable_dataset = test_iterable_dataset.map(scale)

# split the streaming datasets across DDP processes
train_iterable_dataset = datasets.distributed.split_dataset_by_node(
    train_iterable_dataset, world_size=world_size, rank=local_rank).shuffle(seed=1234)
test_iterable_dataset = datasets.distributed.split_dataset_by_node(
    test_iterable_dataset, world_size=world_size, rank=local_rank).shuffle(seed=1234)
print(torch.distributed.get_rank(), train_iterable_dataset.n_shards, test_iterable_dataset.n_shards)

train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
    cuda_kwargs = {'num_workers': 3,  # ngpus_per_node
                   'pin_memory': True,
                   'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)
# no DistributedSampler here: the iterable datasets are already split by node
train_loader = torch.utils.data.DataLoader(train_iterable_dataset, **train_kwargs)
test_loader = torch.utils.data.DataLoader(test_iterable_dataset, **test_kwargs)

for epoch in range(1, args.epochs + 1):
    start_time = time.time()
    train_iterable_dataset.set_epoch(epoch)
    test_iterable_dataset.set_epoch(epoch)
    train(args, model, device, train_loader, optimizer, criterion, epoch, train_len)
    test(args, model, device, criterion2, test_loader)

And here is part of the output:

Train Epoch: 1 Batch_idx: 5000 Process: 0 [320000/4710975.0 (7%)]	
Train Epoch: 1 Batch_idx: 5000 Process: 1 [320000/4710975.0 (7%)]	
Train Epoch: 1 Batch_idx: 5000 Process: 2 [320000/4710975.0 (7%)]	
Train Epoch: 1 Batch_idx: 5862 Process: 3 Data_length: 12 coststime: 0.04095172882080078
Train Epoch: 1 Batch_idx: 5862 Process: 0 Data_length: 3 coststime: 0.0751960277557373
Train Epoch: 1 Batch_idx: 5867 Process: 3 Data_length: 49 coststime: 0.0032558441162109375
Train Epoch: 1 Batch_idx: 5872 Process: 1 Data_length: 2 coststime: 0.022842884063720703
Train Epoch: 1 Batch_idx: 5876 Process: 3 Data_length: 63 coststime: 0.002694845199584961
Process 3 length: 5877 time: 2023-11-17 17:03:26.582317
Train epoch 1 costTime: 241.72063446044922s . Process 3 Start to test.
3 0 tensor(45508.8516, device='cuda:3')
3 100 tensor(45309.0469, device='cuda:3')
3 200 tensor(45675.3047, device='cuda:3')
3 300 tensor(45263.0273, device='cuda:3')
Process 3 Reduce metrics.
Train Epoch: 2 Batch_idx: 0 Process: 3 [0/4710975.0 (0%)]	
Train Epoch: 1 Batch_idx: 5882 Process: 1 Data_length: 63 coststime: 0.05185818672180176
Train Epoch: 1 Batch_idx: 5887 Process: 1 Data_length: 12 coststime: 0.006895303726196289
Process 1 length: 5888 time: 2023-11-17 17:20:48.578204
Train epoch 1 costTime: 1285.7279663085938s . Process 1 Start to test.
1 0 tensor(45265.9141, device='cuda:1')

Expected behavior

I'd like to know how to fix this problem.

Environment info

torch==2.0
datasets==2.14.0
