Add support for FP8 on H100 using NVidia's TransformerEngine #1965
Conversation
Force-pushed from d2917a5 to 5c04109.
Could you add some brief docs to the precision class about this new option? Also, is this something you want to add to setup.py as an optional dependency, or too soon? Lastly, is there any way to write simple tests for this?
Is there a target for transformer_engine in setup.py (is it already a dependency somewhere)? If not, can we add it?
@dakinggg Since it's not really usable on current hardware, I can add some negative tests where we check for the errors on current hardware. Would something like this be useful?
Added a target txengine, but there are two issues: 1) there is no PyPI package for it, and 2) installation depends on torch, so it fails due to pip build isolation :-(
Yeah, a negative test that errors would be helpful for now, just to exercise the code path. Thanks! And hm, installation depending on torch is unfortunate. In that case, could you put installation instructions somewhere? Probably in the documentation for fp8?
Force-pushed from 099c050 to dd11bbd.
@dakinggg added tests
New to
@lukaemon: Two variants of FP8 exist, and there is nothing called bf8 on NVidia cards.
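For reference, a minimal sketch of the two variants as named in TransformerEngine's recipe API (illustrative only, not part of this PR):

```python
from transformer_engine.common.recipe import Format

Format.E4M3  # 4 exponent / 3 mantissa bits: more precision, typically used for the forward pass
Format.E5M2  # 5 exponent / 2 mantissa bits: more dynamic range, typically used for gradients
```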
I think so, provided you have CUDA 12, TransformerEngine layers in your model, and use amp_fp8 precision.
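For anyone wondering what that looks like end to end, here's a rough sketch (not from this PR; `TinyTEModel`, the random dataset, and the `ComposerClassifier` wrapper are illustrative placeholders, and it assumes an H100 with CUDA 12):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import transformer_engine.pytorch as te
from composer import Trainer
from composer.models import ComposerClassifier

# Toy module whose hidden layer comes from TransformerEngine; only TE layers run in FP8.
class TinyTEModel(torch.nn.Module):
    def __init__(self, dim: int = 256, num_classes: int = 10):
        super().__init__()
        self.te_fc = te.Linear(dim, dim)               # TransformerEngine layer -> eligible for FP8
        self.head = torch.nn.Linear(dim, num_classes)  # plain PyTorch layer -> stays in bf16

    def forward(self, x):
        return self.head(self.te_fc(x))

# Random data just to exercise the training loop.
dataset = TensorDataset(torch.randn(64, 256), torch.randint(0, 10, (64,)))
train_loader = DataLoader(dataset, batch_size=8)

trainer = Trainer(
    model=ComposerClassifier(TinyTEModel(), num_classes=10),
    train_dataloader=train_loader,
    precision='amp_fp8',   # the precision added by this PR; requires supported hardware + CUDA 12
    max_duration='1ep',
)
trainer.fit()
```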
@dskhudia Thanks for the clarification.
@vgoklani Thanks. Ada FP8 support is expected after 23Q2 at the earliest.
elif dtype in ['amp_fp8']:
    # We use torch.bfloat16 by default for amp_fp8 as there is no
    # fp8 datatype in PyTorch yet.
    return torch.bfloat16
Hi! New to your repo! FP8 integration sounds super nice! I'm just trying to understand what this line implies: does it mean that everything is running in bf16 when we specify fp8?
Nvm, I think I misunderstood how the fp8 system works; I was somehow expecting to play with a specific dtype, not the context manager. My bad.
Glad that your confusion is resolved.
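To illustrate the context-manager point: under amp_fp8 the outer autocast dtype is bf16 (that's what the snippet above returns), and FP8 is applied by TransformerEngine's fp8_autocast around TE layers. A minimal standalone sketch of that mechanism (illustrative TransformerEngine usage, not this PR's exact code):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

layer = te.Linear(1024, 1024).cuda()
x = torch.randn(16, 1024, device='cuda', dtype=torch.bfloat16)

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)  # E4M3 forward, E5M2 backward

# The surrounding autocast dtype stays bf16; FP8 kicks in only inside
# fp8_autocast, and only for TransformerEngine layers.
with torch.autocast('cuda', dtype=torch.bfloat16):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x)
```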
TransformerEngine has support for Ada GPUs now. Can the restriction in this PR be loosened to allow Ada as well? I've tested fp8 on a 4090 with llm-foundry and training starts successfully.
What does this PR do?
Adds amp_fp8 precision. This will allow us to train models faster using FP8 precision on H100 systems. It's a no-op if amp_fp8 precision is not used.
What issue(s) does this change relate to?
FP8 training support on H100