Kendall's Tau metric, based loosely on scipy. #2169
@@ -0,0 +1,195 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements Kendall's Tau metric and loss.""" | ||

import warnings

import numpy as np
import tensorflow as tf
from tensorflow_addons.metrics.utils import MeanMetricWrapper
from tensorflow_addons.utils.types import TensorLike


def _iterative_mergesort(y: TensorLike, aperm: TensorLike) -> (tf.int32, tf.Tensor):
    """Non-recursive mergesort that counts exchanges.

    Args:
        y: values to be sorted.
        aperm: original ordering.

    Returns:
        A tuple consisting of an int32 scalar that counts the number of
        exchanges required to produce a sorted permutation, and a tf.int32
        Tensor that contains the ordering (permutation) that sorts the
        values of y.
    """
    exchanges = 0
    num = tf.size(y)
    k = tf.constant(1, tf.int32)
    while tf.less(k, num):
Since TFP is a general purpose library, we need to write everything so that it does not presume eager execution. Minimally this means no Python control flow. More generally, this has a code nature that would be quite inefficient in a SIMD regime. Perhaps we could examine the full quadratic comparison? In general this would be quite large, so I recommend extending this idea to "chunks". Thoughts?

Only after I finished writing this did I fully examine the implementation of the AUC metrics in Keras, which use a bucketized approximation, not that dissimilar to the two-dimensional bucket approach presented in https://arxiv.org/abs/1712.01521, which would be much more amenable to streaming. (OTOH, I'm unclear why metrics need to be streaming; usually the classifier label space is much smaller than the input feature space.) I also note that I should have looked at the completely rewritten scipy implementation, which works differently but still requires sorting by both of the full inputs, which seems largely incompatible with the tf Metrics design. I suspect a reasonable thing to do is to fork this into two separate efforts: one for tfp that focuses on the explicit exact solution with a potential backoff to an approximation, and a tensorflow addons submission that focuses on implementing a streaming approximation. I think I can manage that, but it's going to take some time - and at present I don't have a good sense of which of these would be easier to do first.

Hi - sorry I'm back to this finally. I think from the discussion this PR should be moved to tfp. However, I think I need to make a significant redesign to comply with the suggestion of removing the control flow. I'm not quite able to distill what the consensus is here for next steps; maybe I should do all of these:
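To make the full-quadratic-comparison suggestion concrete, here is a rough sketch (an editor's illustration, not code from this PR or from TFP) of tau-b computed with pairwise sign comparisons and no Python control flow over tensor values. It materializes an [n, n] matrix, which is what motivates the chunking idea above.

import tensorflow as tf


def kendalls_tau_quadratic(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """Tau-b via an O(n^2) pairwise comparison; assumes 1-D numeric inputs."""
    y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    # Signs of all pairwise differences, each of shape [n, n].
    dt = tf.sign(y_true[:, None] - y_true[None, :])
    dp = tf.sign(y_pred[:, None] - y_pred[None, :])
    # Each unordered pair is counted twice: +2 if concordant, -2 if discordant,
    # 0 if tied in either ranking; the factor of 2 cancels in the ratio below.
    num = tf.reduce_sum(dt * dp)
    n = tf.cast(tf.size(y_true), tf.float32)
    n0 = n * (n - 1.0)  # twice the number of unordered pairs
    # Off-diagonal zeros count each tied pair twice, matching the tau-b correction.
    ties_true = tf.reduce_sum(tf.cast(tf.equal(dt, 0.0), tf.float32)) - n
    ties_pred = tf.reduce_sum(tf.cast(tf.equal(dp, 0.0), tf.float32)) - n
    denom = tf.sqrt((n0 - ties_true) * (n0 - ties_pred))
    return num / denom

A chunked variant of the same idea would compute dt and dp for [chunk, n] slices inside a tf.while_loop and accumulate the three sums, trading memory for a few more passes.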
        for left in tf.range(0, num - k, 2 * k, dtype=tf.int32):
            rght = left + k
            rend = tf.minimum(rght + k, num)
            tmp = tf.TensorArray(dtype=tf.int32, size=num)
            m, i, j = 0, left, rght
            while tf.less(i, rght) and tf.less(j, rend):
                permij = aperm.gather([i, j])
                yij = tf.gather(y, permij)
                if tf.less_equal(yij[0], yij[1]):
                    tmp = tmp.write(m, permij[0])
                    i += 1
                else:
                    tmp = tmp.write(m, permij[1])
                    # Explanation here
                    # https://www.geeksforgeeks.org/counting-inversions/.
                    exchanges += rght - i
                    j += 1
                m += 1
            while tf.less(i, rght):
                tmp = tmp.write(m, aperm.read(i))
                i += 1
                m += 1
            while tf.less(j, rend):
                tmp = tmp.write(m, aperm.read(j))
                j += 1
                m += 1
            aperm = aperm.scatter(tf.range(left, rend), tmp.gather(tf.range(0, m)))
        k *= 2
    return exchanges, aperm.stack()


def kendalls_tau(y_true: TensorLike, y_pred: TensorLike) -> tf.Tensor:
    """Computes Kendall's Tau for two ordered lists.

    Kendall's Tau measures the correlation between ordinal rankings. This
    implementation is similar to the one used in scipy.stats.kendalltau.

    Args:
        y_true: A tensor that provides a true ordinal ranking of N items.
        y_pred: A presumably model-provided ordering of the same N items.

    Returns:
        Kendall's Tau, the 1945 tau-b formulation that ignores ordering of
        ties, as a scalar Tensor.
    """
    y_true = tf.reshape(y_true, [-1])
    y_pred = tf.reshape(y_pred, [-1])
    y_pred.shape.assert_is_compatible_with(y_true.shape)
    if tf.equal(tf.size(y_true), 0) or tf.equal(tf.size(y_pred), 0):
        warnings.warn("y_true and y_pred tensors are empty.")
        return np.nan

    perm = tf.argsort(y_true)
    n = tf.shape(perm)[0]
    if tf.less(n, 2):
        warnings.warn("Scalar tensors have no defined ordering.")
        return np.nan

    left = 0
    # Scan for ties, and for each range of ties do an argsort on
    # the y_pred value. (TF has no lexicographical sorting, although
    # jax can sort complex numbers lexicographically. Hmm.)
    lexi = tf.TensorArray(tf.int32, size=n)
    for i in tf.range(n):
        lexi = lexi.write(i, perm[i])
    for right in tf.range(1, n):
        ytruelr = tf.gather(y_true, tf.gather(perm, [left, right]))
        if tf.not_equal(ytruelr[0], ytruelr[1]):
            sub = perm[left:right]
            subperm = tf.argsort(tf.gather(y_pred, sub))
            lexi = lexi.scatter(tf.range(left, right), tf.gather(sub, subperm))
            left = right
    sub = perm[left:n]
    subperm = tf.argsort(tf.gather(y_pred, perm[left:n]))
    lexi = lexi.scatter(tf.range(left, n), tf.gather(sub, subperm))

    # This code follows roughly along with scipy/stats/stats.py v. 0.15.1
    # compute joint ties
    first = 0
    t = 0
    for i in tf.range(1, n):
        permfirsti = lexi.gather([first, i])
        y_truefirsti = tf.gather(y_true, permfirsti)
        y_predfirsti = tf.gather(y_pred, permfirsti)
        if y_truefirsti[0] != y_truefirsti[1] or y_predfirsti[0] != y_predfirsti[1]:
            t += ((i - first) * (i - first - 1)) // 2
            first = i
    t += ((n - first) * (n - first - 1)) // 2

    # compute ties in y_true
    first = 0
    u = 0
    for i in tf.range(1, n):
        y_truefirsti = tf.gather(y_true, lexi.gather([first, i]))
        if y_truefirsti[0] != y_truefirsti[1]:
            u += ((i - first) * (i - first - 1)) // 2
            first = i
    u += ((n - first) * (n - first - 1)) // 2

    # count exchanges
    exchanges, newperm = _iterative_mergesort(y_pred, lexi)
    # compute ties in y_pred after mergesort with counting
    first = 0
    v = 0
    for i in tf.range(1, n):
        y_predfirsti = tf.gather(y_pred, tf.gather(newperm, [first, i]))
        if y_predfirsti[0] != y_predfirsti[1]:
            v += ((i - first) * (i - first - 1)) // 2
            first = i
    v += ((n - first) * (n - first - 1)) // 2

    tot = (n * (n - 1)) // 2
    if tf.equal(tot, u) or tf.equal(tot, v):
        return np.nan  # Special case for all ties in either rank.

    # Prevent overflow; equal to np.sqrt((tot - u) * (tot - v)).
    denom = tf.math.exp(
        0.5
        * (
            tf.math.log(tf.cast(tot - u, tf.float32))
            + tf.math.log(tf.cast(tot - v, tf.float32))
        )
    )
    tau = (
        tf.cast(tot - (v + u - t), tf.float32) - 2.0 * tf.cast(exchanges, tf.float32)
    ) / denom

    return tau


class KendallsTau(MeanMetricWrapper):
    """Computes how well a model orders items, reporting the mean tau over batches.

    Any types supported by tf.math.less may be used for y_pred and y_true
    values, and these types do not need to be the same as they are never
    compared against each other. The return type of this metric is always
    tf.float32 and a value between -1.0 and 1.0.

    References:
        "A Note on Average Tau as a Measure of Concordance", William L. Hays,
        Journal of the American Statistical Association, Vol. 55, No. 290,
        pp. 331-341, June 1960.

        "Statistical Properties of Average Kendall's Tau Under Multivariate
        Contaminated Gaussian Model", Huadong Lai and Weichao Xu, IEEE Access,
        Vol. 7, pp. 159177-159189, 2019.

        Note that a streaming implementation of an approximate algorithm
        exists, but it is different from the exact computation implemented
        here; see "An Online Algorithm for Nonparametric Correlations",
        Wei Xiao, 2017, https://arxiv.org/abs/1712.01521.

    Attributes:
        name: (Optional) string name of the metric instance.
    """

    def __init__(self, name: str = "kendalls_tau", **kwargs):
        super().__init__(kendalls_tau, name=name, dtype=tf.float32, **kwargs)
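As a quick usage illustration (an editor's sketch, not part of this diff): the functional form can be called directly on two rankings and should agree with scipy's tau-b, e.g. for the same example used in the tests below.

import tensorflow as tf
from scipy import stats
from tensorflow_addons.metrics import kendalls_tau

x1 = [12.0, 2.0, 1.0, 12.0, 2.0]
x2 = [1.0, 4.0, 7.0, 1.0, 0.0]
tau = kendalls_tau(tf.constant(x1), tf.constant(x2))
print(float(tau))                   # approx -0.4714, tau-b with tie correction
print(stats.kendalltau(x1, x2)[0])  # scipy reference value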
@@ -0,0 +1,89 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests Hamming metrics.""" | ||

import random

import numpy as np
from scipy import stats

import tensorflow as tf
from tensorflow.keras import layers

from tensorflow_addons.metrics import KendallsTau, kendalls_tau


def test_config():
    kl_obj = KendallsTau()
    assert kl_obj.name == "kendalls_tau"
    assert kl_obj.dtype == tf.float32


def test_kendall_tau():
    x1 = [12, 2, 1, 12, 2]
    x2 = [1, 4, 7, 1, 0]
    expected = stats.kendalltau(x1, x2)[0]
    res = kendalls_tau(tf.constant(x1, tf.float32), tf.constant(x2, tf.float32))
    np.testing.assert_allclose(expected, res.numpy(), atol=1e-5)

def test_kendall_tau_float():
    x1 = [0.12, 0.02, 0.01, 0.12, 0.02]
    x2 = [0.1, 0.4, 0.7, 0.1, 0.0]
    expected = stats.kendalltau(x1, x2)[0]
    res = kendalls_tau(tf.constant(x1, tf.float32), tf.constant(x2, tf.float32))
    np.testing.assert_allclose(expected, res.numpy(), atol=1e-5)


def test_kendall_random_lists():
    left = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 6, 7, 8, 9]
    for _ in range(10):
        right = random.sample(left, len(left))
        expected = stats.kendalltau(left, right)[0]
        res = kendalls_tau(
            tf.constant(left, tf.float32), tf.constant(right, tf.float32)
        )
        np.testing.assert_allclose(expected, res.numpy(), atol=1e-5)

def test_keras_model():
    model = tf.keras.Sequential()
    model.add(layers.InputLayer(input_shape=(1,)))
    model.add(layers.Dense(1, kernel_initializer="ones"))
    kt = KendallsTau()
    model.compile(optimizer="rmsprop", loss="mae", metrics=[kt])
    data = np.array([[0.12], [0.02], [0.01], [0.12], [0.02]])
    labels = np.array([0.1, 0.4, 0.7, 0.1, 0.0])
    history = model.fit(data, labels, epochs=1, batch_size=5, verbose=0)
    expected = stats.kendalltau(np.array(data).flat, labels)[0]
    np.testing.assert_allclose(expected, history.history["kendalls_tau"], atol=1e-5)

def test_averaging_tau_model():
    model = tf.keras.Sequential()
    model.add(layers.InputLayer(input_shape=(1,)))
    model.add(layers.Dense(1, kernel_initializer="ones"))
    kt = KendallsTau()
    model.compile(optimizer="rmsprop", loss="mae", metrics=[kt])
    data = np.array([[5], [3], [2], [1], [4], [1], [2], [3], [4], [5]])
    labels = np.array([1, 2, 2, 3, 6, 10, 11, 12, 13, 14])
    history = model.fit(data, labels, epochs=1, batch_size=5, verbose=0, shuffle=False)
    expected = np.mean(
        [
            stats.kendalltau(data[0:5].flat, labels[0:5])[0],
            stats.kendalltau(data[5:].flat, labels[5:])[0],
        ]
    )
    np.testing.assert_allclose(expected, history.history["kendalls_tau"], atol=1e-5)
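One point the last test makes implicitly, restated as a standalone sketch (assuming the classes behave as defined in this diff): the wrapped metric reports the mean of per-batch tau values, which in general is not the same as tau computed over all of the data at once.

import numpy as np
import tensorflow as tf
from scipy import stats
from tensorflow_addons.metrics import KendallsTau

data = np.array([5, 3, 2, 1, 4, 1, 2, 3, 4, 5], dtype=np.float32)
labels = np.array([1, 2, 2, 3, 6, 10, 11, 12, 13, 14], dtype=np.float32)

kt = KendallsTau()
kt.update_state(labels[:5], data[:5])  # first batch of five
kt.update_state(labels[5:], data[5:])  # second batch of five

mean_of_batch_taus = np.mean(
    [stats.kendalltau(labels[:5], data[:5])[0], stats.kendalltau(labels[5:], data[5:])[0]]
)
print(float(kt.result()), mean_of_batch_taus)  # these two should agree
print(stats.kendalltau(labels, data)[0])       # tau over all ten items is generally different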
2020