
[Ray Tune] train_fn without checkpoint_dir kwarg raises a false deprecation warning  #41562

@AnirudhDagar

Description


What happened + What you expected to happen

I have a question about how `_detect_checkpoint_function` is used to raise the deprecation warning when a `checkpoint_dir` kwarg is present in the trainable function.

If the user sets up their trainable (train function/`train_fn`) with multiple parameters passed via `ray.tune.with_parameters` (see the MRE snippet below), the deprecation warning is raised every time, even when the trainable function's kwargs do not include `checkpoint_dir`.
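Presumably the deprecation check should only fire when the signature explicitly declares a parameter named `checkpoint_dir`; a function that merely accepts a `**kwargs` catch-all should not match. A minimal sketch of such a signature check with `inspect` (`has_checkpoint_dir_param` is a hypothetical helper for illustration, not Ray's actual implementation):

```python
import inspect

def has_checkpoint_dir_param(fn) -> bool:
    """Return True only if fn explicitly declares a parameter named
    `checkpoint_dir` (hypothetical check, not Ray's actual helper)."""
    params = inspect.signature(fn).parameters
    return "checkpoint_dir" in params

# Mirrors the MRE below: a **kwargs catch-all, no checkpoint_dir.
def train_fn(config, **trainable_kwargs):
    pass

# Old-style trainable that the deprecation is actually aimed at.
def old_style_fn(config, checkpoint_dir=None):
    pass

print(has_checkpoint_dir_param(train_fn))      # False: **kwargs is not checkpoint_dir
print(has_checkpoint_dir_param(old_style_fn))  # True
```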

Reproduction script

from ray import tune

trainable_kwargs = {'kwarg1': 1, 'kwarg2': 2}
def train_fn(config, **trainable_kwargs):
  pass

tuner = tune.Tuner(
    tune.with_parameters(train_fn, **trainable_kwargs),
)

Reproduction script output

In [1]: from ray import train, tune
   ...:
   ...: trainable_kwargs = {'kwarg1': 1, 'kwarg2': 2}
   ...:
   ...: def train_fn(config, **trainable_kwargs):
   ...:   pass
   ...:
   ...: tuner = tune.Tuner(
   ...:     tune.with_parameters(train_fn, **trainable_kwargs),
   ...: )
---------------------------------------------------------------------------
DeprecationWarning                        Traceback (most recent call last)
Cell In[1], line 9
      5 def train_fn(config, **trainable_kwargs):
      6   pass
      8 tuner = tune.Tuner(
----> 9     tune.with_parameters(train_fn, **trainable_kwargs),
     10 )

File /opt/conda/envs/ag_py311/lib/python3.11/site-packages/ray/tune/trainable/util.py:315, in with_parameters(trainable, **kwargs)
    310 if _detect_checkpoint_function(trainable, partial=True):
    311     from ray.tune.trainable.function_trainable import (
    312         _CHECKPOINT_DIR_ARG_DEPRECATION_MSG,
    313     )
--> 315     raise DeprecationWarning(_CHECKPOINT_DIR_ARG_DEPRECATION_MSG)
    317 def inner(config):
    318     fn_kwargs = {}

DeprecationWarning: Accepting a `checkpoint_dir` argument in your training function is deprecated.
Please use `ray.train.get_checkpoint()` to access your checkpoint as a
`ray.train.Checkpoint` object instead. See below for an example:

Before
------

from ray import tune

def train_fn(config, checkpoint_dir=None):
    if checkpoint_dir:
        torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
    ...

tuner = tune.Tuner(train_fn)
tuner.fit()

After
-----

from ray import train, tune

def train_fn(config):
    checkpoint: train.Checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
    ...

tuner = tune.Tuner(train_fn)
tuner.fit()

Is this not a supported use case? At the moment it raises a deprecation warning even though I am not passing any `checkpoint_dir` argument, so I think the detection logic is incorrect here. This code snippet worked perfectly fine with ray==2.6.3; I can reproduce the issue starting with ray>=2.7.0.
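If the detection is bind-based (an assumption drawn from the traceback), then naming the extra parameters explicitly instead of using `**kwargs` should avoid the false positive, since the signature no longer accepts an arbitrary `checkpoint_dir` keyword. A sketch of the idea (the `bind_partial` probe below is hypothetical, imitating what such a check might do):

```python
import inspect

# Hypothetical workaround: declare the extra parameters explicitly
# instead of a **kwargs catch-all.
def train_fn(config, kwarg1=None, kwarg2=None):
    pass

# A bind-based probe like the one the traceback suggests would now
# fail, because the signature no longer swallows arbitrary keywords.
sig = inspect.signature(train_fn)
try:
    sig.bind_partial(checkpoint_dir="")
    accepts_checkpoint_dir = True
except TypeError:
    accepts_checkpoint_dir = False

print(accepts_checkpoint_dir)  # False: no false positive
```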

Versions / Dependencies

Python 3.11.0
ray==2.8.1
Distributor ID: Ubuntu
Description: Ubuntu 20.04.6 LTS
Release: 20.04
Codename: focal

Metadata

Labels

- P2: Important issue, but not time-critical
- bug: Something that is supposed to be working; but isn't
- tune: Tune-related issues
