While trying to debug the code from the forum post, I came across this bug:
```python
import torch
from pyro.ops.provenance import ProvenanceTensor

device = torch.device("cuda")
torch.set_default_device(device)

x = torch.tensor([1., 2., 3.])
y = ProvenanceTensor(x, frozenset(["x"]))
print(torch.as_tensor(y))
```
This prints `tensor([], device='cuda:0')`, an empty tensor, instead of a tensor holding the values of `x`.
Two observations:

1. This doesn't happen if `torch.set_default_device` is not used: without it, `print(torch.as_tensor(y.cuda()))` works fine.
2. When `torch.set_default_device` is used, the following code gets invoked: https://github.com/pytorch/pytorch/blob/main/torch/utils/_device.py#L72-L76, where `func(*args, **kwargs)` returns `tensor([], device='cuda:0')`. This doesn't happen when `.cuda()` is used as in 1.
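For context on the second observation: `torch.set_default_device` works by installing a `TorchFunctionMode` that intercepts tensor factory functions such as `torch.as_tensor` and re-calls them with a `device` keyword filled in. The sketch below is a simplified, CPU-only illustration of that rerouting. The class `InjectDeviceMode` is hypothetical and is not PyTorch's actual `DeviceContext`. It does not reproduce the empty-tensor result; it only shows that `torch.as_tensor(x)` reaches the mode's handler and is forwarded as `torch.as_tensor(x, device=...)`.

```python
import torch
from torch.overrides import TorchFunctionMode


class InjectDeviceMode(TorchFunctionMode):
    """Hypothetical, simplified stand-in for torch.utils._device.DeviceContext.

    Intercepts a couple of tensor factory functions, fills in a default
    `device` keyword if the caller omitted it, and records which functions
    it rewrote before forwarding the call.
    """

    def __init__(self, device):
        super().__init__()
        self.device = torch.device(device)
        self.intercepted = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = dict(kwargs or {})
        # Only factory-style functions get a device injected, and only when
        # no device was passed explicitly.
        if func in (torch.tensor, torch.as_tensor) and kwargs.get("device") is None:
            kwargs["device"] = self.device
            self.intercepted.append(func)
        # PyTorch pops the mode while this handler runs, so this call
        # goes straight to the real function instead of recursing.
        return func(*args, **kwargs)


x = torch.tensor([1.0, 2.0, 3.0])  # created outside the mode

with InjectDeviceMode("cpu") as mode:
    # Arrives in the handler as func=torch.as_tensor, args=(x,), kwargs={}
    # and is forwarded as torch.as_tensor(x, device=torch.device("cpu")).
    y = torch.as_tensor(x)

print(y)
```

Passing the injected `device` keyword explicitly, as in `torch.as_tensor(y, device=device)`, is one way to exercise the same forwarded call without the global mode, though whether that alone reproduces the bug has not been checked here.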