Description
🐛 Describe the bug
PyTorch 2.4 deprecated torch.cuda.amp.autocast in favor of torch.amp.autocast("cuda", ...), but this change missed updating some internal uses in PyTorch. For example, in DataParallel here:
pytorch/torch/nn/parallel/parallel_apply.py, lines 92 to 93 in 3710a79:
with torch.cuda.device(device), torch.cuda.stream(stream), autocast(
    enabled=autocast_enabled
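Under the non-deprecated API, the equivalent call is torch.amp.autocast("cuda", enabled=autocast_enabled). A minimal sketch contrasting the two forms outside of DataParallel (the tensor and layer below are illustrative, not taken from parallel_apply.py):

import torch

x = torch.randn(8, 10, device="cuda")
linear = torch.nn.Linear(10, 10).cuda()

# Deprecated form: emits the FutureWarning on PyTorch >= 2.4
with torch.cuda.amp.autocast(enabled=True):
    y_old = linear(x)

# Replacement form suggested by the deprecation message
with torch.amp.autocast("cuda", enabled=True):
    y_new = linear(x)

# Both run the linear layer in float16 under CUDA autocast
assert y_old.dtype == y_new.dtype == torch.float16

Running DataParallel itself, however, hits the deprecated path internally: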
import torch
model = torch.nn.Linear(10, 10).cuda()
model = torch.nn.DataParallel(model, device_ids=[0, 1])
output = model(torch.randn(20, 10).cuda())
Produces:
/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/nn/modules/linear.py:117: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at ../aten/src/ATen/cuda/CublasHandlePool.cpp:135.)
return F.linear(input, self.weight, self.bias)
Since these warnings are triggered internally, the user will see them but cannot do anything about them.
If desired, I can send a PR with a fix for this :)
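For reference, a minimal, untested sketch of the kind of change such a fix would make in parallel_apply.py, using the non-deprecated form (the surrounding names mirror the existing code; the toy workload inside the block is only for illustration):

import torch

device = torch.device("cuda", 0)
stream = torch.cuda.Stream(device)
autocast_enabled = torch.is_autocast_enabled()

# Same three context managers as in parallel_apply.py, but with the deprecated
# autocast(enabled=...) replaced by torch.amp.autocast(device_type="cuda", ...)
with torch.cuda.device(device), torch.cuda.stream(stream), torch.amp.autocast(
    device_type="cuda", enabled=autocast_enabled
):
    out = torch.nn.Linear(10, 10).to(device)(torch.randn(4, 10, device=device))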
Versions
PyTorch version: 2.4.0+cu121
Metadata
Assignees
Labels
module: amp (automated mixed precision), autocast, triaged