Skip to content

Commit

Permalink
Merge pull request #225 from THUzyt21/master
Browse files Browse the repository at this point in the history
Joblib support in parallel function
  • Loading branch information
guofei9987 authored Jun 23, 2024
2 parents 647980c + ed51eb6 commit f7e6f06
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions sko/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def func_transformed(X):
set_run_mode(func, 'vectorization')

mode = getattr(func, 'mode', 'others')
valid_mode = ('common', 'multithreading', 'multiprocessing', 'vectorization', 'cached', 'others')
valid_mode = ('common', 'multithreading', 'multiprocessing', 'vectorization', 'cached', 'joblib', 'others')
assert mode in valid_mode, 'valid mode should be in ' + str(valid_mode)
if mode == 'vectorization':
return func
Expand Down Expand Up @@ -116,7 +116,16 @@ def func_transformed(X):
return np.array(pool.map(func, X))

return func_transformed


elif mode == "joblib":
from joblib import Parallel, delayed
def func_transformed(X):
res = Parallel(n_jobs=-1, batch_size='auto')(
delayed(func)(x) for x in X
)
return np.array(res)
return func_transformed

else: # common
def func_transformed(X):
return np.array([func(x) for x in X])
Expand Down

0 comments on commit f7e6f06

Please sign in to comment.