WIP: [python-package] ensure predict() always returns an array #6348
Contributes to #3756.
Contributes to #3867.
## Overview

For most combinations of parameters and input data types, `.predict()` methods in the Python package return either a `numpy` array or a `scipy.sparse.{csc,csr}_matrix`.

Since #3000 (merged June 2020, released in LightGBM v3.0.0 in August 2020), there is exactly one exception: for multiclass classification models, `.predict(X, pred_contrib=True)` returns a Python `list` of sparse matrices if `X` is a `scipy.sparse.{csc,csr}_matrix`.

*example with lightgbm 4.3.0 (click me)*
This PR proposes modifying LightGBM such that `.predict()` methods always return a `numpy` or `scipy.sparse` array.

It also makes other related changes, including removing the code in `lightgbm.dask` that existed to handle this special case.

## Benefits of this change
- simplifies the `lightgbm.dask` interface
- simplifies `lightgbm`'s internal prediction routines, and the docs for `.predict()` methods
- allows users of `lightgbm` to use array operations on the output of `.predict()` methods unconditionally
- removes the need for users of `lightgbm` to ignore or work around type-checking issues when they unconditionally use the output of `.predict()` as if it's only either a `numpy` or `scipy.sparse` array
- improves compatibility with the `shap` library (see "Does this break `shap`?" below)
- makes adding type hints to `lightgbm` easier, improving the likelihood that type-checkers like `mypy` will be able to catch bugs

## Costs of this change
- it is a breaking change for any user code relying on `lightgbm` returning a `list` for this specific case
- it adds the runtime cost of a `scipy.sparse.hstack()` call for CSC/CSR matrix inputs
## Notes for Reviewers

### Why is a list currently returned?
I'm not sure; I couldn't figure that out from #3000. I hope @imatiach-msft or @guolinke will remember, or that someone else who understands C++ and the CSC/CSR formats better than me will be able to explain.
I suspect it was related to concerns of the form "concatenating CSC or CSR matrices requires hard-to-do-efficiently index updates", as mentioned here:
Or maybe it's because of concerns about having more than `INT32_MAX` items, which couldn't be represented in an `int32` `indptr`, as mentioned here:

* `sparse.hstack` returns incorrect result when the stack would result in indices too large for `np.int32` (scipy/scipy#16569)

## Does this break `shap`?

No... the `shap` package already cannot generate SHAP values for multiclass classification `lightgbm` models and `scipy.sparse` inputs.

*reproducible example (click me)*
Consider the following example, with `numpy==1.26.0`, `scikit-learn==1.4.1`, `scipy==1.12.0`, and `shap==0.44.0`, using Python 3.11 on x86_64 macOS.

With `lightgbm==4.3.0`:

As of this branch, `shap` still fails on such cases... but in a way that I think we could fix more easily after a release of `lightgbm` that includes this PR.

*details (click me)*
Error raised by `shap` using `lightgbm` as of this PR:

So yes, still broken... but in a way that I think we could easily fix in `shap`. That error is coming from `shap`'s attempt here to reshape the `.predict(..., pred_contrib=True)` output into a 3D matrix:

https://github.com/shap/shap/blob/d51d173f28b52d2f501b33668bf4529acf22709a/shap/explainers/_tree.py#L444

`scipy.sparse.{csc,csr}_matrix` objects are 2-dimensional.
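That limitation can be demonstrated without `shap` at all; a minimal sketch (shapes are illustrative):

```python
import numpy as np
from scipy import sparse

n_samples, num_class, n_features = 4, 3, 5
cols = num_class * (n_features + 1)

# dense pred_contrib output reshapes cleanly into 3 dimensions
dense = np.zeros((n_samples, cols))
print(dense.reshape(n_samples, num_class, n_features + 1).shape)  # (4, 3, 6)

# a scipy.sparse matrix is strictly 2-D, so the equivalent reshape raises
try:
    sparse.csr_matrix((n_samples, cols)).reshape(n_samples, num_class, n_features + 1)
except ValueError as err:
    print("reshape failed:", err)
```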
With `xgboost==2.0.3`, `shap` happily generates SHAP values for multiclass classification models and `scipy.sparse` matrices.

*details (click me)*

All setup identical to the reproducible example above.
## How I tested this

- Interactively, with the examples shared above.
- Confirmed in the `lint` CI job that this does not introduce new `mypy` errors: https://github.com/microsoft/LightGBM/actions/runs/8107505294/job/22159131171?pr=6348
- Modified the 2 existing Python unit tests (one for `Booster.predict()`, one for Dask) which already thoroughly cover this code path.

## References
More notes on that at #3867 (comment).