
Conversation

dakinggg (Contributor)

What does this PR do?

Adds `map_location` to some calls to `torch.load`.

What issue(s) does this change relate to?

No Jira, just noticed the bug while doing other stuff.

@dakinggg dakinggg requested a review from a team as a code owner February 18, 2023 02:21
@mvpatel2000 (Contributor) left a comment:

Why is this necessary?

@dakinggg (Contributor, Author) replied:

When you call `torch.load`, it tries to put the tensors back on the device they were saved from. From the [torch.load()](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) docs:

> torch.load() uses Python’s unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the run time system doesn’t have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the map_location argument.

This ends up with an error if you, e.g., try to load a checkpoint on CPU after it was saved on GPU, so everywhere we call `torch.load` we map to CPU. I just missed that piece in my original PR with these changes.
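A minimal sketch of the pattern being described (not the exact Composer code; the checkpoint contents and path here are made up for illustration):

```python
import os
import tempfile

import torch

# Save a small checkpoint. In a real training run the tensors might
# live on a GPU, in which case a plain torch.load on a CPU-only
# machine would raise while trying to restore them to that device.
ckpt = {"weight": torch.arange(4, dtype=torch.float32)}
path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
torch.save(ckpt, path)

# Passing map_location="cpu" remaps every storage to CPU during
# deserialization, so the load succeeds regardless of which devices
# the saving machine had.
state = torch.load(path, map_location="cpu")
print(state["weight"].device)  # cpu
```

The loaded tensors can then be moved to whatever device the current run is using (e.g. `state["weight"].to("cuda")`), which is why mapping to CPU first is a safe default.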

@eracah (Contributor) left a comment:

LGTM

@dakinggg dakinggg merged commit b16fb66 into mosaicml:dev Feb 21, 2023
@dakinggg dakinggg deleted the device_map branch June 1, 2023 00:14

3 participants