Closed
Description
Problem:
Cross-validation with catboost (in R) using the YetiRank loss function fails with the error "Targets are required for YetiRank loss function.", even though other loss functions (e.g. PairLogit and StochasticFilter) work with the same dataset.
Reproducible example:
library(catboost)

data <- data.frame(
  "target" = rnorm(10000),
  "x" = rnorm(10000),
  "x_2" = rnorm(10000),
  "group" = rep(1:200, each = 50)
)

data_pool <- catboost.load_pool(
  data[, c("x", "x_2")],
  label = data$target,
  group_id = as.integer(data$group)
)
# Works
catboost_cv <- catboost.cv(
  pool = data_pool,
  params = list(
    "has_time" = TRUE,
    "iterations" = 20,
    "loss_function" = 'PairLogit',
    "od_type" = "Iter",
    "od_wait" = 5,
    "learning_rate" = 0.005,
    "random_seed" = 123L,
    "verbose" = 0
  )
)
# Works
catboost_cv <- catboost.cv(
  pool = data_pool,
  params = list(
    "has_time" = TRUE,
    "iterations" = 20,
    "loss_function" = 'StochasticFilter',
    "od_type" = "Iter",
    "od_wait" = 5,
    "learning_rate" = 0.005,
    "random_seed" = 123L,
    "verbose" = 0
  )
)
# Fails
catboost_cv <- catboost.cv(
  pool = data_pool,
  params = list(
    "has_time" = TRUE,
    "iterations" = 20,
    "loss_function" = 'YetiRank',
    "od_type" = "Iter",
    "od_wait" = 5,
    "learning_rate" = 0.005,
    "random_seed" = 123L,
    "verbose" = 0
  )
)
Error:
Error in catboost.cv(pool = data_pool, params = list(has_time = TRUE, :
catboost/libs/train_lib/options_helper.cpp:88: Targets are required for YetiRank loss function.
Strangely, training (rather than cross-validation) with YetiRank on the same pool works:
# Works
catboost.train(
  data_pool,
  params = list(
    "has_time" = TRUE,
    "iterations" = 20,
    "loss_function" = 'YetiRank',
    "od_type" = "Iter",
    "od_wait" = 5,
    "learning_rate" = 0.005,
    "random_seed" = 123L,
    "verbose" = 0
  )
)
Session info:
R version 4.1.3 (2022-03-10)
Platform: x86_64-redhat-linux-gnu (64-bit)
Running under: CentOS Stream 8
Matrix products: default
BLAS/LAPACK: /usr/lib64/libopenblas-r0.3.15.so
locale:
[1] LC_CTYPE=C.utf8 LC_NUMERIC=C LC_TIME=C.utf8
[4] LC_COLLATE=C.utf8 LC_MONETARY=C.utf8 LC_MESSAGES=C.utf8
[7] LC_PAPER=C.utf8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=C.utf8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] catboost_1.0.5
loaded via a namespace (and not attached):
[1] compiler_4.1.3 tools_4.1.3 jsonlite_1.8.0
catboost version: 1.0.5
Operating System: CentOS Stream 8
R version: 4.1.3 (2022-03-10)
yk4r2