quantile() input type error when using bf16 #3666

@gitabtion

Description

Reproduction

When training with bf16, entropies is a bfloat16 tensor. However, torch.quantile() only accepts a float (float32) or double (float64) input tensor, so the following line raises an error:

entropy_threshold = torch.quantile(entropies.flatten(), self.token_entropy_percentile_threshold)
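The failure can be reproduced outside of training with a small standalone script. This is a minimal sketch: the random bf16 tensor is a hypothetical stand-in for the real per-token entropies, and token_entropy_percentile_threshold is a plain local variable standing in for the trainer attribute.

```python
import torch

# Hypothetical stand-in for the per-token entropies produced under bf16 training
entropies = torch.rand(4, 8).to(torch.bfloat16)
# Hypothetical stand-in for self.token_entropy_percentile_threshold
token_entropy_percentile_threshold = 0.8

error_message = None
try:
    torch.quantile(entropies.flatten(), token_entropy_percentile_threshold)
except RuntimeError as err:
    # torch.quantile() rejects any dtype other than float32/float64
    error_message = str(err)

print(error_message)
```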

System Info

RuntimeError: quantile() input tensor must be either float or double dtype
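A possible workaround is to upcast the entropies to float32 before calling torch.quantile(). The sketch below is an assumption about how the fix could look, not the library's actual patch; the helper name entropy_quantile_threshold is hypothetical. In the trainer, the equivalent one-line change would be calling torch.quantile() on entropies.flatten().float().

```python
import torch


def entropy_quantile_threshold(entropies: torch.Tensor, q: float) -> torch.Tensor:
    """Hypothetical helper: q-th quantile of entropies, safe for bf16/fp16 input.

    torch.quantile() only accepts float32 or float64 tensors, so any other
    floating dtype (e.g. bfloat16 under bf16 training) is upcast to float32.
    """
    flat = entropies.flatten()
    if flat.dtype not in (torch.float32, torch.float64):
        flat = flat.float()
    return torch.quantile(flat, q)


entropies = torch.rand(4, 8).to(torch.bfloat16)
threshold = entropy_quantile_threshold(entropies, 0.8)
# Boolean mask of tokens whose entropy is at or above the threshold
mask = entropies >= threshold
print(threshold, mask.sum())
```

Upcasting only the input to quantile() keeps the entropies tensor itself in bf16, so the extra memory cost is limited to one flattened float32 copy.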

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
  • Any traceback provided is complete

Metadata

Assignees: No one assigned

Labels: 🏋 GRPO (Related to GRPO), 🐛 bug (Something isn't working)
