Optimize the head_size check in scaled_dot_product_attention #73240
modified: python/paddle/nn/functional/flash_attention.py
modified: test/legacy_test/test_scaled_dot_product_attention.py
Your PR was submitted successfully. Thank you for contributing to this open-source project!
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@ Coverage Diff @@
## develop #73240 +/- ##
===========================================
Coverage ? 100.00%
===========================================
Files ? 1
Lines ? 1
Branches ? 0
===========================================
Hits ? 1
Misses ? 0
Partials ? 0
@@ -169,7 +169,7 @@ def test_dot_scale_product_float_mask(self):
         )

         with sdp_kernel(
-            enable_math=True, enable_flash=False, enable_mem_efficient=False
+            enable_math=None, enable_flash=None, enable_mem_efficient=None
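I suggest adding some cases that exercise automatic backend selection, without the with sdp_kernel wrapper, and keeping the original cases as they are.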
modified: test/legacy_test/test_scaled_dot_product_attention.py
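@Qin-sx The original cases don't need to change; just add some extra cases on top of them.
OK, got it. There was a DCU error earlier and I wanted to see what the error was. I'll add the new tests now.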
modified: test/legacy_test/test_scaled_dot_product_attention.py
LGTM
PR Category
User Experience
PR Types
Bug fixes
Description
The head_size check only needs to be head_size <= 256.
Related PR
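To make the boundary concrete, here is a minimal sketch of a check of this shape. It is not the code in flash_attention.py: the names MAX_FLASH_HEAD_SIZE, head_size_supported and pick_backend, and the backend labels "flash" and "fallback", are hypothetical and used only for illustration. The only detail taken from this PR is the head_size <= 256 condition.

```python
# Hypothetical sketch; not the actual implementation in flash_attention.py.
MAX_FLASH_HEAD_SIZE = 256  # limit stated in the PR description


def head_size_supported(head_size: int) -> bool:
    """Return True when head_size passes the `head_size <= 256` check."""
    return head_size <= MAX_FLASH_HEAD_SIZE


def pick_backend(head_size: int) -> str:
    """Illustrative selector; the backend names are placeholders."""
    return "flash" if head_size_supported(head_size) else "fallback"
```

Under these assumptions, a head_size of exactly 256 passes the check and 257 is the first value that does not, so tests for automatic backend selection would most usefully cover both sides of that boundary.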