Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix allowed values for kernel #444

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions sksurv/svm/minlip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from scipy import linalg, sparse
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.metrics.pairwise import PAIRWISE_KERNEL_FUNCTIONS, pairwise_kernels
from sklearn.utils._param_validation import Interval, StrOptions

from ..base import SurvivalAnalysisMixin
Expand Down Expand Up @@ -207,7 +207,7 @@ class MinlipSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
solver : {'ecos', 'osqp'}, optional, default: 'ecos'
Which quadratic program solver to use.

kernel : {'linear', 'poly', 'rbf', 'sigmoid', 'cosine', 'precomputed'} or callable, default: 'linear'.
kernel : str or callable, default: 'linear'.
Kernel mapping used internally. This parameter is directly passed to
:func:`sklearn.metrics.pairwise.pairwise_kernels`.
If `kernel` is a string, it must be one of the metrics
Expand Down Expand Up @@ -290,7 +290,7 @@ class MinlipSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
"solver": [StrOptions({"ecos", "osqp"})],
"alpha": [Interval(numbers.Real, 0, None, closed="neither")],
"kernel": [
StrOptions({"linear", "poly", "rbf", "sigmoid", "precomputed"}),
StrOptions(set(PAIRWISE_KERNEL_FUNCTIONS.keys()) | {"precomputed"}),
callable,
],
"degree": [Interval(numbers.Integral, 0, None, closed="left")],
Expand Down
6 changes: 3 additions & 3 deletions sksurv/svm/survival_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from scipy.optimize import minimize
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.metrics.pairwise import PAIRWISE_KERNEL_FUNCTIONS, pairwise_kernels
from sklearn.utils import check_array, check_consistent_length, check_random_state, check_X_y
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.extmath import safe_sparse_dot, squared_norm
Expand Down Expand Up @@ -998,7 +998,7 @@ class FastKernelSurvivalSVM(BaseSurvivalSVM, SurvivalAnalysisMixin):
Whether to calculate an intercept for the regression model. If set to ``False``, no intercept
will be calculated. Has no effect if ``rank_ratio = 1``, i.e., only ranking is performed.

kernel : {'linear', 'poly', 'rbf', 'sigmoid', 'cosine', 'precomputed'} or callable, default: 'linear'.
kernel : str or callable, default: 'linear'.
Kernel mapping used internally. This parameter is directly passed to
:func:`sklearn.metrics.pairwise.pairwise_kernels`.
If `kernel` is a string, it must be one of the metrics
Expand Down Expand Up @@ -1088,7 +1088,7 @@ class FastKernelSurvivalSVM(BaseSurvivalSVM, SurvivalAnalysisMixin):
_parameter_constraints = {
**FastSurvivalSVM._parameter_constraints,
"kernel": [
StrOptions({"linear", "poly", "rbf", "sigmoid", "precomputed"}),
StrOptions(set(PAIRWISE_KERNEL_FUNCTIONS.keys()) | {"precomputed"}),
callable,
],
"gamma": [Interval(Real, 0.0, None, closed="left"), None],
Expand Down