Closed
Description
Hi, all:
Sorry for the entry-level question, but how should I use
torch.nn.DataParallel(model).cuda()?
The code is as follows:
```python
import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

model = models.__dict__['resnet18']()
model = torch.nn.DataParallel(model).cuda()
```
The last line of the code fails with the following message:
```
  File "train.py", line xxx, in <module>
    model = torch.nn.DataParallel(model).cuda()
  File "yyy/condaenv/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 48, in __init__
    output_device = device_ids[0]
IndexError: list index out of range
```
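A minimal sketch of why this line can fail, assuming the PyTorch 0.1.x `DataParallel` constructor fills in `device_ids` from `torch.cuda.device_count()` when none are passed and then takes `device_ids[0]` as the output device. The helper `default_output_device` is hypothetical and written in plain Python (no PyTorch needed); `device_count` stands in for what `torch.cuda.device_count()` would return.

```python
from typing import List, Optional


def default_output_device(device_count: int,
                          device_ids: Optional[List[int]] = None,
                          output_device: Optional[int] = None) -> int:
    # Hypothetical mirror of DataParallel's default device selection
    # (an assumption about the 0.1.x code path, not a PyTorch API).
    # With no explicit device_ids, every visible GPU is used.
    if device_ids is None:
        device_ids = list(range(device_count))
    # With no explicit output_device, the first GPU in device_ids is used.
    # If no GPU is visible, device_ids == [] and this indexing raises
    # "IndexError: list index out of range", matching the traceback.
    if output_device is None:
        output_device = device_ids[0]
    return output_device


if __name__ == "__main__":
    print(default_output_device(1))  # one visible GPU
    try:
        default_output_device(0)     # no visible GPU
    except IndexError as exc:
        print("IndexError:", exc)
```

Under that assumption, the traceback most likely means `torch.cuda.device_count()` returned 0 in this environment, i.e. PyTorch could not see the GPU at all. Running `torch.cuda.is_available()` and `torch.cuda.device_count()` in the same conda environment would confirm or rule this out.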
Can anybody give me a hand?
I'm using a GTX 980M on Ubuntu 16.04, with an up-to-date Anaconda (Python 3.6.1) and PyTorch 0.1.12.
Thank you very much
Pei