00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036 #ifdef CHASTE_CVODE
00037
00038 #include "CvodeAdaptor.hpp"
00039
00040 #include "Exception.hpp"
00041 #include "TimeStepper.hpp"
00042 #include "VectorHelperFunctions.hpp"
00043 #include "MathsCustomFunctions.hpp"
00044
00045 #include <iostream>
00046 #include <sstream>
00047
00048
00049 #include <sundials/sundials_nvector.h>
00050 #include <cvode/cvode_dense.h>
00051
00052
00069 int CvodeRhsAdaptor(realtype t, N_Vector y, N_Vector ydot, void* pData)
00070 {
00071 assert(pData != NULL);
00072 CvodeData* p_data = (CvodeData*) pData;
00073
00074 static std::vector<realtype> ydot_vec;
00075 CopyToStdVector(y, *p_data->pY);
00076 CopyToStdVector(ydot, ydot_vec);
00077
00078 try
00079 {
00080 p_data->pSystem->EvaluateYDerivatives(t, *(p_data->pY), ydot_vec);
00081 }
00082 catch (const Exception &e)
00083 {
00084 std::cerr << "CVODE RHS Exception: " << e.GetMessage() << std::endl << std::flush;
00085 return -1;
00086 }
00087
00088 CopyFromStdVector(ydot_vec, ydot);
00089 return 0;
00090 }
00091
00112 int CvodeRootAdaptor(realtype t, N_Vector y, realtype* pGOut, void* pData)
00113 {
00114 assert(pData != NULL);
00115 CvodeData* p_data = (CvodeData*) pData;
00116
00117 CopyToStdVector(y, *p_data->pY);
00118
00119 try
00120 {
00121 *pGOut = p_data->pSystem->CalculateRootFunction(t, *p_data->pY);
00122 }
00123 catch (const Exception &e)
00124 {
00125 std::cerr << "CVODE Root Exception: " << e.GetMessage() << std::endl << std::flush;
00126 return -1;
00127 }
00128 return 0;
00129 }
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166
00167
/**
 * Error-handler callback registered with CVODE via CVodeSetErrHandlerFn.
 *
 * Writes a single line "*CVODE Error <code> in module <module> function <function>: <message>"
 * to std::cerr and flushes it. Nothing is thrown here.
 *
 * @param errorCode  CVODE error code
 * @param module  name of the reporting module
 * @param function  name of the reporting function
 * @param message  CVODE's error text
 * @param pData  handler user data (unused; registered as NULL)
 */
