Closed
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)
Description
Feature Request
Bring back resume support for CheckpointCallback, Runner.train, and Runner.predict.
Motivation
That would be a great user-friendly feature: resuming from a checkpoint is a common task during deep learning development.
Proposal
- Uncomment the CheckpointCallback code and check its correctness; from my perspective, it should be refactored a bit.
- Add resume support to Runner.
- Uncomment the CLI interface and update ConfigRunner and HydraRunner.
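For reference, here is a minimal sketch in plain PyTorch (not Catalyst) of the state a resume has to round-trip: model weights, optimizer state, and the last finished epoch. The helpers save_checkpoint and load_checkpoint and the checkpoint keys are hypothetical and are not Catalyst's actual CheckpointCallback format.

```python
import os
import tempfile

import torch
from torch import nn, optim


def save_checkpoint(path, model, optimizer, epoch):
    # Hypothetical layout: model/optimizer state plus the last finished epoch.
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    # Restore both states in place and return the last finished epoch.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]


# One optimizer step so that Adam has per-parameter state worth saving.
torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.02)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "train.1.pth")
save_checkpoint(path, model, optimizer, epoch=1)

# A fresh model/optimizer pair, as a new run picking up the checkpoint would have.
new_model = nn.Linear(4, 2)
new_optimizer = optim.Adam(new_model.parameters(), lr=0.02)
last_epoch = load_checkpoint(path, new_model, new_optimizer)
```

Restoring only the model weights is not enough for a true resume: without the optimizer state, Adam would restart its moment estimates from zero.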
Proposed use case:
```python
import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl, utils
from catalyst.data.transforms import ToTensor
from catalyst.contrib.datasets import MNIST

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

loaders = {
    "train": DataLoader(
        MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32
    ),
    "valid": DataLoader(
        MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32
    ),
}

runner = dl.SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets", loss_key="loss"
)
# model training
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    callbacks=[
        dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3, 5)),
        dl.PrecisionRecallF1SupportCallback(
            input_key="logits", target_key="targets", num_classes=10
        ),
        dl.AUCCallback(input_key="logits", target_key="targets"),
        # catalyst[ml] required ``pip install catalyst[ml]``
        # dl.ConfusionMatrixCallback(input_key="logits", target_key="targets", num_classes=10),
    ],
    logdir="./logs",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
)

# here is the trick
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    callbacks=[
        dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3, 5)),
        dl.PrecisionRecallF1SupportCallback(
            input_key="logits", target_key="targets", num_classes=10
        ),
        dl.AUCCallback(input_key="logits", target_key="targets"),
        # catalyst[ml] required ``pip install catalyst[ml]``
        # dl.ConfusionMatrixCallback(input_key="logits", target_key="targets", num_classes=10),
    ],
    # ----
    logdir="./logs2",
    resume="./logs/checkpoints/train.1.pth",
    # ----
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
)
```
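To make the intended semantics of the second call concrete, here is a toy training loop in plain PyTorch with a hypothetical resume argument. It is only a sketch, not Runner.train: the function name, the checkpoint keys, and the choice that num_epochs counts additional epochs after the checkpoint are all assumptions; Catalyst may reasonably choose to treat num_epochs as the total instead.

```python
import os
import tempfile

import torch
from torch import nn, optim


def train(model, optimizer, batches, num_epochs, logdir, resume=None):
    """Toy loop with a hypothetical ``resume`` argument.

    Assumption: ``num_epochs`` counts epochs to run *after* the resumed one.
    Returns the list of checkpoint paths written during this call.
    """
    start_epoch = 0
    if resume is not None:
        # Restore model/optimizer state and the epoch counter from the checkpoint.
        checkpoint = torch.load(resume, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]

    checkpoint_dir = os.path.join(logdir, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    criterion = nn.MSELoss()
    saved = []
    for epoch in range(start_epoch + 1, start_epoch + num_epochs + 1):
        for features, targets in batches:
            optimizer.zero_grad()
            loss = criterion(model(features), targets)
            loss.backward()
            optimizer.step()
        # Epoch numbering continues from the checkpoint, e.g. train.2.pth after train.1.pth.
        path = os.path.join(checkpoint_dir, f"train.{epoch}.pth")
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
            },
            path,
        )
        saved.append(path)
    return saved


torch.manual_seed(0)
batches = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(3)]
root = tempfile.mkdtemp()

# First run: writes logs/checkpoints/train.1.pth.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.02)
first = train(model, optimizer, batches, num_epochs=1, logdir=os.path.join(root, "logs"))

# Second run: resumes from the first run's checkpoint into a new logdir.
model2 = nn.Linear(4, 2)
optimizer2 = optim.Adam(model2.parameters(), lr=0.02)
second = train(
    model2, optimizer2, batches, num_epochs=1,
    logdir=os.path.join(root, "logs2"), resume=first[-1],
)
print(os.path.basename(second[-1]))  # train.2.pth
```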
Alternatives
Additional context
Checklist
- feature proposal description
- motivation
- extra proposal context / proposal alternatives review
FAQ
Please review the FAQ before submitting an issue:
- I have read the documentation and FAQ
- I have reviewed the minimal examples section
- I have checked the changelog for main framework updates
- I have read the contribution guide
- I have joined Catalyst slack (#__questions channel) for issue discussion