
Conversation

dskhudia
Contributor

What does this PR do?

Adds amp_fp8 precision. This will allow us to train models faster using FP8 precision on H100 systems. It's a no-op if amp_fp8 precision is not used.

What issue(s) does this change relate to?

FP8 training support on H100

@dskhudia dskhudia force-pushed the amp_bf8 branch 2 times, most recently from d2917a5 to 5c04109 on February 14, 2023 00:32
Contributor

@dakinggg dakinggg left a comment


Could you add some brief docs to the precision class about this new option? Also, is this something you want to add to setup.py as an optional dependency, or too soon? Lastly, is there any way to write simple tests for this?

Contributor

@mvpatel2000 mvpatel2000 left a comment


Is there a target for transformer_engine in setup.py (is it already a dependency somewhere)? If not, can we add it?

@dskhudia
Contributor Author

Lastly, is there any way to write simple tests for this?

@dakinggg Since it's not really usable on current hardware, I can add some negative tests where we check for errors on current hardware. Would something like that be useful?
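
For illustration, a rough sketch of what such a negative test could look like. The entry points used here (Precision.AMP_FP8 and get_precision_context) and the exact exception raised on unsupported hardware are assumptions for the sketch, not necessarily what the PR ends up implementing:

```python
# Hedged sketch of a negative test for amp_fp8 on unsupported hardware.
# Assumes composer.core exposes Precision and get_precision_context, and that
# requesting amp_fp8 without FP8-capable hardware / transformer_engine raises.
import pytest
import torch

from composer.core import Precision, get_precision_context


def test_amp_fp8_raises_on_unsupported_hardware():
    fp8_capable = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
    if fp8_capable:
        pytest.skip('FP8-capable GPU present; this negative test does not apply.')
    with pytest.raises(Exception):
        with get_precision_context(Precision.AMP_FP8):
            pass
```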

@dskhudia
Contributor Author

Is there a target for transformer_engine in setup.py

Also, is this something you want to add to setup.py as an optional dependency, or too soon?

Added a target txengine, but there are two issues: 1) there is no PyPI package for it, and 2) installation depends on torch, so it fails due to pip build isolation :-(

@dakinggg
Contributor

Yeah, a negative test that errors would be helpful for now, just to exercise the code path. Thanks! And hm, installation depending on torch is unfortunate. In that case, could you put installation instructions somewhere? Probably in the documentation for fp8?

@dskhudia dskhudia force-pushed the amp_bf8 branch 2 times, most recently from 099c050 to dd11bbd on February 16, 2023 20:34
@dskhudia
Contributor Author

@dakinggg added tests
@mvpatel2000 added installation instructions

@dskhudia dskhudia merged commit c4ce366 into mosaicml:dev Feb 16, 2023
@dskhudia dskhudia deleted the amp_bf8 branch February 16, 2023 22:51
dakinggg pushed a commit to dakinggg/composer that referenced this pull request Feb 17, 2023
@lukaemon

New to fp8 / bf8. Does this mean it's possible to do fp8 training with a consumer-level Ada GPU like the RTX 4090?

@dskhudia
Contributor Author

@lukaemon: Two variants of FP8 exist (E4M3 and E5M2), and there is nothing called bf8 on NVIDIA cards.

training with a consumer-level Ada GPU

I think so, provided you have CUDA 12, TransformerEngine layers in your model, and use amp_fp8 precision.
NVIDIA's announcement does point to FP8 on the 4090.

Ada’s new 4th Generation Tensor Cores are unbelievably fast, with an all new 8-Bit Floating Point (FP8) Tensor Engine, increasing throughput by up to 5X, to 1.32 Tensor-petaFLOPS on the GeForce RTX 4090.
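
To make those prerequisites concrete, here is a small, hedged check one could run before turning on amp_fp8; the CUDA 12 and compute-capability thresholds come from this discussion, not from Composer's own validation logic:

```python
# Hedged prerequisite check for FP8 training, based on this thread:
# CUDA 12, an FP8-capable GPU, and transformer_engine importable.
import torch


def fp8_prereqs_met() -> bool:
    if not torch.cuda.is_available():
        return False
    cuda_major = int(torch.version.cuda.split('.')[0]) if torch.version.cuda else 0
    # Hopper is (9, 0); Ada (8, 9) gained TransformerEngine support later
    # (see the comment further down in this thread).
    capability_ok = torch.cuda.get_device_capability() >= (8, 9)
    try:
        import transformer_engine.pytorch  # noqa: F401  # TE layers must be used in the model
    except ImportError:
        return False
    return cuda_major >= 12 and capability_ok


print(fp8_prereqs_met())
```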

@lukaemon

@dskhudia Thanks for the clarification.

@vgoklani

@lukaemon

@vgoklani Thanks. Ada FP8 support after 23Q2, at least.

```python
elif dtype in ['amp_fp8']:
    # We use torch.bfloat16 by default for amp_fp8 as there is no
    # fp8 datatype in PyTorch yet.
    return torch.bfloat16
```


Hi! New to your repo! FP8 integration sounds super nice! I'm just trying to understand what this line implies. Does it mean that everything is running in bf16 when we specify fp8?


Never mind, I think I misunderstood how the fp8 system works. I was somehow expecting to play with a specific dtype, not the context manager. My bad.

@dskhudia
Contributor Author


Glad that your confusion is resolved.
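
For anyone else reading along, a rough sketch of how FP8 actually gets applied through TransformerEngine's context manager rather than a torch dtype. This assumes TransformerEngine's public fp8_autocast / DelayedScaling API; the layer sizes and recipe settings are illustrative, and this is not Composer's internal code:

```python
# Hedged sketch: FP8 is enabled through TransformerEngine's fp8_autocast
# context manager around TE modules, while non-TE ops stay in bf16/fp32.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.Linear(768, 768, bias=True).cuda()
inp = torch.randn(16, 768, device='cuda')

# Recipe controlling the FP8 formats and scaling behavior (illustrative settings).
fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)
out.sum().backward()
```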

@float-trip

float-trip commented Aug 9, 2023

TransformerEngine has support for Ada GPUs now. Can the restriction in this PR be loosened to torch.cuda.get_device_capability() >= (8, 9)?

I've tested fp8 on a 4090 with llm-foundry and training starts successfully.
