Global Precision/Recall/F1 Callback #433
Conversation
Keeps track of global true positives, false positives, and false negatives for each epoch and calculates precision, recall, and F1-score from those counts. Currently for binary cases only (use multiple instances for multi-label).
Calculates the global precision (positive predictive value, ppv), recall (true positive rate, tpr), and F1-score per class for each loader. Currently supports binary and multi-label cases.
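The arithmetic behind a "global" metric is worth spelling out: instead of averaging per-batch scores, you accumulate raw counts over the whole loader and compute the metric once at the end. A minimal sketch of that calculation (not the actual catalyst meter; the function name and `eps` smoothing term are illustrative):

```python
def precision_recall_f1(tp: int, fp: int, fn: int, eps: float = 1e-7):
    """Compute global precision (ppv), recall (tpr), and F1-score
    from counts accumulated over an entire loader/epoch."""
    precision = tp / (tp + fp + eps)  # ppv: of all positive predictions, how many were right
    recall = tp / (tp + fn + eps)     # tpr: of all actual positives, how many were found
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1

# e.g. 8 true positives, 2 false positives, 4 false negatives over a loader
p, r, f1 = precision_recall_f1(8, 2, 4)
```

Because the counts are summed globally, every sample carries equal weight regardless of batch size, which is exactly what averaging per-batch F1 scores gets wrong.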
…global-metrics Updating with the latest changes to remove the possibility of git diff bugs
For PrecisionRecallF1ScoreCallback
Hi! Thank you for your PR!
Yep, I'll give it a go tonight!
Quick self-reminder: Edit: Fixed
Passes `make codestyle`.
…ending the metric This caused metrics to be averaged incorrectly, because we were just calling `np.mean(float)` instead of `np.mean(list of floats)`
```python
def on_batch_end(self, state: RunnerState):
    logits: torch.Tensor = state.output[self.output_key].detach().float()
    targets: torch.Tensor = state.input[self.input_key].detach().float()
    probabilities: torch.Tensor = torch.sigmoid(logits)
```
I think we could parametrize the function used here during callback initialization
I might be misinterpreting what you're saying, but are you asking to make a reusable function to do the ops above and initialize it as an attribute during callback initialization?
I mean, you can parametrize the activation function like here
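The suggestion is to resolve the activation by name once, at construction time, rather than hard-coding `torch.sigmoid` (or looking it up on every batch). A hedged sketch of the idea — the `ACTIVATIONS` table and `CallbackSketch` class are illustrative stand-ins, not catalyst's actual `utils.get_activation_fn`:

```python
import math

# Hypothetical name-to-function table; catalyst's real helper supports
# more options, but the resolution pattern is the same.
ACTIVATIONS = {
    "none": lambda x: x,
    "Sigmoid": lambda x: 1.0 / (1.0 + math.exp(-x)),
}

class CallbackSketch:
    def __init__(self, activation: str = "Sigmoid"):
        # resolve the callable once during callback initialization
        self.activation_fn = ACTIVATIONS[activation]

    def on_batch_end(self, logit: float) -> float:
        # per-batch code just calls the pre-resolved function
        return self.activation_fn(logit)

cb = CallbackSketch("Sigmoid")
prob = cb.on_batch_end(0.0)  # sigmoid(0) == 0.5
```

This keeps the callback usable for models that already emit probabilities (`"none"`) as well as raw logits.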
I think it would be more consistent, but I don't think PrecisionRecallF1ScoreCallback should inherit MetricCallback/MultiMetricCallback, because PrecisionRecallF1ScoreCallback calculates the metrics for the entire loader instead of for each batch (maybe make a separate base global callback?). I'm curious about your thoughts?
(I'll add the option to specify the activation function, but without the `super().__init__(...)`, in the meantime)
I think we can create a new abstraction like "DatasetMetricCallback" that needs to collect statistics during `on_batch_end` and also do some additional calculations in `on_loader_end`. We have the same case with AUCCallback, so... looks like we need something general :)
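The proposed abstraction splits the work across two callback hooks: accumulate during `on_batch_end`, finalize during `on_loader_end`. A minimal sketch of that shape — class and method names here mirror the discussion but are not catalyst's actual API:

```python
class LoaderMetricSketch:
    """Hypothetical 'loader-level metric' base pattern:
    accumulate statistics per batch, compute the metric per loader."""

    def on_loader_start(self):
        # reset global counts at the start of every loader
        self.tp = self.fp = self.fn = 0

    def on_batch_end(self, preds, targets):
        # accumulate raw counts instead of averaging per-batch metrics
        for p, t in zip(preds, targets):
            if p == 1 and t == 1:
                self.tp += 1
            elif p == 1 and t == 0:
                self.fp += 1
            elif p == 0 and t == 1:
                self.fn += 1

    def on_loader_end(self):
        # the only place the metric itself is computed
        eps = 1e-7
        precision = self.tp / (self.tp + self.fp + eps)
        recall = self.tp / (self.tp + self.fn + eps)
        return 2 * precision * recall / (precision + recall + eps)

m = LoaderMetricSketch()
m.on_loader_start()
m.on_batch_end([1, 1, 0], [1, 0, 1])  # tp=1, fp=1, fn=1
f1 = m.on_loader_end()  # precision == recall == 0.5, so f1 ≈ 0.5
```

AUCCallback fits the same mold (accumulate scores per batch, compute AUC per loader), which is why a shared base class is attractive here.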
#450 <- Related PR
I'll try to get to it tomorrow!
Do we need to add `.cpu()`? Otherwise it throws an error:
`TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.`
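The error comes from numpy only being able to read host memory: a CUDA tensor must be copied back with `.cpu()` before `.numpy()`. A small sketch of the safe conversion chain (the tensor values are arbitrary examples):

```python
import torch

# Example tensor; move it to the GPU when one is available to mirror
# the training setup that triggered the TypeError.
t = torch.tensor([0.25, 0.75])
if torch.cuda.is_available():
    t = t.cuda()

# t.numpy() would raise the TypeError above for a CUDA tensor;
# .detach() drops the autograd graph and .cpu() copies to host memory,
# after which the numpy conversion is safe on either device.
arr = t.detach().cpu().numpy()
```

Calling `.detach().cpu()` on a tensor that is already detached and on the CPU is a no-op, so the chain is safe to apply unconditionally inside the meter.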
If you have any questions – do not hesitate to ask :)
…ils.get_activation_fn
…ds for PrecisionRecallF1ScoreMeter Did so for increased clarity.
…oreMeter Still need to create tests for the callback, but this is a rough start.
@jchen42703 PR looks good, the only thing – check the codestyle please.
@Scitator Should I add in a test for the callback (not just the meter + metrics) as well?
```python
def on_batch_end(self, state: RunnerState):
    logits: torch.Tensor = state.output[self.output_key].detach().float()
    targets: torch.Tensor = state.input[self.input_key].detach().float()
    activation_fn = get_activation_fn(self.activation)
```
I think you can move it to `__init__`
A callback that tracks metrics through meters and prints metrics for each class on `state.on_loader_end`. **Have not tested yet**
Gonna try to adapt
Cleaner, reduces repeat code.
They work for binary (
@jchen42703 You are right, currently we do not support Meanwhile,
Gonna try to do some last-minute cleanup:
Edit:
Removed unnecessary imports + overindented
Did so because catalyst does not currently support `self.num_classes == 1` and `len(probabilities.shape) == 1`; tensors have no channels and you get a bunch of CUDA and memory-pinning errors.
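The shape problem described here is that with a single class the probability tensor can arrive as a flat `(batch,)` tensor with no class channel, so per-class indexing downstream breaks. A hedged sketch of the usual fix (the variable names are illustrative, not the PR's code):

```python
import torch

# Simulate single-class output: shape (8,), no class channel.
probabilities = torch.rand(8)

# Add an explicit channel dimension so the tensor is always
# (batch, num_classes) and per-class loops work uniformly.
if probabilities.dim() == 1:
    probabilities = probabilities.unsqueeze(1)  # -> shape (8, 1)
```

The PR instead requires `num_classes > 1`, which sidesteps the channel-less case entirely; the sketch above is the alternative of normalizing the shape rather than rejecting it.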
…back` Did so because doing so makes it clear to anyone implementing a child of `MeterMetricsCallback` that you need to specify `class_names` and `num_classes` for the callback to work properly. Also added a check for num_classes == 1 (should be > 1).
…zation changes (class_names and num_classes) Did so for PrecisionRecallF1ScoreCallback and AUCCallback
Awesome PR!
Description
- `PrecisionRecallF1ScoreMeter`: tracks TP, FP, and FN for each loader, and calculates precision (ppv), recall (tpr), and F1-score based on those counts.
- `PrecisionRecallF1ScoreCallback`: callback for `PrecisionRecallF1ScoreMeter`. Modeled after the AUCCallback & AccuracyCallback (multiple metrics).

Related Issue
N/A

Type of Change

Checklist
- Passes `make check-style`.
- Passes `make check-docs`.