Describe the bug
In `nemo_reinforcer/algorithms/loss_functions.py`, there is a potential risk of producing NaN values when the mask tensor contains all zeros. This occurs in the following code section:
mult_prob_error = ((torch.exp(lp_error) * mask).sum() / mask.sum()).item()
...
with torch.no_grad():
    probs_ratio = masked_mean(ratios.detach(), mask).item()
    probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item()
If `mask` contains all zeros, `mask.sum()` is zero, so the `masked_mean` calls (which compute a mean over the masked elements) divide by zero and produce NaN values.
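A minimal, self-contained sketch of the failure mode (the tensor shapes and values below are hypothetical, and `masked_mean` is assumed to reduce roughly as `(x * mask).sum() / mask.sum()`):

import torch

ratios = torch.rand(4, 8)    # stand-in for the ratios tensor
mask = torch.zeros(4, 8)     # all-zero mask, e.g. a fully padded batch

# masked_mean is assumed to reduce like this:
masked_mean_value = (ratios * mask).sum() / mask.sum()
print(masked_mean_value)     # tensor(nan) -- 0 / 0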
Steps/Code to reproduce bug
NA
Expected behavior
Add a check to handle the case when the mask is all zeros:
with torch.no_grad():
    if mask.sum() > 0:
        probs_ratio = masked_mean(ratios.detach(), mask).item()
        probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item()
    else:
        probs_ratio = 0.0  # Default value when mask is all zeros
        probs_ratio_clamped = 0.0  # Default value when mask is all zeros
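An alternative to the explicit branch (a sketch, not the existing library API; assumes a binary mask) is to clamp the denominator so the division is always well-defined and evaluates to 0.0 when the mask is all zeros:

import torch

def safe_masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # clamp(min=1) avoids 0 / 0; for a binary mask with any nonzero entry,
    # mask.sum() >= 1, so the result is unchanged in the normal case.
    return (values * mask).sum() / mask.sum().clamp(min=1)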
Environment overview (please complete the following information)
- Environment location: Docker
- Method of install: from source, via `uv pip install -e '.[dev,test]'`