
Conversation

peishenyan
Contributor

In the development of ONNX Runtime, we need to know the output shape of each Op node for static graph compilation. However, we found that ONNX shape inference could produce almost every output shape except Einsum's. In onnx/defs/math/defs.cc, Einsum has only a rank inference function rather than full shape inference.
Given the results of equation parsing and the input shapes of the Einsum operation, we can easily infer the output shape. As such, we have expanded the rank inference to include shape inference for Einsum.
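
For intuition, here is a minimal Python sketch of that inference (a hypothetical helper for illustration, not the C++ implementation this PR adds to onnx/defs/math/defs.cc; it ignores ellipses, broadcasting, and symbolic dimensions):

```python
def infer_einsum_shape(equation: str, *input_shapes) -> tuple:
    """Infer an einsum output shape from the equation and input shapes."""
    lhs, arrow, output = equation.partition("->")
    terms = lhs.split(",")
    label_dims = {}  # each label maps to the dimension it was first bound to
    for term, shape in zip(terms, input_shapes):
        for label, dim in zip(term, shape):
            label_dims.setdefault(label, dim)
    if not arrow:
        # Implicit mode: output labels are those appearing exactly once,
        # in alphabetical order (NumPy's convention).
        counts = {}
        for term in terms:
            for label in term:
                counts[label] = counts.get(label, 0) + 1
        output = "".join(sorted(l for l, c in counts.items() if c == 1))
    return tuple(label_dims[label] for label in output)

assert infer_einsum_shape("ij,jk->ik", (2, 3), (3, 4)) == (2, 4)  # matmul
assert infer_einsum_shape("ii->i", (5, 5)) == (5,)                # diagonal
assert infer_einsum_shape("ij,jk", (2, 3), (3, 4)) == (2, 4)      # implicit mode
```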

@peishenyan peishenyan requested a review from a team as a code owner March 11, 2024 05:52
@justinchuby justinchuby added the module: shape inference Issues related to shape inference label Mar 11, 2024
@justinchuby justinchuby added this to the 1.17 milestone Mar 11, 2024
@justinchuby
Member

Thanks for your contribution! Could you check the CI errors?

@justinchuby
Member

It would be helpful to improve the tests for Einsum starting from this line:

def test_einsum_transpose(self) -> None:

Signed-off-by: peishenyan <peishen.yan@intel.com>
Signed-off-by: peishenyan <peishen.yan@intel.com>
@peishenyan peishenyan requested a review from a team as a code owner March 12, 2024 06:45
Signed-off-by: peishenyan <peishen.yan@intel.com>

codecov bot commented Mar 12, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 56.94%. Comparing base (0bb2775) to head (44024b1).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #6010      +/-   ##
==========================================
+ Coverage   56.82%   56.94%   +0.11%     
==========================================
  Files         506      506              
  Lines       30377    30461      +84     
  Branches     4592     4592              
==========================================
+ Hits        17263    17347      +84     
  Misses      12285    12285              
  Partials      829      829              


@peishenyan peishenyan force-pushed the einsum_shape_inference branch from bf9f7b1 to 6ce4699 on March 12, 2024 07:09
Signed-off-by: peishenyan <peishen.yan@intel.com>
@peishenyan
Contributor Author

Hi @justinchuby, thanks for your comments. I have improved the tests for Einsum and fixed some bugs in my previous code with the help of the tests! The code is ready for review now. PTAL, thanks!

@justinchuby
Member

LGTM. I wonder if there are more test cases that should be added? Consider referring to the pytorch tests here: https://github.com/pytorch/pytorch/blob/02bb2180f497d6d10dd11da51fffefaa5af80aca/torch/testing/_internal/common_methods_invocations.py#L6083-L6117

@justinchuby
Member

As well as some complicated ones: 'ijk,ilm,njm,nlk,abc->' (https://numpy.org/doc/stable/reference/generated/numpy.einsum.html)
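
For instance, 'ijk,ilm,njm,nlk,abc->' contracts every label away, so shape inference should yield a scalar; a quick NumPy check (illustrative only, not from this PR's test suite):

```python
import numpy as np

A = np.random.rand(2, 3, 4)   # i=2, j=3, k=4
B = np.random.rand(2, 5, 6)   # i=2, l=5, m=6
C = np.random.rand(7, 3, 6)   # n=7, j=3, m=6
D = np.random.rand(7, 5, 4)   # n=7, l=5, k=4
E = np.random.rand(8, 9, 10)  # a, b, c are summed out entirely
out = np.einsum('ijk,ilm,njm,nlk,abc->', A, B, C, D, E)
assert out.shape == ()  # every label is contracted away: scalar output
```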

@justinchuby
Member

justinchuby commented Mar 13, 2024

Or this table in your reference:

https://fdwr.github.io/MachineLearningOperators/OperatorFormulas.html

| Call signature | NumPy equivalent | Description |
|---|---|---|
| ('i', A1) | A1 | returns a view of A1 |
| ('i->', A1) | sum(A1) | sums the values of A1 |
| ('i,i->i', A1, B1) | A1 * B1 | element-wise multiplication of A1 and B1 |
| ('i,i->', A1, B1) | inner(A1, B1) or dot(A1, B1) | inner product of A1 and B1 |
| ('i,i', A1, B1) | inner(A1, B1) or dot(A1, B1) | inner product of A1 and B1 |
| ('i,j->ij', A1, B1) | outer(A1, B1) | outer product of A1 and B1 |
| ('ij->ij', A2) | A2 | returns a view of A2 |
| ('ij', A2) | A2 | returns a view of A2 |
| ('ji', A2) | A2.T | view transpose of A2 |
| ('ji->ij', A2) | A2.T | view transpose of A2 |
| ('ii->i', A2) | diag(A2) | view main diagonal of A2 |
| ('ii->', A2) | trace(A2) | sums main diagonal of A2 |
| ('ij->', A2) | sum(A2) | sums the values of A2 |
| ('ij->j', A2) | sum(A2, axis=0) | sum down the columns of A2 (across rows) |
| ('ij->i', A2) | sum(A2, axis=1) | sum horizontally along the rows of A2 |
| ('ij,ij->ij', A2, B2) | A2 * B2 | element-wise multiplication of A2 and B2 |
| ('ij,ji->ij', A2, B2) | A2 * B2.transpose() | element-wise multiplication of A2 and B2.T |
| ('ij,jk', A2, B2) | matmul(A2, B2) or dot(A2, B2) | matrix multiplication of A2 and B2 |
| ('ij,jk->ik', A2, B2) | matmul(A2, B2) or dot(A2, B2) | matrix multiplication of A2 and B2 |
| ('bij,bjk->bik', A3, B3) | matmul(A3, B3) | matrix multiplication of A3 and B3 (a stack of 2D matrices) |
| ('bij,bkj->bik', A3, B3) | matmul(A3, transpose(B3)) | matrix multiplication of A3 and B3 (a stack of 2D matrices) |
| ('ij,kj->ik', A2, B2) | inner(A2, B2) | inner product of A2 and B2 |
| ('ij,kj->ikj', A2, B2) | A2[:, None] * B2 | each row of A2 multiplied by B2 |
| ('ij,kl->ijkl', A2, B2) | A2[:, :, None, None] * B2 | each value of A2 multiplied by B2 |
| (',ij', 3, B2) | 3 * B2 | scalar times array: array([[ 0, 3, 6], [ 9, 12, 15]]) |
| ("ij,j", A2, B1) | matvec(A2, B1) | matrix and vector |
| ("ii,ii->i", A2, B2) | A2.diag() * B2.diag() | diagonals multiplied by each other |
| ("ii,ii->", A2, B2) | dot(A2.diag(), B2.diag()) | dot product of diagonals |

You may leverage parameterized to parameterize the tests
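
A sketch of what such parameterized tests could look like, using the public onnx.shape_inference API with NumPy as the ground truth (the helper style in onnx's own shape_inference_test.py may differ):

```python
import unittest

import numpy as np
from onnx import TensorProto, helper, shape_inference
from parameterized import parameterized


class EinsumShapeInferenceTest(unittest.TestCase):
    @parameterized.expand([
        ("ij,jk->ik", [(2, 3), (3, 4)]),           # matmul
        ("i,i->", [(5,), (5,)]),                   # inner product -> scalar
        ("ii->i", [(4, 4)]),                       # diagonal
        ("bij,bjk->bik", [(2, 3, 4), (2, 4, 5)]),  # batched matmul
    ])
    def test_einsum_shape(self, equation, input_shapes):
        # NumPy is the ground truth for the expected output shape.
        expected = np.einsum(equation, *(np.zeros(s) for s in input_shapes)).shape

        inputs = [
            helper.make_tensor_value_info(f"x{i}", TensorProto.FLOAT, shape)
            for i, shape in enumerate(input_shapes)
        ]
        node = helper.make_node(
            "Einsum",
            [f"x{i}" for i in range(len(input_shapes))],
            ["y"],
            equation=equation,
        )
        graph = helper.make_graph(
            [node], "einsum_test", inputs,
            [helper.make_tensor_value_info("y", TensorProto.FLOAT, None)],
        )
        inferred = shape_inference.infer_shapes(helper.make_model(graph))
        dims = inferred.graph.output[0].type.tensor_type.shape.dim
        self.assertEqual(tuple(d.dim_value for d in dims), expected)
```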

@justinchuby justinchuby self-assigned this Mar 13, 2024
Member

@justinchuby justinchuby left a comment


Thank you!

Signed-off-by: peishenyan <peishen.yan@intel.com>
@justinchuby justinchuby added the review needed: operators approvers Require reviews from members of operators-approvers label Mar 15, 2024
@justinchuby
Member

@gramalingam @xadupre could you help approve? Thanks!

@justinchuby justinchuby enabled auto-merge March 22, 2024 00:29
@fdwr
Contributor

fdwr commented Mar 22, 2024

That reminds me, I should review this one too (I only reviewed the ORT WebNN EP one)...

Contributor

@fdwr fdwr left a comment


Two minor comments.

Signed-off-by: peishenyan <peishen.yan@intel.com>
auto-merge was automatically disabled March 22, 2024 04:05

Head branch was pushed to by a user without write access

Contributor

@fdwr fdwr left a comment


👍

continue;
}

const auto inserted = label_maps.insert({term[index], num_labels}).second;

Check notice — Code scanning / CodeQL: For loop variable changed in body. Loop counters should not be modified in the body of the loop.

}
}
} else { // Infer the dimension for right-hand side
// If there's an ellipsis, add its corresponding dimensions

Check notice — Code scanning / CodeQL: For loop variable changed in body. Loop counters should not be modified in the body of the loop.
Signed-off-by: peishenyan <peishen.yan@intel.com>
@gramalingam
Contributor

@fdwr : do you know if the onnxruntime implementation supports broadcasting across the ellipsis dimensions? The ONNX spec says "broadcasting" ... which would mean the shape-inference should do the broadcasting logic for the ellipsis dimensions (or, at least, leave them as unknown).

@fdwr
Contributor

fdwr commented Mar 22, 2024

> @fdwr : do you know if the onnxruntime implementation supports broadcasting across the ellipsis dimensions? The ONNX spec says "broadcasting" ... which would mean the shape-inference should do the broadcasting logic for the ellipsis dimensions (or, at least, leave them as unknown).

🤔 I'm not sure about ellipsis broadcasting without spelunking. What I know is:
(a) the ORT DML EP does not (we fall back to the CPU EP if an ellipsis is found inside the equation);
(b) the ORT CPU EP has at least some degree of ellipsis support, because I had test cases for ...ii->...i, which passed.

(update) Judging from this comment in the CPU EP, it probably does support ellipsis broadcasting:

  // A function to process broadcasted dims (ellipsis) of inputs that they occur in
  Status PostProcessBroadcastedDims();
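
For reference, ellipsis dimensions follow ordinary NumPy broadcasting, so shape inference would have to broadcast them too; a small illustrative check (shapes chosen arbitrarily):

```python
import numpy as np

# Batch dims (1, 5) and (4, 1) broadcast to (4, 5); then i=2, j=3, k=6.
A = np.random.rand(1, 5, 2, 3)
B = np.random.rand(4, 1, 3, 6)
out = np.einsum('...ij,...jk->...ik', A, B)
assert out.shape == (4, 5, 2, 6)
```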

Signed-off-by: peishenyan <peishen.yan@intel.com>
Contributor

@fdwr fdwr left a comment


👍

Contributor

@gramalingam gramalingam left a comment


Thanks for the extension ... greatly appreciated.

@justinchuby justinchuby enabled auto-merge March 22, 2024 22:08
@justinchuby justinchuby added this pull request to the merge queue Mar 22, 2024
Merged via the queue into onnx:main with commit 5b191de Mar 22, 2024
linshokaku pushed a commit to linshokaku/onnx that referenced this pull request Oct 2, 2024
In the development of ONNX Runtime, we need to know the output shape of
each Op node for static graph compilation. However, we found that ONNX
shape inference could produce almost every output shape except Einsum's.
In `onnx/defs/math/defs.cc`, Einsum has only a rank inference function
rather than full shape inference.
Given the results of equation parsing and the input shapes of the Einsum
operation, we can easily infer the output shape. As such, we have
expanded the rank inference to include shape inference for Einsum.

---------

Signed-off-by: peishenyan <peishen.yan@intel.com>
Co-authored-by: G. Ramalingam <grama@microsoft.com>
Signed-off-by: Linsho Kaku <linsho@preferred.jp>
fdwr pushed a commit to microsoft/onnxruntime that referenced this pull request Oct 10, 2024
…22376)

### Description
Pick up onnx/onnx#6010 to support EinSum shape inference

### Motivation and Context
This change allows the EinSum operator's output shape to be inferred so that it can run on accelerators.
rohan11235813 pushed a commit to quadric-io/onnxruntime that referenced this pull request Aug 19, 2025
…22376)

### Description
Pick up onnx/onnx#6010 to support EinSum shape inference

### Motivation and Context
This change allows the EinSum operator's output shape to be inferred so that it can run on accelerators.