
Potential NaN risk in loss_functions.py when mask is all zeros #158

@zpqiu

Description

Describe the bug

In nemo_reinforcer/algorithms/loss_functions.py, there is a risk of producing NaN values when the mask tensor is all zeros. This occurs in the following code:

```python
mult_prob_error = ((torch.exp(lp_error) * mask).sum() / mask.sum()).item()
```

...

```python
with torch.no_grad():
    probs_ratio = masked_mean(ratios.detach(), mask).item()
    probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item()
```

If mask is all zeros, mask.sum() is 0, so the explicit division above evaluates 0/0, and masked_mean (which computes a mean over the masked elements) presumably divides by the same zero denominator; both produce NaN.
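For reference, masked_mean is presumably a straightforward masked average along these lines (a hypothetical sketch, not the actual nemo_reinforcer implementation):

```python
import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Assumed shape of the helper: any masked average of this form
    # divides by mask.sum(), so an all-zero mask gives 0/0 == NaN.
    return (values * mask).sum() / mask.sum()
```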

Steps/Code to reproduce bug

N/A
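That said, the failure mode is easy to demonstrate in isolation. A minimal standalone sketch (hypothetical tensors, mirroring the expression from loss_functions.py):

```python
import torch

lp_error = torch.randn(4)
mask = torch.zeros(4)  # all-zero mask, e.g. a fully padded sample

# Dividing by mask.sum() == 0 evaluates 0/0 and yields NaN:
mult_prob_error = ((torch.exp(lp_error) * mask).sum() / mask.sum()).item()
print(mult_prob_error)  # nan
```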

Expected behavior

Add a check to handle the case when the mask is all zeros:

```python
with torch.no_grad():
    if mask.sum() > 0:
        probs_ratio = masked_mean(ratios.detach(), mask).item()
        probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item()
    else:
        probs_ratio = 0.0  # Default value when mask is all zeros
        probs_ratio_clamped = 0.0  # Default value when mask is all zeros
```
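An alternative to branching, assuming masked_mean is a plain masked average as sketched above, is to clamp the denominator so the expressions stay in one code path:

```python
with torch.no_grad():
    # Clamping the denominator to at least 1 makes an all-zero mask
    # yield 0.0 instead of NaN; for a 0/1-valued mask with any nonzero
    # entries, mask.sum() >= 1 and the result is unchanged.
    denom = mask.sum().clamp_min(1.0)
    probs_ratio = ((ratios.detach() * mask).sum() / denom).item()
    probs_ratio_clamped = ((ratios_clamped.detach() * mask).sum() / denom).item()
```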

Environment overview

  • Environment location: Docker
  • Method of install: from source, via uv pip install -e '.[dev,test]'
