Skip to content

Commit

Permalink
Refactor SUNLinSolWrapper (#2520)
Browse files Browse the repository at this point in the history
* Make it consistent that SUNLinSolWrapper always holds the associated matrix
* Always use SUNMatrixWrapper instead of raw SUNMatrix objects
* Implement declared but missing move assignment

This makes it a bit easier to finally address #1164.
  • Loading branch information
dweindl authored Oct 1, 2024
1 parent 7f27bb3 commit be02a6e
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 79 deletions.
67 changes: 24 additions & 43 deletions include/amici/sundials_linsol_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,21 @@ class SUNLinSolWrapper {

/**
* @brief Wrap existing SUNLinearSolver
* @param linsol
*
* @param linsol SUNLinSolWrapper takes ownership of `linsol`.
*/
explicit SUNLinSolWrapper(SUNLinearSolver linsol);

/**
* @brief Wrap existing SUNLinearSolver
*
* @param linsol SUNLinSolWrapper takes ownership of `linsol`.
* @param A Matrix
*/
explicit SUNLinSolWrapper(
SUNLinearSolver linsol, SUNMatrixWrapper const& A
);

virtual ~SUNLinSolWrapper();

/**
Expand Down Expand Up @@ -80,26 +91,17 @@ class SUNLinSolWrapper {
/**
* @brief Performs any linear solver setup needed, based on an updated
* system matrix A.
* @param A
*/
void setup(SUNMatrix A) const;

/**
* @brief Performs any linear solver setup needed, based on an updated
* system matrix A.
* @param A
*/
void setup(SUNMatrixWrapper const& A) const;
void setup() const;

/**
* @brief Solves a linear system A*x = b
* @param A
* @param x A template for cloning vectors needed within the solver.
* @param b
* @param tol Tolerance (weighted 2-norm), iterative solvers only
* @return error flag
*/
int Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol) const;
int solve(N_Vector x, N_Vector b, realtype tol) const;

/**
* @brief Returns the last error flag encountered within the linear solver
Expand All @@ -119,7 +121,7 @@ class SUNLinSolWrapper {
* @brief Get the matrix A (matrix solvers only).
* @return A
*/
virtual SUNMatrix getMatrix() const;
virtual SUNMatrixWrapper& getMatrix();

protected:
/**
Expand All @@ -131,6 +133,9 @@ class SUNLinSolWrapper {

/** Wrapped solver */
SUNLinearSolver solver_{nullptr};

/** Matrix A for solver. */
SUNMatrixWrapper A_;
};

/**
Expand All @@ -139,12 +144,12 @@ class SUNLinSolWrapper {
class SUNLinSolBand : public SUNLinSolWrapper {
public:
/**
* @brief Create solver using existing matrix A without taking ownership of
* A.
* @brief Create solver using existing matrix A
*
* @param x A template for cloning vectors needed within the solver.
* @param A square matrix
*/
SUNLinSolBand(N_Vector x, SUNMatrix A);
SUNLinSolBand(N_Vector x, SUNMatrixWrapper A);

/**
* @brief Create new band solver and matrix A.
Expand All @@ -153,12 +158,6 @@ class SUNLinSolBand : public SUNLinSolWrapper {
* @param lbw lower bandwidth of band matrix A
*/
SUNLinSolBand(AmiVector const& x, int ubw, int lbw);

SUNMatrix getMatrix() const override;

private:
/** Matrix A for solver, only if created by here. */
SUNMatrixWrapper A_;
};

/**
Expand All @@ -171,12 +170,6 @@ class SUNLinSolDense : public SUNLinSolWrapper {
* @param x A template for cloning vectors needed within the solver.
*/
explicit SUNLinSolDense(AmiVector const& x);

SUNMatrix getMatrix() const override;

private:
/** Matrix A for solver, only if created by here. */
SUNMatrixWrapper A_;
};

/**
Expand All @@ -192,7 +185,7 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
* @param x A template for cloning vectors needed within the solver.
* @param A sparse matrix
*/
SUNLinSolKLU(N_Vector x, SUNMatrix A);
SUNLinSolKLU(N_Vector x, SUNMatrixWrapper A);

/**
* @brief Create KLU solver and matrix to operate on
Expand All @@ -202,11 +195,9 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
* @param ordering
*/
SUNLinSolKLU(
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering = StateOrdering::COLAMD
);

SUNMatrix getMatrix() const override;

/**
* @brief Reinitializes memory and flags for a new factorization
* (symbolic and numeric) to be conducted at the next solver setup call.
Expand All @@ -223,10 +214,6 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
* @param ordering
*/
void setOrdering(StateOrdering ordering);

private:
/** Sparse matrix A for solver, only if created by here. */
SUNMatrixWrapper A_;
};

#ifdef SUNDIALS_SUPERLUMT
Expand All @@ -249,7 +236,7 @@ class SUNLinSolSuperLUMT : public SUNLinSolWrapper {
* @param A sparse matrix
* @param numThreads Number of threads to be used by SuperLUMT
*/
SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads);
SUNLinSolSuperLUMT(N_Vector x, SUNMatrixWrapper A, int numThreads);

