38 #include "CvodeAdaptor.hpp"
42 #include "TimeStepper.hpp"
49 #include <sundials/sundials_nvector.h>
51 #if CHASTE_SUNDIALS_VERSION >= 30000
52 #include <cvode/cvode_direct.h>
53 #include <sundials/sundials_types.h>
54 #include <sunlinsol/sunlinsol_dense.h>
55 #include <sunmatrix/sunmatrix_dense.h>
57 #include <cvode/cvode_dense.h>
78 assert(pData !=
nullptr);
81 static std::vector<realtype> ydot_vec;
91 std::cerr <<
"CVODE RHS Exception: " << e.
GetMessage() << std::endl
120 int CvodeRootAdaptor(realtype t,
N_Vector y, realtype* pGOut,
void* pData)
122 assert(pData !=
nullptr);
133 std::cerr <<
"CVODE Root Exception: " << e.
GetMessage() << std::endl
175 void CvodeErrorHandler(
int errorCode,
const char* module,
const char*
function,
176 char* message,
void* pData)
178 std::stringstream err;
179 err <<
"CVODE Error " << errorCode <<
" in module " << module
180 <<
" function " <<
function <<
": " << message;
181 std::cerr <<
"*" << err.str() << std::endl
188 std::vector<double>& rInitialY,
193 assert(maxStep > 0.0);
196 N_Vector initial_values = N_VMake_Serial(rInitialY.size(), &(rInitialY[0]));
197 assert(NV_DATA_S(initial_values) == &(rInitialY[0]));
198 assert(!NV_OWN_DATA_S(initial_values));
208 for (
unsigned i = 0; i < size; i++)
224 CVodeSetErrHandlerFn(
mpCvodeMem, CvodeErrorHandler,
nullptr);
228 #if CHASTE_SUNDIALS_VERSION >= 20400
235 #if CHASTE_SUNDIALS_VERSION >= 20400
236 CVodeInit(
mpCvodeMem, CvodeRhsAdaptor, startTime, initial_values);
239 CVodeMalloc(
mpCvodeMem, CvodeRhsAdaptor, startTime, initial_values,
246 #if CHASTE_SUNDIALS_VERSION >= 20400
247 CVodeRootInit(
mpCvodeMem, 1, CvodeRootAdaptor);
253 #if CHASTE_SUNDIALS_VERSION >= 30000
255 mpSundialsDenseMatrix = SUNDenseMatrix(rInitialY.size(), rInitialY.size());
258 mpSundialsLinearSolver = SUNDenseLinearSolver(initial_values, mpSundialsDenseMatrix);
261 CVDlsSetLinearSolver(
mpCvodeMem, mpSundialsLinearSolver, mpSundialsDenseMatrix);
272 #if CHASTE_SUNDIALS_VERSION >= 20400
278 #if CHASTE_SUNDIALS_VERSION >= 20400
279 CVodeReInit(
mpCvodeMem, startTime, initial_values);
282 CVodeReInit(
mpCvodeMem, CvodeRhsAdaptor, startTime, initial_values,
286 #if CHASTE_SUNDIALS_VERSION >= 30000
287 if (mpSundialsLinearSolver)
290 SUNLinSolFree(mpSundialsLinearSolver);
292 if (mpSundialsDenseMatrix)
295 SUNMatDestroy(mpSundialsDenseMatrix);
299 mpSundialsDenseMatrix = SUNDenseMatrix(rInitialY.size(), rInitialY.size());
302 mpSundialsLinearSolver = SUNDenseLinearSolver(initial_values, mpSundialsDenseMatrix);
305 CVDlsSetLinearSolver(
mpCvodeMem, mpSundialsLinearSolver, mpSundialsDenseMatrix);
330 #if CHASTE_SUNDIALS_VERSION >= 30000
331 if (mpSundialsLinearSolver)
334 SUNLinSolFree(mpSundialsLinearSolver);
336 mpSundialsLinearSolver =
nullptr;
338 if (mpSundialsDenseMatrix)
341 SUNMatDestroy(mpSundialsDenseMatrix);
343 mpSundialsDenseMatrix =
nullptr;
374 std::stringstream err;
375 char* p_flag_name = CVodeGetReturnFlagName(flag);
376 err << msg <<
": " << p_flag_name;
378 std::cerr << err.str() << std::endl
384 std::vector<double>& rYValues,
390 assert(endTime > startTime);
391 assert(timeSampling > 0.0);
396 EXCEPTION(
"(Solve with sampling) Stopping event is true for initial condition");
399 SetupCvode(pOdeSystem, rYValues, startTime, maxStep);
401 TimeStepper stepper(startTime, endTime, timeSampling);
402 N_Vector yout = N_VMake_Serial(rYValues.size(), &(rYValues[0]));
408 solutions.
rGetTimes().push_back(startTime);
416 assert(ierr == CV_SUCCESS);
425 CvodeError(ierr,
"CVODE failed to solve system");
430 if (ierr == CV_ROOT_RETURN)
444 assert(ierr == CV_SUCCESS);
453 std::vector<double>& rYValues,
458 assert(endTime > startTime);
463 EXCEPTION(
"(Solve) Stopping event is true for initial condition");
466 SetupCvode(pOdeSystem, rYValues, startTime, maxStep);
468 N_Vector yout = N_VMake_Serial(rYValues.size(), &(rYValues[0]));
471 int ierr = CVodeSetStopTime(
mpCvodeMem, endTime);
472 assert(ierr == CV_SUCCESS);
476 ierr = CVode(
mpCvodeMem, endTime, yout, &tend, CV_NORMAL);
481 CvodeError(ierr,
"CVODE failed to solve system");
483 if (ierr == CV_ROOT_RETURN)
489 assert(NV_DATA_S(yout) == &(rYValues[0]));
490 assert(!NV_OWN_DATA_S(yout));
497 assert(ierr == CV_SUCCESS);
508 mLastInternalStepSize(-0.0),
510 mCheckForRoots(false),
511 mLastSolutionState(nullptr),
512 mLastSolutionTime(0.0),
513 #if CHASTE_SUNDIALS_VERSION >= 20400
518 mForceMinimalReset(false)
519 #if CHASTE_SUNDIALS_VERSION >= 30000
521 mpSundialsDenseMatrix(nullptr),
522 mpSundialsLinearSolver(nullptr)
539 for (
unsigned i = 0; i < size; i++)
587 #endif // CHASTE_CVODE
AbstractOdeSystem * pSystem
bool mStoppingEventOccurred
void SetupCvode(AbstractOdeSystem *pOdeSystem, std::vector< double > &rInitialY, double startTime, double maxStep)
static bool WithinAnyTolerance(double number1, double number2, double relTol=DBL_EPSILON, double absTol=DBL_EPSILON, bool printError=false)
CvodeAdaptor(double relTol=1e-4, double absTol=1e-6)
std::vector< std::vector< double > > & rGetSolutions()
void CheckForStoppingEvents()
#define EXCEPTION(message)
N_Vector mLastSolutionState
virtual void EvaluateYDerivatives(double time, const std::vector< double > &rY, std::vector< double > &rDY)=0
double mLastInternalStepSize
void AdvanceOneTimeStep()
double GetAbsoluteTolerance()
unsigned GetNumberOfStateVariables() const
void CreateVectorIfEmpty(VECTOR &rVec, unsigned size)
void SetOdeSystemInformation(boost::shared_ptr< const AbstractOdeSystemInformation > pOdeSystemInfo)
void CvodeError(int flag, const char *msg)
void SetMinimalReset(bool minimalReset)
std::string GetMessage() const
virtual double CalculateRootFunction(double time, const std::vector< double > &rY)
void SetVectorComponent(VECTOR &rVec, unsigned index, double value)
std::vector< realtype > * pY
unsigned GetTotalTimeStepsTaken() const
boost::shared_ptr< const AbstractOdeSystemInformation > GetSystemInformation() const
std::vector< double > & rGetTimes()
unsigned EstimateTimeSteps() const
void SetTolerances(double relTol=1e-4, double absTol=1e-6)
#define CHASTE_CLASS_EXPORT(T)
void RecordStoppingPoint(double stopTime, N_Vector yEnd)
void SetForceReset(bool autoReset)
void CopyFromStdVector(const std::vector< double > &rSrc, VECTOR &rDest)
void CopyToStdVector(const VECTOR &rSrc, std::vector< double > &rDest)
double GetNextTime() const
double GetVectorComponent(const VECTOR &rVec, unsigned index)
void SetMaxSteps(long int numSteps)
double GetRelativeTolerance()
OdeSolution Solve(AbstractOdeSystem *pOdeSystem, std::vector< double > &rYValues, double startTime, double endTime, double maxStep, double timeSampling)
void SetNumberOfTimeSteps(unsigned numTimeSteps)
void DeleteVector(VECTOR &rVec)
unsigned GetVectorSize(const VECTOR &rVec)
virtual bool CalculateStoppingEvent(double time, const std::vector< double > &rY)