Skip to content

Commit

Permalink
[MRG] Expose n_iter_ to BaseLibSVM (scikit-learn#21408)
Browse files Browse the repository at this point in the history

Co-authored-by: Julien Jerphanion <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Adrin Jalali <[email protected]>
  • Loading branch information
4 people authored and venkyyuvy committed Jan 1, 2022
1 parent 3f03dfc commit d71163d
Show file tree
Hide file tree
Showing 13 changed files with 218 additions and 16 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,13 @@ Changelog
Setting a transformer to "passthrough" will pass the features unchanged.
:pr:`20860` by :user:`Shubhraneel Pal <shubhraneel>`.

:mod:`sklearn.svm`
...................

- |Enhancement| :class:`svm.OneClassSVM`, :class:`svm.NuSVC`,
:class:`svm.NuSVR`, :class:`svm.SVC` and :class:`svm.SVR` now expose
`n_iter_`, the number of iterations of the libsvm optimization routine.
:pr:`21408` by :user:`Juan Martín Loyola <jmloyola>`.
- |Fix| :class: `pipeline.Pipeline` now does not validate hyper-parameters in
`__init__` but in `.fit()`.
:pr:`21888` by :user:`iofall <iofall>` and :user: `Arisa Y. <arisayosh>`.
Expand Down
13 changes: 13 additions & 0 deletions sklearn/svm/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,17 @@ def fit(self, X, y, sample_weight=None):
self.intercept_ *= -1
self.dual_coef_ = -self.dual_coef_

# Since, in the case of SVC and NuSVC, the number of models optimized by
# libSVM could be greater than one (depending on the input), `n_iter_`
# stores an ndarray.
# For the other sub-classes (SVR, NuSVR, and OneClassSVM), the number of
# models optimized by libSVM is always one, so `n_iter_` stores an
# integer.
if self._impl in ["c_svc", "nu_svc"]:
self.n_iter_ = self._num_iter
else:
self.n_iter_ = self._num_iter.item()

return self

def _validate_targets(self, y):
Expand Down Expand Up @@ -320,6 +331,7 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel, random_seed):
self._probA,
self._probB,
self.fit_status_,
self._num_iter,
) = libsvm.fit(
X,
y,
Expand Down Expand Up @@ -360,6 +372,7 @@ def _sparse_fit(self, X, y, sample_weight, solver_type, kernel, random_seed):
self._probA,
self._probB,
self.fit_status_,
self._num_iter,
) = libsvm_sparse.libsvm_sparse_train(
X.shape[1],
X.data,
Expand Down
29 changes: 29 additions & 0 deletions sklearn/svm/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,13 @@ class SVC(BaseSVC):
.. versionadded:: 1.0
n_iter_ : ndarray of shape (n_classes * (n_classes - 1) // 2,)
Number of iterations run by the optimization routine to fit the model.
The shape of this attribute depends on the number of models optimized
which in turn depends on the number of classes.
.. versionadded:: 1.1
support_ : ndarray of shape (n_SV)
Indices of support vectors.
Expand Down Expand Up @@ -925,6 +932,13 @@ class NuSVC(BaseSVC):
.. versionadded:: 1.0
n_iter_ : ndarray of shape (n_classes * (n_classes - 1) // 2,)
Number of iterations run by the optimization routine to fit the model.
The shape of this attribute depends on the number of models optimized
which in turn depends on the number of classes.
.. versionadded:: 1.1
support_ : ndarray of shape (n_SV,)
Indices of support vectors.
Expand Down Expand Up @@ -1140,6 +1154,11 @@ class SVR(RegressorMixin, BaseLibSVM):
.. versionadded:: 1.0
n_iter_ : int
Number of iterations run by the optimization routine to fit the model.
.. versionadded:: 1.1
n_support_ : ndarray of shape (n_classes,), dtype=int32
Number of support vectors for each class.
Expand Down Expand Up @@ -1328,6 +1347,11 @@ class NuSVR(RegressorMixin, BaseLibSVM):
.. versionadded:: 1.0
n_iter_ : int
Number of iterations run by the optimization routine to fit the model.
.. versionadded:: 1.1
n_support_ : ndarray of shape (n_classes,), dtype=int32
Number of support vectors for each class.
Expand Down Expand Up @@ -1512,6 +1536,11 @@ class OneClassSVM(OutlierMixin, BaseLibSVM):
.. versionadded:: 1.0
n_iter_ : int
Number of iterations run by the optimization routine to fit the model.
.. versionadded:: 1.1
n_support_ : ndarray of shape (n_classes,), dtype=int32
Number of support vectors for each class.
Expand Down
1 change: 1 addition & 0 deletions sklearn/svm/_libsvm.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ cdef extern from "libsvm_helper.c":
char *, char *, char *, char *)

void copy_sv_coef (char *, svm_model *)
void copy_n_iter (char *, svm_model *)
void copy_intercept (char *, svm_model *, np.npy_intp *)
void copy_SV (char *, svm_model *, np.npy_intp *)
int copy_support (char *data, svm_model *model)
Expand Down
9 changes: 8 additions & 1 deletion sklearn/svm/_libsvm.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def fit(
probA, probB : array of shape (n_class*(n_class-1)/2,)
Probability estimates, empty array for probability=False.
n_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),)
Number of iterations run by the optimization routine to fit the model.
"""

cdef svm_parameter param
Expand Down Expand Up @@ -199,6 +202,10 @@ def fit(
SV_len = get_l(model)
n_class = get_nr(model)

cdef np.ndarray[int, ndim=1, mode='c'] n_iter
n_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.intc)
copy_n_iter(n_iter.data, model)

cdef np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef
sv_coef = np.empty((n_class-1, SV_len), dtype=np.float64)
copy_sv_coef (sv_coef.data, model)
Expand Down Expand Up @@ -248,7 +255,7 @@ def fit(
free(problem.x)

return (support, support_vectors, n_class_SV, sv_coef, intercept,
probA, probB, fit_status)
probA, probB, fit_status, n_iter)


cdef void set_predict_params(
Expand Down
7 changes: 6 additions & 1 deletion sklearn/svm/_libsvm_sparse.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ cdef extern from "libsvm_sparse_helper.c":
double, int, int, int, char *, char *, int,
int)
void copy_sv_coef (char *, svm_csr_model *)
void copy_n_iter (char *, svm_csr_model *)
void copy_support (char *, svm_csr_model *)
void copy_intercept (char *, svm_csr_model *, np.npy_intp *)
int copy_predict (char *, svm_csr_model *, np.npy_intp *, char *, BlasFunctions *)
Expand Down Expand Up @@ -159,6 +160,10 @@ def libsvm_sparse_train ( int n_features,
cdef np.npy_intp SV_len = get_l(model)
cdef np.npy_intp n_class = get_nr(model)

cdef np.ndarray[int, ndim=1, mode='c'] n_iter
n_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.intc)
copy_n_iter(n_iter.data, model)

# copy model.sv_coef
# we create a new array instead of resizing, otherwise
# it would not erase previous information
Expand Down Expand Up @@ -217,7 +222,7 @@ def libsvm_sparse_train ( int n_features,
free_param(param)

return (support, support_vectors_, sv_coef_data, intercept, n_class_SV,
probA, probB, fit_status)
probA, probB, fit_status, n_iter)


def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data,
Expand Down
1 change: 1 addition & 0 deletions sklearn/svm/src/libsvm/LIBSVM_CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ This is here mainly as checklist for incorporation of new versions of libsvm.
* Improved random number generator (fix on windows, enhancement on other
platforms). See <https:/scikit-learn/scikit-learn/pull/13511#issuecomment-481729756>
* invoke scipy blas api for svm kernel function to improve performance with speedup rate of 1.5X to 2X for dense data only. See <https:/scikit-learn/scikit-learn/pull/16530>
* Expose the number of iterations run in optimization. See <https:/scikit-learn/scikit-learn/pull/21408>
The changes made with respect to upstream are detailed in the heading of svm.cpp
25 changes: 23 additions & 2 deletions sklearn/svm/src/libsvm/libsvm_helper.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
#include <numpy/arrayobject.h>
#include "svm.h"
#include "_svm_cython_blas_helpers.h"


#ifndef MAX
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
#endif


/*
* Some helper methods for libsvm bindings.
*
Expand Down Expand Up @@ -128,6 +135,9 @@ struct svm_model *set_model(struct svm_parameter *param, int nr_class,
if ((model->rho = malloc( m * sizeof(double))) == NULL)
goto rho_error;

// This is only allocated in dynamic memory while training.
model->n_iter = NULL;

model->nr_class = nr_class;
model->param = *param;
model->l = (int) support_dims[0];
Expand Down Expand Up @@ -218,6 +228,15 @@ npy_intp get_nr(struct svm_model *model)
return (npy_intp) model->nr_class;
}

/*
* Get the number of iterations run in optimization
*/
void copy_n_iter(char *data, struct svm_model *model)
{
const int n_models = MAX(1, model->nr_class * (model->nr_class-1) / 2);
memcpy(data, model->n_iter, n_models * sizeof(int));
}

/*
* Some helpers to convert from libsvm sparse data structures
* model->sv_coef is a double **, whereas data is just a double *,
Expand Down Expand Up @@ -363,9 +382,11 @@ int free_model(struct svm_model *model)
if (model == NULL) return -1;
free(model->SV);

/* We don't free sv_ind, since we did not create them in
/* We don't free sv_ind and n_iter, since we did not create them in
set_model */
/* free(model->sv_ind); */
/* free(model->sv_ind);
* free(model->n_iter);
*/
free(model->sv_coef);
free(model->rho);
free(model->label);
Expand Down
19 changes: 19 additions & 0 deletions sklearn/svm/src/libsvm/libsvm_sparse_helper.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
#include "svm.h"
#include "_svm_cython_blas_helpers.h"


#ifndef MAX
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
#endif


/*
* Convert scipy.sparse.csr to libsvm's sparse data structure
*/
Expand Down Expand Up @@ -122,6 +128,9 @@ struct svm_csr_model *csr_set_model(struct svm_parameter *param, int nr_class,
if ((model->rho = malloc( m * sizeof(double))) == NULL)
goto rho_error;

// This is only allocated in dynamic memory while training.
model->n_iter = NULL;

/* in the case of precomputed kernels we do not use
dense_to_precomputed because we don't want the leading 0. As
indices start at 1 (not at 0) this will work */
Expand Down Expand Up @@ -348,6 +357,15 @@ void copy_sv_coef(char *data, struct svm_csr_model *model)
}
}

/*
* Get the number of iterations run in optimization
*/
void copy_n_iter(char *data, struct svm_csr_model *model)
{
const int n_models = MAX(1, model->nr_class * (model->nr_class-1) / 2);
memcpy(data, model->n_iter, n_models * sizeof(int));
}

/*
* Get the number of support vectors in a model.
*/
Expand Down Expand Up @@ -402,6 +420,7 @@ int free_problem(struct svm_csr_problem *problem)
int free_model(struct svm_csr_model *model)
{
/* like svm_free_and_destroy_model, but does not free sv_coef[i] */
/* We don't free n_iter, since we did not create them in set_model. */
if (model == NULL) return -1;
free(model->SV);
free(model->sv_coef);
Expand Down
19 changes: 19 additions & 0 deletions sklearn/svm/src/libsvm/svm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Sylvain Marie, Schneider Electric
see <https:/scikit-learn/scikit-learn/pull/13511#issuecomment-481729756>
Modified 2021:
- Exposed number of iterations run in optimization, Juan Martín Loyola.
See <https:/scikit-learn/scikit-learn/pull/21408/>
*/

#include <math.h>
Expand Down Expand Up @@ -553,6 +557,7 @@ class Solver {
double *upper_bound;
double r; // for Solver_NU
bool solve_timed_out;
int n_iter;
};

void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
Expand Down Expand Up @@ -919,6 +924,9 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
for(int i=0;i<l;i++)
si->upper_bound[i] = C[i];

// store number of iterations
si->n_iter = iter;

info("\noptimization finished, #iter = %d\n",iter);

delete[] p;
Expand Down Expand Up @@ -1837,6 +1845,7 @@ struct decision_function
{
double *alpha;
double rho;
int n_iter;
};

static decision_function svm_train_one(
Expand Down Expand Up @@ -1902,6 +1911,7 @@ static decision_function svm_train_one(
decision_function f;
f.alpha = alpha;
f.rho = si.rho;
f.n_iter = si.n_iter;
return f;
}

Expand Down Expand Up @@ -2387,6 +2397,8 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p
NAMESPACE::decision_function f = NAMESPACE::svm_train_one(prob,param,0,0, status,blas_functions);
model->rho = Malloc(double,1);
model->rho[0] = f.rho;
model->n_iter = Malloc(int,1);
model->n_iter[0] = f.n_iter;

int nSV = 0;
int i;
Expand Down Expand Up @@ -2523,8 +2535,12 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p
model->label[i] = label[i];

model->rho = Malloc(double,nr_class*(nr_class-1)/2);
model->n_iter = Malloc(int,nr_class*(nr_class-1)/2);
for(i=0;i<nr_class*(nr_class-1)/2;i++)
{
model->rho[i] = f[i].rho;
model->n_iter[i] = f[i].n_iter;
}

if(param->probability)
{
Expand Down Expand Up @@ -2978,6 +2994,9 @@ void PREFIX(free_model_content)(PREFIX(model)* model_ptr)

free(model_ptr->nSV);
model_ptr->nSV = NULL;

free(model_ptr->n_iter);
model_ptr->n_iter = NULL;
}

void PREFIX(free_and_destroy_model)(PREFIX(model)** model_ptr_ptr)
Expand Down
2 changes: 2 additions & 0 deletions sklearn/svm/src/libsvm/svm.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ struct svm_model
int l; /* total #SV */
struct svm_node *SV; /* SVs (SV[l]) */
double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
int *n_iter; /* number of iterations run by the optimization routine to fit the model */

int *sv_ind; /* index of support vectors */

Expand All @@ -101,6 +102,7 @@ struct svm_csr_model
int l; /* total #SV */
struct svm_csr_node **SV; /* SVs (SV[l]) */
double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
int *n_iter; /* number of iterations run by the optimization routine to fit the model */

int *sv_ind; /* index of support vectors */

Expand Down
Loading

0 comments on commit d71163d

Please sign in to comment.