RandomSurvivalForest unusually slow #343

Closed
shil3754 opened this issue Mar 3, 2023 · 19 comments · Fixed by #371

@shil3754

shil3754 commented Mar 3, 2023

I attempted to fit a RandomSurvivalForest with 500,000 training instances and 140 features on a machine with 90GB memory. Unfortunately, after hours of waiting, the program ran out of memory and crashed. I wasn't able to see any progress even though the parameter 'verbose' was set to 1.

However, I was able to fit a sklearn.RandomForestRegressor on the same data, using the time of the event (whether censored or not) as the label, under the exact same settings. The whole fitting process took less than 1 minute. All common parameters, such as 'n_jobs', were set to the same values; the only difference was the type of model. In both cases, 'n_jobs' was set to -1 to utilize parallelization.

I am struggling to understand why there is such a significant difference in training time between these two models. Although I expect survival analysis to take somewhat longer than ordinary regression, the difference here is dramatic. Unfortunately, my entire training dataset has more than 10 million instances, so applying RandomSurvivalForest seems rather hopeless at the moment.

I am wondering if there are any suggestions on how I could speed up the training process.
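For reference, a minimal sketch of this kind of side-by-side comparison on synthetic data; the sizes, event-time distribution, and ~70% event rate below are illustrative assumptions, not the actual dataset described above.

import time
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sksurv.ensemble import RandomSurvivalForest
from sksurv.util import Surv

rng = np.random.default_rng(0)
n_samples, n_features = 2_000, 140
X = rng.normal(size=(n_samples, n_features))
event_time = rng.exponential(scale=10.0, size=n_samples)
event = rng.random(n_samples) < 0.7  # ~70% of samples experience the event

# regression forest on the raw event time (censoring ignored)
# vs. survival forest on the structured (event, time) labels
for name, model, y in [
    ("RandomForestRegressor", RandomForestRegressor(n_estimators=10, n_jobs=-1), event_time),
    ("RandomSurvivalForest", RandomSurvivalForest(n_estimators=10, n_jobs=-1), Surv.from_arrays(event, event_time)),
]:
    start = time.perf_counter()
    model.fit(X, y)
    print(f"{name}: {time.perf_counter() - start:.1f}s")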

@sebp
Owner

sebp commented Mar 3, 2023

That's not good.

sksurv calls sklearn to grow the forest; it just overrides the split criterion. It would be interesting to check whether increasing the number of samples or the number of features has the bigger impact.

I would suspect that the number of samples could be the problem, because computing the split criterion involves sorting by the time of an event.
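A rough way to check this on synthetic data, assuming arbitrary sizes and n_estimators=10 (not measurements from this thread):

import time
import numpy as np
from sksurv.ensemble import RandomSurvivalForest
from sksurv.util import Surv

def fit_seconds(n_samples, n_features, seed=0):
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n_samples, n_features))
    y = Surv.from_arrays(
        event=rng.random(n_samples) < 0.7,
        time=rng.exponential(scale=10.0, size=n_samples),
    )
    start = time.perf_counter()
    RandomSurvivalForest(n_estimators=10, n_jobs=-1, random_state=seed).fit(X, y)
    return time.perf_counter() - start

# double the number of samples at a fixed feature count ...
for n in (1_000, 2_000, 4_000, 8_000):
    print(f"{n:>5} samples, 20 features: {fit_seconds(n, 20):.1f}s")
# ... and double the number of features at a fixed sample count
for p in (20, 40, 80, 160):
    print(f"2,000 samples, {p:>3} features: {fit_seconds(2_000, p):.1f}s")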

@sebp
Owner

sebp commented Mar 12, 2023

I can confirm that it is indeed the sorting operation that is responsible for ~60% of the time it takes to fit a tree.

[profiling screenshot attached]
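Not the profiling run attached above, just a minimal way to reproduce this kind of measurement with cProfile on a single SurvivalTree (synthetic data, arbitrary sizes; time spent inside the compiled split criterion shows up under the tree builder, so the exact breakdown is hard to see from Python alone):

import cProfile
import pstats
import numpy as np
from sksurv.tree import SurvivalTree
from sksurv.util import Surv

rng = np.random.default_rng(0)
X = rng.normal(size=(2_000, 25))
y = Surv.from_arrays(
    event=rng.random(2_000) < 0.7,
    time=rng.exponential(scale=10.0, size=2_000),
)

with cProfile.Profile() as profiler:
    SurvivalTree().fit(X, y)
pstats.Stats(profiler).sort_stats("cumulative").print_stats(15)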

@solidate

I am also having a very similar experience with RandomSurvivalForest, as mentioned by @shil3754. I tested it with 100 records and it worked fine. Then I ran it with 500k records on a machine with 96 cores and 200GB of memory. Now it is taking forever to get any output from RandomSurvivalForest, even after setting min_samples_split, max_depth, and various other hyperparameters.
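For what it's worth, these are the usual scikit-learn forest parameters (which RandomSurvivalForest exposes) that bound tree size, and therefore fit time and the amount of per-node data; the exact values below are arbitrary, not a recommendation from this thread:

from sksurv.ensemble import RandomSurvivalForest

model = RandomSurvivalForest(
    n_estimators=100,
    max_depth=10,            # cap tree depth
    min_samples_leaf=50,     # larger leaves mean far fewer nodes per tree
    max_features="sqrt",     # fewer candidate splits evaluated per node
    max_samples=0.1,         # grow each tree on a 10% bootstrap sample (bootstrap=True is the default)
    n_jobs=-1,
    random_state=0,
)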

@juancq

juancq commented Apr 2, 2023

@sebp what is your recommendation in view of this?

@sebp
Owner

sebp commented Apr 2, 2023

@sebp what is your recommendation in view of this?

I haven't investigated what's happening inside the tree builder from sklearn, but my hope is that it would be possible to sort only once and then use the order inside sksurv's split criterion instead of sorting during each call by the tree builder.

@tommyfuu

I'm having a similar experience with only 33,599 samples and 29 features in the training set, running on a cluster.

@sebp
Owner

sebp commented May 16, 2023

FYI, I started looking into this. It should be possible to pre-compute the log-rank statistics for all possible splits in LogrankCriterion.init such that LogrankCriterion.update is just a lookup, without having to sort again.
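A toy illustration of that idea in plain Python (not the actual Cython LogrankCriterion): sort by time once and build prefix sums, so that later queries become constant-time lookups instead of repeated sorts.

import numpy as np

times = np.array([5.0, 2.0, 8.0, 3.0, 9.0, 1.0])
events = np.array([True, False, True, True, False, True])

order = np.argsort(times)                  # sort by time once, up front
events_sorted = events[order].astype(int)
events_before = np.concatenate(([0], np.cumsum(events_sorted)))

# number of events among the k earliest time points, for any k,
# answered by a lookup instead of another sort
k = 3
print(events_before[k])  # -> 2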

@sebp
Owner

sebp commented Jun 8, 2023

The fix-rsf-performance branch contains improved code to grow trees. For me, it reduces the time to grow a single tree on 500 samples and 25 features from 1330ms to 340ms.

It would be great if you could give it a try too and let me know what you think.
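A rough way to reproduce this kind of micro-benchmark, assuming synthetic data with 500 samples and 25 features as in the numbers above (the data generation and timeit setup are illustrative, not taken from the branch):

import timeit
import numpy as np
from sksurv.tree import SurvivalTree
from sksurv.util import Surv

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 25))
y = Surv.from_arrays(
    event=rng.random(500) < 0.7,
    time=rng.exponential(scale=10.0, size=500),
)

seconds = timeit.timeit(lambda: SurvivalTree().fit(X, y), number=10) / 10
print(f"mean time to grow one tree: {seconds * 1000:.0f} ms")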

sebp added a commit that referenced this issue Jun 9, 2023
- Construct total risk set without sorting
- Compute risk set for left node on the fly
- Rename event_times_ to unique_times_ everywhere

Fixes #343
sebp closed this as completed in #371 Jun 10, 2023
@juancq

juancq commented Jun 13, 2023

@sebp there's something wrong with the new implementation.
Here is the code I used:

import pandas as pd
from sksurv.ensemble import RandomSurvivalForest
from sksurv import datasets as sksurv_datasets

X, y = sksurv_datasets.load_flchain()
Xt = X.drop('chapter', axis=1)
Xt = pd.get_dummies(Xt)

mask = Xt.isna().any(axis=1)
Xt = Xt[~mask]
y = y[~mask]

model = RandomSurvivalForest(n_estimators=100)
model.fit(Xt, y)
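(Not part of the original script: a wall-clock loop like the following, reusing Xt and y from above, is one way to obtain per-n_estimators timings like those reported below.)

import time

for n in (10, 100, 200, 500, 1000):
    model = RandomSurvivalForest(n_estimators=n)
    start = time.perf_counter()
    model.fit(Xt, y)
    print(n, f"{time.perf_counter() - start:.0f} seconds")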

With 0.20.0, I get the following running times:

n_estimators    run time
10              6 seconds
100             16 seconds
200             16 seconds
500             17 seconds
1000            17 seconds

With 0.21.0, I get the following running times:

n_estimators    run time
10              5 seconds
100             34 seconds
200             68 seconds
500             97 seconds and then "killed"

System:
python: 3.10.8 (main, Dec 5 2022, 10:38:26) [GCC 12.2.0]
executable: /home/helloworld/venvs/survival_experiments/bin/python
machine: Linux-4.18.0-425.19.2.el8_7.x86_64-x86_64-with-glibc2.28

Python dependencies:
sklearn: 1.2.2
pip: 23.1.2
setuptools: 63.2.0
numpy: 1.23.5
scipy: 1.10.1
Cython: 0.29.35
pandas: 1.5.3
matplotlib: 3.7.1
joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
user_api: openmp
internal_api: openmp
prefix: libgomp
filepath: /home/helloworld/venvs/survival_experiments/lib/python3.10/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
version: None
num_threads: 1

user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /home/helloworld/venvs/survival_experiments/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-742d56dc.3.20.so
version: 0.3.20
threading_layer: pthreads
architecture: SkylakeX
num_threads: 1

user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /home/helloworld/venvs/survival_experiments/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-41284840.3.18.so
version: 0.3.18
threading_layer: pthreads
architecture: SkylakeX
num_threads: 1

sksurv: 0.21.0
cvxopt: not installed (ModuleNotFoundError)
cvxpy: not installed (ModuleNotFoundError)
numexpr: 2.8.4
osqp: 0.6.2

@sebp
Owner

sebp commented Jun 13, 2023

@juancq Could you please post which versions you are using, as described here.

@juancq

juancq commented Jun 14, 2023

@sebp I updated the comment with the versions I am using, but wait until I post again; I'm trying to reproduce the running times and I'm not getting the same numbers now.

@juancq

juancq commented Jun 14, 2023

@sebp my apologies, the times I posted before for 0.20.0 were the times at which the script received the killed signal, not the successful running times, which explains the consistent times from 100 to 1000 estimators.

I do see speed ups with version 0.21.0, as evident below:

With 0.20.0:

n_estimators    run time
10              5 seconds
100             41 seconds
200             86 seconds
500             217 seconds and then "killed"
1000            217 seconds and then "killed"

With 0.21.0:

n_estimators    run time
10              5 seconds
100             34 seconds
200             68 seconds
500             97 seconds and then "killed"

@oelhammouchi

Hi, thanks a lot for your work on this package! I've been experiencing similar performance difficulties despite the fix in 0.21. My training data consists of approximately 50K rows and 27 features (after one-hot encoding). Fitting takes 15-20 minutes for a single run, so it's very difficult to do cross-validation, feature selection, etc. Any idea how I could remedy this? Below is the output of cProfile.

[cProfile output screenshot]

@sebp
Owner

sebp commented Aug 2, 2023

@OthmanElHammouchi The current bottleneck is #382

Under the hood, scikit-learn is used to build trees. Unfortunately, this means that each node in every tree contains a survival and cumulative hazard function, which causes overhead due to large portions of memory being copied. I don't have a straightforward solution for this.
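A back-of-envelope sketch of why this blows up (purely illustrative numbers, not from the issue): each node stores function values at every unique event time, so memory scales roughly with n_trees x n_nodes x n_unique_times.

n_trees = 1000
n_nodes_per_tree = 2_000     # depends on depth and min_samples_leaf
n_unique_times = 5_000       # roughly grows with the number of samples
functions_per_node = 2       # survival function + cumulative hazard function
bytes_per_value = 8          # float64

approx_bytes = n_trees * n_nodes_per_tree * n_unique_times * functions_per_node * bytes_per_value
print(f"~{approx_bytes / 1e9:.0f} GB")  # ~160 GB under these assumptions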

@oelhammouchi

@sebp Ah, I see, thanks for your reply. I'm not familiar with the sklearn implementation; I tried looking into it yesterday, but it does seem quite intricate.

@gpwhs

gpwhs commented Aug 29, 2024

Has there been any movement on this issue? I'm still finding RSF to be incredibly slow on a dataset with ~40 features, 200k samples.

@sebp
Owner

sebp commented Sep 6, 2024 via email

@gpwhs

gpwhs commented Sep 6, 2024

@sebp have you done any scoping at all? I may have some capacity to attack this :)

@sebp
Owner

sebp commented Sep 6, 2024

See #382 (please continue discussion there)

Essentially, by relying on scikit-learn's implementation, we have to store the survival and cumulative hazard function in each node of the tree. With many samples you also have many time points, and thus a large overhead in terms of memory that has to be initialized.
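One possible mitigation (not proposed in this thread, and not validated here): coarsen the event times before fitting, for example by rounding, which shrinks the number of unique time points and therefore the per-node survival and cumulative hazard arrays. A sketch on synthetic durations:

import numpy as np
from sksurv.util import Surv

rng = np.random.default_rng(0)
event = rng.random(200_000) < 0.7
time = rng.exponential(scale=365.0, size=200_000)  # durations, e.g. in days

print(np.unique(time).size)           # ~200,000 distinct time points
time_coarse = np.round(time)          # round to whole days
print(np.unique(time_coarse).size)    # only a few thousand remain
y_coarse = Surv.from_arrays(event=event, time=time_coarse)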

Repository owner locked and limited conversation to collaborators Sep 6, 2024