
Add Attention Op to ONNX Opset 23 #6501


Merged
merged 32 commits into onnx:main on Mar 1, 2025

Conversation

Contributor

@shubhambhokare1 shubhambhokare1 commented Oct 28, 2024

Description

Add the following key LLM op to the ONNX standard: Attention.
This standardized attention operator should cover:

  • Self and Cross Attentions
  • Multi-Head Attention (MHA)
  • Group-Query Attention (GQA)
  • Multi-Query Attention (MQA)
  • No-bias and Causal Mask attentions

Motivation and Context

Standardize operators that are showing up in key LLM models.
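
For context, a minimal numpy sketch of the computation this op is meant to standardize, assuming 4D Q/K/V of shape (batch, num_heads, seq_len, head_size); the function name and argument set are illustrative, not the reference implementation added in this PR:

import numpy as np

def sdpa_sketch(Q, K, V, *, is_causal=False, scale=None):
    # Q: (batch, q_num_heads, q_seq_len, head_size)
    # K, V: (batch, kv_num_heads, kv_seq_len, head_size)
    q_heads, kv_heads = Q.shape[1], K.shape[1]
    if q_heads != kv_heads:
        # GQA/MQA: repeat each K/V head so every query head has a matching key/value head
        reps = q_heads // kv_heads
        K = np.repeat(K, reps, axis=1)
        V = np.repeat(V, reps, axis=1)

    if scale is None:
        scale = 1.0 / np.sqrt(Q.shape[-1])

    # Attention scores: (batch, heads, q_seq_len, kv_seq_len)
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) * scale

    if is_causal:
        # Lower-triangular mask; assumes q and kv sequence lengths match here
        q_len, kv_len = scores.shape[-2], scores.shape[-1]
        causal = np.tril(np.ones((q_len, kv_len), dtype=bool))
        scores = np.where(causal, scores, -np.inf)

    # Numerically stable softmax over the key axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, V)

Self-attention vs. cross-attention is just a matter of where K and V come from (the same sequence as Q, or a different one); MHA, GQA, and MQA differ only in the ratio of query heads to key/value heads.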

@shubhambhokare1 shubhambhokare1 requested a review from a team as a code owner October 28, 2024 20:49

codecov bot commented Oct 28, 2024

Codecov Report

Attention: Patch coverage is 0% with 469 lines in your changes missing coverage. Please review.

Project coverage is 56.45%. Comparing base (3d5acaf) to head (bdc31a3).
Report is 143 commits behind head on main.

Files with missing lines | Patch % | Lines
onnx/backend/test/case/node/attention.py | 0.00% | 469 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #6501      +/-   ##
==========================================
- Coverage   57.13%   56.45%   -0.68%     
==========================================
  Files         507      509       +2     
  Lines       31927    32515     +588     
  Branches     3040     3057      +17     
==========================================
+ Hits        18240    18356     +116     
- Misses      12864    13334     +470     
- Partials      823      825       +2     

☔ View full report in Codecov by Sentry.

@WilliamTambellini

+1

@yuanyao-nv
Contributor

Some comments and questions:

  1. I think the name of the op should be Scaled... rather than Scalar...?
  2. Should kv-caching happen inside SDPA as you have proposed here? My understanding is it should be outside, since the input to SDPA is already the projected QKV.
  3. In some cases the attention mask can show up more generally as a sequence of pointwise ops, such as in the grok model where it is essentially a scale+tanh operation. Is it possible, and should we strive, to make the SDPA definition more general? (A small sketch of such a softcap follows this list.)
  4. In many cases more granular control over the precision of each operation is desired. In fact most backends will have highly tuned precision combinations for the sequence of ops in attention. Should the spec be more flexible to accommodate that?
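
On point 3, the scale+tanh pattern can be expressed as a soft cap on the pre-softmax scores. A minimal numpy sketch, assuming the cap is applied right after the QK matmul (illustrative only, not the spec text of this PR):

import numpy as np

def apply_softcap(scores, softcap):
    # Soft-cap the pre-softmax attention scores: softcap * tanh(scores / softcap)
    # bounds the logits to (-softcap, softcap); with the cap disabled, the
    # scores pass through unchanged.
    if softcap is None or softcap == 0:
        return scores
    return softcap * np.tanh(scores / softcap)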

@gramalingam
Contributor

Minor nit about the name: maybe "Attention" would be a better choice. It is a shorter name, and in any case the op covers multiple variants like MHA/GQA/SDPA.

@justinchuby justinchuby changed the title Add Attention Op (ScalarDotProductAttention) to ONNX Opset 23 Add Attention Op (ScaledDotProductAttention) to ONNX Opset 23 Jan 4, 2025
@justinchuby justinchuby added this to the 1.18 milestone Jan 8, 2025
@shubhambhokare1 shubhambhokare1 requested a review from a team as a code owner January 23, 2025 17:27
@shubhambhokare1 shubhambhokare1 changed the title Add Attention Op (ScaledDotProductAttention) to ONNX Opset 23 Add Attention Op to ONNX Opset 23 Jan 23, 2025
from onnx.reference.ops.op_asinh import Asinh
from onnx.reference.ops.op_atan import Atan
from onnx.reference.ops.op_atanh import Atanh
from onnx.reference.ops.op_attention import Attention

Check notice — Code scanning / CodeQL: Unused import
Import of 'AttributeHasValue' is not used.
Import of 'Attention' is not used.

@github-project-automation github-project-automation bot moved this from In progress to Reviewer approved in PR Tracker Feb 28, 2025
@gramalingam gramalingam added this pull request to the merge queue Feb 28, 2025
Merged via the queue into onnx:main with commit d9b1e4f Mar 1, 2025
39 of 41 checks passed
@github-project-automation github-project-automation bot moved this from Reviewer approved to Done in PR Tracker Mar 1, 2025
@justinchuby justinchuby added the release notes label Mar 3, 2025

head_size_q = int(hidden_size_q / q_num_heads)
new_shape_q = [batch_size, q_num_heads, Q.shape[1], head_size_q]
Q = np.reshape(Q, new_shape_q)
Contributor


When Q has a 3D shape (batch_size, q_sequence_length, q_hidden_size), it cannot be reshaped directly to [batch_size, q_num_heads, q_sequence_length, head_size_q]. It needs a reshape first and then a transpose.


Agree with @tianleiwu. When Q is 3D, [b, seq_len, num_head * head_size], it should first be reshaped to [b, seq_len, num_head, head_size] and then transposed to [b, num_head, seq_len, head_size].
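
A minimal numpy sketch of the reshape-then-transpose both reviewers describe, assuming a 3D Q of shape (batch_size, seq_len, hidden_size); the helper name is illustrative:

import numpy as np

def split_heads(Q, q_num_heads):
    # (batch_size, seq_len, hidden_size) -> (batch_size, num_heads, seq_len, head_size)
    batch_size, seq_len, hidden_size = Q.shape
    head_size = hidden_size // q_num_heads
    # First split the hidden dimension into (num_heads, head_size) ...
    Q = np.reshape(Q, (batch_size, seq_len, q_num_heads, head_size))
    # ... then move the head axis in front of the sequence axis.
    return np.transpose(Q, (0, 2, 1, 3))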

softcap=None,
qk_matmul_output_mode=None,
) -> np.ndarray:
assert len(Q.shape) == len(K.shape) == len(V.shape)
Contributor

@tianleiwu tianleiwu Mar 19, 2025


If we require Q, K, and V (and the output) to have the same rank, it is better to add that to the operator spec.

Labels
release notes (Important changes to call out in release notes) · topic: operator (Issues related to ONNX operators)
Projects
Status: Done