Closed
Labels
module: complex (Related to complex number support in PyTorch)
module: fft
module: half (Related to float16 half-precision floats)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
Hi, I'm trying to use torch.fft.rfft2 together with half precision (via torch.cuda.amp), but the two don't seem to work together.
(For an easy demonstration, I assign half precision directly via dtype.)
```python
import torch

a = torch.randn((2, 3, 128, 128), dtype=torch.float16)
torch.fft.rfft2(a, norm='backward')
```
Output:
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-19-b4a82a9b573a> in <module>
      1 import torch
      2 a = torch.randn((2,3,128,128), dtype=torch.float16)
----> 3 torch.fft.rfft2(a, norm='backward')

RuntimeError: Unsupported dtype Half
```
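As a stopgap (not an official fix, just a sketch of a common workaround), one can upcast the input to float32 before the transform and cast the result back afterwards. Since complex half-precision support is limited, the sketch below keeps the downcast result as stacked real/imaginary parts via `torch.view_as_real`:

```python
import torch

# Workaround sketch: rfft2 rejects float16 here, so run the FFT in float32.
a = torch.randn((2, 3, 128, 128), dtype=torch.float16)

spec = torch.fft.rfft2(a.float(), norm='backward')  # complex64, shape (2, 3, 128, 65)

# If a half-precision tensor is needed downstream, view the complex output
# as a trailing real/imag pair and downcast that instead of the complex dtype.
spec_half = torch.view_as_real(spec).half()  # float16, shape (2, 3, 128, 65, 2)

print(spec.dtype, spec_half.dtype)
```

Inside an autocast region the same idea applies: the upcast runs the FFT outside the half-precision fast path, trading speed and memory for compatibility.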
Versions
torch==1.8.1
cc @ezyang @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @peterbell10