00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifdef CHASTE_CVODE
00030
00031 #include "CvodeAdaptor.hpp"
00032
00033 #include "Exception.hpp"
00034 #include "TimeStepper.hpp"
00035 #include "VectorHelperFunctions.hpp"
00036
00037 #include <iostream>
00038 #include <sstream>
00039
00040
00041 #include <sundials/sundials_nvector.h>
00042 #include <cvode/cvode_dense.h>
00043
00044
00061 int CvodeRhsAdaptor(realtype t, N_Vector y, N_Vector ydot, void* pData)
00062 {
00063 assert(pData != NULL);
00064 CvodeData* p_data = (CvodeData*) pData;
00065
00066 static std::vector<realtype> ydot_vec;
00067 CopyToStdVector(y, *p_data->pY);
00068 CopyToStdVector(ydot, ydot_vec);
00069
00070 try
00071 {
00072 p_data->pSystem->EvaluateYDerivatives(t, *(p_data->pY), ydot_vec);
00073 }
00074 catch (const Exception &e)
00075 {
00076 std::cerr << "CVODE RHS Exception: " << e.GetMessage() << std::endl << std::flush;
00077 return -1;
00078 }
00079
00080 CopyFromStdVector(ydot_vec, ydot);
00081 return 0;
00082 }
00083
00104 int CvodeRootAdaptor(realtype t, N_Vector y, realtype* pGOut, void* pData)
00105 {
00106 assert(pData != NULL);
00107 CvodeData* p_data = (CvodeData*) pData;
00108
00109 CopyToStdVector(y, *p_data->pY);
00110
00111 try
00112 {
00113 *pGOut = p_data->pSystem->CalculateRootFunction(t, *p_data->pY);
00114 }
00115 catch (const Exception &e)
00116 {
00117 std::cerr << "CVODE Root Exception: " << e.GetMessage() << std::endl << std::flush;
00118 return -1;
00119 }
00120 return 0;
00121 }
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159
/**
 * Error-handler callback installed with CVodeSetErrHandlerFn.
 *
 * Writes a single line of the form
 * "*CVODE Error <code> in module <module> function <function>: <message>"
 * to std::cerr. It deliberately does not throw: it is invoked from inside
 * CVODE's C code.
 *
 * @param errorCode  the CVODE error code
 * @param module  name of the CVODE module reporting the error
 * @param function  name of the CVODE function reporting the error
 * @param message  CVODE's description of the error
 * @param pData  unused
 */
