
[R / Bug] Cox loss function is trying to minimize "log-likelihood" instead of "negative log-likelihood" #2701

@darentsai

Description


Problem: [R] The Cox loss function should minimize the "negative log-likelihood" rather than the "log-likelihood"
catboost version: 1.2.5
Operating System: Windows


Dear maintainers,

I need to use CatBoost v1.2.5 to train a Cox model. I have read CatBoost's tutorial, Survival analysis with Catboost, and tested the example from it. The tutorial example works in Python, but I am more familiar with R, so I translated the code into R:

library(catboost)
packageVersion("catboost")
# [1] ‘1.2.5’

# Load dataset "flchain", same as the tutorial example
data(flchain, package = "survival")

# Create label for Cox objective function
flchain['label'] <- ifelse(flchain[['death']] == 1, flchain[['futime']], -1 * flchain[['futime']])

# Split training and testing dataset
set.seed(1)
test_ind <- sample(nrow(flchain), nrow(flchain) * 0.3)
train <- flchain[-test_ind, ]
test <- flchain[test_ind, ]
features <- c('age', 'creatinine', 'flc.grp', 'kappa', 'lambda', 'mgus', 'sample.yr', 'sex')

train_pool <- catboost.load_pool(train[features], label = train[['label']])
test_pool <- catboost.load_pool(test[features], label = test[['label']])

model <- catboost.train(train_pool,
               test_pool = test_pool,
               params = list(iterations = 100,
                             loss_function = 'Cox',
                             eval_metric = 'Cox',
                             metric_period = 50))
Output:
0:	learn: -12463.1527910	test: -5042.8238785	best: -5042.8238785 (0)	        total: 2.5ms	remaining: 248ms
50:	learn: -12721.8966627	test: -5161.8085568	best: -5161.8085568 (50)	total: 104ms	remaining: 99.6ms
99:	learn: -13036.1622355	test: -5305.4916415	best: -5305.4916415 (99)	total: 202ms	remaining: 0us

Problem 1: The goal is to maximize the "log-likelihood" or minimize the "negative log-likelihood", but in R it seems to be minimizing the "log-likelihood"

The tutorial says that the function CatBoost optimizes is the "Cox log partial likelihood". In the output, the metric values for "learn" and "test" are all negative, so I believe they are the "log-likelihood" itself (in the statistical sense, a log partial likelihood is non-positive). If so, the iterations should maximize it. However, the output shows it being minimized (from -12463 down to -13036 in the "learn" column).
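
To make the sign argument concrete, here is a minimal R sketch (my own code, not CatBoost's internals) of the Cox log partial likelihood with Breslow-style handling of ties. Since every event's risk set contains the event itself, each term is at most 0, so the whole quantity is non-positive and should be maximized, or equivalently its negative minimized:

# Minimal sketch of the Cox log partial likelihood (Breslow-style ties).
# time: follow-up times, status: 1 = event / 0 = censored, eta: raw scores.
cox_loglik <- function(time, status, eta) {
  # For an event at time t, the risk set is every subject with time >= t.
  log_risk <- sapply(time, function(t) log(sum(exp(eta)[time >= t])))
  # Each event contributes eta_i - log(risk-set sum) <= 0, because
  # exp(eta_i) is itself one of the terms in its own risk-set sum.
  sum((eta - log_risk)[status == 1])
}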

Problem 2: Output is inconsistent with v1.2.3

I reverted to v1.2.3. With the same code as above, the metric values for "learn" and "test" are all positive and decreasing. I think the log-likelihoods are negated there, i.e. reported as the "negative log-likelihood", which the iterations should indeed minimize. That version behaves exactly as expected. What changed between the implementations of v1.2.3 and v1.2.5? (A quick sign cross-check is sketched after the log below.)

0:	learn: 12453.1683362	test: 5038.4024811	best: 5038.4024811 (0)	total: 154ms	remaining: 15.2s
50:	learn: 12256.5191692	test: 4947.5292621	best: 4947.5292621 (50)	total: 261ms	remaining: 251ms
99:	learn: 12109.6471344	test: 4880.2979213	best: 4880.2979213 (99)	total: 367ms	remaining: 0us
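
As a sanity check of the sign, the partial likelihood can be recomputed by hand from the model's raw predictions using the cox_loglik helper sketched above. If the reported metric is the plain negative Breslow partial log-likelihood on the test pool (my assumption about the metric's definition, not verified against the source), the two numbers should roughly agree:

eta_test <- catboost.predict(model, test_pool)  # raw scores (log relative hazards)
-cox_loglik(test$futime, test$death, eta_test)  # compare with the "test" column above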

Problem 3: Prediction performance is very poor!

I calculated the concordance index (C-index) on the test set for v1.2.3 and v1.2.5, using the same code for both. The latter is very poor. (An equivalent check that avoids negating the predictions is sketched after the results below.)

library(survival)
concordancefit(Surv(test$futime, test$death),
               -catboost.predict(model, test_pool))$concordance
  • v1.2.3: C-index = 0.76
  • v1.2.5: C-index = 0.20
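
For reference, the same C-index can be obtained without manually negating the predictions by letting concordance() treat larger scores as higher risk (assuming survival 3.x, which has the reverse argument):

library(survival)
pred <- catboost.predict(model, test_pool)       # raw risk scores
concordance(Surv(test$futime, test$death) ~ pred,
            reverse = TRUE)$concordance          # larger pred = higher risk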

Problem 4: Metric value reaches NaN

With a moderately large number of iterations, the metric values reach NaN in v1.2.5.

model <- catboost.train(train_pool,
                        test_pool = test_pool,
                        params = list(iterations = 2500,
                                      loss_function = 'Cox',
                                      eval_metric = 'Cox',
                                      metric_period = 100))

v1.2.5

2100:	learn: -1467146.2738498	test: -676558.5489888	best: -676558.5489888 (2100)	total: 4.14s	remaining: 786ms
2200:	learn: -1516008.6256723	test: -694665.6985130	best: -694665.6985130 (2200)	total: 4.35s	remaining: 591ms
2300:	learn:              nan	test:             nan	best: -694665.6985130 (2200)	total: 4.57s	remaining: 395ms
2400:	learn:              nan	test:             nan	best: -694665.6985130 (2200)	total: 4.78s	remaining: 197ms
2499:	learn:              nan	test:             nan	best: -694665.6985130 (2200)	total: 4.99s	remaining: 0us

v1.2.3

With v1.2.3 the output looks fine.

2100:	learn: 15208.7450328	test: 6225.6551569	best: 4774.6366039 (300)	total: 4.36s	remaining: 828ms
2200:	learn: 15559.1376839	test: 6406.8914711	best: 4774.6366039 (300)	total: 4.56s	remaining: 619ms
2300:	learn: 15576.2159165	test: 6405.9180922	best: 4774.6366039 (300)	total: 4.75s	remaining: 411ms
2400:	learn: 15709.8865444	test: 6474.3107742	best: 4774.6366039 (300)	total: 4.95s	remaining: 204ms
2499:	learn: 16104.1899851	test: 6581.9568144	best: 4774.6366039 (300)	total: 5.14s	remaining: 0us
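
A possible, unconfirmed explanation for the NaN in Problem 4: if the objective is being pushed in the wrong direction, the raw scores grow without bound, exp() of a score overflows to Inf in double precision, and an Inf - Inf somewhere in the partial-likelihood arithmetic yields NaN. This is speculation about the internals; the snippet below only illustrates the floating-point mechanism:

# In IEEE double precision, exp(x) overflows to Inf for x above ~709.78.
exp(800)                    # Inf
exp(800) - exp(800)         # NaN (Inf - Inf)
log(sum(exp(c(800, 10))))   # Inf: an overflowed risk-set sum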
