feat: fp8 block scaling #543
Conversation
Should we also add a config, RL/examples/configs/grpo_math_8B_fp8_L3_F1_G_i.yaml?
For example, the config below could be a good candidate (optionally with num_last_layers_in_bf16: 0 and num_first_layers_in_bf16: 0):
```yaml
# GRPO Algorithm Configuration
defaults: "grpo_math_1B.yaml"

grpo:
  num_prompts_per_step: 64
  num_generations_per_prompt: 32

loss_fn:
  use_importance_sampling_correction: true

policy:
  model_name: "meta-llama/Llama-3.1-8B-Instruct"
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 512
  train_micro_batch_size: 1
  generation_batch_size: 32 # Only used when generating using HF backend
  logprob_batch_size: 2
  max_total_sequence_length: 4096
  precision: "bfloat16"
  fsdp_offload_enabled: false
  activation_checkpointing_enabled: false

  dtensor_cfg:
    enabled: true

  dynamic_batching:
    train_mb_tokens: 4096
    logprob_mb_tokens: 8192

  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 3.0e-7
      weight_decay: 0.01
      betas: [0.9, 0.999]
      eps: 1e-8

  scheduler:
    - name: "torch.optim.lr_scheduler.LinearLR"
      kwargs:
        start_factor: 0.1
        end_factor: 1.0
        # The scheduler iteration is per GRPO step and is decoupled from the optimizer step (may be >=1 per GRPO step)
        total_iters: 13
    - name: "torch.optim.lr_scheduler.ConstantLR"
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones: [13]

  generation:
    backend: "vllm"
    max_new_tokens: ${policy.max_total_sequence_length}
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      precision: 'fp8'
      use_deep_gemm: true
      num_last_layers_in_bf16: 3
      num_first_layers_in_bf16: 1
      tensor_parallel_size: 1
      gpu_memory_utilization: 0.6
      max_model_len: ${policy.max_total_sequence_length}

cluster:
  gpus_per_node: 8
  num_nodes: 1
```
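For context on what `precision: 'fp8'` with `use_deep_gemm: true` implies, below is a minimal, illustrative PyTorch sketch of block-wise FP8 weight quantization with one scale per 128x128 block, the granularity DeepGEMM-style kernels use (activations are typically scaled per 1x128 group). The helper names are hypothetical; this is not the PR's actual implementation:

```python
# A minimal sketch of FP8 block scaling, assuming a PyTorch build with float8
# support (torch.float8_e4m3fn, available since ~2.1). Illustrative only.
import torch

FP8_MAX = 448.0  # max representable magnitude of float8 E4M3
BLOCK = 128

def quantize_blockwise(w: torch.Tensor):
    """Quantize a 2D weight to FP8 with one fp32 scale per 128x128 block."""
    rows, cols = w.shape
    assert rows % BLOCK == 0 and cols % BLOCK == 0, "pad to a multiple of 128 first"
    # View as (row_blocks, BLOCK, col_blocks, BLOCK) so each block can be reduced.
    blocks = w.view(rows // BLOCK, BLOCK, cols // BLOCK, BLOCK)
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX  # per-block scale mapping the block's max to FP8_MAX
    q = (blocks / scale).to(torch.float8_e4m3fn)
    return q.view(rows, cols), scale[:, 0, :, 0]

def dequantize_blockwise(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Inverse of quantize_blockwise, useful for checking round-trip error."""
    rows, cols = q.shape
    blocks = q.view(rows // BLOCK, BLOCK, cols // BLOCK, BLOCK).to(torch.float32)
    return (blocks * scale[:, None, :, None]).view(rows, cols)

w = torch.randn(256, 512)
q, s = quantize_blockwise(w)
print((dequantize_blockwise(q, s) - w).abs().max())  # small round-trip error
```

The `num_first_layers_in_bf16` / `num_last_layers_in_bf16` options in the config above presumably leave the most quantization-sensitive layers out of this step entirely, keeping them in bf16.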
Not super necessary immediately, but I think it'd be nice to include convergence plots in the repo as proof of convergence.
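Not part of this PR, but as a sketch of what generating such a plot could look like, here is a hedged example assuming per-step rewards are logged to a CSV with `step` and `reward` columns (the path and column names are hypothetical, not the repo's actual logging format):

```python
# Illustrative only: plot a reward-vs-step convergence curve from logged metrics.
import csv
import matplotlib.pyplot as plt

steps, rewards = [], []
with open("results/grpo_math_8B_fp8/metrics.csv") as f:  # hypothetical path
    for row in csv.DictReader(f):
        steps.append(int(row["step"]))
        rewards.append(float(row["reward"]))

plt.plot(steps, rewards, label="fp8 generation")
plt.xlabel("GRPO step")
plt.ylabel("mean reward")
plt.legend()
plt.savefig("fp8_convergence.png", dpi=150)
```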
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Signed-off-by: Jimmy Zhang <133159885+jiemingz@users.noreply.github.com>
Signed-off-by: Sahil Jain <sahilj@nvidia.com>
Signed-off-by: Qidong Su <qidongs@nvidia.com>
Signed-off-by: Julien Veron Vialard <jveronvialar@nvidia.com>
Co-authored-by: Sahil Jain <sahilj@nvidia.com>
Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com>
What does this PR do?
Adds FP8 block-scaled generation for the vLLM backend: `precision: 'fp8'` with a `use_deep_gemm` option, plus `num_first_layers_in_bf16` / `num_last_layers_in_bf16` to keep selected layers in bf16.
Issues
List issues that this PR closes (syntax):
Usage
```python
# Add a code snippet demonstrating how to use this
```
Before your PR is "Ready for review"
Pre checks:
Additional Information