tree_map doesn't work over return_types #75218

@ezyang

Description

🐛 Describe the bug

```python
import torch
from torch.utils._pytree import tree_map

print(tree_map(lambda a: None, torch.cummin(torch.randn(3), 0)))
```

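This prints None. In other words, tree_map treats the whole torch.return_types.cummin object as a single leaf and applies the lambda to it. I expected it to recurse into the fields and return torch.return_types.cummin(values=None, indices=None).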
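To make the expected behavior concrete, here is a minimal, hypothetical sketch of a tree_map (named toy_tree_map; it is not the torch.utils._pytree implementation) that recurses into tuple subclasses instead of treating them as leaves. It assumes that torch.return_types classes are C-level struct sequences, like time.struct_time, which expose n_sequence_fields and are rebuilt from a single sequence argument, whereas namedtuples are rebuilt from positional arguments. Standard-library types stand in for torch so the sketch runs without it.

```python
import time
from collections import namedtuple
from typing import Any, Callable


def _is_structseq(x: Any) -> bool:
    # Assumption: C-level struct sequences (time.struct_time, os.stat_result,
    # and, presumably, torch.return_types.*) are tuple subclasses exposing n_sequence_fields
    return isinstance(x, tuple) and hasattr(type(x), "n_sequence_fields")


def _is_namedtuple(x: Any) -> bool:
    # namedtuple classes are tuple subclasses exposing a `_fields` attribute
    return isinstance(x, tuple) and hasattr(type(x), "_fields")


def toy_tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Hypothetical sketch: apply `fn` to every leaf, rebuilding containers
    (lists, dicts, tuples, namedtuples, struct sequences) around the results."""
    if isinstance(tree, list):
        return [toy_tree_map(fn, t) for t in tree]
    if isinstance(tree, dict):
        return {k: toy_tree_map(fn, v) for k, v in tree.items()}
    if _is_structseq(tree):
        # struct sequences take one sequence argument; iteration only yields
        # the visible fields, so any hidden fields come back as None
        return type(tree)([toy_tree_map(fn, t) for t in tree])
    if _is_namedtuple(tree):
        # namedtuples take their fields as positional arguments
        return type(tree)(*(toy_tree_map(fn, t) for t in tree))
    if isinstance(tree, tuple):
        return tuple(toy_tree_map(fn, t) for t in tree)
    return fn(tree)  # anything else is a leaf


Point = namedtuple("Point", ["x", "y"])
print(toy_tree_map(lambda a: None, Point(1, 2)))  # Point(x=None, y=None)

# Stand-in for a torch.return_types object: the struct sequence type is kept,
# and every visible field is replaced by None
print(toy_tree_map(lambda a: None, time.gmtime(0)))
```

With torch installed, toy_tree_map(lambda a: None, torch.cummin(torch.randn(3), 0)) should then yield torch.return_types.cummin(values=None, indices=None), provided the struct-sequence assumption holds. One caveat of this sketch: time.struct_time also carries tm_zone and tm_gmtoff beyond its nine visible fields, and those are lost on reconstruction because iteration only yields the visible ones.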

cc @zou3519

Versions

master

Metadata

Assignees

Labels

module: pytree, triaged

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions