Return resume support for CheckpointCallback and .train, predict #1193

@Scitator

🚀 Feature Request

Restore resume support for CheckpointCallback, Runner.train, and Runner.predict.

Motivation

This would be a user-friendly feature: resuming training is a common task during deep learning development.

Proposal

  1. Uncomment the CheckpointCallback code and check its correctness. From my perspective, it should be refactored a bit.
  2. Add support for it to Runner.
  3. Uncomment the CLI interface and update ConfigRunner and HydraRunner.

Proposed use case:

import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl, utils
from catalyst.data.transforms import ToTensor
from catalyst.contrib.datasets import MNIST

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

loaders = {
    "train": DataLoader(
        MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32
    ),
    "valid": DataLoader(
        MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32
    ),
}

runner = dl.SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets", loss_key="loss"
)
# model training
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    callbacks=[
        dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3, 5)),
        dl.PrecisionRecallF1SupportCallback(
            input_key="logits", target_key="targets", num_classes=10
        ),
        dl.AUCCallback(input_key="logits", target_key="targets"),
        # requires catalyst[ml]: ``pip install catalyst[ml]``
        # dl.ConfusionMatrixCallback(input_key="logits", target_key="targets", num_classes=10),
    ],
    logdir="./logs",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
)

# here is the trick: a second run that resumes from the first run's checkpoint
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    callbacks=[
        dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3, 5)),
        dl.PrecisionRecallF1SupportCallback(
            input_key="logits", target_key="targets", num_classes=10
        ),
        dl.AUCCallback(input_key="logits", target_key="targets"),
        # requires catalyst[ml]: ``pip install catalyst[ml]``
        # dl.ConfusionMatrixCallback(input_key="logits", target_key="targets", num_classes=10),
    ],
    # ----
    logdir="./logs2",
    resume="./logs/checkpoints/train.1.pth",
    # ----
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
)
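
To make the intended behaviour of the proposed `resume` argument concrete, here is a minimal sketch in plain PyTorch of what resuming usually involves: restoring the model weights, the optimizer state, and the epoch counter. This is not Catalyst's implementation. The helpers `save_checkpoint` and `resume_from_checkpoint`, and the checkpoint keys `model_state_dict`, `optimizer_state_dict`, and `epoch`, are hypothetical names chosen for illustration, and the real Catalyst checkpoint layout may differ.

```python
import os
import tempfile

import torch
from torch import nn, optim


def save_checkpoint(path, model, optimizer, epoch):
    """Hypothetical helper: store everything needed to continue training."""
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
        },
        path,
    )


def resume_from_checkpoint(path, model, optimizer):
    """Hypothetical helper: restore state in place, return the next epoch."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"] + 1


def make_model_and_optimizer():
    """Same architecture and optimizer as in the use case above."""
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    optimizer = optim.Adam(model.parameters(), lr=0.02)
    return model, optimizer


# First run: one optimization step on random data, then save a checkpoint.
model, optimizer = make_model_and_optimizer()
features = torch.randn(4, 1, 28, 28)
targets = torch.randint(0, 10, (4,))
loss = nn.functional.cross_entropy(model(features), targets)
loss.backward()
optimizer.step()

checkpoint_path = os.path.join(tempfile.mkdtemp(), "train.1.pth")
save_checkpoint(checkpoint_path, model, optimizer, epoch=1)

# Second run: fresh objects pick up where the first run stopped.
resumed_model, resumed_optimizer = make_model_and_optimizer()
start_epoch = resume_from_checkpoint(
    checkpoint_path, resumed_model, resumed_optimizer
)
```

Restoring the optimizer state matters for optimizers such as Adam, which keep per-parameter running statistics; loading only the model weights would reset them.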

Alternatives

Additional context

Checklist

  • feature proposal description
  • motivation
  • extra proposal context / proposal alternatives review

FAQ

Please review the FAQ before submitting an issue:

Labels

enhancement (New feature or request), help wanted (Extra attention is needed)
