Conversation

@zkh2016 (Contributor) commented Aug 25, 2021

PR types

New features

PR changes

OPs

Describe

Fuses elementwise_add, dropout, elementwise_add, and layer_norm into one operator; only the forward pass is supported.

// before fusion
out1 = elementwise_add(src, bias)
out2 = dropout(out1)
out3 = elementwise_add(residual, out2)
out = layer_norm(out3, other args)

// after fusion
out = fused_layernorm_residual_dropout_bias(src, residual, bias, other args)
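
For reference, the fused forward computation is equivalent to the following single-threaded C++ sketch (illustrative only; the actual operator is a CUDA kernel, and the function name, the precomputed mask, and the factor parameter are assumptions, not the PR's API):

#include <cmath>
#include <cstdint>

// Illustrative sketch of the fused math for one row; not the actual kernel.
// mask[i] is the kept/dropped flag; factor is the dropout scale
// (e.g. 1 / (1 - dropout_prob) for upscale_in_train).
void FusedLnResidualDropoutBiasRow(const float* src, const float* residual,
                                   const float* bias, const uint8_t* mask,
                                   float factor, const float* gamma,
                                   const float* beta, float epsilon, int cols,
                                   float* out) {
  // out3 = residual + dropout(src + bias)
  float mean = 0.0f;
  for (int i = 0; i < cols; ++i) {
    out[i] = residual[i] + (src[i] + bias[i]) * mask[i] * factor;
    mean += out[i];
  }
  mean /= cols;

  float var = 0.0f;
  for (int i = 0; i < cols; ++i) {
    float d = out[i] - mean;
    var += d * d;
  }
  var /= cols;

  // out = layer_norm(out3) with scale gamma and shift beta
  float inv_std = 1.0f / std::sqrt(var + epsilon);
  for (int i = 0; i < cols; ++i) {
    out[i] = (out[i] - mean) * inv_std * gamma[i] + beta[i];
  }
}

Fusing these four steps saves three round trips of the intermediate tensors through global memory, which is the point of the operator.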

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@zkh2016 zkh2016 force-pushed the fused_layernorm_residual_dropout_bias branch from 87b2723 to aa27d96 Compare August 31, 2021 06:35
@zkh2016 zkh2016 force-pushed the fused_layernorm_residual_dropout_bias branch from 9c77253 to bdb5f85 Compare September 9, 2021 02:12
@zkh2016 zkh2016 marked this pull request as draft September 9, 2021 06:16
@zkh2016 zkh2016 marked this pull request as ready for review September 15, 2021 03:15
namespace paddle {
namespace operators {

namespace cg = cooperative_groups;
Contributor:

Is this unused?

Contributor Author:

done

if (is_test) {
factor = is_upscale_in_train ? static_cast<T>(1.0f)
: static_cast<T>(1.0f - dropout_prob);
}
Contributor:

You forgot to replace this block with GetFactor.

Contributor Author:

done
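
For context, a GetFactor helper consolidating this scale computation might look like the following (a hypothetical sketch; the actual helper in Paddle may differ in name, location, and signature):

// Hypothetical sketch of a dropout-scale helper; not Paddle's actual code.
// Training, upscale_in_train: kept values scaled by 1 / (1 - p).
// Training, downscale:        no scaling during training.
// Test, upscale_in_train:     identity.
// Test, downscale:            output scaled by (1 - p).
template <typename T>
inline T GetFactor(const float dropout_prob, const bool is_upscale_in_train,
                   const bool is_test) {
  T factor = is_upscale_in_train
                 ? static_cast<T>(1.0f / (1.0f - dropout_prob))
                 : static_cast<T>(1.0f);
  if (is_test) {
    factor = is_upscale_in_train ? static_cast<T>(1.0f)
                                 : static_cast<T>(1.0f - dropout_prob);
  }
  return factor;
}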

@xingfeng01 (Contributor) commented:

Suggestion: in the next PR, add comments to the functions explaining the computation logic.

@xingfeng01 (Contributor) commented:

LGTM

@lanxianghit lanxianghit merged commit 7975dfc into PaddlePaddle:develop Sep 17, 2021
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Sep 29, 2021
…35151)

Fused elementwise_add, dropout, elementwise_add, and layer_norm into one operator; only the forward pass is supported.
No Python API changes.
@zkh2016 zkh2016 deleted the fused_layernorm_residual_dropout_bias branch August 19, 2022 04:05