Add Attention Op to ONNX Opset 23 #6501
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #6501      +/-   ##
==========================================
- Coverage   57.13%   56.45%   -0.68%
==========================================
  Files         507      509       +2
  Lines       31927    32515     +588
  Branches     3040     3057      +17
==========================================
+ Hits        18240    18356     +116
- Misses      12864    13334     +470
- Partials      823      825       +2

+1
Some comments and questions:
Minor nit about the name: maybe "Attention" would be a better choice. It is a shorter name, and the op covers multiple variants like MHA/GQA/SDPA anyway.
from onnx.reference.ops.op_asinh import Asinh
from onnx.reference.ops.op_atan import Atan
from onnx.reference.ops.op_atanh import Atanh
from onnx.reference.ops.op_attention import Attention

Check notice (Code scanning / CodeQL): Unused import
head_size_q = int(hidden_size_q / q_num_heads)
new_shape_q = [batch_size, q_num_heads, Q.shape[1], head_size_q]
Q = np.reshape(Q, new_shape_q)
When Q has the 3D shape (batch_size, q_sequence_length, q_hidden_size), it cannot be reshaped directly to [batch_size, q_num_heads, q_sequence_length, head_size_q]. It needs a reshape first, then a transpose.
Agree with @tianleiwu. When Q is 3D, [b, seq_len, num_head * head_size], it should first be reshaped to [b, seq_len, num_head, head_size] and then transposed to [b, num_head, seq_len, head_size].
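For illustration, a minimal NumPy sketch of the reshape-then-transpose order the reviewers describe (the helper name split_heads is hypothetical; variable names follow the snippet above, and this is not the PR's final implementation):

import numpy as np

def split_heads(Q: np.ndarray, q_num_heads: int) -> np.ndarray:
    # Hypothetical helper, not part of the PR.
    # Q: (batch_size, q_sequence_length, q_hidden_size)
    batch_size, q_sequence_length, hidden_size_q = Q.shape
    head_size_q = hidden_size_q // q_num_heads
    # Reshape first: (batch_size, q_sequence_length, q_num_heads, head_size_q)
    Q = np.reshape(Q, (batch_size, q_sequence_length, q_num_heads, head_size_q))
    # Then transpose: (batch_size, q_num_heads, q_sequence_length, head_size_q)
    return np.transpose(Q, (0, 2, 1, 3))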
    softcap=None,
    qk_matmul_output_mode=None,
) -> np.ndarray:
    assert len(Q.shape) == len(K.shape) == len(V.shape)
If we require Q, K, and V (and the output) to have the same rank, it would be better to add that constraint to the operator spec.
Description
Add the following key LLM op to the ONNX standard: Attention.
This standardized attention operator should cover multiple variants such as MHA, GQA, and SDPA.
Motivation and Context
Standardize operators that are showing up in key LLM models.
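As a rough illustration of the core computation such an operator standardizes, here is a minimal NumPy sketch of scaled dot-product attention (no masking, softcap, or GQA head-sharing; the shapes and the default scale of 1/sqrt(head_size) are assumptions for this sketch, not the spec's exact semantics):

import numpy as np

def sdpa(Q, K, V, scale=None):
    # Q: (batch, num_heads, q_seq, head_size)
    # K: (batch, num_heads, kv_seq, head_size)
    # V: (batch, num_heads, kv_seq, head_size)
    if scale is None:
        scale = 1.0 / np.sqrt(Q.shape[-1])
    # Attention scores: (batch, num_heads, q_seq, kv_seq)
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) * scale
    # Numerically stable softmax over the key dimension
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weighted sum of values: (batch, num_heads, q_seq, head_size)
    return np.matmul(weights, V)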