Skip to content

Commit

Permalink
Refactor NewtonSolver (#2521)
Browse files Browse the repository at this point in the history
Use SUNLinSolWrapper. Get rid of subclasses. Simplify.

Closes #1164.

We should also be able to get rid of the remaining data members,
but that's for another time...
  • Loading branch information
dweindl authored Oct 3, 2024
1 parent c199897 commit f66f0b0
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 291 deletions.
135 changes: 21 additions & 114 deletions include/amici/newton_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
#define amici_newton_solver_h

#include "amici/solver.h"
#include "amici/sundials_matrix_wrapper.h"
#include "amici/vector.h"

#include <memory>

namespace amici {

class Model;
Expand All @@ -26,35 +23,30 @@ class NewtonSolver {
* @brief Initializes solver according to the dimensions in the provided
* model
*
* @param model pointer to the model object
* @param model the model object
* @param linsol_type type of linear solver to use
*/
explicit NewtonSolver(Model const& model);
explicit NewtonSolver(Model const& model, LinearSolver linsol_type);

/**
* @brief Factory method to create a NewtonSolver based on linsolType
*
* @param simulationSolver solver with settings
* @param model pointer to the model instance
* @return solver NewtonSolver according to the specified linsolType
*/
static std::unique_ptr<NewtonSolver>
getSolver(Solver const& simulationSolver, Model const& model);
NewtonSolver(NewtonSolver const&) = delete;

NewtonSolver& operator=(NewtonSolver const& other) = delete;

/**
* @brief Computes the solution of one Newton iteration
*
* @param delta containing the RHS of the linear system, will be
* overwritten by solution to the linear system
* @param model pointer to the model instance
* @param model the model instance
* @param state current simulation state
*/
void getStep(AmiVector& delta, Model& model, SimulationState const& state);

/**
* @brief Computes steady state sensitivities
*
* @param sx pointer to state variable sensitivities
* @param model pointer to the model instance
* @param sx state variable sensitivities
* @param model the model instance
* @param state current simulation state
*/
void computeNewtonSensis(
Expand All @@ -65,51 +57,45 @@ class NewtonSolver {
* @brief Writes the Jacobian for the Newton iteration and passes it to the
* linear solver
*
* @param model pointer to the model instance
* @param model the model instance
* @param state current simulation state
*/
virtual void prepareLinearSystem(Model& model, SimulationState const& state)
= 0;
void prepareLinearSystem(Model& model, SimulationState const& state);

/**
* Writes the Jacobian (JB) for the Newton iteration and passes it to the
* linear solver
*
* @param model pointer to the model instance
* @param model the model instance
* @param state current simulation state
*/
virtual void
prepareLinearSystemB(Model& model, SimulationState const& state)
= 0;
void prepareLinearSystemB(Model& model, SimulationState const& state);

/**
* @brief Solves the linear system for the Newton step
*
* @param rhs containing the RHS of the linear system, will be
* overwritten by solution to the linear system
*/
virtual void solveLinearSystem(AmiVector& rhs) = 0;
void solveLinearSystem(AmiVector& rhs);

/**
* @brief Reinitialize the linear solver
*
*/
virtual void reinitialize() = 0;
void reinitialize();

/**
* @brief Checks whether linear system is singular
* @brief Checks whether the linear system is singular
*
* @param model pointer to the model instance
* @param model the model instance
* @param state current simulation state
* @return boolean indicating whether the linear system is singular
* (condition number < 1/machine precision)
*/
virtual bool is_singular(Model& model, SimulationState const& state) const
= 0;

virtual ~NewtonSolver() = default;
bool is_singular(Model& model, SimulationState const& state) const;

protected:
private:
/** dummy rhs, used as dummy argument when computing J and JB */
AmiVector xdot_;
/** dummy state, attached to linear solver */
Expand All @@ -119,88 +105,9 @@ class NewtonSolver {
/** dummy differential adjoint state, used as dummy argument when computing
* JB */
AmiVector dxB_;
};

/**
* @brief The NewtonSolverDense provides access to the dense linear solver for
* the Newton method.
*/

class NewtonSolverDense : public NewtonSolver {

public:
/**
* @brief constructor for sparse solver
*
* @param model model instance that provides problem dimensions
*/
explicit NewtonSolverDense(Model const& model);

NewtonSolverDense(NewtonSolverDense const&) = delete;

NewtonSolverDense& operator=(NewtonSolverDense const& other) = delete;

~NewtonSolverDense() override;

void solveLinearSystem(AmiVector& rhs) override;

void
prepareLinearSystem(Model& model, SimulationState const& state) override;

void
prepareLinearSystemB(Model& model, SimulationState const& state) override;

void reinitialize() override;

bool is_singular(Model& model, SimulationState const& state) const override;

private:
/** temporary storage of Jacobian */
SUNMatrixWrapper Jtmp_;

/** dense linear solver */
SUNLinearSolver linsol_{nullptr};
};

/**
* @brief The NewtonSolverSparse provides access to the sparse linear solver for
* the Newton method.
*/

class NewtonSolverSparse : public NewtonSolver {

public:
/**
* @brief constructor for dense solver
*
* @param model model instance that provides problem dimensions
*/
explicit NewtonSolverSparse(Model const& model);

NewtonSolverSparse(NewtonSolverSparse const&) = delete;

NewtonSolverSparse& operator=(NewtonSolverSparse const& other) = delete;

~NewtonSolverSparse() override;

void solveLinearSystem(AmiVector& rhs) override;

void
prepareLinearSystem(Model& model, SimulationState const& state) override;

void
prepareLinearSystemB(Model& model, SimulationState const& state) override;

bool is_singular(Model& model, SimulationState const& state) const override;

void reinitialize() override;

private:
/** temporary storage of Jacobian */
SUNMatrixWrapper Jtmp_;

/** sparse linear solver */
SUNLinearSolver linsol_{nullptr};
/** linear solver */
std::unique_ptr<SUNLinSolWrapper> linsol_;
};

} // namespace amici
Expand Down
2 changes: 1 addition & 1 deletion include/amici/steadystateproblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ class SteadystateProblem {
realtype rtol_quad_{NAN};

/** newton solver */
std::unique_ptr<NewtonSolver> newton_solver_{nullptr};
NewtonSolver newton_solver_;

/** damping factor flag */
NewtonDampingFactorMode damping_factor_mode_{NewtonDampingFactorMode::on};
Expand Down
11 changes: 10 additions & 1 deletion include/amici/sundials_linsol_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
* @param ordering
*/
SUNLinSolKLU(
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering = StateOrdering::COLAMD
AmiVector const& x, int nnz, int sparsetype,
StateOrdering ordering = StateOrdering::COLAMD
);

/**
Expand All @@ -214,6 +215,14 @@ class SUNLinSolKLU : public SUNLinSolWrapper {
* @param ordering
*/
void setOrdering(StateOrdering ordering);

/**
* @brief Checks whether the linear system is singular
*
* @return boolean indicating whether the linear system is singular
* (condition number < 1/machine precision)
*/
bool is_singular() const;
};

#ifdef SUNDIALS_SUPERLUMT
Expand Down
Loading

0 comments on commit f66f0b0

Please sign in to comment.