Add FLOAT8E8M0 data type #7030

Conversation
Could you update https://github.com/onnx/ir-py/blob/main/src/onnx_ir/_enums.py and the tensor representations, e.g. https://github.com/onnx/ir-py/blob/fdee1e28e199f67ced802d785565ff6ebba6f63c/src/onnx_ir/_core.py#L258, as well, after consensus is reached? Thanks!
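For reference, a minimal sketch of the enum addition being requested; the value 24 mirrors the proto enum per this PR (FLOAT4E2M1 is 23), though the final name and placement in ir-py are up to its maintainers:

```python
import enum

# Hypothetical sketch of the requested onnx_ir._enums change; 24 is the
# next available value after FLOAT4E2M1 = 23 in the proto enum.
class DataType(enum.IntEnum):
    FLOAT4E2M1 = 23
    FLOAT8E8M0 = 24
```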
Force-pushed from b14d50e to cbfe6c8.
Out of curiosity: what are the benefits of each rounding mode? Did implementations differ because of the lack of a spec, or due to platform characteristics / performance considerations?
Does the proposed rounding mode attribute for Cast affect any other data types?
Given the difference in native behavior, a given backend is unlikely to implement all rounding modes, I assume. Wondering if this has implications for model portability.
@justinchuby CUDA has done extensive experiments showing that round-up gives the best accuracy and has standardized it in the CUDA spec, so round-up should essentially be the only mode that matters for MX applications. I'm OK with adding just round-up in the ONNX spec. Unfortunately OCP didn't define it this way, and efforts to correct that have not seen much progress. Other libraries have mostly chosen RNE for consistency with the other float types, but it's unlikely people will use that for MX use cases. I included the other modes as well. As the doc says, the "round_mode" attribute only applies to e8m0, so this won't interact with the existing types.
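To make the difference between the modes concrete, here is a toy illustration (mine, not from the spec): e8m0 can represent only powers of two, so the rounding mode decides which neighbor a value lands on.

```python
import numpy as np

# Toy illustration: 2.5 sits between the representable e8m0 values 2 and 4.
x = 2.5
down = 2.0 ** np.floor(np.log2(x))   # 2.0, the power of two below x
up = 2.0 ** np.ceil(np.log2(x))      # 4.0, the power of two above x
nearest = down if (x - down) <= (up - x) else up
print(nearest)  # 2.0 -> round-to-nearest keeps the smaller power of two
print(up)       # 4.0 -> round-up always takes the larger one
```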
The reference evaluator is likely going to be implemented with ml_dtypes (proposed). Is there a way to simulate the rounding mode in an efficient manner? I assume we can create a mask of everything that needs to be rounded up and manipulate those elements as a post-processing step?
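That mask-and-fix-up idea can be sketched directly. A minimal sketch, assuming ml_dtypes' default cast rounds to nearest; `cast_e8m0_roundup` is a hypothetical helper, not part of the reference evaluator, and saturation/NaN at the top of the range (biased exponent 0xFF) is not handled:

```python
import numpy as np
import ml_dtypes

def cast_e8m0_roundup(x: np.ndarray) -> np.ndarray:
    # Cast with the default (nearest) rounding first.
    x = np.asarray(x, dtype=np.float32)
    rne = x.astype(ml_dtypes.float8_e8m0fnu)
    # Mask of elements whose rounded value landed below the input.
    needs_up = rne.astype(np.float32) < x
    # e8m0 is a bare biased exponent, so incrementing the underlying byte
    # moves to the next-larger representable power of two.
    bumped = (rne.view(np.uint8) + 1).view(ml_dtypes.float8_e8m0fnu)
    return np.where(needs_up, bumped, rne)

print(cast_e8m0_roundup(np.array([2.5, 2.0])).astype(np.float32))  # [4. 2.]
```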
Codecov Report. Attention: Patch coverage is … Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #7030      +/-   ##
==========================================
- Coverage   56.40%   56.37%   -0.04%
==========================================
  Files         510      510
  Lines       32721    32806      +85
  Branches     3093     3115      +22
==========================================
+ Hits        18457    18493      +36
- Misses      13410    13456      +46
- Partials      854      857       +3
```

☔ View full report in Codecov by Sentry.
Could you also update https://github.com/onnx/onnx/blob/main/docs/docsgen/source/technical/float8.md
Force-pushed from 092ab78 to d31890b.
Do you plan to update CastLike as well?
Review thread on onnx/backend/test/data/node/test_cast_BFLOAT16_to_FLOAT/model.onnx (outdated; resolved).
Looks good to me. Would be helpful to get another review.
Be sure to regenerate the test data. There must be some issue in CI that's not catching the discrepancies.
@gramalingam @xadupre @onnx/sig-operators
LGTM ... just had a minor comment as above.
### Description

This is a follow-up PR to #7030, which added the float8e8m0 dtype and updated the Cast op. It enables float8e8m0 for the following ops in opset 24 (a usage sketch follows below):
- QuantizeLinear, DequantizeLinear, CastLike
- Constant, ConstantOfShape, Identity, Reshape, Shape, Size, If, Loop, Scan, Flatten, Pad, Squeeze, Unsqueeze, Transpose

Signed-off-by: Yuan Yao <yuanyao@nvidia.com>
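As a hedged sketch of the MX pattern these ops enable, the following builds a blocked DequantizeLinear whose scale input is a FLOAT8E8M0 tensor; it assumes an onnx build that already defines `TensorProto.FLOAT8E8M0` (opset 24), and the names, shapes, and `block_size` are illustrative:

```python
from onnx import helper, TensorProto

# Quantized data in one of the MX element types, e8m0 block scales:
# 128 elements along axis 1, one scale per block of 32.
x = helper.make_tensor_value_info("x", TensorProto.FLOAT8E4M3FN, [1, 128])
scale = helper.make_tensor_value_info("scale", TensorProto.FLOAT8E8M0, [1, 4])
node = helper.make_node(
    "DequantizeLinear", ["x", "scale"], ["y"], axis=1, block_size=32
)
```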
This PR adds comprehensive support for the FLOAT8E8M0 data type that was added to ONNX in onnx/onnx#7030.

## Changes Made

- **Added FLOAT8E8M0 enum value**: set to 24 (the next available value after FLOAT4E2M1 = 23)
- **Updated numpy type mapping**: added support for `ml_dtypes.float8_e8m0fnu`
- **Added type properties**: configured as an 8-bit floating-point, signed type
- **Added short name**: "f8e8m0" for compact representation
- **Updated serialization**: added FLOAT8E8M0 to the appropriate sets in `serde.py` for proper tensor serialization/deserialization
- **Added tests**: included a parameterized test case and a conditional ONNX compatibility check

## Testing

The implementation includes comprehensive testing:

```python
import numpy as np
import ml_dtypes
import onnx_ir._core as ir_core
import onnx_ir._enums as enums
import onnx_ir.serde as serde

# Create a tensor with the FLOAT8E8M0 type
data = np.array([1.0, 2.0, 3.0], dtype=ml_dtypes.float8_e8m0fnu)
tensor = ir_core.Tensor(data)
assert tensor.dtype == enums.DataType.FLOAT8E8M0

# Check the new type's properties
assert enums.DataType.FLOAT8E8M0.is_floating_point()
assert enums.DataType.FLOAT8E8M0.bitwidth == 8
assert enums.DataType.FLOAT8E8M0.short_name() == "f8e8m0"

# Serialization round-trip
tensor_proto = serde.serialize_tensor(tensor)
assert tensor_proto.data_type == 24
```

All existing tests continue to pass, ensuring no regression in functionality.

Fixes #127.

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Description
Add new data type FLOAT8E8M0 and related helper functions.

Update Cast op for this new type.
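A hedged sketch of what the updated Cast could look like in practice; the PR adds a "round_mode" attribute that applies only to e8m0, and the value `"up"` here is an assumption about the final spelling of the round-up mode:

```python
from onnx import helper, TensorProto

# Assumes an onnx build where TensorProto.FLOAT8E8M0 exists and Cast
# exposes the proposed round_mode attribute.
node = helper.make_node(
    "Cast", ["x"], ["y"], to=TensorProto.FLOAT8E8M0, round_mode="up"
)
```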
Paper on CUDA's choice of round-up: https://arxiv.org/abs/2506.08027
A follow-up PR will update Q/DQ and other non-compute operators.
Motivation and Context
E8M0 serves as the common scale type for microscaling (MX) formats: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
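For readers new to the format, a minimal sketch of E8M0 decoding as described by the OCP MX spec: the byte is a biased exponent (bias 127), 0xFF encodes NaN, and there is no sign bit, zero, or infinity, so every other byte is a power of two.

```python
import math

def decode_e8m0(byte: int) -> float:
    # 0xFF is the single NaN encoding; everything else is 2**(byte - 127).
    if byte == 0xFF:
        return float("nan")
    return math.ldexp(1.0, byte - 127)

print(decode_e8m0(127))  # 1.0
print(decode_e8m0(130))  # 8.0
```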