torch.fft.rfft2 doesn't support half dtype #70664

@dong03

Description

🐛 Describe the bug

Hi, I'm trying to use torch.fft.rfft2 together with half precision (via torch.cuda.amp), but they don't seem to work together.

(For a simple demonstration, I assign half precision directly via the dtype argument.)

import torch
a = torch.randn((2,3,128,128), dtype=torch.float16)
torch.fft.rfft2(a, norm='backward')

Output:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-19-b4a82a9b573a> in <module>
      1 import torch
      2 a = torch.randn((2,3,128,128), dtype=torch.float16)
----> 3 torch.fft.rfft2(a, norm='backward')

RuntimeError: Unsupported dtype Half
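A possible workaround until half-precision FFT support lands (this is a sketch, not an official fix): upcast the float16 tensor to float32 before calling the transform, since rfft2 accepts float32 and returns a complex64 result.

```python
import torch

# Workaround sketch: torch.fft.rfft2 rejects torch.float16 in torch 1.8.1,
# so upcast to float32 first. The result is complex64.
a = torch.randn((2, 3, 128, 128), dtype=torch.float16)
out = torch.fft.rfft2(a.float(), norm='backward')

# The one-sided transform shortens the last dimension: 128 // 2 + 1 = 65.
print(out.shape, out.dtype)  # torch.Size([2, 3, 128, 65]) torch.complex64
```

When running under torch.cuda.amp, the same idea can be applied locally by wrapping the FFT call in `with torch.cuda.amp.autocast(enabled=False):` and casting the input with `.float()`, so the rest of the model still runs in mixed precision.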

Versions

torch==1.8.1

cc @ezyang @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @peterbell10

Metadata

Assignees

No one assigned

Labels

module: complex — Related to complex number support in PyTorch
module: fft
module: half — Related to float16 half-precision floats
triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
