Optimize the backend switching logic in scaled_dot_product_attention #73157
base: develop
Conversation
modified: python/paddle/nn/functional/flash_attention.py
Your PR has been submitted successfully. Thank you for contributing to this open-source project!
modified: python/paddle/nn/functional/flash_attention.py
@@ -66,12 +66,13 @@ def check_flash_head_dim_constraints(query, dropout_p=0.0):
     is_head_dim_gt192 = head_dim > 192
     is_head_dim_lte224 = head_dim <= 224
-    is_dropout = dropout_p > 0.0
+    # is_dropout = dropout_p > 0.0
Please compare these three checks against torch's logic and paste it here.
The corresponding check in PyTorch:
bool can_use_flash_attention(sdp_params const& params, bool debug) {
#ifndef USE_FLASH_ATTENTION
  if (debug) {
    TORCH_WARN("Torch was not compiled with flash attention.");
  }
  return false;
#else // defined(USE_FLASH_ATTENTION)
  // Define gate functions that determine if a flash kernel can be ran
  // Replace with std::to_array when we migrate to c++20
  constexpr auto general_constraints = array_of<bool (*)(sdp_params const&, bool)>(
      check_runtime_disabled_flash,
      check_all_tensors_on_device,
      check_tensor_shapes,
      check_for_attn_mask,
      check_head_dim_size_flash<false /*caller_is_meff*/>,
      check_flash_attention_hardware_support,
      check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89,
      check_flash_causal_non_square_seqlens,
      check_dtypes_low_precision);
  for (auto& constraint : general_constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }

  if (has_for_nested_inputs(params)) {
    constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_batch_size_nested,
        check_head_dim_size_flash_nested<false /*caller_is_meff*/>,
        check_for_seq_len_0_nested_tensor);
    for (auto& constraint : nested_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }

  constexpr bool backend_supports_grouped_query_attention = true;
  if (has_only_dense_inputs(params)) {
    constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_batch_size_and_num_heads_dense<backend_supports_grouped_query_attention>,
        check_nonzero_sequence_lengths_dense,
        check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>);
    for (auto& constraint : dense_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }
  return true;
#endif // defined(USE_FLASH_ATTENTION)
}
bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
    sdp_params const& params,
    bool debug) {
  // Flash Attention will raise an error in the backward pass if the head_dim
  // size is greater than 192 And the device is between in the range [sm86, sm89]
  using sm86 = SMVersion<8, 6>;
  using sm89 = SMVersion<8, 9>;
  auto dprops = at::cuda::getCurrentDeviceProperties();
  bool is_sm86_or_sm89 = check_sm_version<sm86, sm89>(dprops);
  bool is_head_dim_gt192 = params.query.sym_size(-1) > 192;
  bool is_head_dim_lte224 = params.query.sym_size(-1) <= 224;
  bool is_dropout = params.dropout > 0.0;
  // head_dim size in (192, 224] is not supported on sm86 and sm89
  bool cond1 = is_head_dim_gt192 && is_head_dim_lte224;
  // head_dim size > 224 and is_dropout is not supported on sm86 and sm89
  bool cond2 = params.query.sym_size(-1) > 224 && is_dropout;
  if (input_requires_grad(params) && is_sm86_or_sm89 && (cond1 || cond2)) {
    if (debug) {
      TORCH_WARN(
          "Flash attention currently doesn't support training with head_dim ∈ (192, 224] or "
          "(head_dim ∈ (224, 256] and dropout > 0.0) on gpu architectures in the range[sm86, sm89].",
          "Attempting to run with dropout set to: ", params.dropout,
          "and head_dim: ",
          params.query.sym_size(-1), " on a sm ", dprops->major, ".",
          dprops->minor, " gpu.");
    }
    return false;
  }
  return true;
}
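For readability, the condition PyTorch gates on can be paraphrased in Python roughly as follows (an illustrative sketch of the C++ above; the function name is made up for this example):

def blocked_on_sm86_89(head_dim, dropout_p, requires_grad, is_sm86_or_sm89):
    # Python paraphrase of check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89:
    # training with head_dim in (192, 224], or with head_dim > 224 and dropout > 0.0,
    # is rejected on sm86/sm89 devices.
    cond1 = 192 < head_dim <= 224
    cond2 = head_dim > 224 and dropout_p > 0.0
    return requires_grad and is_sm86_or_sm89 and (cond1 or cond2)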
Paddle's original check logic here should be consistent with PyTorch's; PyTorch performs this check in the check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89 function.
In the flash attention library, I have so far only found the backward-pass requirement for head_dim > 192; I have not yet found the other checks.
The corresponding check in the flash attention library:
#define CHECK_BWD_EXECTUABLE(__seqlen_q, __seqlen_k)                   \
  CHECK_FWD_EXECTUABLE(__seqlen_q, __seqlen_k)                         \
  const bool is_sm80 = dprops->major == 8 && dprops->minor == 0;       \
  if (head_size > 192) {                                               \
    /* FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800 */ \
    ASSERT_CHECK(is_sm80 || is_sm90);                                  \
  }
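Translated into Python, the macro's backward-pass requirement amounts to roughly this gate (a sketch only; the helper name is hypothetical and it simply mirrors the ASSERT_CHECK above):

def flash_backward_head_dim_supported(head_size, major, minor):
    # Mirrors CHECK_BWD_EXECTUABLE: backward for head_size > 192 is only
    # asserted to work on sm80 (A100/A800) or sm90 (H100/H800).
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_size > 192:
        return is_sm80 or is_sm90
    return True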
May I ask whether we should follow PyTorch's check logic here, i.e. leave Paddle's original checks in this function unmodified?
@Qin-sx I suspect this may be a bug in PyTorch's own implementation that causes the backward pass to fail in this case. There is no need to align with it.
Take a look at the error checks inside these Paddle backend kernels and align the frontend selection with what the backend kernels actually support, to avoid errors like this one: the frontend selects the flash attention branch, but the backend kernel does not support the case, so the frontend and backend are inconsistent. (Also take a look at the other backend, the memory-efficient variant.)
After the change, please also fix this unit test in PaConvert; because of the backend selection problem, many of its cases currently fail (they are incorrectly routed to the flash attention branch):
https://github.com/PaddlePaddle/PaConvert/blob/master/tests/test_scaled_dot_product_attention.py
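For illustration, a case along these lines can trigger the mismatch (a hedged repro sketch: head_dim = 208 falls into (192, 224], shapes assume Paddle's [batch_size, seq_len, num_heads, head_dim] layout, and whether it actually fails depends on the GPU and kernel build):

import paddle
import paddle.nn.functional as F

# On a device where the backend kernel rejects training with this head_dim,
# the frontend may still route to the flash attention branch, and the call
# then fails inside the kernel instead of falling back to another backend.
q = paddle.randn([2, 128, 8, 208]).astype('float16')
k = paddle.randn([2, 128, 8, 208]).astype('float16')
v = paddle.randn([2, 128, 8, 208]).astype('float16')
q.stop_gradient = False
k.stop_gradient = False
v.stop_gradient = False

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
out.sum().backward()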
Hmm, but I think both the Paddle and PyTorch backends call the flash attention library. How about I open two PRs? The head_dim <= 256 change is definitely needed, so I will submit it as a separate PR; that also makes a revert easier if one is needed. For this PR I will check again what errors actually occur.
I see another PR has been opened, so this PR should remove check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89_or_120, along with any other logic you consider unreasonable.
At the same time, align the frontend with the backend. The frontend/backend alignment definitely needs to be fixed: if the backend does not support a case, there is no point in the frontend selecting it, since it will only trigger an error.
@Qin-sx This PR also needs to be updated.
OK, got it. But this part is somewhat complicated and may involve DCU testing, so I plan to handle it later.
Sorry to inform you that 1cb2356's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
PR Category
User Experience
PR Types
Improvements
Description
flash attention should be calling Paddle's fork of the flash attention library.
The interface that scaled_dot_product_attention calls should be:
The check in it is:
The head_size check only requires head_size <= 256.
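In other words, the frontend head-dim gate only needs to mirror the forward requirement. A minimal sketch of the intended check (an illustration of the description above, not the actual contents of flash_attention.py):

def check_flash_head_dim(query):
    # Forward pass in Paddle's flash-attention fork only requires
    # head_size <= 256; the stricter (192, 224] / dropout conditions quoted
    # earlier in the thread are backward-pass limits on specific SM versions.
    head_size = query.shape[-1]
    return head_size <= 256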