add fp32 grad plus fp16 param in adamw #51141
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
Is grad_type completely independent of T and MPDType? If not, there might be a better way to gather the two branches into a single templated interface.
Yes, grad_type is independent of both T and MPDType.
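For context, here is a minimal sketch (hypothetical names, not the actual kernel code) of the kind of two-branch dispatch discussed above: because the gradient dtype is independent of the parameter type T and the master type MPDType, the kernel can inspect the grad tensor's dtype at runtime and instantiate the same templated body with either float or T as the gradient type.

```cpp
#include <cstdint>

// Hypothetical templated update body: GradT is a third, independent
// template parameter alongside the param type T and master type MPDType.
template <typename T, typename MPDType, typename GradT>
void AdamWForEachSketch(const GradT* grad, T* param_out,
                        MPDType* master_out, int64_t numel) {
  for (int64_t i = 0; i < numel; ++i) {
    // All arithmetic happens in the master precision.
    MPDType g = static_cast<MPDType>(grad[i]);
    MPDType p = master_out[i] - g;  // lr/moments omitted for brevity
    master_out[i] = p;
    param_out[i] = static_cast<T>(p);
  }
}

// The two branches then collapse into two instantiations of one interface.
template <typename T, typename MPDType>
void DispatchOnGradType(bool grad_is_fp32, const void* grad, T* param_out,
                        MPDType* master_out, int64_t numel) {
  if (grad_is_fp32) {
    AdamWForEachSketch<T, MPDType, float>(
        static_cast<const float*>(grad), param_out, master_out, numel);
  } else {
    AdamWForEachSketch<T, MPDType, T>(
        static_cast<const T*>(grad), param_out, master_out, numel);
  }
}
```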
LGTM
* Cherry-pick the register of bfloat16 for amp_kernel, pull request #45541.
* Cherry-pick the master_grad support of adamw, pull request #51141.
* Add bf16 for some ops in static mode (#51582)
* Add bfloat16 support for some api in static mode.
* Fix codestyle.
* Revert the change of layer_function_generator.py.

Co-authored-by: Shaojie WANG <wsjmessi@163.com>
PR types
Others
PR changes
OPs
Describe
Add support for an FP32 gradient combined with an FP16 parameter in the AdamW optimizer; the optimizer casts datatypes internally during the update.
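As a rough illustration only (hypothetical names, simplified math, bias correction omitted), the casting path can be thought of as: keep the AdamW moments and a master copy of the parameter in FP32, cast the incoming gradient (FP16 or FP32) up to FP32, perform the update there, and write the result back to the FP16 parameter.

```cpp
#include <cmath>
#include <cstdint>

// Simplified single-element AdamW step in FP32 master precision.
// "Half" stands in for the framework's FP16 type; all names are illustrative.
struct AdamWStateSketch {
  float m = 0.f, v = 0.f;  // first/second moments, kept in FP32
};

template <typename Half, typename GradT>
void AdamWStepSketch(GradT grad, Half* param_fp16, float* master_param,
                     AdamWStateSketch* s, float lr, float beta1, float beta2,
                     float eps, float weight_decay) {
  float g = static_cast<float>(grad);  // cast FP16/FP32 grad up to FP32
  s->m = beta1 * s->m + (1.f - beta1) * g;
  s->v = beta2 * s->v + (1.f - beta2) * g * g;
  float p = *master_param;
  p -= lr * (s->m / (std::sqrt(s->v) + eps) + weight_decay * p);
  *master_param = p;                   // update the FP32 master copy
  *param_fp16 = static_cast<Half>(p);  // refresh the FP16 parameter
}
```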