Add FusedMultiTransformer fuse pass for GPT3 #45907
Updated from 552a2c5 to 0617335
paddle/fluid/framework/ir/pass.cc (outdated)
if (!graph->Has(kPassRecorder)) {
  graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
}
graph->Get<PassRecorder>(kPassRecorder).insert(Type());

if(graph->IsMainGraph() and "graph_viz_pass"!=Type() and "graph_to_program_pass"!=Type()) {
`and`? (Prefer `&&` over the alternative token `and`.)
Done, thanks!
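To make the review point concrete, here is a minimal, self-contained sketch of the kind of guard this condition expresses, written with `&&`. It assumes (not confirmed by the diff excerpt) that the branch applies the pass to sub-graphs as well, which fits the "subgraph pass support" item in the PR description. `ToyGraph` and `ApplyPass` are hypothetical stand-ins, not Paddle's `ir::Graph` or `Pass::Apply`.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for a graph (NOT Paddle's ir::Graph): a main graph owns
// sub-graphs, and each graph records the passes applied to it.
struct ToyGraph {
  bool is_main = true;
  std::set<std::string> applied_passes;  // plays the role of PassRecorder
  std::vector<std::unique_ptr<ToyGraph>> sub_graphs;
};

// Hypothetical ApplyPass: record the pass on `g`; if `g` is the main graph
// and the pass is not one of the two excluded utility passes, recurse into
// every sub-graph. Note `&&` instead of the alternative token `and`.
void ApplyPass(ToyGraph* g, const std::string& type) {
  g->applied_passes.insert(type);
  if (g->is_main && type != "graph_viz_pass" &&
      type != "graph_to_program_pass") {
    for (auto& sub : g->sub_graphs) {
      ApplyPass(sub.get(), type);
    }
  }
}
```

`and` and `&&` are equivalent in standard C++; the request is about readability and consistency with the rest of the codebase, not behavior.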
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                   Graph* g) {
  GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern);
You need to explicitly call IsCompat(subgraph, graph); otherwise the rules registered with AddOpCompat have no effect.
Done, thanks!
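The point of this review comment can be shown with a toy model: registered compatibility rules are just data until the handler checks them. The sketch below is hypothetical throughout (`ToyOp`, `ToyFusePass`, and its `AddOpCompat`/`IsCompat`/`Handle` are simplified stand-ins that do not mirror Paddle's real signatures); it only shows the shape "check compatibility first, skip the match if it fails, otherwise count the fusion".

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy op: a type name plus integer attributes.
struct ToyOp {
  std::string type;
  std::map<std::string, int> int_attrs;
};

// Hypothetical pass: rules registered via AddOpCompat() are enforced only
// when Handle() explicitly calls IsCompat().
class ToyFusePass {
 public:
  // Rule: ops of `op_type` must have attribute `attr` equal to `value`.
  void AddOpCompat(const std::string& op_type, const std::string& attr,
                   int value) {
    rules_[op_type][attr] = value;
  }

  // True only if every op in the matched subgraph satisfies its rules.
  bool IsCompat(const std::vector<ToyOp>& subgraph) const {
    for (const ToyOp& op : subgraph) {
      auto rule = rules_.find(op.type);
      if (rule == rules_.end()) continue;  // no rule for this op type
      for (const auto& expected : rule->second) {
        auto actual = op.int_attrs.find(expected.first);
        if (actual == op.int_attrs.end() || actual->second != expected.second) {
          return false;
        }
      }
    }
    return true;
  }

  // Handler for one matched subgraph; without the IsCompat() call the
  // registered rules would silently be ignored.
  void Handle(const std::vector<ToyOp>& subgraph) {
    if (!IsCompat(subgraph)) return;  // incompatible match: do not fuse
    ++fusion_count;
  }

  int fusion_count{0};

 private:
  std::map<std::string, std::map<std::string, int>> rules_;
};
```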
@@ -0,0 +1,416 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
The copyright year should be 2022.
Done, thanks.
… add_fused_multi_transformer_pass_fleetx (4e1aba3)
Updated from 1cd0d85 to 409cf3a
Please add a more specific description of what each of the newly added passes does and what effect it has.
On the training side, please also check whether this can be reused for static-graph training. @sneaxiy @Xreki @gongweibao
PR types: New features
PR changes: Others
Describe
Add FusedMultiTransformer fuse pass for GPT3
Add passes:
fused_multi_transformer_encoder
fused_multi_transformer_decoder
fused_multi_transformer_encoder_fuse_qkv
fused_multi_transformer_decoder_fuse_qkv
fused_multi_transformer_encoder


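As a toy illustration of the kind of rewrite this pass performs, the sketch below replaces every occurrence of a fixed op sequence with one `fused_multi_transformer` op. The op list in `kLayerPattern` is a heavily simplified, hypothetical stand-in for a pre-LN GPT layer (the real pattern also matches reshapes, transposes, scaling, biases, and weight inputs), and `FuseLayers` works on a flat list of op names rather than on a graph.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified op sequence of one Transformer layer:
// attention block followed by the feed-forward block.
const std::vector<std::string> kLayerPattern = {
    "layer_norm", "matmul_v2", "softmax", "matmul_v2", "elementwise_add",
    "layer_norm", "matmul_v2", "gelu",    "matmul_v2", "elementwise_add"};

// Replace each occurrence of kLayerPattern in `ops` with a single
// "fused_multi_transformer" op; returns the number of layers fused.
int FuseLayers(std::vector<std::string>* ops) {
  std::vector<std::string> out;
  int fused = 0;
  const size_t n = kLayerPattern.size();
  size_t i = 0;
  while (i < ops->size()) {
    if (i + n <= ops->size() &&
        std::equal(kLayerPattern.begin(), kLayerPattern.end(),
                   ops->begin() + i)) {
      out.push_back("fused_multi_transformer");  // one fused op per layer
      ++fused;
      i += n;
    } else {
      out.push_back((*ops)[i]);  // op outside any matched layer: keep it
      ++i;
    }
  }
  *ops = out;
  return fused;
}
```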
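Matches the Transformer layers in the GPT encoder (each including the multi-head attention and feed-forward parts) and converts them into the FusedMultiTransformer fused op to accelerate inference. Each red box in the figure is one such layer.
fused_multi_transformer_decoder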


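Conceptually, decoding with a KV cache works like the sketch below: at step `time_step` the new key and value are written into the cache at that position, and the query attends over cached positions `0..time_step`. This is a single-head, single-sequence illustration under those assumptions, not the FusedMultiTransformer kernel; `KVCache` and `DecodeStep` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical cache with max_len slots of d-dimensional keys and values.
struct KVCache {
  std::vector<std::vector<float>> k, v;  // [max_len][d]
  KVCache(int max_len, int d)
      : k(max_len, std::vector<float>(d, 0.f)),
        v(max_len, std::vector<float>(d, 0.f)) {}
};

// One decoding step of single-head scaled dot-product attention.
std::vector<float> DecodeStep(KVCache* cache, int time_step,
                              const std::vector<float>& q,
                              const std::vector<float>& k_new,
                              const std::vector<float>& v_new) {
  const size_t d = q.size();
  cache->k[time_step] = k_new;  // store this step's key/value in the cache
  cache->v[time_step] = v_new;

  // Scores against every cached key up to and including time_step.
  std::vector<float> scores(time_step + 1);
  for (int t = 0; t <= time_step; ++t) {
    float s = 0.f;
    for (size_t j = 0; j < d; ++j) s += q[j] * cache->k[t][j];
    scores[t] = s / std::sqrt(static_cast<float>(d));
  }

  // Numerically stable softmax over the scores.
  const float max_score = *std::max_element(scores.begin(), scores.end());
  float sum = 0.f;
  for (float& s : scores) {
    s = std::exp(s - max_score);
    sum += s;
  }

  // Weighted sum of cached values.
  std::vector<float> out(d, 0.f);
  for (int t = 0; t <= time_step; ++t)
    for (size_t j = 0; j < d; ++j) out[j] += scores[t] / sum * cache->v[t][j];
  return out;
}
```

Only the new token's key and value are computed per step; earlier ones are read back from the cache, which is what makes step-by-step decoding cheap.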
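The decoder structure is similar to the encoder's, but it additionally handles the KV cache (Cache KV). The part in the red box in the figure is converted into a FusedMultiTransformer fused op with a TimeStep input, used for decoding.
fused_multi_transformer_encoder/decoder_fuse_qkv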

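fused_multi_transformer_encoder/decoder_fuse_qkv is structurally similar to fused_multi_transformer_encoder/decoder. The difference is that in the multi-head attention, Q, K, and V are concatenated into a single tensor for the computation and split apart afterwards, as shown in the red box in the figure.
Add multi-devices passes: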
multi_devices_fused_multi_transformer_encoder_fuse_qkv
multi_devices_fused_multi_transformer_decoder_fuse_qkv
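Assuming the usual Megatron-style tensor parallelism that c_identity and c_allreduce_sum are typically used for (the first weight split by columns, the second by rows, partial outputs summed across devices), the sketch below simulates two "devices" in one process on a feed-forward block and checks that the summed partial results match the single-device result. The device split, ReLU activation, and all function names are illustrative assumptions; in the forward pass c_identity is modeled as a plain identity.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>;  // [rows][cols]

// y = x * W for a single row vector x.
Vec VecMat(const Vec& x, const Mat& w) {
  Vec y(w[0].size(), 0.f);
  for (size_t i = 0; i < x.size(); ++i)
    for (size_t j = 0; j < y.size(); ++j) y[j] += x[i] * w[i][j];
  return y;
}

Vec Relu(Vec v) {
  for (float& e : v) e = std::max(e, 0.f);
  return v;
}

// Reference single-device FFN: relu(x * W1) * W2.
Vec FfnSingle(const Vec& x, const Mat& w1, const Mat& w2) {
  return VecMat(Relu(VecMat(x, w1)), w2);
}

// Simulated tensor-parallel FFN on `ranks` devices: W1 split by columns,
// W2 by rows; each rank produces a partial output and the partials are
// summed, which is the role c_allreduce_sum plays across real devices.
Vec FfnTensorParallel(const Vec& x, const Mat& w1, const Mat& w2,
                      size_t ranks) {
  const size_t shard = w1[0].size() / ranks;  // hidden units per rank
  Vec y(w2[0].size(), 0.f);
  for (size_t r = 0; r < ranks; ++r) {
    const Vec& x_r = x;  // c_identity: identity in the forward pass
    Mat w1_r(w1.size()), w2_r;
    for (size_t i = 0; i < w1.size(); ++i)
      w1_r[i].assign(w1[i].begin() + r * shard,
                     w1[i].begin() + (r + 1) * shard);
    w2_r.assign(w2.begin() + r * shard, w2.begin() + (r + 1) * shard);
    Vec partial = VecMat(Relu(VecMat(x_r, w1_r)), w2_r);
    for (size_t j = 0; j < y.size(); ++j) y[j] += partial[j];  // all-reduce sum
  }
  return y;
}
```

The column/row split works because the activation is elementwise, so each device can apply it to its own slice of the hidden units before the reduction.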
The multi-devices passes build on the passes above and additionally insert c_identity and c_allreduce_sum ops for inter-device communication.
Add subgraph pass support for the 6 passes
Add IR_NODE_UNLINK for unlinking nodes in a Graph
Fix a segmentation fault in the Var->persistent check when the Variable is null
TODO
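For the IR_NODE_UNLINK and null-Variable items in the list above, here is a toy sketch under stated assumptions. It assumes edges are stored on both endpoints (each node keeps `inputs` and `outputs`), so unlinking has to remove the edge from both lists, and it shows the null guard that prevents dereferencing a missing variable. `Node`, `VarInfo`, `Link`, `Unlink`, and `IsPersistable` are hypothetical names, not Paddle's `ir::Node` API.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical variable descriptor with a persistable flag.
struct VarInfo {
  bool persistable = false;
};

// Toy IR node: edges are recorded on both ends.
struct Node {
  std::vector<Node*> inputs, outputs;
  VarInfo* var = nullptr;  // may be null, e.g. for op nodes
};

// Add the edge a -> b on both sides.
void Link(Node* a, Node* b) {
  a->outputs.push_back(b);
  b->inputs.push_back(a);
}

// Remove the edge a -> b from both sides (the inverse of Link).
void Unlink(Node* a, Node* b) {
  a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
                   a->outputs.end());
  b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
                  b->inputs.end());
}

// Null-safe check: reading n->var->persistable without the null test
// would crash when the node carries no variable.
bool IsPersistable(const Node* n) {
  return n->var != nullptr && n->var->persistable;
}
```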