
[FR] Support Automatic Mixed Precision training #3316

@austinv11

Issue Description

Better support for mixed precision training would be extremely helpful, at least for SVI. I can manually cast data to float16 or bfloat16, but I am unable to leverage PyTorch's automatic mixed precision (AMP) training. This is because AMP requires using the GradScaler class in the optimization loop to scale gradients in a mixed-precision-aware manner. See the documentation for more info: https://pytorch.org/docs/stable/amp.html

It would be nice to be able to use this class within Pyro's optimizers to enable AMP.
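For reference, this is the standard PyTorch AMP pattern the request refers to: the forward pass runs under `torch.autocast`, and `GradScaler` scales the loss before `backward()` so float16 gradients do not underflow, then unscales them before the optimizer step. A minimal sketch of a plain PyTorch training loop (not a Pyro API; the model, data, and hyperparameters here are placeholders), which is what would need to be interleaved with Pyro's SVI step:

```python
import torch

# GradScaler only does useful work with float16 on CUDA; with
# enabled=False it becomes a transparent no-op, so the same loop
# runs unchanged on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn(8, 4, device=device)
y = torch.randn(8, 1, device=device)

for _ in range(3):
    optimizer.zero_grad()
    # Ops inside autocast run in reduced precision where safe.
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscale grads; skip step on inf/nan
    scaler.update()                # adjust the scale factor for next iteration
```

The difficulty for Pyro is that `scaler.scale(...)`/`scaler.step(...)` must wrap the loss and the optimizer step, both of which are hidden inside `SVI.step`, so supporting AMP would presumably mean threading a scaler through Pyro's optimizer wrappers.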

Metadata


Assignees

No one assigned

    Labels

    enhancement, help wanted (issues suitable for, and inviting, external contributions)
