Description
What happened + What you expected to happen
I have a question about how _detect_checkpoint_function is used to raise the deprecation warning when a checkpoint_dir kwarg is present in the trainable function.
If the user sets up their trainable (train function / train_fn) with multiple parameters (see the MRE snippet below) that are passed via ray.tune.with_parameters, the deprecation warning is raised every time, even when the trainable function's kwargs specifically do not include checkpoint_dir.
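My guess (an assumption on my part; I have not traced the Ray internals) is that the detection relies on signature binding, e.g. something like inspect.Signature.bind, which succeeds for any keyword name when the function declares a catch-all **kwargs:

```python
import inspect

def train_fn(config, **trainable_kwargs):
    pass

# bind() succeeds because the VAR_KEYWORD parameter (**trainable_kwargs)
# swallows any keyword argument, including checkpoint_dir, even though
# the function never names checkpoint_dir explicitly.
bound = inspect.signature(train_fn).bind("cfg", checkpoint_dir="/tmp/ckpt")
print(bound.kwargs)  # {'checkpoint_dir': '/tmp/ckpt'}
```

So a presence check based on "does checkpoint_dir bind?" cannot distinguish a function that actually declares checkpoint_dir from one that merely accepts arbitrary keyword arguments.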
Reproduction script
from ray import tune

trainable_kwargs = {'kwarg1': 1, 'kwarg2': 2}

def train_fn(config, **trainable_kwargs):
    pass

tuner = tune.Tuner(
    tune.with_parameters(train_fn, **trainable_kwargs),
)
Reproduction script output
In [1]: from ray import train, tune
...:
...: trainable_kwargs = {'kwarg1': 1, 'kwarg2': 2}
...:
...: def train_fn(config, **trainable_kwargs):
...: pass
...:
...: tuner = tune.Tuner(
...: tune.with_parameters(train_fn, **trainable_kwargs),
...: )
---------------------------------------------------------------------------
DeprecationWarning Traceback (most recent call last)
Cell In[1], line 9
5 def train_fn(config, **trainable_kwargs):
6 pass
8 tuner = tune.Tuner(
----> 9 tune.with_parameters(train_fn, **trainable_kwargs),
10 )
File /opt/conda/envs/ag_py311/lib/python3.11/site-packages/ray/tune/trainable/util.py:315, in with_parameters(trainable, **kwargs)
310 if _detect_checkpoint_function(trainable, partial=True):
311 from ray.tune.trainable.function_trainable import (
312 _CHECKPOINT_DIR_ARG_DEPRECATION_MSG,
313 )
--> 315 raise DeprecationWarning(_CHECKPOINT_DIR_ARG_DEPRECATION_MSG)
317 def inner(config):
318 fn_kwargs = {}
DeprecationWarning: Accepting a `checkpoint_dir` argument in your training function is deprecated.
Please use `ray.train.get_checkpoint()` to access your checkpoint as a
`ray.train.Checkpoint` object instead. See below for an example:
Before
------
from ray import tune
def train_fn(config, checkpoint_dir=None):
if checkpoint_dir:
torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
...
tuner = tune.Tuner(train_fn)
tuner.fit()
After
-----
from ray import train, tune
def train_fn(config):
    checkpoint: train.Checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
    ...

tuner = tune.Tuner(train_fn)
tuner.fit()
Is this not a supported use case? At the moment it raises a DeprecationWarning even though I am not passing any checkpoint_dir argument, so I think the detection logic is incorrect here. This code snippet worked perfectly fine with ray==2.6.3; I can reproduce the issue starting with ray>=2.7.0.
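As a possible workaround in the meantime (a sketch, assuming the false positive comes from the catch-all **kwargs as described above): naming the extra parameters explicitly means an arbitrary checkpoint_dir keyword can no longer bind to the signature, so a bind-based check would not misfire:

```python
import inspect

# Explicitly named parameters instead of a catch-all **trainable_kwargs.
def train_fn(config, kwarg1=None, kwarg2=None):
    pass

# An unexpected checkpoint_dir keyword now fails to bind.
try:
    inspect.signature(train_fn).bind("cfg", checkpoint_dir="/tmp/ckpt")
    detected = True
except TypeError:
    detected = False
print(detected)  # False
```

Whether this actually avoids the warning in ray>=2.7.0 depends on how _detect_checkpoint_function is implemented; I have only verified the signature-binding behavior shown here.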
Versions / Dependencies
Python 3.11.0
ray==2.8.1
Distributor ID: Ubuntu
Description: Ubuntu 20.04.6 LTS
Release: 20.04
Codename: focal