/**
* @brief Create SuperLUMT solver and matrix to operate on
Expand Down Expand Up @@ -279,18 +266,12 @@ class SUNLinSolSuperLUMT : public SUNLinSolWrapper {
int numThreads
);

SUNMatrix getMatrix() const override;

/**
* @brief Sets the ordering used by SuperLUMT for reducing fill in the
* linear solve.
* @param ordering
*/
void setOrdering(StateOrdering ordering);

private:
/** Sparse matrix A for solver, only if created by here. */
SUNMatrixWrapper A;
};

#endif
Expand Down
8 changes: 5 additions & 3 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2221,10 +2221,12 @@ void Model::fdJydy(int const it, AmiVector const& x, ExpData const& edata) {
BLASLayout::colMajor, BLASTranspose::noTrans,
BLASTranspose::noTrans, nJ, ny, ny, 1.0,
&derived_state_.dJydsigma_.at(iyt * nJ * ny), nJ,
derived_state_.dsigmaydy_.data(), ny, 1.0, derived_state_.dJydy_dense_.data(), nJ
derived_state_.dsigmaydy_.data(), ny, 1.0,
derived_state_.dJydy_dense_.data(), nJ
);

auto tmp_sparse = SUNMatrixWrapper(derived_state_.dJydy_dense_, 0.0, CSC_MAT);
auto tmp_sparse
= SUNMatrixWrapper(derived_state_.dJydy_dense_, 0.0, CSC_MAT);
auto ret = SUNMatScaleAdd(
1.0, derived_state_.dJydy_.at(iyt), tmp_sparse
);
Expand Down Expand Up @@ -3079,7 +3081,7 @@ std::vector<double> Model::get_trigger_timepoints() const {
return trigger_timepoints;
}

