Description
Problem: [R] Cox loss function should use "negative log-likelihood" instead of "log-likelihood"
catboost version: 1.2.5
Operating System: Windows
Dear maintainers,
I need to use CatBoost v1.2.5 to train a Cox model. I've read CatBoost's tutorial "Survival analysis with Catboost" and tested its example. The example works in Python, but I am more familiar with R, so I translated the code into R:
library(catboost)
packageVersion("catboost")
# [1] ‘1.2.5’
# Load dataset "flchain", same as the tutorial example
data(flchain, package = "survival")
# Create label for Cox objective function
flchain['label'] <- ifelse(flchain[['death']] == 1, flchain[['futime']], -1 * flchain[['futime']])
# Split training and testing dataset
set.seed(1)
test_ind <- sample(nrow(flchain), nrow(flchain) * 0.3)
train <- flchain[-test_ind, ]
test <- flchain[test_ind, ]
features <- c('age', 'creatinine', 'flc.grp', 'kappa', 'lambda', 'mgus', 'sample.yr', 'sex')
train_pool <- catboost.load_pool(train[features], label = train[['label']])
test_pool <- catboost.load_pool(test[features], label = test[['label']])
model <- catboost.train(train_pool,
                        test_pool = test_pool,
                        params = list(iterations = 100,
                                      loss_function = 'Cox',
                                      eval_metric = 'Cox',
                                      metric_period = 50))
Output:
0: learn: -12463.1527910 test: -5042.8238785 best: -5042.8238785 (0) total: 2.5ms remaining: 248ms
50: learn: -12721.8966627 test: -5161.8085568 best: -5161.8085568 (50) total: 104ms remaining: 99.6ms
99: learn: -13036.1622355 test: -5305.4916415 best: -5305.4916415 (99) total: 202ms remaining: 0us
Problem 1: The goal should be to maximize the "log-likelihood" (or minimize the "negative log-likelihood"), but in R the training seems to minimize the "log-likelihood"...
The tutorial says the function CatBoost optimizes is the "Cox log partial likelihood". In the output above, the metric values for "learn" and "test" are all negative, so I assume they are log-likelihoods (in the statistical sense, log-likelihoods are typically negative). If so, the iterations should maximize this value. However, the output shows the iterations minimizing it (from -12463 down to -13036 in the "learn" column).
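For reference, the Cox log partial likelihood can be evaluated directly in R. The following is only a minimal sketch (Breslow handling of ties, no sample weights), and it assumes that catboost.predict returns the linear predictor of the Cox model; the helper cox_log_partial_lik is my own, not part of CatBoost:
# Cox log partial likelihood (Breslow form): for each event i,
#   lp_i - log( sum_{j: time_j >= time_i} exp(lp_j) )
cox_log_partial_lik <- function(time, status, lp) {
  event_idx <- which(status == 1)
  sum(vapply(event_idx, function(i) {
    lp[i] - log(sum(exp(lp[time >= time[i]])))
  }, numeric(1)))
}
lp_test <- catboost.predict(model, test_pool)
cox_log_partial_lik(test$futime, test$death, lp_test)
# On real data this is a negative number; its negation (the negative
# log-likelihood) is what I would expect training to minimize.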
Problem 2: Output is inconsistent with v1.2.3
I reverted to v1.2.3. With the same code as above, the metric values for "learn" and "test" are positive and decreasing. I think the log-likelihoods are negated here, i.e. reported as the "negative log-likelihood". If so, the iterations should minimize this value, and this version of CatBoost behaves sensibly. What changed between the implementations in v1.2.3 and v1.2.5?
0: learn: 12453.1683362 test: 5038.4024811 best: 5038.4024811 (0) total: 154ms remaining: 15.2s
50: learn: 12256.5191692 test: 4947.5292621 best: 4947.5292621 (50) total: 261ms remaining: 251ms
99: learn: 12109.6471344 test: 4880.2979213 best: 4880.2979213 (99) total: 367ms remaining: 0us
Problem 3: Prediction performance is very poor!
I calculated the concordance index (C-index) on the test set for v1.2.3 and v1.2.5 using the same code. The latter is very poor.
library(survival)
concordancefit(Surv(test$futime, test$death),
               -catboost.predict(model, test_pool))$concordance
- v1.2.3: C-index = 0.76
- v1.2.5: C-index = 0.20
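A C-index far below 0.5 usually indicates a reversed ranking rather than an uninformative one. One way to check this, using the same objects as above, is to compute the concordance without negating the predictions; if only the sign convention of the scores changed between versions, this value should be roughly one minus the value reported above:
library(survival)
# Diagnostic only: same C-index computation, but with un-negated predictions.
concordancefit(Surv(test$futime, test$death),
               catboost.predict(model, test_pool))$concordance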
Problem 4: Metric value reaches NaN
When I use a moderately large number of iterations, the metric value reaches NaN in v1.2.5.
model <- catboost.train(train_pool,
                        test_pool = test_pool,
                        params = list(iterations = 2500,
                                      loss_function = 'Cox',
                                      eval_metric = 'Cox',
                                      metric_period = 100))
v1.2.5
2100: learn: -1467146.2738498 test: -676558.5489888 best: -676558.5489888 (2100) total: 4.14s remaining: 786ms
2200: learn: -1516008.6256723 test: -694665.6985130 best: -694665.6985130 (2200) total: 4.35s remaining: 591ms
2300: learn: nan test: nan best: -694665.6985130 (2200) total: 4.57s remaining: 395ms
2400: learn: nan test: nan best: -694665.6985130 (2200) total: 4.78s remaining: 197ms
2499: learn: nan test: nan best: -694665.6985130 (2200) total: 4.99s remaining: 0us
v1.2.3
With v1.2.3 the output looks fine.
2100: learn: 15208.7450328 test: 6225.6551569 best: 4774.6366039 (300) total: 4.36s remaining: 828ms
2200: learn: 15559.1376839 test: 6406.8914711 best: 4774.6366039 (300) total: 4.56s remaining: 619ms
2300: learn: 15576.2159165 test: 6405.9180922 best: 4774.6366039 (300) total: 4.75s remaining: 411ms
2400: learn: 15709.8865444 test: 6474.3107742 best: 4774.6366039 (300) total: 4.95s remaining: 204ms
2499: learn: 16104.1899851 test: 6581.9568144 best: 4774.6366039 (300) total: 5.14s remaining: 0us
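Not a fix for the sign problem, but a possible mitigation for the NaN: enabling CatBoost's overfitting detector (od_type, od_wait) together with use_best_model should stop training before the metric degenerates. This is only a sketch, assuming the detector respects the Cox metric's optimization direction:
model <- catboost.train(train_pool,
                        test_pool = test_pool,
                        params = list(iterations = 2500,
                                      loss_function = 'Cox',
                                      eval_metric = 'Cox',
                                      od_type = 'Iter',       # stop after no improvement...
                                      od_wait = 100,          # ...for 100 iterations on test_pool
                                      use_best_model = TRUE,  # keep the best iteration
                                      metric_period = 100))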