00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifdef CHASTE_CVODE
00030
#include "CvodeAdaptor.hpp"

#include "Exception.hpp"
#include "TimeStepper.hpp"
#include "VectorHelperFunctions.hpp"

#include <exception>
#include <iostream>
#include <sstream>


#include <sundials/sundials_nvector.h>
#include <cvode/cvode_dense.h>
00043
00044
00061 int CvodeRhsAdaptor(realtype t, N_Vector y, N_Vector ydot, void* pData)
00062 {
00063 assert(pData != NULL);
00064 CvodeData* p_data = (CvodeData*) pData;
00065
00066 static std::vector<realtype> ydot_vec;
00067 CopyToStdVector(y, *p_data->pY);
00068 CopyToStdVector(ydot, ydot_vec);
00069
00070 try
00071 {
00072 p_data->pSystem->EvaluateYDerivatives(t, *(p_data->pY), ydot_vec);
00073 }
00074 catch (const Exception &e)
00075 {
00076 std::cerr << "CVODE RHS Exception: " << e.GetMessage() << std::endl << std::flush;
00077 return -1;
00078 }
00079
00080 CopyFromStdVector(ydot_vec, ydot);
00081 return 0;
00082 }
00083
00104 int CvodeRootAdaptor(realtype t, N_Vector y, realtype* pGOut, void* pData)
00105 {
00106 assert(pData != NULL);
00107 CvodeData* p_data = (CvodeData*) pData;
00108
00109 CopyToStdVector(y, *p_data->pY);
00110
00111 try
00112 {
00113 *pGOut = p_data->pSystem->CalculateRootFunction(t, *p_data->pY);
00114 }
00115 catch (const Exception &e)
00116 {
00117 std::cerr << "CVODE Root Exception: " << e.GetMessage() << std::endl << std::flush;
00118 return -1;
00119 }
00120 return 0;
00121 }
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159
/**
 * Error-handler hook installed with CVodeSetErrHandlerFn in SetupCvode.
 *
 * Writes one line of the form
 *   *CVODE Error <errorCode> in module <module> function <function>: <message>
 * to std::cerr. It only logs; it does not throw.
 *
 * @param errorCode  CVODE error code
 * @param module  name of the reporting SUNDIALS module
 * @param function  name of the reporting function
 * @param message  CVODE's error text
 * @param pData  unused
 */