void Model::set_steadystate_mask(const std::vector<realtype> &mask) {
void Model::set_steadystate_mask(std::vector<realtype> const& mask) {
if (mask.size() == 0) {
steadystate_mask_.clear();
return;
Expand Down
11 changes: 5 additions & 6 deletions src/steadystateproblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,7 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) {
"steady state computations. Stopping."
);
wrms = getWrmsNorm(
xQB_, xQBdot_, steadystate_mask_, atol_quad_,
rtol_quad_, ewtQB_
xQB_, xQBdot_, steadystate_mask_, atol_quad_, rtol_quad_, ewtQB_
);
} else {
/* If we're doing a forward simulation (with or without sensitivities:
Expand All @@ -563,8 +562,8 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) {
else
updateRightHandSide(model);
wrms = getWrmsNorm(
state_.x, newton_step_conv_ ? delta_ : xdot_,
steadystate_mask_, atol_, rtol_, ewt_
state_.x, newton_step_conv_ ? delta_ : xdot_, steadystate_mask_,
atol_, rtol_, ewt_
);
}
return wrms;
Expand All @@ -586,8 +585,8 @@ realtype SteadystateProblem::getWrmsFSA(Model& model) {
if (newton_step_conv_)
newton_solver_->solveLinearSystem(xdot_);
wrms = getWrmsNorm(
state_.sx[ip], xdot_, steadystate_mask_, atol_sensi_,
rtol_sensi_, ewt_
state_.sx[ip], xdot_, steadystate_mask_, atol_sensi_, rtol_sensi_,
ewt_
);
/* ideally this function would report the maximum of all wrms over
all ip, but for practical purposes we can just report the wrms for
Expand Down
67 changes: 40 additions & 27 deletions src/sundials_linsol_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,26 @@ namespace amici {
SUNLinSolWrapper::SUNLinSolWrapper(SUNLinearSolver linsol)
: solver_(linsol) {}

SUNLinSolWrapper::SUNLinSolWrapper(
SUNLinearSolver linsol, SUNMatrixWrapper const& A
)
: solver_(linsol)
, A_(A) {}

SUNLinSolWrapper::~SUNLinSolWrapper() {
if (solver_)
SUNLinSolFree(solver_);
}

SUNLinSolWrapper::SUNLinSolWrapper(SUNLinSolWrapper&& other) noexcept {
std::swap(solver_, other.solver_);
std::swap(A_, other.A_);
}

SUNLinSolWrapper& SUNLinSolWrapper::operator=(SUNLinSolWrapper&& other) noexcept {
std::swap(solver_, other.solver_);
std::swap(A_, other.A_);
return *this;
}

SUNLinearSolver SUNLinSolWrapper::get() const { return solver_; }
Expand All @@ -31,19 +44,14 @@ int SUNLinSolWrapper::initialize() {
return res;
}

void SUNLinSolWrapper::setup(SUNMatrix A) const {
auto res = SUNLinSolSetup(solver_, A);
void SUNLinSolWrapper::setup() const {
auto res = SUNLinSolSetup(solver_, A_.get());
if (res != SUNLS_SUCCESS)
throw AmiException("Solver setup failed with code %d", res);
}

void SUNLinSolWrapper::setup(SUNMatrixWrapper const& A) const {
return setup(A.get());
}

int SUNLinSolWrapper::Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol)
const {
return SUNLinSolSolve(solver_, A, x, b, tol);
int SUNLinSolWrapper::solve(N_Vector x, N_Vector b, realtype tol) const {
return SUNLinSolSolve(solver_, A_.get(), x, b, tol);
}

long SUNLinSolWrapper::getLastFlag() const {
Expand All @@ -54,7 +62,7 @@ int SUNLinSolWrapper::space(long* lenrwLS, long* leniwLS) const {
return SUNLinSolSpace(solver_, lenrwLS, leniwLS);
}

SUNMatrix SUNLinSolWrapper::getMatrix() const { return nullptr; }
SUNMatrixWrapper& SUNLinSolWrapper::getMatrix() { return A_; }

SUNNonLinSolWrapper::SUNNonLinSolWrapper(SUNNonlinearSolver sol)
: solver(sol) {}
Expand Down Expand Up @@ -153,31 +161,29 @@ void SUNNonLinSolWrapper::initialize() {
);
}

SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrix A)
SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrixWrapper A)
: SUNLinSolWrapper(SUNLinSol_Band(x, A)) {
if (!solver_)
throw AmiException("Failed to create solver.");
}

SUNLinSolBand::SUNLinSolBand(AmiVector const& x, int ubw, int lbw)
: A_(SUNMatrixWrapper(x.getLength(), ubw, lbw)) {
: SUNLinSolWrapper(nullptr, SUNMatrixWrapper(x.getLength(), ubw, lbw)) {
solver_ = SUNLinSol_Band(const_cast<N_Vector>(x.getNVector()), A_);
if (!solver_)
throw AmiException("Failed to create solver.");
}

SUNMatrix SUNLinSolBand::getMatrix() const { return A_.get(); }

SUNLinSolDense::SUNLinSolDense(AmiVector const& x)
: A_(SUNMatrixWrapper(x.getLength(), x.getLength())) {
: SUNLinSolWrapper(
nullptr, SUNMatrixWrapper(x.getLength(), x.getLength())
) {
solver_ = SUNLinSol_Dense(const_cast<N_Vector>(x.getNVector()), A_);
if (!solver_)
throw AmiException("Failed to create solver.");
}

SUNMatrix SUNLinSolDense::getMatrix() const { return A_.get(); }

SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrix A)
SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrixWrapper A)
: SUNLinSolWrapper(SUNLinSol_KLU(x, A)) {
if (!solver_)
throw AmiException("Failed to create solver.");
Expand All @@ -186,16 +192,17 @@ SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrix A)
SUNLinSolKLU::SUNLinSolKLU(
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
)
: A_(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
: SUNLinSolWrapper(
nullptr,
SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)
) {
solver_ = SUNLinSol_KLU(const_cast<N_Vector>(x.getNVector()), A_);
if (!solver_)
throw AmiException("Failed to create solver.");

setOrdering(ordering);
}

SUNMatrix SUNLinSolKLU::getMatrix() const { return A_.get(); }

void SUNLinSolKLU::reInit(int nnz, int reinit_type) {
int status = SUNLinSol_KLUReInit(solver_, A_, nnz, reinit_type);
if (status != SUNLS_SUCCESS)
Expand Down Expand Up @@ -413,8 +420,10 @@ int SUNNonLinSolFixedPoint::getSysFn(SUNNonlinSolSysFn* SysFn) const {

#ifdef SUNDIALS_SUPERLUMT

SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads)
: SUNLinSolWrapper(SUNLinSol_SuperLUMT(x, A, numThreads)) {
SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(
N_Vector x, SUNMatrixWrapper A, int numThreads
)
: SUNLinSolWrapper(SUNLinSol_SuperLUMT(x, A, numThreads), A) {
if (!solver)
throw AmiException("Failed to create solver.");
}
Expand All @@ -423,7 +432,10 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(
AmiVector const& x, int nnz, int sparsetype,
SUNLinSolSuperLUMT::StateOrdering ordering
)
: A(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
: SUNLinSolWrapper(
nullptr,
SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)
) {
int numThreads = 1;
if (auto env = std::getenv("AMICI_SUPERLUMT_NUM_THREADS")) {
numThreads = std::max(1, std::stoi(env));
Expand All @@ -440,16 +452,17 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering,
int numThreads
)
: A(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
: SUNLinSolWrapper(
nullptr,
SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)
) {
solver = SUNLinSol_SuperLUMT(x.getNVector(), A.get(), numThreads);
if (!solver)
throw AmiException("Failed to create solver.");

setOrdering(ordering);
}

SUNMatrix SUNLinSolSuperLUMT::getMatrix() const { return A.get(); }

void SUNLinSolSuperLUMT::setOrdering(StateOrdering ordering) {
auto status
= SUNLinSol_SuperLUMTSetOrdering(solver, static_cast<int>(ordering));
Expand Down

0 comments on commit be02a6e

Please sign in to comment.