CppADCodeGen 2.4.3
A C++ Algorithmic Differentiation Package with Source Code Generation
Loading...
Searching...
No Matches
functor_model_library.hpp
1#ifndef CPPAD_CG_FUNCTOR_MODEL_LIBRARY_INCLUDED
2#define CPPAD_CG_FUNCTOR_MODEL_LIBRARY_INCLUDED
3/* --------------------------------------------------------------------------
4 * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5 * Copyright (C) 2013 Ciengis
6 * Copyright (C) 2018 Joao Leal
7 *
8 * CppADCodeGen is distributed under multiple licenses:
9 *
10 * - Eclipse Public License Version 1.0 (EPL1), and
11 * - GNU General Public License Version 3 (GPL3).
12 *
13 * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
14 * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
15 * ----------------------------------------------------------------------------
16 * Author: Joao Leal
17 */
18
19namespace CppAD {
20namespace cg {
21
/**
 * Base class for model libraries whose generated functions are obtained by
 * name through loadFunction() and then called through function pointers.
 * Concrete libraries implement loadFunction() and modelFunctor(); validate()
 * reads the library's metadata (API version, model names) and its optional
 * thread-pool control functions.
 *
 * @tparam Base the base scalar type of the models
 */
27template<class Base>
28class FunctorModelLibrary : public ModelLibrary<Base> {
29protected:
// names of the models contained in the library (filled by validate())
30 std::set<std::string> _modelNames;
31 unsigned long _version; // API version reported by the library (read in validate(), 0 before that)
// optional entry point loaded by validate(); it is not called anywhere in this class
// (NOTE(review): presumably invoked by derived classes when the library is closed — confirm)
32 void (*_onClose)();
/*
 * Optional thread-pool control functions exported by the library.
 * They are initialized to nullptr by the constructor and loaded by validate()
 * with required=false; the public accessors check each pointer for nullptr
 * and fall back to a no-op or a default value when it is not set.
 */
33 void (*_setThreadPoolDisabled)(int);
34 int (*_isThreadPoolDisabled)();
35 void (*_setThreads)(unsigned int);
36 unsigned int (*_getThreads)();
37 void (*_setSchedulerStrategy)(int);
38 int (*_getSchedulerStrategy)();
39 void (*_setThreadPoolVerbose)(int v);
40 int (*_isThreadPoolVerbose)();
41 void (*_setThreadPoolGuidedMaxWork)(float v);
42 float (*_getThreadPoolGuidedMaxWork)();
43 void (*_setThreadPoolNumberOfTimeMeas)(unsigned int n);
44 unsigned int (*_getThreadPoolNumberOfTimeMeas)();
45public:
46
47 std::set<std::string> getModelNames() override {
48 return _modelNames;
49 }
50
59 virtual std::unique_ptr<FunctorGenericModel<Base>> modelFunctor(const std::string& modelName) = 0;
60
61 std::unique_ptr<GenericModel<Base>> model(const std::string& modelName) override final {
62 return std::unique_ptr<GenericModel<Base>> (modelFunctor(modelName).release());
63 }
64
70 virtual unsigned long getAPIVersion() {
71 return _version;
72 }
73
87 virtual void* loadFunction(const std::string& functionName,
88 bool required = true) = 0;
89
90 void setThreadPoolDisabled(bool disabled) override {
91 if(_setThreadPoolDisabled != nullptr) {
92 (*_setThreadPoolDisabled)(disabled);
93 }
94 }
95
96 bool isThreadPoolDisabled() const override {
97 if(_isThreadPoolDisabled != nullptr) {
98 return bool((*_isThreadPoolDisabled)());
99 }
100 return true;
101 }
102
103 unsigned int getThreadNumber() const override {
104 if (_getThreads != nullptr) {
105 return (*_getThreads)();
106 }
107 return 1;
108 }
109
110 void setThreadNumber(unsigned int n) override {
111 if (_setThreads != nullptr) {
112 (*_setThreads)(n);
113 }
114 }
115
116 ThreadPoolScheduleStrategy getThreadPoolSchedulerStrategy() const override {
117 if (_getSchedulerStrategy != nullptr) {
118 return ThreadPoolScheduleStrategy((*_getSchedulerStrategy)());
119 }
120 return ThreadPoolScheduleStrategy::DYNAMIC;
121 }
122
123 void setThreadPoolSchedulerStrategy(ThreadPoolScheduleStrategy s) override {
124 if (_setSchedulerStrategy != nullptr) {
125 (*_setSchedulerStrategy)(int(s));
126 }
127 }
128
129 void setThreadPoolVerbose(bool v) override {
130 if (_setThreadPoolVerbose != nullptr) {
131 (*_setThreadPoolVerbose)(int(v));
132 }
133 }
134
135 bool isThreadPoolVerbose() const override {
136 if (_isThreadPoolVerbose != nullptr) {
137 return bool((*_isThreadPoolVerbose)());
138 }
139 return false;
140 }
141
142 void setThreadPoolGuidedMaxWork(float v) override {
143 if (_setThreadPoolGuidedMaxWork != nullptr) {
144 (*_setThreadPoolGuidedMaxWork)(v);
145 }
146 }
147
148 float getThreadPoolGuidedMaxWork() const override {
149 if (_getThreadPoolGuidedMaxWork != nullptr) {
150 return (*_getThreadPoolGuidedMaxWork)();
151 }
152 return 1.0;
153 }
154
155 void setThreadPoolNumberOfTimeMeas(unsigned int n) override {
156 if (_setThreadPoolNumberOfTimeMeas != nullptr) {
157 (*_setThreadPoolNumberOfTimeMeas)(n);
158 }
159 }
160
161 unsigned int getThreadPoolNumberOfTimeMeas() const override {
162 if (_getThreadPoolNumberOfTimeMeas != nullptr) {
163 return (*_getThreadPoolNumberOfTimeMeas)();
164 }
165 return 0;
166 }
167
168 inline virtual ~FunctorModelLibrary() = default;
169
170protected:
/*
 * Constructor. NOTE(review): its declaration line (source line 171) is not
 * shown in this rendering.
 * Sets _version to 0 and every optional library function pointer to nullptr,
 * so the thread-pool accessors use their fallback behaviour until validate()
 * loads the library's functions. _modelNames is not listed and therefore
 * starts out as an empty, default-constructed set.
 */
172 _version(0), // not really required (but it avoids warnings)
173 _onClose(nullptr),
174 _setThreadPoolDisabled(nullptr),
175 _isThreadPoolDisabled(nullptr),
176 _setThreads(nullptr),
177 _getThreads(nullptr),
178 _setSchedulerStrategy(nullptr),
179 _getSchedulerStrategy(nullptr),
180 _setThreadPoolVerbose(nullptr),
181 _isThreadPoolVerbose(nullptr),
182 _setThreadPoolGuidedMaxWork(nullptr),
183 _getThreadPoolGuidedMaxWork(nullptr),
184 _setThreadPoolNumberOfTimeMeas(nullptr),
185 _getThreadPoolNumberOfTimeMeas(nullptr) {
186 }
187
/**
 * Reads the library's metadata and entry points through loadFunction():
 *  - the API version (FUNCTION_VERSION, loaded with the default required=true);
 *  - the model names (FUNCTION_MODELS, loaded with required=true), copied into _modelNames;
 *  - the optional close and thread-pool control functions (loaded with required=false).
 * Finally, if a thread-count setter was loaded, it is called with
 * std::thread::hardware_concurrency().
 */
188 inline void validate() {
// Query the API version implemented by the library.
192 unsigned long (*versionFunc)();
193 versionFunc = reinterpret_cast<decltype(versionFunc)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_VERSION));
194
195 _version = (*versionFunc)();
// NOTE(review): the condition guarding this throw and the end of its argument
// list (source lines 196 and 199) are not visible in this rendering; presumably
// the library's version is compared with the version expected here — confirm.
197 throw CGException("The API version of the dynamic library (", _version,
198 ") is incompatible with the current version (",
200
// Ask the library for its model names: it returns a pointer to an array of
// model_count C strings, and each name is copied into _modelNames.
// NOTE(review): model_count is left uninitialized; if the library function does
// not assign it, the loop below reads an indeterminate value — consider `= 0`.
204 void (*modelsFunc)(char const *const**, int*);
205 modelsFunc = reinterpret_cast<decltype(modelsFunc)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_MODELS));
206
207 char const*const* model_names = nullptr;
208 int model_count;
209 (*modelsFunc)(&model_names, &model_count);
210
211 for (int i = 0; i < model_count; i++) {
212 _modelNames.insert(model_names[i]);
213 }
214
// Optional entry points (required=false). The public accessors check these
// pointers for nullptr before calling them.
218 _onClose = reinterpret_cast<decltype(_onClose)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_ONCLOSE, false));
219
223 _setThreadPoolDisabled = reinterpret_cast<decltype(_setThreadPoolDisabled)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLDISABLED, false));
224 _isThreadPoolDisabled = reinterpret_cast<decltype(_isThreadPoolDisabled)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_ISTHREADPOOLDISABLED, false));
225 _setThreads = reinterpret_cast<decltype(_setThreads)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADS, false));
226 _getThreads = reinterpret_cast<decltype(_getThreads)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADS, false));
227 _setSchedulerStrategy = reinterpret_cast<decltype(_setSchedulerStrategy)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADSCHEDULERSTRAT, false));
228 _getSchedulerStrategy = reinterpret_cast<decltype(_getSchedulerStrategy)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADSCHEDULERSTRAT, false));
229 _setThreadPoolVerbose = reinterpret_cast<decltype(_setThreadPoolVerbose)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLVERBOSE, false));
230 _isThreadPoolVerbose = reinterpret_cast<decltype(_isThreadPoolVerbose)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_ISTHREADPOOLVERBOSE, false));
231 _setThreadPoolGuidedMaxWork = reinterpret_cast<decltype(_setThreadPoolGuidedMaxWork)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLGUIDEDMAXGROUPWORK, false));
232 _getThreadPoolGuidedMaxWork = reinterpret_cast<decltype(_getThreadPoolGuidedMaxWork)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADPOOLGUIDEDMAXGROUPWORK, false));
233 _setThreadPoolNumberOfTimeMeas = reinterpret_cast<decltype(_setThreadPoolNumberOfTimeMeas)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLNUMBEROFTIMEMEAS, false));
234 _getThreadPoolNumberOfTimeMeas = reinterpret_cast<decltype(_getThreadPoolNumberOfTimeMeas)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADPOOLNUMBEROFTIMEMEAS, false));
235
// Request as many threads as std::thread::hardware_concurrency() reports (a hint).
// NOTE(review): hardware_concurrency() returns 0 when the value is not computable;
// consider guarding against 0 before forwarding it to the library.
236 if(_setThreads != nullptr) {
237 (*_setThreads)(std::thread::hardware_concurrency());
238 }
239 }
240};
241
242} // END cg namespace
243} // END CppAD namespace
244
245#endif
ThreadPoolScheduleStrategy getThreadPoolSchedulerStrategy() const override
std::set< std::string > getModelNames() override
unsigned int getThreadPoolNumberOfTimeMeas() const override
void setThreadPoolDisabled(bool disabled) override
void setThreadNumber(unsigned int n) override
std::unique_ptr< GenericModel< Base > > model(const std::string &modelName) override final
virtual void * loadFunction(const std::string &functionName, bool required=true)=0
void setThreadPoolSchedulerStrategy(ThreadPoolScheduleStrategy s) override
unsigned int getThreadNumber() const override
void setThreadPoolNumberOfTimeMeas(unsigned int n) override
virtual std::unique_ptr< FunctorGenericModel< Base > > modelFunctor(const std::string &modelName)=0