Closed
Labels
high priority, module: complex (Related to complex number support in PyTorch), module: half (Related to float16 half-precision floats), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
Currently our ComplexHalf support is in a bad state. Users file issues complaining that ops on ComplexHalf trigger internal asserts (#71671, #71635). That shouldn't happen: even if a particular op is not supported for ComplexHalf, the error message should state that the datatype is not supported. We have two options going forward: disable ComplexHalf completely (#70606) or support it for real.
Disabling is relatively easy, we already have a PR for this.
However, there are requests to keep ComplexHalf enabled because users would like to take advantage of better half-precision performance on GPUs (see e.g. #67324, pytorch/audio#2097).
This issue is to track enabling ComplexHalf, should we decide to do that. At a minimum, the steps would be:
- Enable common pointwise and reduction ops. Those ops for the two complex datatypes we currently support take approximately 40 MB of CUDA context size; adding another complex type would grow the context by roughly 20 MB and increase build time. With the jiterator, however, we can avoid the build-time and context-size penalty by jiterating pointwise/reduction kernels for complex half. Defining math ops on complex half is not a problem: we can use an approach similar to our current half support, where values are first cast to single-precision complex, the math is performed in single precision, and the result is truncated back.
- Enable FFT operations. cuFFT supports complex half.
- Enable matrix multiplication operations. These are unfortunately not supported by cuBLAS. Alternatives are using Triton, or performing 3 or 4 real matrix multiplications with the corresponding copies needed to accommodate the complex data layout.
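The upcast-compute-truncate approach for pointwise math can be sketched in NumPy. Since NumPy has no complex32 dtype, the sketch represents a "complex half" tensor as a pair of float16 arrays (real, imag); the function name is a hypothetical illustration, not PyTorch API:

```python
import numpy as np

def add_complex_half(ar, ai, br, bi):
    """Pointwise add of 'complex half' values stored as float16
    (real, imag) array pairs.

    Mirrors the strategy described above: upcast to single-precision
    complex, do the math there, then truncate back to float16 parts.
    """
    # Upcast both operands to complex64 (float32 real/imag parts).
    a = ar.astype(np.float32) + 1j * ai.astype(np.float32)
    b = br.astype(np.float32) + 1j * bi.astype(np.float32)
    # Math is performed entirely in single-precision complex.
    c = a + b
    # Truncate the result back to the half-precision representation.
    return c.real.astype(np.float16), c.imag.astype(np.float16)
```

The same wrapper shape works for any pointwise or reduction op: only the single-precision compute in the middle changes, which is why jiterating these kernels is mechanical.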
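The three-real-matmul alternative mentioned above is the Karatsuba-style identity for complex products; a minimal NumPy sketch (function name hypothetical, and real kernels would also need the layout copies and float16 accumulation care the text alludes to):

```python
import numpy as np

def complex_matmul_3(A, B, C, D):
    """Compute (A + iB) @ (C + iD) with three real matmuls
    instead of the naive four.

    (A + iB)(C + iD) = (AC - BD) + i(AD + BC), and
    AD + BC = (A + B)(C + D) - AC - BD, so only three
    products are needed. The identity preserves operand
    order, so it is valid for (non-commuting) matrices.
    """
    M1 = A @ C
    M2 = B @ D
    M3 = (A + B) @ (C + D)
    return M1 - M2, M3 - M1 - M2  # (real part, imag part)
```

The trade-off is one extra addition per operand and slightly worse numerical behavior than the four-multiplication form, in exchange for 25% fewer multiplications.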
cc @ezyang @gchanan @zou3519 @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @mthrok @peterbell10