
Conversation

@yuanyao-nv (Contributor) commented Aug 7, 2024

Description

  • FLOAT4E2M1 has been added to the proto in #6318 (Add FLOAT4E2M1 data type).
  • This PR adds FLOAT4E2M1 support for QuantizeLinear, DequantizeLinear, Cast, and CastLike (opset 23).
  • It also adds support to non-compute ops: Constant, ConstantOfShape, Identity, Reshape, Shape, Size, If, Loop, Scan, Flatten, Pad, Squeeze, Unsqueeze, and Transpose (opset 23).

Similar to INT4/UINT4, FP4 weights/inputs are expected to be packed.
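
As an illustration of what packing means here, the sketch below packs pairs of already-encoded 4-bit FLOAT4E2M1 codes into single bytes, with the first element of each pair in the low nibble (the convention used for INT4/UINT4). The helper name and the standalone packing step are assumptions for the example, not the onnx implementation.

```python
# Illustrative sketch only, not onnx library code: pack 4-bit FLOAT4E2M1 codes two per byte.
import numpy as np

def pack_fp4_pairs(codes: np.ndarray) -> np.ndarray:
    """codes: flat uint8 array of 4-bit values (0..15); returns ceil(n/2) packed bytes."""
    codes = codes.astype(np.uint8) & 0x0F
    if codes.size % 2:                      # pad odd-length input with a zero nibble
        codes = np.append(codes, np.uint8(0))
    low, high = codes[0::2], codes[1::2]
    return (low | (high << 4)).astype(np.uint8)

# Three FP4 elements fit in two bytes; the second byte's high nibble is padding.
packed = pack_fp4_pairs(np.array([0b0010, 0b1010, 0b0111], dtype=np.uint8))
print([hex(b) for b in packed])             # ['0xa2', '0x7']
```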

@yuanyao-nv requested review from a team as code owners August 7, 2024 05:04

codecov bot commented Aug 7, 2024

Codecov Report

Attention: Patch coverage is 18.00000% with 41 lines in your changes missing coverage. Please review.

Project coverage is 57.22%. Comparing base (83194ed) to head (b6544c3).
Report is 94 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| onnx/backend/test/case/node/cast.py | 0.00% | 17 Missing ⚠️ |
| onnx/backend/test/case/node/dequantizelinear.py | 0.00% | 8 Missing ⚠️ |
| onnx/backend/test/case/node/quantizelinear.py | 0.00% | 8 Missing ⚠️ |
| onnx/reference/ops/op_quantize_linear.py | 14.28% | 5 Missing and 1 partial ⚠️ |
| onnx/reference/ops/op_cast_like.py | 0.00% | 1 Missing and 1 partial ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #6283      +/-   ##
==========================================
+ Coverage   56.95%   57.22%   +0.26%     
==========================================
  Files         506      507       +1     
  Lines       30467    31398     +931     
  Branches     4592     4691      +99     
==========================================
+ Hits        17353    17968     +615     
- Misses      12285    12577     +292     
- Partials      829      853      +24     

☔ View full report in Codecov by Sentry.

@justinchuby (Member)

Do you have plans to also push jax-ml/ml_dtypes#116 forward? If this is included in ml_dtypes it would make the interop experience much better (and the code run faster).

@yuanyao-nv (Contributor, Author)

@onnx/sig-archinfra-approvers I'm seeing the following test errors

FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_cast_FLOAT16_to_FLOAT4E2M1_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Cast): output has unsupported type tensor(float4e2m1)
FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_cast_FLOAT4E2M1_to_FLOAT16_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Cast): input typestr: T1, has unsupported type: tensor(float4e2m1)
FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_cast_FLOAT4E2M1_to_FLOAT_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Cast): input typestr: T1, has unsupported type: tensor(float4e2m1)
FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_cast_FLOAT_to_FLOAT4E2M1_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Cast): output has unsupported type tensor(float4e2m1)
FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_dequantizelinear_float4e2m1_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:DequantizeLinear): x typestr: T1, has unsupported type: tensor(float4e2m1)
FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_quantizelinear_float4e2m1_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:QuantizeLinear): y_zero_point typestr: T2, has unsupported type: tensor(float4e2m1)

Seems like there's some list I need to propagate the new data type to. Any pointers?

@yuanyao-nv (Contributor, Author)

> Do you have plans to also push jax-ml/ml_dtypes#116 forward? If this is included in ml_dtypes it would make the interop experience much better (and the code run faster).

FP4 would be the more immediate priority for us. The remaining FP6 types could be worked on as well in the future if bandwidth permits.

@xadupre (Contributor) commented Aug 7, 2024

> @onnx/sig-archinfra-approvers I'm seeing the following test errors
>
> FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_cast_FLOAT16_to_FLOAT4E2M1_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Cast): output has unsupported type tensor(float4e2m1)
> FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_cast_FLOAT4E2M1_to_FLOAT16_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Cast): input typestr: T1, has unsupported type: tensor(float4e2m1)
> FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_cast_FLOAT4E2M1_to_FLOAT_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Cast): input typestr: T1, has unsupported type: tensor(float4e2m1)
> FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_cast_FLOAT_to_FLOAT4E2M1_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Cast): output has unsupported type tensor(float4e2m1)
> FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_dequantizelinear_float4e2m1_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:DequantizeLinear): x typestr: T1, has unsupported type: tensor(float4e2m1)
> FAILED onnx/test/test_backend_test.py::OnnxBackendNodeModelTest::test_quantizelinear_float4e2m1_cpu - onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:QuantizeLinear): y_zero_point typestr: T2, has unsupported type: tensor(float4e2m1)
>
> Seems like there's some list I need to propagate the new data type to. Any pointers?

I usually look for strings such as FLOAT8E4M3FN and float8e4m3fn to see all the places the type is used, and then insert a new line to handle the new type.
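
A rough way to follow that suggestion, sketched below; it assumes you run it from the repo root, and the extension filter is my own guess at which files matter, not an exhaustive list.

```python
# Sketch only: list files under onnx/ that already mention an existing narrow type,
# so the corresponding FLOAT4E2M1 entries can be added next to them.
import pathlib

for path in pathlib.Path("onnx").rglob("*"):
    if path.suffix not in {".py", ".cc", ".h", ".in"}:
        continue
    text = path.read_text(errors="ignore")
    if "FLOAT8E4M3FN" in text or "float8e4m3fn" in text:
        print(path)
```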

@justinchuby (Member)

> Do you have plans to also push jax-ml/ml_dtypes#116 forward? If this is included in ml_dtypes it would make the interop experience much better (and the code run faster).
>
> FP4 would be the more immediate priority for us. The remaining FP6 types could be worked on as well in the future if bandwidth permits.

It would be very helpful to have an unpacked version of fp4e2m1 in ml_dtypes

@justinchuby (Member) commented Aug 9, 2024

Is float32x2 conventional naming? Should it just be float32?

@justinchuby self-assigned this Aug 9, 2024
@justinchuby (Member)

lintrunner errors will need to be ignored inline, for example with # noqa: PLR2004.
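
For illustration, a hypothetical snippet showing the inline suppression (the function name and value are invented for the example):

```python
# PLR2004 flags comparisons against "magic" literal values; the trailing noqa
# comment tells lintrunner/ruff to skip that rule on this line only.
def is_max_exponent(exponent_bits: int) -> bool:
    return exponent_bits == 3  # noqa: PLR2004
```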

@justinchuby (Member) left a comment


lgtm. Thanks!

@justinchuby reopened this Aug 23, 2024
@justinchuby (Member)

cc @gramalingam for another look

github-merge-queue bot pushed a commit that referenced this pull request Aug 24, 2024
### Description
- Add FLOAT4E2M1 as a new data type to proto as well as relevant helper
functions and tests.
- This PR splits out the portion of
#6283 relevant to data type updates to
reduce the PR's size.

### Motivation and Context
Narrow precision data types with sub-byte bit widths are becoming
solutions to the rising cost, performance, and deployment challenges of
LLMs. ONNX already has INT4/UINT4. FP4 is another commonly used
narrow-precision data type for compressing both the weights and
activations of LLMs. For example
[OCP](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
MXFP4 uses E2M1 as element type.

Similar to INT4/UINT4, FP4 weights/inputs are expected to be packed.

Signed-off-by: Yuan Yao <yuanyao@nvidia.com>
Signed-off-by: Yuan Yao (yuanyao) <yuanyao@nvidia.com>
Signed-off-by: Yuan Yao <yuanyao@nvidia.com>
@yuanyao-nv changed the title from "Add FLOAT4E2M1 data type" to "Add FLOAT4E2M1 support to relevant operators" Aug 24, 2024
Signed-off-by: Yuan Yao <yuanyao@nvidia.com>
@yuanyao-nv removed the "run release CIs" label Aug 25, 2024
@yuanyao-nv closed this Aug 25, 2024
@yuanyao-nv reopened this Aug 25, 2024
andife pushed a commit to andife/onnx that referenced this pull request Aug 26, 2024 (same commit message as above)
@liqunfu (Collaborator) commented Aug 27, 2024

@yuanyao-nv the test failure can be reproduced if you create a Python 3.9 environment, build onnx from your branch, run `python -m pip install -r requirements-min.txt`, and then run `python onnx\test\test_backend_reference.py -k test_dequantizelinear_float4e2m1_cpu` (make sure your numpy version is 1.20.3). A quick fix is to insert `mantissa = mantissa.astype(np.float32)` in onnx\numpy_helper.py before the `val = np.where(` line. Please let me know if you still have issues. I am happy to meet you on Teams or Zoom. Thank you
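
To illustrate where that cast would sit, here is a minimal FP4E2M1 decode sketch; it is not the actual onnx/numpy_helper.py code, and the bit-field handling and helper name are assumptions for the example.

```python
# Minimal sketch, not onnx library code: decode all 16 FLOAT4E2M1 bit patterns.
import numpy as np

def fp4e2m1_to_float(sign, exponent, mantissa):
    # The suggested quick fix: cast the mantissa bits to float32 before np.where
    # so both branches are computed in floating point.
    mantissa = mantissa.astype(np.float32)
    val = np.where(
        exponent == 0,
        mantissa * 0.5,                                    # subnormals: 0, 0.5
        (1.0 + mantissa * 0.5) * np.exp2(exponent - 1.0),  # normals: 1 .. 6
    )
    return np.where(sign, -val, val).astype(np.float32)

bits = np.arange(16, dtype=np.uint8)
print(fp4e2m1_to_float(bits >> 3, (bits >> 1) & 0b11, bits & 0b1))
# magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6 and their negatives
```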

@justinchuby added the "run release CIs" label Aug 27, 2024
Signed-off-by: Yuan Yao <yuanyao@nvidia.com>
@yuanyao-nv (Contributor, Author)

Thanks all for the help! @justinchuby @liqunfu @gramalingam
I think this PR is ready to merge.

@justinchuby added this pull request to the merge queue Aug 27, 2024
Merged via the queue into onnx:main with commit f22a2ad Aug 27, 2024
70 checks passed
linshokaku pushed a commit to linshokaku/onnx that referenced this pull request Oct 2, 2024 (same commit message as above)
linshokaku pushed a commit to linshokaku/onnx that referenced this pull request Oct 2, 2024 (the commit message repeats this PR's description)
@TedThemistokleous

> Do you have plans to also push jax-ml/ml_dtypes#116 forward? If this is included in ml_dtypes it would make the interop experience much better (and the code run faster).
>
> FP4 would be the more immediate priority for us. The remaining FP6 types could be worked on as well in the future if bandwidth permits.

Hi, just seeing this PR after a search on fp6 in your repo. Are there now efforts to add fp6 support to the ONNX spec? I don't see any PRs or relevant specs that we can leverage right now on your technical site.

If not, what would adding that support involve?

@justinchuby (Member)

@yuanyao-nv @TedThemistokleous could you share your use case?

@yuanyao-nv (Contributor, Author)

> Do you have plans to also push jax-ml/ml_dtypes#116 forward? If this is included in ml_dtypes it would make the interop experience much better (and the code run faster).
>
> FP4 would be the more immediate priority for us. The remaining FP6 types could be worked on as well in the future if bandwidth permits.
>
> Hi, just seeing this PR after a search on fp6 in your repo. Are there now efforts to add fp6 support to the ONNX spec? I don't see any PRs or relevant specs that we can leverage right now on your technical site.
>
> If not, what would adding that support involve?

I don't have immediate use cases for MXFP6 at this point, but please feel free to add it if you need it. I also currently have a PR open for e8m0 (#7030), which should cover the scale type for all MX formats.

@TedThemistokleous

> @yuanyao-nv @TedThemistokleous could you share your use case?

I'm a dev on MIGraphX and we're looking to support the fp6 MX type for our upcoming ROCm 7.1 release.

I've made a feature request here:
#7048

I'm also the developer responsible for the MIGraphX and ROCm execution providers in ONNX Runtime, and I've made a feature request there as well to handle that end: microsoft/onnxruntime#25054

On the MIGraphX side we'd need to add the type support once we have an idea of how to handle input data and are able to find fp6-quantized ONNX models to parse.

Labels: module: spec, run release CIs
Project status: Done