Global Precision/Recall/F1 Callback #433

Merged (21 commits, Oct 29, 2019)

Conversation

@jchen42703 (Contributor) commented Oct 10, 2019

Description

  • PrecisionRecallF1ScoreMeter: tracks TP, FP, and FN for each loader and calculates precision (ppv), recall (tpr), and F1-score from those counts.
  • PrecisionRecallF1ScoreCallback: callback for PrecisionRecallF1ScoreMeter, modeled after the AUCCallback & AccuracyCallback (multiple metrics).
  • Example logs for a 4-class multi-label classifier

Related Issue

N/A

Type of Change

  • Examples / docs / tutorials / contributors update
  • Bug fix (non-breaking change which fixes an issue)
  • Improvement (non-breaking change which improves an existing feature)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Checklist

  • I have read the CODE_OF_CONDUCT document.
  • I have read the CONTRIBUTING document.
  • I have checked the code-style using make check-style.
  • I have written the docstrings in Google format for all the methods and classes that I used.
  • I have checked the docs using make check-docs.

Keeps track of global true positives, false positives, and false negatives for each epoch and calculates precision, recall, and F1-score based on those metrics. Currently, for binary cases only (use multiple instances for multi-label).
Calculates the global precision (positive predictive value or ppv), recall (true positive rate or tpr), and F1-score per class for each loader. Currently, supports binary and multi-label cases.
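
A minimal sketch of what such a meter can look like: it accumulates TP/FP/FN counts across an epoch for a single class and derives precision, recall, and F1 from them. The class and method names here are illustrative, not the exact API merged in this PR.

```python
import numpy as np


class PrecisionRecallF1ScoreMeterSketch:
    """Illustrative meter: accumulates TP/FP/FN over an epoch for one class."""

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold
        self.reset()

    def reset(self):
        self.tp = 0
        self.fp = 0
        self.fn = 0

    def add(self, probabilities: np.ndarray, targets: np.ndarray):
        # binarize probabilities, then accumulate the confusion counts
        predictions = (probabilities >= self.threshold).astype(int)
        targets = targets.astype(int)
        self.tp += int(np.sum((predictions == 1) & (targets == 1)))
        self.fp += int(np.sum((predictions == 1) & (targets == 0)))
        self.fn += int(np.sum((predictions == 0) & (targets == 1)))

    def value(self, eps: float = 1e-7):
        precision = self.tp / (self.tp + self.fp + eps)  # ppv
        recall = self.tp / (self.tp + self.fn + eps)     # tpr
        f1 = 2 * precision * recall / (precision + recall + eps)
        return precision, recall, f1
```

For a multi-label problem, one such meter per class can be kept, and the per-class values averaged at loader end.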
…global-metrics

Updating with the latest changes to remove possibility of git diff bugs
For PrecisionRecallF1ScoreCallback
@TezRomacH added the enhancement (New feature or request) and Hacktoberfest labels on Oct 10, 2019
@TezRomacH (Contributor)

Hi! Thank you for your PR!
Can you please run the `make codestyle` command to convert the code to the style used in Catalyst and make our CI tests pass?

@jchen42703 (Contributor, Author)

Yep, I'll give it a go tonight!

@jchen42703 (Contributor, Author) commented Oct 10, 2019

Quick self-reminder:
Double-check that the average values are computed correctly (`prec_recall_f1score[prefix] = metric_` vs. `prec_recall_f1score[prefix].append(metric_)`)

Edit: Fixed

Passes `make codestyle`.
…ending the metric

This caused metrics to be averaged incorrectly, because we were calling `np.mean(float)` instead of `np.mean(list of floats)`.
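
For context, a minimal illustration of the difference between the two lines mentioned above (the dictionary keys and values are made up):

```python
import numpy as np

prec_recall_f1score = {}

# before the fix: assignment overwrote the value each time, so the later
# np.mean(...) effectively averaged a single float
# prec_recall_f1score[prefix] = metric_

# after the fix: append, so np.mean(...) averages the whole list of floats
for prefix, metric_ in [("precision", 0.8), ("precision", 0.6)]:
    prec_recall_f1score.setdefault(prefix, []).append(metric_)

mean_value = np.mean(prec_recall_f1score["precision"])  # 0.7
```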
    def on_batch_end(self, state: RunnerState):
        logits: torch.Tensor = state.output[self.output_key].detach().float()
        targets: torch.Tensor = state.input[self.input_key].detach().float()
        probabilities: torch.Tensor = torch.sigmoid(logits)
Member:
I think we could parametrize the function used during callback initialization

Contributor Author:
I might be misinterpreting what you're saying, but are you asking to make a reusable function to do the ops above and initialize it as an attribute during callback initialization?

Member:
I mean, you can parametrize the activation function like here

@jchen42703 (Contributor, Author), Oct 12, 2019:

I think it would be more consistent, but I don't think PrecisionRecallF1ScoreCallback should inherit from MetricCallback/MultiMetricCallback, because PrecisionRecallF1ScoreCallback calculates the metrics for the entire loader instead of for each batch (maybe make a separate base global callback?). I'm curious about your thoughts.
(I'll add the option to specify the activation function, but without the `super().__init__(...)`, in the meantime.)

Member:

I think we can create a new abstraction like "DatasetMetricCallback" that needs to collect statistics during on_batch_end and also do some additional calculations on_loader_end. We have the same case with AUCCallback, so... looks like we need something general :)
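
A rough sketch of the kind of abstraction being discussed here (the class name, keys, and the way the value is reported are illustrative, not the final MeterMetricsCallback API):

```python
class LoaderMetricCallbackSketch:
    """Illustrative base: collect statistics per batch, compute metrics per loader."""

    def __init__(self, prefix: str, input_key: str = "targets", output_key: str = "logits"):
        self.prefix = prefix
        self.input_key = input_key
        self.output_key = output_key
        self._outputs = []
        self._targets = []

    def on_loader_start(self, state):
        self._outputs, self._targets = [], []

    def on_batch_end(self, state):
        # only accumulate here; no per-batch metric computation
        self._outputs.append(state.output[self.output_key].detach().cpu())
        self._targets.append(state.input[self.input_key].detach().cpu())

    def on_loader_end(self, state):
        # reduce once over the whole loader; how the value is written back to
        # the runner state is simplified here
        metric_value = self._compute(self._outputs, self._targets)
        print(f"{self.prefix}: {metric_value}")

    def _compute(self, outputs, targets):
        raise NotImplementedError
```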

Contributor Author:

#450 <- Related PR

@jchen42703 (Contributor, Author), Oct 16, 2019:

I'll try to get to it tomorrow!

Contributor:

Do we need to add `.cpu()`? Otherwise it throws an error:
`TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.`
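
A minimal illustration of the fix being suggested, assuming the probabilities later feed into numpy-based meters:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(4, 3, device=device)

# On a CUDA tensor this raises:
# TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy ...
# probabilities = torch.sigmoid(logits).numpy()

# Moving the tensor to host memory first works on both CPU and GPU:
probabilities = torch.sigmoid(logits).detach().cpu().numpy()
```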

@Scitator (Member)

If you have any questions – do not hesitate to ask :)

@Scitator (Member)

@jchen42703 PR looks good, the only thing – check the codestyle please.

@jchen42703 (Contributor, Author)

@Scitator Should I add in a test for the callback (not just the meter + metrics) as well?

    def on_batch_end(self, state: RunnerState):
        logits: torch.Tensor = state.output[self.output_key].detach().float()
        targets: torch.Tensor = state.input[self.input_key].detach().float()
        activation_fn = get_activation_fn(self.activation)
Member:
I think you can move it to `__init__`
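
What "move it to `__init__`" might look like in practice (a sketch; `get_activation_fn` is defined locally here as a stand-in for whatever helper the callback actually imports, just to keep the example self-contained):

```python
import torch


def get_activation_fn(name: str):
    # stand-in for the helper used in the PR
    return {
        "Sigmoid": torch.sigmoid,
        "Softmax": lambda x: torch.softmax(x, dim=1),
        "none": lambda x: x,
    }[name]


class CallbackSketch:
    """Illustrative: resolve the activation function once, at construction."""

    def __init__(self, input_key: str, output_key: str, activation: str = "Sigmoid"):
        self.input_key = input_key
        self.output_key = output_key
        # resolved once here instead of on every on_batch_end call
        self.activation_fn = get_activation_fn(activation)

    def on_batch_end(self, state):
        logits = state.output[self.output_key].detach().float()
        targets = state.input[self.input_key].detach().float()
        probabilities = self.activation_fn(logits)
        # ... update meters with probabilities / targets ...
```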

A callback that tracks metrics through meters and prints metrics for each class on `state.on_loader_end`. **Have not tested yet**
@jchen42703 (Contributor, Author)

Gonna try to adapt PrecisionRecallF1ScoreCallback and AUCCallback to MeterMetricsCallback tomorrow and test to see if it works properly

@jchen42703 (Contributor, Author)

They work for binary (`num_classes=2`) and multi-label cases, but not when `num_classes=1`, so I'm considering dropping that portion of MeterMetricsCallback.
Also, I'm considering refactoring `confusionmeter.py` using #450's `catalyst.utils.confusion_matrix.py` to track all of the stats so that we can expand this class to #450's callbacks. I could be overcomplicating this, though.

@Scitator (Member) commented Oct 18, 2019

@jchen42703 You are right, currently we do not support `num_classes == 1`, so I think it's a good idea to add `assert num_classes > 1` for now.

Meanwhile, MeterMetricsCallback looks really good 👍

@jchen42703 (Contributor, Author) commented Oct 20, 2019

Gonna try to do some last minute cleanup:

Edit:

Removed unnecessary imports + over-indented lines
Did so because Catalyst does not currently support `self.num_classes == 1` and `len(probabilities.shape) == 1`; tensors have no channels and you get a bunch of CUDA and memory-pinning errors.
…back`

Did so because it makes it clear to anyone implementing a child of `MeterMetricsCallback` that you need to specify `class_names` and `num_classes` for the callback to work properly.
Also added a check for `num_classes == 1` (should be > 1).
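
The guard described above might look roughly like this (a sketch, not the exact merged code):

```python
class MeterMetricsCallbackSketch:
    """Illustrative constructor with the num_classes guard described above."""

    def __init__(self, num_classes: int, class_names=None):
        # binary problems are expected as num_classes=2 (one meter per class);
        # num_classes == 1 is not supported, so fail fast
        assert num_classes > 1, "num_classes must be > 1 (use 2 for binary problems)"
        self.num_classes = num_classes
        self.class_names = class_names or [str(i) for i in range(num_classes)]
```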
…zation changes (class_names and num_classes)

Did so for PrecisionRecallF1ScoreCallback and AUCCallback
@jchen42703 requested a review from Scitator on October 23, 2019
@Scitator merged commit c6ea0fc into catalyst-team:master on Oct 29, 2019
@Scitator (Member)

Awesome PR!

@smivv mentioned this pull request on Oct 30, 2019
Labels: enhancement (New feature or request), Hacktoberfest
4 participants