
Conversation

Qin-sx
Contributor

@Qin-sx Qin-sx commented Jun 7, 2025

PR Category

User Experience

PR Types

Improvements

Description

The flash attention here should be calling Paddle's fork of flash attention.
The interface that scaled_dot_product_attention calls should be:

bool flash_attn_fwd(const void * const q,         // batch_size x seqlen_q x num_heads x head_size
                    const void * const k,         // batch_size x seqlen_k x num_heads_k x head_size
                    const void * const v,         // batch_size x seqlen_k x num_heads_k x head_size
                    void * const rng_state,
                    void * const out,
                    void * const softmax_ptr,
                    void * const softmax_lse_ptr,
                    const int batch_size,
                    const int seqlen_q,
                    const int seqlen_k,
                    const int seqlen_q_rounded,
                    const int seqlen_k_rounded,
                    const int num_heads,
                    const int num_heads_k,
                    const int head_size,
                    const int head_size_rounded,
                    const float p_dropout,
                    const float softmax_scale,
                    const float softmax_unscale,
                    const bool is_causal,
                    const bool return_softmax,
                    const bool is_bf16,
                    cudaStream_t stream,
                    uint64_t seed,
                    uint64_t offset,
                    const void * const attn_mask,
                    const int64_t * const mask_dims,
                    const void * const flashmask_downstart_ptr,
                    const int64_t * const flashmask_dims,
                    const void * const flashmask_upend_ptr,
                    const void * const flashmask_downend_ptr,
                    const void * const flashmask_upstart_ptr,
                    const void * const flashmask_maxmin_ptr,
                    const int q_row_stride,
                    const int k_row_stride,
                    const int v_row_stride,
                    const int q_head_stride,
                    const int k_head_stride,
                    const int v_head_stride,
                    const int o_row_stride,
                    const int o_head_stride,
                    const int q_batch_stride,
                    const int k_batch_stride,
                    const int v_batch_stride,
                    const int o_batch_stride);

The checks in it are:

#define CHECK_FWD_EXECTUABLE(__seqlen_q, __seqlen_k)                     \
      auto dprops = at::cuda::getCurrentDeviceProperties();              \
      const bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;     \
      const bool is_sm90 = dprops->major == 9 && dprops->minor == 0;     \
      ASSERT_CHECK(is_sm8x || is_sm90);                                  \
      ASSERT_CHECK(batch_size > 0);                                      \
      ASSERT_CHECK(head_size % 8 == 0);                                  \
      ASSERT_CHECK(head_size <= 256);                                    \
      ASSERT_CHECK(num_heads % num_heads_k == 0);                        \
      if (attn_mask) {                                                   \
          ASSERT_CHECK(mask_dims[0] == batch_size);                      \
          ASSERT_CHECK(mask_dims[1] == 1 || mask_dims[1] == num_heads);  \
          ASSERT_CHECK(mask_dims[2] == 1 || mask_dims[2] == __seqlen_q); \
          ASSERT_CHECK(mask_dims[3] == __seqlen_k);                      \
      }

The check on head_size only needs to be head_size <= 256.
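
For illustration, here is a minimal Python sketch of a front-end check that mirrors only these forward-kernel constraints (the function name and return convention are hypothetical, not the exact code in this PR):

# Hypothetical sketch: mirror CHECK_FWD_EXECTUABLE's head_size requirements
# (head_size % 8 == 0 and head_size <= 256) on the Python side.
def forward_head_dim_supported(query):
    head_dim = query.shape[-1]
    return head_dim % 8 == 0 and head_dim <= 256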

	modified:   python/paddle/nn/functional/flash_attention.py

paddle-bot bot commented Jun 7, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Jun 7, 2025
	modified:   python/paddle/nn/functional/flash_attention.py
@@ -66,12 +66,13 @@ def check_flash_head_dim_constraints(query, dropout_p=0.0):

is_head_dim_gt192 = head_dim > 192
is_head_dim_lte224 = head_dim <= 224
is_dropout = dropout_p > 0.0
# is_dropout = dropout_p > 0.0
Contributor


Compare these three checks against torch's logic and paste it here.

Contributor Author

@Qin-sx Qin-sx Jun 10, 2025


The checks required in pytorch are:

bool can_use_flash_attention(sdp_params const& params, bool debug) {
#ifndef USE_FLASH_ATTENTION
  if (debug) {
    TORCH_WARN("Torch was not compiled with flash attention.");
  }
  return false;
#else // defined(USE_FLASH_ATTENTION)
  // Define gate functions that determine if a flash kernel can be ran
  // Replace with std::to_array when we migrate to c++20
  constexpr auto general_constraints = array_of<bool (*)(sdp_params const&, bool)>(
      check_runtime_disabled_flash,
      check_all_tensors_on_device,
      check_tensor_shapes,
      check_for_attn_mask,
      check_head_dim_size_flash<false /*caller_is_meff*/>,
      check_flash_attention_hardware_support,
      check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89,
      check_flash_causal_non_square_seqlens,
      check_dtypes_low_precision);
  for (auto& constraint : general_constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }

  if (has_for_nested_inputs(params)) {
    constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_batch_size_nested,
        check_head_dim_size_flash_nested<false /*caller_is_meff*/>,
        check_for_seq_len_0_nested_tensor);
    for (auto& constraint : nested_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }
  constexpr bool backend_supports_grouped_query_attention = true;
  if (has_only_dense_inputs(params)) {
    constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_batch_size_and_num_heads_dense<backend_supports_grouped_query_attention>,
        check_nonzero_sequence_lengths_dense,
        check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>);
    for (auto& constraint : dense_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }
  return true;
#endif // defined(USE_FLASH_ATTENTION)
}

bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
    sdp_params const& params,
    bool debug) {
  // Flash Attention will raise an error in the backward pass if the head_dim
  // size is greater than 192 And the device is between in the range [sm86, sm89]
  using sm86 = SMVersion<8, 6>;
  using sm89 = SMVersion<8, 9>;
  auto dprops = at::cuda::getCurrentDeviceProperties();
  bool is_sm86_or_sm89 = check_sm_version<sm86, sm89>(dprops);
  bool is_head_dim_gt192 = params.query.sym_size(-1) > 192;
  bool is_head_dim_lte224 = params.query.sym_size(-1) <= 224;
  bool is_dropout = params.dropout > 0.0;
  //  head_dim size  in (192, 224] is not supported on sm86 and sm89
  bool cond1 = is_head_dim_gt192 && is_head_dim_lte224;
  // head_dim size > 224 and is_dropout is not supported on sm86 and sm89
  bool cond2 = params.query.sym_size(-1) > 224 && is_dropout;
  if (input_requires_grad(params) && is_sm86_or_sm89 && (cond1 || cond2)) {
    if (debug) {
      TORCH_WARN(
          "Flash attention currently doesn't support training with head_dim ∈ (192, 224] or "
          "(head_dim ∈ (224, 256] and dropout > 0.0) on gpu architectures in the range[sm86, sm89].",
          "Attempting to run with dropout set to: ", params.dropout,
          "and head_dim: ",
          params.query.sym_size(-1), " on a sm ", dprops->major, ".",
          dprops->minor, " gpu.");
    }
    return false;
  }
  return true;
}
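
For reference, the restriction quoted above can be restated as a small Python predicate (a sketch only, not pytorch or Paddle code):

# Sketch of the sm86-sm89 training restriction in
# check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89:
# head_dim in (192, 224] is rejected, and head_dim in (224, 256] is
# rejected only when dropout > 0, when gradients are required.
def blocked_on_sm86_89(head_dim, dropout_p, requires_grad, sm):
    is_sm86_or_sm89 = (8, 6) <= sm <= (8, 9)  # sm is a (major, minor) tuple
    cond1 = 192 < head_dim <= 224
    cond2 = head_dim > 224 and dropout_p > 0.0
    return requires_grad and is_sm86_or_sm89 and (cond1 or cond2)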

Paddle's original check logic here should be consistent with pytorch's; pytorch performs this check in the check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89 function.
In the flash attention library I have so far only found the backward-pass requirement for head dim > 192, but I haven't found the other checks yet.

The check in the flash attention library is:

#define CHECK_BWD_EXECTUABLE(__seqlen_q, __seqlen_k)                                       \
      CHECK_FWD_EXECTUABLE(__seqlen_q, __seqlen_k)                                         \
      const bool is_sm80 = dprops->major == 8 && dprops->minor == 0;                       \
      if (head_size > 192) {                                                               \
          /* FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800 */ \
          ASSERT_CHECK(is_sm80 || is_sm90);                                                \
      }
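
For illustration, a Python-side guard that mirrors this backward constraint could look roughly like the sketch below (the helper name and the way compute capability is passed in are assumptions, not Paddle API):

# Hypothetical sketch: mirror CHECK_BWD_EXECTUABLE on the front end.
def backward_head_dim_supported(head_dim, compute_capability):
    major, minor = compute_capability
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim > 192:
        # FlashAttention backward for head dim > 192 requires
        # A100/A800 (sm80) or H100/H800 (sm90).
        return is_sm80 or is_sm90
    return True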

May I ask whether we should follow pytorch's check logic here, i.e. leave Paddle's original function checks unchanged?

Contributor

@zhwesky2010 zhwesky2010 Jun 10, 2025


@Qin-sx I suspect this may be a bug in pytorch's own implementation that causes the backward pass to error out in this case. No need to align with it.

[screenshot: infoflow 2025-06-10 15-47-30]

Take a look at the error checks in these Paddle backend kernels and align the front end with the backend kernels' actual capabilities, to avoid errors like this one, where the front end picks the flash attention branch but the backend kernel doesn't support it and the two ends are inconsistent (also check the other backend, memory efficient attention):

[screenshot: infoflow 2025-06-10 15-55-30]

After making the change, please also fix this unit test in PaConvert; because of the backend selection issue, many of its cases fail (they incorrectly pick the flash attn branch):

https://github.com/PaddlePaddle/PaConvert/blob/master/tests/test_scaled_dot_product_attention.py
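
As an illustration, a rough sketch of the kind of front/back-end alignment being asked for (the function and backend names below are illustrative, not Paddle's actual internals):

# Hypothetical backend selection: only take the flash attention branch when
# the backend kernel's constraints are satisfied; otherwise fall back instead
# of letting the kernel raise an error later.
def select_attention_backend(head_dim, needs_grad, is_sm80_or_sm90):
    flash_ok = head_dim % 8 == 0 and head_dim <= 256
    if needs_grad and head_dim > 192:
        # Backward with head_dim > 192 is only supported on sm80/sm90.
        flash_ok = flash_ok and is_sm80_or_sm90
    return "flash_attention" if flash_ok else "memory_efficient_attention"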

Contributor Author


OK, but I think both the Paddle and pytorch backends call the flash attention library. How about I open two PRs? The head_dim <= 256 change is definitely needed, so I'll open a separate PR for it, which also makes it easier to revert if necessary. For this PR I'll test again and see what error cases come up.

Contributor

@zhwesky2010 zhwesky2010 Jun 11, 2025


OK, but I think both the Paddle and pytorch backends call the flash attention library. How about I open two PRs? The head_dim <= 256 change is definitely needed, so I'll open a separate PR for it, which also makes it easier to revert if necessary. For this PR I'll test again and see what error cases come up.

I see the other one has been opened, so this PR should remove check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89_or_120, plus any other logic you think is unreasonable.

At the same time, align the front end and back end. The front/back-end alignment definitely needs to be done as well, because if the backend doesn't support a case, selecting it in the front end is pointless and will just trigger an error.

@zhwesky2010 zhwesky2010 changed the title from "Optimize the head_dim check in scaled_dot_product_attention" to "Optimize the backend switching logic in scaled_dot_product_attention" Jun 10, 2025
@zhwesky2010
Contributor

@Qin-sx This PR also needs to be revised.

@Qin-sx
Contributor Author

Qin-sx commented Jun 13, 2025

@Qin-sx This PR also needs to be revised.

OK, got it. But this part is a bit complicated and may involve DCU testing, so I plan to handle it later.


paddle-ci-bot bot commented Jun 16, 2025

Sorry to inform you that 1cb2356's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