void CvodeErrorHandler(int errorCode, const char *module, const char *function,
                       char *message, void* pData)
{
    std::ostringstream description;
    description << "CVODE Error " << errorCode;
    description << " in module " << module;
    description << " function " << function;
    description << ": " << message;

    std::cerr << "*" << description.str() << std::endl << std::flush;
}
00170
00171
00172
00173 void CvodeAdaptor::SetupCvode(AbstractOdeSystem* pOdeSystem,
00174 std::vector<double>& rInitialY,
00175 double startTime, double maxStep)
00176 {
00177 assert(rInitialY.size() == pOdeSystem->GetNumberOfStateVariables());
00178 assert(maxStep > 0.0);
00179
00180 mInitialValues = N_VMake_Serial(rInitialY.size(), &(rInitialY[0]));
00181 assert(NV_DATA_S(mInitialValues) == &(rInitialY[0]));
00182 assert(!NV_OWN_DATA_S(mInitialValues));
00183
00184
00185
00186
00187 mpCvodeMem = CVodeCreate(CV_BDF, CV_NEWTON);
00188 if (mpCvodeMem == NULL) EXCEPTION("Failed to SetupCvode CVODE");
00189
00190 CVodeSetErrHandlerFn(mpCvodeMem, CvodeErrorHandler, NULL);
00191
00192 mData.pSystem = pOdeSystem;
00193 mData.pY = &rInitialY;
00194 CVodeSetFdata(mpCvodeMem, (void*)(&mData));
00195
00196 CVodeMalloc(mpCvodeMem, CvodeRhsAdaptor, startTime, mInitialValues,
00197 CV_SS, mRelTol, &mAbsTol);
00198 CVodeSetMaxStep(mpCvodeMem, maxStep);
00199
00200 if (mCheckForRoots)
00201 {
00202 CVodeRootInit(mpCvodeMem, 1, CvodeRootAdaptor, (void*)(&mData));
00203 }
00204
00205 if (mMaxSteps > 0)
00206 {
00207 CVodeSetMaxNumSteps(mpCvodeMem, mMaxSteps);
00208 }
00209
00210 CVDense(mpCvodeMem, rInitialY.size());
00211 }
00212
00213 void CvodeAdaptor::FreeCvodeMemory()
00214 {
00215
00216
00217
00218 N_VDestroy_Serial(mInitialValues); mInitialValues = NULL;
00219
00220
00221 CVodeFree(&mpCvodeMem);
00222 }
00223
00224
00225 void CvodeAdaptor::CvodeError(int flag, const char * msg)
00226 {
00227 std::stringstream err;
00228 char* p_flag_name = CVodeGetReturnFlagName(flag);
00229 err << msg << ": " << p_flag_name;
00230 free(p_flag_name);
00231 std::cerr << err.str() << std::endl << std::flush;
00232 EXCEPTION(err.str());
00233 }
00234
00235
00236 OdeSolution CvodeAdaptor::Solve(AbstractOdeSystem* pOdeSystem,
00237 std::vector<double>& rYValues,
00238 double startTime,
00239 double endTime,
00240 double maxStep,
00241 double timeSampling)
00242 {
00243 assert(endTime > startTime);
00244 assert(timeSampling > 0.0);
00245
00246 mStoppingEventOccurred = false;
00247 if (mCheckForRoots && pOdeSystem->CalculateStoppingEvent(startTime, rYValues) == true)
00248 {
00249 EXCEPTION("(Solve with sampling) Stopping event is true for initial condition");
00250 }
00251
00252 SetupCvode(pOdeSystem, rYValues, startTime, maxStep);
00253
00254 TimeStepper stepper(startTime, endTime, timeSampling);
00255 N_Vector yout = mInitialValues;
00256
00257
00258 OdeSolution solutions;
00259 solutions.SetNumberOfTimeSteps(stepper.EstimateTimeSteps());
00260 solutions.rGetSolutions().push_back(rYValues);
00261 solutions.rGetTimes().push_back(startTime);
00262 solutions.SetOdeSystemInformation(pOdeSystem->GetSystemInformation());
00263
00264
00265 while (!stepper.IsTimeAtEnd() && !mStoppingEventOccurred)
00266 {
00267 double tend;
00268 int ierr = CVode(mpCvodeMem, stepper.GetNextTime(), yout, &tend, CV_NORMAL);
00269 if (ierr<0)
00270 {
00271 FreeCvodeMemory();
00272 CvodeError(ierr, "CVODE failed to solve system");
00273 }
00274
00275 solutions.rGetSolutions().push_back(rYValues);
00276 solutions.rGetTimes().push_back(tend);
00277 if (ierr == CV_ROOT_RETURN)
00278 {
00279
00280 mStoppingEventOccurred = true;
00281 mStoppingTime = tend;
00282 }
00283 stepper.AdvanceOneTimeStep();
00284 }
00285
00286
00287 solutions.SetNumberOfTimeSteps(stepper.GetTotalTimeStepsTaken());
00288
00289 int ierr = CVodeGetLastStep(mpCvodeMem, &mLastInternalStepSize);
00290 assert(ierr == CV_SUCCESS); ierr=ierr;
00291 FreeCvodeMemory();
00292
00293 return solutions;
00294 }
00295
00296
00297 void CvodeAdaptor::Solve(AbstractOdeSystem* pOdeSystem,
00298 std::vector<double>& rYValues,
00299 double startTime,
00300 double endTime,
00301 double maxStep)
00302 {
00303 assert(endTime > startTime);
00304
00305 mStoppingEventOccurred = false;
00306 if (mCheckForRoots && pOdeSystem->CalculateStoppingEvent(startTime, rYValues) == true)
00307 {
00308 EXCEPTION("(Solve) Stopping event is true for initial condition");
00309 }
00310
00311 SetupCvode(pOdeSystem, rYValues, startTime, maxStep);
00312
00313 N_Vector yout = mInitialValues;
00314 double tend;
00315 int ierr = CVode(mpCvodeMem, endTime, yout, &tend, CV_NORMAL);
00316 if (ierr<0)
00317 {
00318 FreeCvodeMemory();
00319 CvodeError(ierr, "CVODE failed to solve system");
00320 }
00321 if (ierr == CV_ROOT_RETURN)
00322 {
00323
00324 mStoppingEventOccurred = true;
00325 mStoppingTime = tend;
00326 }
00327 assert(NV_DATA_S(yout) == &(rYValues[0]));
00328 assert(!NV_OWN_DATA_S(yout));
00329
00330
00331
00332
00333
00334 ierr = CVodeGetLastStep(mpCvodeMem, &mLastInternalStepSize);
00335 assert(ierr == CV_SUCCESS); ierr=ierr;
00336 FreeCvodeMemory();
00337 }
00338
00339 CvodeAdaptor::CvodeAdaptor(double relTol, double absTol)
00340 : AbstractIvpOdeSolver(),
00341 mpCvodeMem(NULL), mInitialValues(NULL),
00342 mRelTol(relTol), mAbsTol(absTol),
00343 mLastInternalStepSize(-0.0),
00344 mMaxSteps(0),
00345 mCheckForRoots(false)
00346 {
00347 }
00348
00349 void CvodeAdaptor::SetTolerances(double relTol, double absTol)
00350 {
00351 mRelTol = relTol;
00352 mAbsTol = absTol;
00353 }
00354
00355 double CvodeAdaptor::GetRelativeTolerance()
00356 {
00357 return mRelTol;
00358 }
00359
00360 double CvodeAdaptor::GetAbsoluteTolerance()
00361 {
00362 return mAbsTol;
00363 }
00364
00365 double CvodeAdaptor::GetLastStepSize()
00366 {
00367 return mLastInternalStepSize;
00368 }
00369
00370 void CvodeAdaptor::CheckForStoppingEvents()
00371 {
00372 mCheckForRoots = true;
00373 }
00374
00375 void CvodeAdaptor::SetMaxSteps(long int numSteps)
00376 {
00377 mMaxSteps = numSteps;
00378 }
00379
00380 long int CvodeAdaptor::GetMaxSteps()
00381 {
00382 return mMaxSteps;
00383 }
00384
00385
00386 #include "SerializationExportWrapperForCpp.hpp"
00387 CHASTE_CLASS_EXPORT(CvodeAdaptor)
00388
00389 #endif // CHASTE_CVODE