
multiprocessing deadlocks #1120

@FuriouslyCurious

Description


I am experimenting with Soumith's imagenet example, but it is crashing or deadlocking in several ways. I have added a number of "print" statements to pin down where it fails; the full script is in the gist below (as you can see, there are almost no significant modifications to the original code). All runs are on 2x NVIDIA Titan X 12 GB cards with 96 GB of RAM.

https://gist.github.com/FuriouslyCurious/81742b8126f07f919522a588147e6086

Issue 1: transforms.Scale(512) fails in THCTensorMathBlas.cu:241

How to reproduce:

  1. Images are being fed with transforms.Scale(512) or transforms.Scale(1024)
  2. Source images are 2048x2048.
  3. Workers >= 1
  4. Batchsize >= 2
  5. The script crashes on its own within a few minutes.

Output

 python train.py -a resnet18 -j 1 -b 2 /home/FC/data/P/
=> Parsing complete...
=> creating model 'resnet18'
=> Using CUDA DataParallel
=> Starting training images loading...
=> Starting validation images loading...
=> Loss criterion and optimizer setup
=> Starting training...
=> Training Epoch 0
Traceback (most recent call last):
  File "train.py", line 299, in <module>
    main()
  File "train.py", line 140, in main
    train(train_loader, model, criterion, optimizer, epoch)
  File "train.py", line 177, in train
    output = model(input_var)
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/nn/modules/module.py", line 202, in __call__
    result = self.forward(*input, **kwargs)
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 92, in forward
    outputs = self.parallel_apply(replicas, scattered, gpu_dicts)
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 102, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs)
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 50, in parallel_apply
    raise output
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 30, in _worker
    output = module(*input, **kwargs)
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/nn/modules/module.py", line 202, in __call__
    result = self.forward(*input, **kwargs)
  File "/conda3/envs/idp/lib/python3.5/site-packages/torchvision-0.1.6-py3.5.egg/torchvision/models/resnet.py", line 150, in forward
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/nn/modules/module.py", line 202, in __call__
    result = self.forward(*input, **kwargs)
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 54, in forward
    return self._backend.Linear()(input, self.weight, self.bias)
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/nn/_functions/linear.py", line 10, in forward
    output.addmm_(0, 1, input, weight.t())
RuntimeError: size mismatch at /data/users/soumith/miniconda2/conda-bld/pytorch-cuda80-0.1.10_1488757768560/work/torch/lib/THC/generic/THCTensorMathBlas.cu:241
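
The size mismatch comes from the fully connected layer, not the data loader: resnet18's fc expects the 512-element feature vector produced by 224x224 inputs, and a bare Scale(512) changes that shape. Below is a minimal sketch of a pipeline that avoids it, assuming the stock torchvision resnet18 and the 0.1.x transforms API (the 224 crop size and the ImageNet normalization constants are the usual defaults, not values taken from the gist).

import torchvision.transforms as transforms

# Sketch only: with Scale(512) alone, layer4 of resnet18 emits a 16x16 feature
# map, the 7x7 average pool then yields 2x2, and the flattened tensor no longer
# matches the 512 inputs of fc, which is exactly the addmm size mismatch above.
# A fixed crop after the scale keeps the 224x224 input the network was built for.
train_transform = transforms.Compose([
    transforms.Scale(512),            # shorter side of the 2048x2048 source -> 512
    transforms.RandomCrop(224),       # fixed-size crop so fc always sees 512 features
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])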

Issue 2: Multiple DataLoader workers deadlock in index_queue.get() while the main process waits in waiter.acquire()

How to reproduce:

  1. Images are being fed with default crop: transforms.RandomSizedCrop(224)
  2. Source images are 2048x2048.
  3. Workers > 2
  4. Batchsize > 40
  5. When the GPU clock speed drops to its idle MHz in nvidia-smi, the script has deadlocked in waiter.acquire() and index_queue.get(); abort it manually.
python train.py -a resnet18 /home/FC/data/P
=> Parsing complete...
=> creating model 'resnet18'
=> Using CUDA DataParallel
=> Starting training images loading...
=> Starting validation images loading...
=> Loss criterion and optimizer setup
=> Starting training...
=> Training Epoch 0
^CProcess Process-4:
Process Process-3:
Traceback (most recent call last):
Traceback (most recent call last):
  File "train.py", line 299, in <module>
    main()
  File "train.py", line 140, in main
    train(train_loader, model, criterion, optimizer, epoch)
  File "train.py", line 168, in train
    for i, (input, target) in enumerate(train_loader):
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 168, in __next__
    idx, batch = self.data_queue.get()
  File "/conda3/envs/idp/lib/python3.5/queue.py", line 164, in get
    self.not_empty.wait()
  File "/conda3/envs/idp/lib/python3.5/threading.py", line 293, in wait
    waiter.acquire()
Traceback (most recent call last):
  File "/conda3/envs/idp/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/conda3/envs/idp/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 26, in _worker_loop
    r = index_queue.get()
  File "/conda3/envs/idp/lib/python3.5/multiprocessing/queues.py", line 342, in get
    with self._rlock:
  File "/conda3/envs/idp/lib/python3.5/multiprocessing/synchronize.py", line 96, in __enter__
    return self._semlock.__enter__()
  File "/conda3/envs/idp/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/conda3/envs/idp/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 26, in _worker_loop
    r = index_queue.get()
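
One way to narrow this down is to iterate the DataLoader on its own, with nothing on the GPU. A sketch under the same repro settings (>2 workers, batch size >40); the path is a placeholder based on the command above, and the exact numbers are only illustrative.

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Loader-only loop: no model, no CUDA, no DataParallel. If this also stalls,
# the deadlock is in the worker processes feeding index_queue/data_queue.
dataset = datasets.ImageFolder(
    '/home/FC/data/P/train',          # placeholder path, mirroring the command above
    transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.ToTensor(),
    ]))

loader = DataLoader(dataset, batch_size=48, shuffle=True, num_workers=4)

for i, (input, target) in enumerate(loader):
    print(i, input.size())            # should keep printing; a silent stop mirrors the hang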

Issue 3: With a single worker, the main process hangs in threading.py:293 waiter.acquire()

How to reproduce:

  1. Images are being fed with NO crop or scale
  2. Source images are 2048x2048.
  3. Workers >= 1
  4. Batchsize >= 1
  5. When the GPU clock speed drops to its idle MHz in nvidia-smi, the script has stalled in waiter.acquire(); abort it manually.
python train.py -a resnet152 -j 1 -b 1 /home/FC/data/P/
=> Parsing complete...
=> creating model 'resnet152'
=> Using CUDA DataParallel
=> Starting training images loading...
=> Starting validation images loading...
=> Loss criterion and optimizer setup
=> Starting training...
=> Training Epoch 0
^CTraceback (most recent call last):
  File "train.py", line 298, in <module>
    main()
  File "train.py", line 139, in main
    train(train_loader, model, criterion, optimizer, epoch)
  File "train.py", line 167, in train
    for i, (input, target) in enumerate(train_loader):
  File "/conda3/envs/idp/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 168, in __next__
    idx, batch = self.data_queue.get()
  File "/conda3/envs/idp/lib/python3.5/queue.py", line 164, in get
    self.not_empty.wait()
  File "/conda3/envs/idp/lib/python3.5/threading.py", line 293, in wait
    waiter.acquire()
KeyboardInterrupt
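
As a quick cross-check (a sketch, not a fix), the same data can be loaded with num_workers=0, which bypasses the worker/queue machinery entirely; with the stock example script this should correspond to passing -j 0. The path below is again a placeholder.

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Single-process check: with num_workers=0 the batches are loaded in the main
# process, so the worker processes and their index_queue/data_queue are never
# created. If this runs, the hang is specific to the multiprocessing path.
dataset = datasets.ImageFolder('/home/FC/data/P/train', transforms.ToTensor())
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

for i, (input, target) in enumerate(loader):
    print(i, input.size())
    if i == 10:                       # a few batches are enough to tell
        break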

    Labels

    high priority, needs reproduction
