
Conversation

pjin-nvidia (Contributor)

No description provided.

@pjin-nvidia pjin-nvidia marked this pull request as ready for review August 6, 2025 20:56
@pjin-nvidia pjin-nvidia changed the title Fix logprobs and logits-related OOM fix: logprobs and logits-related OOM Aug 6, 2025
Signed-off-by: Peter Jin <pjin@nvidia.com>
Based on NeMo commit: 8ddf4387344c6423763ec9ee0c9a755cbb5d8d35
from nemo.collections.llm.t5.model.t5 import T5Config


def get_model_from_config_no_float32(
Contributor

was this function copied from somewhere? if so, what changes were made?

Contributor

If it's only a one-line change, I'd prefer the change be reflected in the submodule (you can branch from where the submodule is at to update).

Also, if you expect the model coming back to not be FP16 but something else, could you add a test asserting the model type? We're currently migrating away from tron, so once that is done, this test would ensure we don't miss this typing fix you're adding.

Contributor Author

> Also, if you expect the model coming back to not be FP16 but something else, could you add a test asserting the model type? We're currently migrating away from tron, so once that is done, this test would ensure we don't miss this typing fix you're adding.

What I did is add a float32 dtype check to the existing megatron logprobs test, and run that test over more combinations of (logprob chunk size, deferred float32 logits):
https://github.com/NVIDIA-NeMo/RL/pull/856/files#diff-9556cb57e37308923c54e7a6df8982afafef5e36544f350af3324db43f74bdbeR703

One thing to note is that the policy model worker mainly exposes the model output through get_logprobs, and there is no separate interface for getting at the underlying torch model logits. But I think just checking that the returned logprobs are float32 should be sufficient?
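
For illustration, a rough sketch of that kind of parameterized dtype check; `build_policy_worker` and `sample_batch` are hypothetical placeholders for the setup done by the existing megatron logprobs test, not actual code from this PR:

```python
import pytest
import torch


@pytest.mark.parametrize("logprob_chunk_size", [None, 128, 1024])
@pytest.mark.parametrize("defer_fp32_logits", [False, True])
def test_logprobs_return_float32(logprob_chunk_size, defer_fp32_logits):
    # build_policy_worker / sample_batch are hypothetical stand-ins for the
    # worker construction and input batch used by the existing logprobs test.
    worker = build_policy_worker(
        logprob_chunk_size=logprob_chunk_size,
        defer_fp32_logits=defer_fp32_logits,
    )
    logprobs = worker.get_logprobs(sample_batch())
    # However the logits are chunked and whenever the fp32 cast happens,
    # the logprobs returned to the caller should be full precision.
    assert logprobs.dtype == torch.float32
```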

@@ -141,6 +143,123 @@ def backward(
        return grad_input, None, None, None, None, None, None


class ChunkedDistributedLogprob(torch.autograd.Function):
Contributor

Can we add a unit test for this function so we make sure the non-chunked version matches the chunked one?
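
For illustration, a self-contained sketch of that kind of equivalence check. It is only a toy: the real ChunkedDistributedLogprob also handles tensor-parallel vocab sharding and a custom backward, while this version just chunks over the sequence dimension and defers the fp32 cast to each chunk:

```python
import torch


def full_logprobs(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Reference path: cast the whole logits tensor to fp32 at once.
    logp = torch.log_softmax(logits.float(), dim=-1)
    return logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)


def chunked_logprobs(logits: torch.Tensor, targets: torch.Tensor, chunk: int) -> torch.Tensor:
    # Chunked path: only `chunk` positions are ever materialized in fp32 at a time.
    outs = []
    for s in range(0, logits.shape[1], chunk):
        piece = torch.log_softmax(logits[:, s : s + chunk].float(), dim=-1)
        outs.append(piece.gather(-1, targets[:, s : s + chunk].unsqueeze(-1)).squeeze(-1))
    return torch.cat(outs, dim=1)


def test_chunked_matches_full():
    torch.manual_seed(0)
    logits = torch.randn(2, 16, 32, dtype=torch.bfloat16)
    targets = torch.randint(0, 32, (2, 16))
    torch.testing.assert_close(
        chunked_logprobs(logits, targets, chunk=5),
        full_logprobs(logits, targets),
    )
```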

Signed-off-by: Peter Jin <pjin@nvidia.com>
moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
apply_rope_fusion: True
activation_checkpointing: True
defer_fp32_logits: True
Contributor

what would be the reason to set this to False?

@pjin-nvidia (Contributor Author) Aug 12, 2025

mostly for strict backward compat, but we could instead enable it by default (i.e. make it an opt-out config like no_defer_fp32_logits or similar)

wdyt?

Contributor

I see. How about the following:

1. this PR introduces it, default off
2. follow up PR where we run all our nightly tests to see if defaulting to true is ok, if so, remove the arg

wdyt? If the feature is broadly applicable we should probably switch it to true so no one else runs into the same issue (assuming no accuracy penalty)

Contributor Author

yup, (1) and then (2) SGTM!

Signed-off-by: Peter Jin <pjin@nvidia.com>
@terrykong terrykong changed the title fix: logprobs and logits-related OOM feat: chunked logprob calculation with deferred fp32 cast to help with OOM Aug 12, 2025
Signed-off-by: Peter Jin <pjin@nvidia.com>
@terrykong (Contributor)

There's a permission issue with that Check submodule fast-forward job. @chtruong814 is taking a look

Signed-off-by: Peter Jin <pjin@nvidia.com>
@chtruong814 chtruong814 added the CI:L1 Run doctests, unit tests, and functional tests label Aug 13, 2025
@chtruong814 chtruong814 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Aug 13, 2025
Signed-off-by: Peter Jin <pjin@nvidia.com>
@pjin-nvidia (Contributor Author)

closing in favor of #918
