feat: chunked logprob calculation with deferred fp32 cast to help with OOM #856
Conversation
nemo_rl/tron/model.py (outdated)

    from nemo.collections.llm.t5.model.t5 import T5Config
    ...
    def get_model_from_config_no_float32(
was this function copied from somewhere? if so, what changes were made?
it's a copy of nemo/tron/model.py:
https://github.com/NVIDIA/NeMo/blob/8ddf4387344c6423763ec9ee0c9a755cbb5d8d35/nemo/tron/model.py

the main change is removing the Float16Module wrapper (which is what originally casts the model logits output to float32):
https://github.com/NVIDIA-NeMo/RL/pull/856/files/a020289609cfa0d7a695a175eed009fdb4695088#diff-37539801eab6c58172c5cf85be33a1f9eac04c096a8e23170550ddf3bff8e3b3R125-R128
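For intuition, here's a minimal sketch of the difference (illustrative only, not the NeMo/NeMo-RL code; the shapes, names, and chunking over the sequence dimension are assumptions):

```python
import torch
import torch.nn.functional as F


def logprobs_eager_fp32(logits_bf16: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Float16Module-style behavior: the full [batch, seq, vocab] logits tensor is
    # materialized in float32 before computing logprobs.
    logps = F.log_softmax(logits_bf16.float(), dim=-1)
    return torch.gather(logps, dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)


def logprobs_deferred_fp32(
    logits_bf16: torch.Tensor, targets: torch.Tensor, chunk_size: int = 1024
) -> torch.Tensor:
    # Deferred cast: only one [batch, chunk_size, vocab] slice is ever held in
    # float32 at a time, so peak memory stays bounded.
    out = []
    for start in range(0, logits_bf16.shape[1], chunk_size):
        chunk = logits_bf16[:, start : start + chunk_size].float()
        tgt = targets[:, start : start + chunk_size].unsqueeze(-1)
        out.append(torch.gather(F.log_softmax(chunk, dim=-1), dim=-1, index=tgt).squeeze(-1))
    return torch.cat(out, dim=1)
```

The eager path is what was driving the OOM, since the float32 copy of the logits doubles their footprint; the deferred path never holds more than one chunk in float32.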
If it's only a one-line change, I'd prefer the change be reflected in the submodule (you can branch where the submodule is at to update)
also, if you expect the model coming back to not be a FP16 but something else, could you add a test asserting the model type? We're currently migrating away from tron, so once that is done, this test would ensure we don't miss this typing fix you're adding
updated the NeMo submodule:
- branch: https://github.com/NVIDIA/NeMo/tree/pjin/nemorl-logprob
- commit: NVIDIA-NeMo/NeMo@0bf0dbc
> also, if you expect the model coming back to not be a FP16 but something else, could you add a test asserting the model type? We're currently migrating away from tron, so once that is done, this test would ensure we don't miss this typing fix you're adding
what I did is add a float32 dtype check to the existing megatron logprobs test, and run that test on more combinations of (logprob chunk size, deferred float32 logits):
https://github.com/NVIDIA-NeMo/RL/pull/856/files#diff-9556cb57e37308923c54e7a6df8982afafef5e36544f350af3324db43f74bdbeR703

one thing to note is that the policy model worker mainly exposes the model output through get_logprobs, and there is no other interface for getting at the underlying torch model logits. but I think just checking that the returned logprobs are float32 should be sufficient?
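Roughly the shape of that check (illustrative only: build_test_policy and make_test_batch are hypothetical stand-ins for the worker fixture and test data in the linked test, not real helpers):

```python
import pytest
import torch


@pytest.mark.parametrize("logprob_chunk_size", [None, 256])
@pytest.mark.parametrize("defer_fp32_logits", [False, True])
def test_logprobs_are_float32(logprob_chunk_size, defer_fp32_logits):
    # Hypothetical fixtures standing in for the Megatron policy worker setup.
    policy = build_test_policy(
        logprob_chunk_size=logprob_chunk_size,
        defer_fp32_logits=defer_fp32_logits,
    )
    logprobs = policy.get_logprobs(make_test_batch())
    # The worker only exposes logprobs (not the raw logits), so the returned
    # dtype is the observable guarantee that the deferred-cast path still
    # produces float32 results.
    assert logprobs.dtype == torch.float32
```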
    @@ -141,6 +143,123 @@ def backward(
            return grad_input, None, None, None, None, None, None

    class ChunkedDistributedLogprob(torch.autograd.Function):
can we add a unit test for this function so we make sure the non-chunked version equals the chunked one?
added a chunk_size parameter to DistributedLogprobTestActor.
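For reference, a self-contained sketch of the equivalence the test is checking (this is not ChunkedDistributedLogprob itself, which also handles the tensor-parallel reduction; the helper below is a plain single-process stand-in):

```python
import torch
import torch.nn.functional as F


def token_logprobs(logits: torch.Tensor, targets: torch.Tensor, chunk_size=None) -> torch.Tensor:
    # Gathered per-token log-probabilities, casting each chunk to float32.
    if chunk_size is None:
        chunk_size = logits.shape[1]  # non-chunked: one full-sequence "chunk"
    out = []
    for start in range(0, logits.shape[1], chunk_size):
        chunk = logits[:, start : start + chunk_size].float()
        tgt = targets[:, start : start + chunk_size].unsqueeze(-1)
        out.append(torch.gather(F.log_softmax(chunk, dim=-1), dim=-1, index=tgt).squeeze(-1))
    return torch.cat(out, dim=1)


torch.manual_seed(0)
logits = torch.randn(2, 16, 128, dtype=torch.bfloat16)
targets = torch.randint(0, 128, (2, 16))

full = token_logprobs(logits, targets)                   # non-chunked path
chunked = token_logprobs(logits, targets, chunk_size=4)  # chunked path
assert full.dtype == chunked.dtype == torch.float32
torch.testing.assert_close(full, chunked)                # values must agree
```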
    moe_router_bias_update_rate: 0.0  # by default, disable bias updates for grpo
    apply_rope_fusion: True
    activation_checkpointing: True
    defer_fp32_logits: True
what would be the reason to set this to False?
mostly for strict backward compat, but we could instead enable it by default (i.e. make it an opt-out config like no_defer_fp32_logits or similar).
wdyt?
I see. How about the following:
- this PR introduces it, default off
- follow-up PR where we run all our nightly tests to see if defaulting to true is ok; if so, remove the arg
wdyt? If the feature is broadly applicable, we should probably switch it to true so no one else runs into the same issue (assuming no accuracy penalty).
yup, (1) and then (2) SGTM!
There's a permission issue with that
closing in favor of #918