[tools] add fp8 max/min constant in utils #3959
Merged
Motivation
As suggested in PR #3702, AMD uses two FP8 formats:
Mismatch of FP8 behavior between the Meta implementation and the AMD implementation
By convention, FP8 MAX in the torch.float8_e4m3fnuz format is represented as:
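For reference, a minimal sketch of how these limits can be queried from torch (not a snippet from this PR):

```python
import torch

# Limits that torch exposes by default for the AMD-style FP8 E4M3 FNUZ dtype.
finfo = torch.finfo(torch.float8_e4m3fnuz)
print(finfo.max)  # 240.0
print(finfo.min)  # -240.0
```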
Currently on AMD, 224 is chosen as FP8 MAX:
But both values can be represented in AMD device functions. Here is the device function test with ROCm SDK 6.3:
From the above test, both 224 and 240 are valid numbers in the AMD FP8 E4M3 FNUZ format. So what's the problem?
The problem is the round trip. Here is the torch test on an AMD device:
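A minimal sketch of the kind of round-trip check described here, assuming a ROCm build of torch on an AMD GPU (exact out-of-range behavior depends on the torch/ROCm version):

```python
import torch

# Cast values around FP8 MAX to float8_e4m3fnuz and back to float32.
x = torch.tensor([224.0, 240.0, 241.0], device="cuda")
y = x.to(torch.float8_e4m3fnuz).to(torch.float32)
print(y)  # values above FP8 MAX come back as NaN in torch, not clamped
```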
Any value greater than FP8 MAX will be represented as NaN in PyTorch, not clamped to a proper max value.
However, in the AMD device implementation, it is clamped to 240.
In the current vLLM implementation, AMD FP8 MAX is hard-coded as 224 for better precision, though 240 is also valid in the ROCm 6.3 SDK.
Besides, I also found that the minimum subnormal value (~1e-3, i.e. 0.000976562) is not provided by default in torch, so I added it. It is useful for extending the representation to values below the minimum normal, and hence can be used to adjust weights.
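For context, a sketch of how this value can be derived by hand, since torch.finfo does not expose it directly (not a snippet from this PR):

```python
import torch

# torch.finfo exposes the smallest normal, but not the smallest subnormal,
# so it has to be derived manually (E4M3 has 3 mantissa bits).
finfo = torch.finfo(torch.float8_e4m3fnuz)
min_sub_norm = finfo.smallest_normal * 2.0 ** -3  # 2**-10 ≈ 0.000976562
print(min_sub_norm)
```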
Group Quant efficiency
MXFP8 was introduced in the OCP project, is supported in NVIDIA Blackwell and AMD's quantization toolkit Quark, and is projected to be available in PyTorch in 2025 Q1. This enables native group quant over 32 consecutive elements instead of per-tensor quant. The immediate benefit brought by MXFP8 is that we can use FP8 E8M0 to store the scaling factor instead of FP32.
Currently SGLang implements group quant with group size 128; the implementation can then be adjusted to support 32-element group quant on Blackwell and other chips that support OCP MXFP8.
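A minimal sketch of per-group FP8 quantization with a configurable group size, to illustrate moving from 128-element groups to 32-element (MX-style) groups; the function name and layout are illustrative, not SGLang's actual kernels:

```python
import torch

def group_quant_fp8(x: torch.Tensor, group_size: int = 128, fp8_max: float = 224.0):
    # Assumes x.numel() is divisible by group_size; one scale per group.
    groups = x.reshape(-1, group_size)
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (groups / scale).to(torch.float8_e4m3fnuz)
    return q, scale

# 32-element groups as used by OCP MXFP8:
# q, scale = group_quant_fp8(weights, group_size=32)
```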
Modifications
Add global constants. Usage example:
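A sketch of the intended usage, assuming constant names like FP8_E4M3_MAX / FP8_E4M3_MIN (the actual names are defined in the utils module touched by this PR):

```python
import torch
# Hypothetical import; see the diff for the real constant names and location.
from sglang.srt.utils import FP8_E4M3_MAX, FP8_E4M3_MIN

# Clamp a tensor into the representable FP8 E4M3 range before casting,
# so out-of-range values are saturated instead of becoming NaN.
x = torch.randn(8) * 500
x_fp8 = x.clamp(FP8_E4M3_MIN, FP8_E4M3_MAX).to(torch.float8_e4m3fnuz)
```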
Note: FP8_E4M3_MIN_SUB_NORM and OCP_MXFP8_E8M0_GROUP_QUANT_GRANULARITY will be added in the relevant PRs.
Checklist