void CvodeErrorHandler(int errorCode, const char *module, const char *function,
                       char *message, void* pData)
{
    (void)pData;

    std::ostringstream text;
    text << "CVODE Error " << errorCode << " in module " << module
         << " function " << function << ": " << message;

    std::cerr << "*" << text.str() << std::endl;
    std::cerr.flush();
}
00178
00179
00180
00181 void CvodeAdaptor::SetupCvode(AbstractOdeSystem* pOdeSystem,
00182 std::vector<double>& rInitialY,
00183 double startTime,
00184 double maxStep)
00185 {
00186 assert(rInitialY.size() == pOdeSystem->GetNumberOfStateVariables());
00187 assert(maxStep > 0.0);
00188
00190 N_Vector initial_values = N_VMake_Serial(rInitialY.size(), &(rInitialY[0]));
00191 assert(NV_DATA_S(initial_values) == &(rInitialY[0]));
00192 assert(!NV_OWN_DATA_S(initial_values));
00193
00194
00195
00196
00197
00198 bool reinit = !mpCvodeMem || mForceReset || !mLastSolutionState || !CompareDoubles::WithinAnyTolerance(startTime, mLastSolutionTime);
00199 if (!reinit && !mForceMinimalReset)
00200 {
00201 const unsigned size = GetVectorSize(rInitialY);
00202 for (unsigned i=0; i<size; i++)
00203 {
00204 if (!CompareDoubles::WithinAnyTolerance(GetVectorComponent(mLastSolutionState, i), GetVectorComponent(rInitialY, i)))
00205 {
00206 reinit = true;
00207 break;
00208 }
00209 }
00210 }
00211
00212 if (!mpCvodeMem)
00213 {
00214
00215 mpCvodeMem = CVodeCreate(CV_BDF, CV_NEWTON);
00216 if (mpCvodeMem == NULL) EXCEPTION("Failed to SetupCvode CVODE");
00217
00218 CVodeSetErrHandlerFn(mpCvodeMem, CvodeErrorHandler, NULL);
00219
00220 mData.pSystem = pOdeSystem;
00221 mData.pY = &rInitialY;
00222 #if CHASTE_SUNDIALS_VERSION >= 20400
00223 CVodeSetUserData(mpCvodeMem, (void*)(&mData));
00224 #else
00225 CVodeSetFdata(mpCvodeMem, (void*)(&mData));
00226 #endif
00227
00228
00229 #if CHASTE_SUNDIALS_VERSION >= 20400
00230 CVodeInit(mpCvodeMem, CvodeRhsAdaptor, startTime, initial_values);
00231 CVodeSStolerances(mpCvodeMem, mRelTol, mAbsTol);
00232 #else
00233 CVodeMalloc(mpCvodeMem, CvodeRhsAdaptor, startTime, initial_values,
00234 CV_SS, mRelTol, &mAbsTol);
00235 #endif
00236
00237
00238 if (mCheckForRoots)
00239 {
00240 #if CHASTE_SUNDIALS_VERSION >= 20400
00241 CVodeRootInit(mpCvodeMem, 1, CvodeRootAdaptor);
00242 #else
00243 CVodeRootInit(mpCvodeMem, 1, CvodeRootAdaptor, (void*)(&mData));
00244 #endif
00245 }
00246
00247 CVDense(mpCvodeMem, rInitialY.size());
00248 }
00249 else if (reinit)
00250 {
00251
00252 mData.pSystem = pOdeSystem;
00253 mData.pY = &rInitialY;
00254 #if CHASTE_SUNDIALS_VERSION >= 20400
00255 CVodeSetUserData(mpCvodeMem, (void*)(&mData));
00256 #else
00257 CVodeSetFdata(mpCvodeMem, (void*)(&mData));
00258 #endif
00259
00260 #if CHASTE_SUNDIALS_VERSION >= 20400
00261 CVodeReInit(mpCvodeMem, startTime, initial_values);
00262 CVodeSStolerances(mpCvodeMem, mRelTol, mAbsTol);
00263 #else
00264 CVodeReInit(mpCvodeMem, CvodeRhsAdaptor, startTime, initial_values,
00265 CV_SS, mRelTol, &mAbsTol);
00266 #endif
00267
00268
00269 CVDense(mpCvodeMem, rInitialY.size());
00270 }
00271
00272 CVodeSetMaxStep(mpCvodeMem, maxStep);
00273
00274 if (mMaxSteps > 0)
00275 {
00276 CVodeSetMaxNumSteps(mpCvodeMem, mMaxSteps);
00277 CVodeSetMaxErrTestFails(mpCvodeMem, 15);
00278 }
00279 DeleteVector(initial_values);
00280 }
00281
00282 void CvodeAdaptor::FreeCvodeMemory()
00283 {
00284 if (mpCvodeMem)
00285 {
00286 CVodeFree(&mpCvodeMem);
00287 }
00288 mpCvodeMem = NULL;
00289 }
00290
00291
00292 void CvodeAdaptor::SetForceReset(bool autoReset)
00293 {
00294 mForceReset = autoReset;
00295 if (mForceReset)
00296 {
00297 SetMinimalReset(false);
00298 ResetSolver();
00299 }
00300 }
00301
00302 void CvodeAdaptor::SetMinimalReset(bool minimalReset)
00303 {
00304 mForceMinimalReset = minimalReset;
00305 if (mForceMinimalReset)
00306 {
00307 SetForceReset(false);
00308 }
00309 }
00310
00311 void CvodeAdaptor::ResetSolver()
00312 {
00313
00314 DeleteVector(mLastSolutionState);
00315 }
00316
00317 void CvodeAdaptor::CvodeError(int flag, const char * msg)
00318 {
00319 std::stringstream err;
00320 char* p_flag_name = CVodeGetReturnFlagName(flag);
00321 err << msg << ": " << p_flag_name;
00322 free(p_flag_name);
00323 std::cerr << err.str() << std::endl << std::flush;
00324 EXCEPTION(err.str());
00325 }
00326
00327
00328 OdeSolution CvodeAdaptor::Solve(AbstractOdeSystem* pOdeSystem,
00329 std::vector<double>& rYValues,
00330 double startTime,
00331 double endTime,
00332 double maxStep,
00333 double timeSampling)
00334 {
00335 assert(endTime > startTime);
00336 assert(timeSampling > 0.0);
00337
00338 mStoppingEventOccurred = false;
00339 if (mCheckForRoots && pOdeSystem->CalculateStoppingEvent(startTime, rYValues) == true)
00340 {
00341 EXCEPTION("(Solve with sampling) Stopping event is true for initial condition");
00342 }
00343
00344 SetupCvode(pOdeSystem, rYValues, startTime, maxStep);
00345
00346 TimeStepper stepper(startTime, endTime, timeSampling);
00347 N_Vector yout = N_VMake_Serial(rYValues.size(), &(rYValues[0]));
00348
00349
00350 OdeSolution solutions;
00351 solutions.SetNumberOfTimeSteps(stepper.EstimateTimeSteps());
00352 solutions.rGetSolutions().push_back(rYValues);
00353 solutions.rGetTimes().push_back(startTime);
00354 solutions.SetOdeSystemInformation(pOdeSystem->GetSystemInformation());
00355
00356
00357 while (!stepper.IsTimeAtEnd() && !mStoppingEventOccurred)
00358 {
00359
00360 int ierr = CVodeSetStopTime(mpCvodeMem, stepper.GetNextTime());
00361 assert(ierr == CV_SUCCESS); UNUSED_OPT(ierr);
00362
00363 double tend;
00364 ierr = CVode(mpCvodeMem, stepper.GetNextTime(), yout, &tend, CV_NORMAL);
00365 if (ierr<0)
00366 {
00367 FreeCvodeMemory();
00368 DeleteVector(yout);
00369 CvodeError(ierr, "CVODE failed to solve system");
00370 }
00371
00372 solutions.rGetSolutions().push_back(rYValues);
00373 solutions.rGetTimes().push_back(tend);
00374 if (ierr == CV_ROOT_RETURN)
00375 {
00376
00377 mStoppingEventOccurred = true;
00378 mStoppingTime = tend;
00379 }
00380 mLastSolutionTime = tend;
00381 stepper.AdvanceOneTimeStep();
00382 }
00383
00384
00385 solutions.SetNumberOfTimeSteps(stepper.GetTotalTimeStepsTaken());
00386
00387 int ierr = CVodeGetLastStep(mpCvodeMem, &mLastInternalStepSize);
00388 assert(ierr == CV_SUCCESS); UNUSED_OPT(ierr);
00389 RecordStoppingPoint(mLastSolutionTime, yout);
00390 DeleteVector(yout);
00391
00392 return solutions;
00393 }
00394
00395
00396 void CvodeAdaptor::Solve(AbstractOdeSystem* pOdeSystem,
00397 std::vector<double>& rYValues,
00398 double startTime,
00399 double endTime,
00400 double maxStep)
00401 {
00402 assert(endTime > startTime);
00403
00404 mStoppingEventOccurred = false;
00405 if (mCheckForRoots && pOdeSystem->CalculateStoppingEvent(startTime, rYValues) == true)
00406 {
00407 EXCEPTION("(Solve) Stopping event is true for initial condition");
00408 }
00409
00410 SetupCvode(pOdeSystem, rYValues, startTime, maxStep);
00411
00412 N_Vector yout = N_VMake_Serial(rYValues.size(), &(rYValues[0]));
00413
00414
00415 int ierr = CVodeSetStopTime(mpCvodeMem, endTime);
00416 assert(ierr == CV_SUCCESS); UNUSED_OPT(ierr);
00417
00418 double tend;
00419 ierr = CVode(mpCvodeMem, endTime, yout, &tend, CV_NORMAL);
00420 if (ierr<0)
00421 {
00422 FreeCvodeMemory();
00423 DeleteVector(yout);
00424 CvodeError(ierr, "CVODE failed to solve system");
00425 }
00426 if (ierr == CV_ROOT_RETURN)
00427 {
00428
00429 mStoppingEventOccurred = true;
00430 mStoppingTime = tend;
00431 }
00432 assert(NV_DATA_S(yout) == &(rYValues[0]));
00433 assert(!NV_OWN_DATA_S(yout));
00434
00435
00436
00437
00438
00439 ierr = CVodeGetLastStep(mpCvodeMem, &mLastInternalStepSize);
00440 assert(ierr == CV_SUCCESS); UNUSED_OPT(ierr);
00441 RecordStoppingPoint(tend, yout);
00442 DeleteVector(yout);
00443 }
00444
00445 CvodeAdaptor::CvodeAdaptor(double relTol, double absTol)
00446 : AbstractIvpOdeSolver(),
00447 mpCvodeMem(NULL),
00448 mRelTol(relTol),
00449 mAbsTol(absTol),
00450 mLastInternalStepSize(-0.0),
00451 mMaxSteps(0),
00452 mCheckForRoots(false),
00453 mLastSolutionState(NULL),
00454 mLastSolutionTime(0.0),
00455 #if CHASTE_SUNDIALS_VERSION >= 20400
00456 mForceReset(false),
00457 #else
00458 mForceReset(true),
00459 #endif
00460 mForceMinimalReset(false)
00461 {
00462 }
00463
00464 CvodeAdaptor::~CvodeAdaptor()
00465 {
00466 FreeCvodeMemory();
00467 DeleteVector(mLastSolutionState);
00468 }
00469
00470 void CvodeAdaptor::RecordStoppingPoint(double stopTime, N_Vector yEnd)
00471 {
00472 if (!mForceReset)
00473 {
00474 const unsigned size = GetVectorSize(yEnd);
00475 CreateVectorIfEmpty(mLastSolutionState, size);
00476 for (unsigned i=0; i<size; i++)
00477 {
00478 SetVectorComponent(mLastSolutionState, i, GetVectorComponent(yEnd, i));
00479 }
00480 mLastSolutionTime = stopTime;
00481 }
00482 }
00483
00484 void CvodeAdaptor::SetTolerances(double relTol, double absTol)
00485 {
00486 mRelTol = relTol;
00487 mAbsTol = absTol;
00488 }
00489
00490 double CvodeAdaptor::GetRelativeTolerance()
00491 {
00492 return mRelTol;
00493 }
00494
00495 double CvodeAdaptor::GetAbsoluteTolerance()
00496 {
00497 return mAbsTol;
00498 }
00499
00500 double CvodeAdaptor::GetLastStepSize()
00501 {
00502 return mLastInternalStepSize;
00503 }
00504
00505 void CvodeAdaptor::CheckForStoppingEvents()
00506 {
00507 mCheckForRoots = true;
00508 }
00509
00510 void CvodeAdaptor::SetMaxSteps(long int numSteps)
00511 {
00512 mMaxSteps = numSteps;
00513 }
00514
00515 long int CvodeAdaptor::GetMaxSteps()
00516 {
00517 return mMaxSteps;
00518 }
00519
00520
00521 #include "SerializationExportWrapperForCpp.hpp"
00522 CHASTE_CLASS_EXPORT(CvodeAdaptor)
00523
00524 #endif // CHASTE_CVODE