
Add attribute output_dtype to QuantizeLinear #5956


Merged
4 commits merged into onnx:main from quantize-output-dtype-attr on Feb 25, 2024

Conversation

@galagam (Contributor) commented Feb 23, 2024

The purpose of this change is to allow setting the quantized type without providing the zero-point tensor for symmetric quantization.
This reduces model size, most importantly for block quantization where the zero-point tensor dimensions are large, and reduces backend runtime.

This implements issue #5943
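
For context, here is a minimal sketch using onnx.helper of the two ways to pin the quantized output type. It is not taken from this PR; it assumes the form of QuantizeLinear this change introduces, where output_dtype is available as an attribute:

```python
from onnx import TensorProto, helper

# With output_dtype: the target type is carried by the attribute, so
# symmetric quantization needs no zero-point initializer at all.
node_new = helper.make_node(
    "QuantizeLinear",
    inputs=["x", "scale"],
    outputs=["y"],
    output_dtype=TensorProto.INT8,
)

# Without it: the type is inferred from an explicit zero-point tensor,
# which for blocked quantization has the same (large) shape as the scale.
zp = helper.make_tensor("zp", TensorProto.INT8, [], [0])
node_old = helper.make_node(
    "QuantizeLinear",
    inputs=["x", "scale", "zp"],
    outputs=["y"],
)
```

In the second form the zero-point tensor must also be registered as a graph initializer, which is exactly the overhead the new attribute avoids.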

@galagam requested review from a team as code owners on February 23, 2024 14:36
@galagam (Author) commented Feb 23, 2024

@gramalingam following up on our discussion in the Operators SIG meeting yesterday, here are the changes for #5943.
If you can do a quick review, hopefully we'll be able to get this into v1.16.


codecov bot commented Feb 23, 2024

Codecov Report

Attention: Patch coverage is 62.50% with 30 lines in your changes missing coverage. Please review.

Project coverage is 56.79%. Comparing base (945d7be) to head (05b222a).

Files | Patch % | Lines
onnx/backend/test/case/node/quantizelinear.py | 0.00% | 14 Missing ⚠️
onnx/reference/ops/op_quantize_linear.py | 70.00% | 7 Missing and 5 partials ⚠️
onnx/test/shape_inference_test.py | 77.77% | 4 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #5956   +/-   ##
=======================================
  Coverage   56.79%   56.79%           
=======================================
  Files         506      506           
  Lines       30308    30349   +41     
  Branches     4580     4589    +9     
=======================================
+ Hits        17214    17238   +24     
- Misses      12267    12283   +16     
- Partials      827      828    +1     


Comment on lines +39 to +42
ONNX_ASSERTM(
    false,
    "Attribute output_dtype is not supported for Opset Version %d, supply a zero-point tensor instead",
    target_version().version())

Check notice (Code scanning / CodeQL): Too many arguments to formatting function
Format for barf (in a macro expansion) expects 5 arguments but given 6.
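
The guarded path above rejects output_dtype outright when down-converting to an older opset. For illustration only, here is a minimal Python sketch (a hypothetical helper, not part of this PR or the ONNX version-converter API) of how a converter could instead fall back to an explicit zero-point tensor:

```python
from onnx import helper

def downgrade_quantize_linear(node):
    # Hypothetical fallback: when targeting an opset that predates
    # output_dtype, drop the attribute and supply an explicit
    # all-zeros zero-point input instead of failing.
    dtype = None
    kept = []
    for attr in node.attribute:
        if attr.name == "output_dtype":
            dtype = attr.i
        else:
            kept.append(attr)
    if dtype is None:
        return node, None  # attribute absent, nothing to rewrite
    # Scalar zero point shown for simplicity; per-axis or blocked
    # quantization would need a zero point shaped like the scale input.
    zp = helper.make_tensor(node.output[0] + "_zp", dtype, [], [0])
    new_node = helper.make_node(
        node.op_type,
        list(node.input) + [zp.name],
        list(node.output),
        name=node.name,
    )
    new_node.attribute.extend(kept)
    return new_node, zp  # caller must add zp to the graph initializers
```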
@gramalingam (Contributor) left a comment


LGTM, thanks for the quick PR, greatly appreciate it!

The purpose of this change is to allow setting the quantized type
without providing the zero-point tensor.
This reduces model size, most importantly for block quantization
where the zero-point tensor dimensions are large.
It also simplifies the creation of symmetric quantization nodes.

Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>
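
As a rough numerical illustration of the symmetric case the commit message describes, here is a sketch (not the reference implementation's actual code) of int8 symmetric quantization with an implicit zero point of 0:

```python
import numpy as np

def quantize_symmetric_int8(x, scale):
    # Symmetric quantization: the zero point is implicitly 0, so the
    # operation reduces to scale, round half to even, and saturate.
    y = np.rint(x / scale)
    return np.clip(y, -128, 127).astype(np.int8)

x = np.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=np.float32)
print(quantize_symmetric_int8(x, np.float32(0.0078125)))
# -> [-128  -64    0   64  127]
```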
@galagam force-pushed the quantize-output-dtype-attr branch from 1e26f28 to 05b222a on February 24, 2024 00:10
@gramalingam added this pull request to the merge queue on Feb 25, 2024
Merged via the queue into onnx:main with commit c95a59c on Feb 25, 2024
cjvolzka added a commit that referenced this pull request Feb 26, 2024
* 'main' of https://github.com/onnx/onnx:
  Add attribute output_dtype to QuantizeLinear (#5956)
  Update inliner to propagate valueinfos (#5942)
  Fix ConstantOfShape type constraints (#5961)
  Support register custom OpSchema by python (#5906)
  Fix ReferenceEvaluator when run from a subclass (#5936)
isdanni pushed a commit to isdanni/onnx that referenced this pull request Mar 18, 2024
linshokaku pushed a commit to linshokaku/onnx that referenced this pull request Oct 2, 2024