void CvodeErrorHandler(int errorCode, const char *module, const char *function,
                       char *message, void* pData)
{
    (void)pData; // no user data is needed to format the report

    std::ostringstream report;
    report << "*"
           << "CVODE Error " << errorCode
           << " in module " << module
           << " function " << function
           << ": " << message;

    std::cerr << report.str() << std::endl << std::flush;
}
00170
00171
00172
00173 void CvodeAdaptor::SetupCvode(AbstractOdeSystem* pOdeSystem,
00174 std::vector<double>& rInitialY,
00175 double startTime, double maxStep)
00176 {
00177 assert(rInitialY.size() == pOdeSystem->GetNumberOfStateVariables());
00178 assert(maxStep > 0.0);
00179
00180 mInitialValues = N_VMake_Serial(rInitialY.size(), &(rInitialY[0]));
00181 assert(NV_DATA_S(mInitialValues) == &(rInitialY[0]));
00182 assert(!NV_OWN_DATA_S(mInitialValues));
00183
00184
00185
00186
00187 mpCvodeMem = CVodeCreate(CV_BDF, CV_NEWTON);
00188 if (mpCvodeMem == NULL) EXCEPTION("Failed to SetupCvode CVODE");
00189
00190 CVodeSetErrHandlerFn(mpCvodeMem, CvodeErrorHandler, NULL);
00191
00192 mData.pSystem = pOdeSystem;
00193 mData.pY = &rInitialY;
00194 #if CHASTE_SUNDIALS_VERSION >= 20400
00195 CVodeSetUserData(mpCvodeMem, (void*)(&mData));
00196 #else
00197 CVodeSetFdata(mpCvodeMem, (void*)(&mData));
00198 #endif
00199
00200 #if CHASTE_SUNDIALS_VERSION >= 20400
00201 CVodeInit(mpCvodeMem, CvodeRhsAdaptor, startTime, mInitialValues);
00202 CVodeSStolerances(mpCvodeMem, mRelTol, mAbsTol);
00203 #else
00204 CVodeMalloc(mpCvodeMem, CvodeRhsAdaptor, startTime, mInitialValues,
00205 CV_SS, mRelTol, &mAbsTol);
00206 #endif
00207 CVodeSetMaxStep(mpCvodeMem, maxStep);
00208
00209 if (mCheckForRoots)
00210 {
00211 #if CHASTE_SUNDIALS_VERSION >= 20400
00212 CVodeRootInit(mpCvodeMem, 1, CvodeRootAdaptor);
00213 #else
00214 CVodeRootInit(mpCvodeMem, 1, CvodeRootAdaptor, (void*)(&mData));
00215 #endif
00216 }
00217
00218 if (mMaxSteps > 0)
00219 {
00220 CVodeSetMaxNumSteps(mpCvodeMem, mMaxSteps);
00221 }
00222
00223 CVDense(mpCvodeMem, rInitialY.size());
00224 }
00225
00226 void CvodeAdaptor::FreeCvodeMemory()
00227 {
00228
00229
00230
00231 N_VDestroy_Serial(mInitialValues); mInitialValues = NULL;
00232
00233
00234 CVodeFree(&mpCvodeMem);
00235 }
00236
00237
00238 void CvodeAdaptor::CvodeError(int flag, const char * msg)
00239 {
00240 std::stringstream err;
00241 char* p_flag_name = CVodeGetReturnFlagName(flag);
00242 err << msg << ": " << p_flag_name;
00243 free(p_flag_name);
00244 std::cerr << err.str() << std::endl << std::flush;
00245 EXCEPTION(err.str());
00246 }
00247
00248
00249 OdeSolution CvodeAdaptor::Solve(AbstractOdeSystem* pOdeSystem,
00250 std::vector<double>& rYValues,
00251 double startTime,
00252 double endTime,
00253 double maxStep,
00254 double timeSampling)
00255 {
00256 assert(endTime > startTime);
00257 assert(timeSampling > 0.0);
00258
00259 mStoppingEventOccurred = false;
00260 if (mCheckForRoots && pOdeSystem->CalculateStoppingEvent(startTime, rYValues) == true)
00261 {
00262 EXCEPTION("(Solve with sampling) Stopping event is true for initial condition");
00263 }
00264
00265 SetupCvode(pOdeSystem, rYValues, startTime, maxStep);
00266
00267 TimeStepper stepper(startTime, endTime, timeSampling);
00268 N_Vector yout = mInitialValues;
00269
00270
00271 OdeSolution solutions;
00272 solutions.SetNumberOfTimeSteps(stepper.EstimateTimeSteps());
00273 solutions.rGetSolutions().push_back(rYValues);
00274 solutions.rGetTimes().push_back(startTime);
00275 solutions.SetOdeSystemInformation(pOdeSystem->GetSystemInformation());
00276
00277
00278 while (!stepper.IsTimeAtEnd() && !mStoppingEventOccurred)
00279 {
00280 double tend;
00281 int ierr = CVode(mpCvodeMem, stepper.GetNextTime(), yout, &tend, CV_NORMAL);
00282 if (ierr<0)
00283 {
00284 FreeCvodeMemory();
00285 CvodeError(ierr, "CVODE failed to solve system");
00286 }
00287
00288 solutions.rGetSolutions().push_back(rYValues);
00289 solutions.rGetTimes().push_back(tend);
00290 if (ierr == CV_ROOT_RETURN)
00291 {
00292
00293 mStoppingEventOccurred = true;
00294 mStoppingTime = tend;
00295 }
00296 stepper.AdvanceOneTimeStep();
00297 }
00298
00299
00300 solutions.SetNumberOfTimeSteps(stepper.GetTotalTimeStepsTaken());
00301
00302 int ierr = CVodeGetLastStep(mpCvodeMem, &mLastInternalStepSize);
00303 assert(ierr == CV_SUCCESS); ierr=ierr;
00304 FreeCvodeMemory();
00305
00306 return solutions;
00307 }
00308
00309
00310 void CvodeAdaptor::Solve(AbstractOdeSystem* pOdeSystem,
00311 std::vector<double>& rYValues,
00312 double startTime,
00313 double endTime,
00314 double maxStep)
00315 {
00316 assert(endTime > startTime);
00317
00318 mStoppingEventOccurred = false;
00319 if (mCheckForRoots && pOdeSystem->CalculateStoppingEvent(startTime, rYValues) == true)
00320 {
00321 EXCEPTION("(Solve) Stopping event is true for initial condition");
00322 }
00323
00324 SetupCvode(pOdeSystem, rYValues, startTime, maxStep);
00325
00326 N_Vector yout = mInitialValues;
00327 double tend;
00328 int ierr = CVode(mpCvodeMem, endTime, yout, &tend, CV_NORMAL);
00329 if (ierr<0)
00330 {
00331 FreeCvodeMemory();
00332 CvodeError(ierr, "CVODE failed to solve system");
00333 }
00334 if (ierr == CV_ROOT_RETURN)
00335 {
00336
00337 mStoppingEventOccurred = true;
00338 mStoppingTime = tend;
00339 }
00340 assert(NV_DATA_S(yout) == &(rYValues[0]));
00341 assert(!NV_OWN_DATA_S(yout));
00342
00343
00344
00345
00346
00347 ierr = CVodeGetLastStep(mpCvodeMem, &mLastInternalStepSize);
00348 assert(ierr == CV_SUCCESS); ierr=ierr;
00349 FreeCvodeMemory();
00350 }
00351
00352 CvodeAdaptor::CvodeAdaptor(double relTol, double absTol)
00353 : AbstractIvpOdeSolver(),
00354 mpCvodeMem(NULL), mInitialValues(NULL),
00355 mRelTol(relTol), mAbsTol(absTol),
00356 mLastInternalStepSize(-0.0),
00357 mMaxSteps(0),
00358 mCheckForRoots(false)
00359 {
00360 }
00361
00362 void CvodeAdaptor::SetTolerances(double relTol, double absTol)
00363 {
00364 mRelTol = relTol;
00365 mAbsTol = absTol;
00366 }
00367
00368 double CvodeAdaptor::GetRelativeTolerance()
00369 {
00370 return mRelTol;
00371 }
00372
00373 double CvodeAdaptor::GetAbsoluteTolerance()
00374 {
00375 return mAbsTol;
00376 }
00377
00378 double CvodeAdaptor::GetLastStepSize()
00379 {
00380 return mLastInternalStepSize;
00381 }
00382
00383 void CvodeAdaptor::CheckForStoppingEvents()
00384 {
00385 mCheckForRoots = true;
00386 }
00387
00388 void CvodeAdaptor::SetMaxSteps(long int numSteps)
00389 {
00390 mMaxSteps = numSteps;
00391 }
00392
00393 long int CvodeAdaptor::GetMaxSteps()
00394 {
00395 return mMaxSteps;
00396 }
00397
00398
00399 #include "SerializationExportWrapperForCpp.hpp"
00400 CHASTE_CLASS_EXPORT(CvodeAdaptor)
00401
00402 #endif // CHASTE_CVODE