Add a bool mask test for the scaled_dot_product_attention function #586
Conversation
modified: tests/test_scaled_dot_product_attention.py
Thanks for your contribution!
@Qin-sx PaddlePaddle/Paddle#72927 — that Paddle change has not been merged yet, so how is this test able to pass?
Hmm, it was probably skipped. The test environment may have changed; the skip conditions were taken from the conditions used in the earlier tests.

tests/test_scaled_dot_product_attention.py::test_case_1 W0604 23:36:46.620575 29364 gpu_resources.cc:114] Please NOTE: device: 0, GPU Compute Capability: 8.9, Driver API Version: 12.7, Runtime API Version: 11.8
PASSED
tests/test_scaled_dot_product_attention.py::test_case_2 PASSED
tests/test_scaled_dot_product_attention.py::test_case_3 PASSED
tests/test_scaled_dot_product_attention.py::test_case_4 PASSED
tests/test_scaled_dot_product_attention.py::test_case_5 PASSED
tests/test_scaled_dot_product_attention.py::test_case_6 FAILED
------------------------------------------------------------------------ Captured log call ------------------------------------------------------------------------
INFO Converter_5:utils.py:176 ===========================================
INFO Converter_5:utils.py:176 PyTorch to Paddle Convert Start ------>:
INFO Converter_5:utils.py:176 ===========================================
INFO Converter_5:utils.py:176 Start convert file: /home/qinsx/paddle_develop/PaConvert/test_project/pytorch_temp.py --> /home/qinsx/paddle_develop/PaConvert/test_project/paddle_temp.py
INFO Converter_5:utils.py:176 [pytorch_temp.py:3] remove 'import torch'
INFO Converter_5:utils.py:176 [pytorch_temp.py] add 'import paddle' in line 1
INFO Converter_5:utils.py:176 [pytorch_temp.py:1] [Success] Convert torch.float16 to Paddle
INFO Converter_5:utils.py:176 [pytorch_temp.py:7] [Success] Convert torch.tensor to Paddle
INFO Converter_5:utils.py:176 [pytorch_temp.py:1] [Success] Convert torch.bool to Paddle
INFO Converter_5:utils.py:176 [pytorch_temp.py:8] [Success] Convert torch.tensor to Paddle
INFO Converter_5:utils.py:176 [pytorch_temp.py:9] [Success] Convert torch.nn.functional.scaled_dot_product_attention to Paddle
INFO Converter_5:utils.py:176 [pytorch_temp.py:9] [Success] Convert Class Method: torch.Tensor.float to Paddle
INFO Converter_5:utils.py:176 Finish convert /home/qinsx/paddle_develop/PaConvert/test_project/pytorch_temp.py --> /home/qinsx/paddle_develop/PaConvert/test_project/paddle_temp.py
WARNING Converter_5:utils.py:165
===========================================
WARNING Converter_5:utils.py:165 Convert Summary
WARNING Converter_5:utils.py:165 ===========================================
WARNING Converter_5:utils.py:165 There are 6 Pytorch APIs in this Project:
WARNING Converter_5:utils.py:165 6 Pytorch APIs have been converted to Paddle successfully!
WARNING Converter_5:utils.py:165 0 Pytorch APIs are not supported to convert to Paddle currently!
WARNING Converter_5:utils.py:165 Convert Rate is: 100.00%
WARNING Converter_5:utils.py:165
Thank you to use Paddle Code Convert Tool. You can make any suggestions
to us by submitting issues to [https://github.com/PaddlePaddle/PaConvert].
===================================================================== short test summary info =====================================================================
FAILED tests/test_scaled_dot_product_attention.py::test_case_6 - ValueError: (InvalidArgument) attn_mask is expected to have the same data type with q.
=================================================================== 1 failed, 5 passed in 8.93s ===================================================================
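The ValueError above comes from the kernel requiring attn_mask to share q's dtype, which a bool mask violates. As a minimal NumPy sketch (not Paddle's or PyTorch's actual implementation), the usual reconciliation is to convert a bool mask (True = attend) into an additive float mask before the softmax:

```python
import numpy as np

def sdpa(q, k, v, attn_mask=None):
    """Minimal scaled dot product attention sketch in NumPy.

    A bool attn_mask (True = keep) is converted to an additive float mask,
    sidestepping the dtype mismatch reported above ("attn_mask is expected
    to have the same data type with q").
    """
    d = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d)
    if attn_mask is not None:
        if attn_mask.dtype == np.bool_:
            # True -> add 0 (keep), False -> add -inf (mask out)
            attn_mask = np.where(attn_mask, 0.0, -np.inf).astype(scores.dtype)
        scores = scores + attn_mask
    # numerically stable softmax over the key dimension
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

q = np.random.rand(2, 4, 8).astype("float32")
k = np.random.rand(2, 4, 8).astype("float32")
v = np.random.rand(2, 4, 8).astype("float32")
mask = np.ones((2, 4, 4), dtype=bool)
mask[..., -1] = False  # never attend to the last key position
out = sdpa(q, k, v, attn_mask=mask)
print(out.shape)  # (2, 4, 8)
```

Because the masked key's weight becomes exactly zero, the output is independent of the corresponding rows of v, which is what the new bool-mask test case would need to verify.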
@Qin-sx After you finished the local change adding bool support to sdpa, does it still fail? If so, that sdpa change may itself have a problem.
My local branch atten_bool_mask, which contains the sdpa fix, passed when I ran it earlier. Yesterday's run used a different development branch where the sdpa function was not modified, so the sixth test failed. The point was to check whether the original function can pass the newly added fifth test, and whether the tests actually execute in my local environment.
Running on the atten_bool_mask branch locally works fine:

tests/test_scaled_dot_product_attention.py::test_case_1 W0605 22:21:59.724702 4018 gpu_resources.cc:114] Please NOTE: device: 0, GPU Compute Capability: 8.9, Driver API Version: 12.7, Runtime API Version: 11.8
PASSED
tests/test_scaled_dot_product_attention.py::test_case_2 PASSED
tests/test_scaled_dot_product_attention.py::test_case_3 PASSED
tests/test_scaled_dot_product_attention.py::test_case_4 PASSED
tests/test_scaled_dot_product_attention.py::test_case_5 PASSED
tests/test_scaled_dot_product_attention.py::test_case_6 PASSED
============================================================= 6 passed in 8.49s ==============================================================
@@ -99,3 +99,45 @@ def test_case_4():
        unsupport=True,
        reason="paddle not support 'scale' and 'enable_gqa' ",
    )

@pytest.mark.skipif(
Merging this one first. Later, let's look at whether this sdpa unit test really needs so many restrictions, and whether it can run with these skips removed.
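For reference, the diff adds a @pytest.mark.skipif guard of the kind the reviewer is questioning. A hypothetical sketch (the probe and reason below are illustrative stand-ins, not the actual conditions in tests/test_scaled_dot_product_attention.py):

```python
import pytest

# Illustrative capability probe; a real test might instead call e.g.
# paddle.device.is_compiled_with_cuda() (assumption, not the PR's code).
HAS_CUDA = False

@pytest.mark.skipif(not HAS_CUDA, reason="requires a CUDA build of Paddle")
def test_case_bool_mask():
    # Body would exercise scaled_dot_product_attention with a bool attn_mask.
    assert True
```

Removing such skips, as suggested, would surface environment gaps as failures rather than silently passing, which is the trade-off under discussion.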
PR Docs
https://github.com/PaddlePaddle/docs/pull/7318/files
PaddlePaddle/Paddle#72927
PR APIs