CppADCodeGen  2.3.0
A C++ Algorithmic Differentiation Package with Source Code Generation
functor_model_library.hpp
1 #ifndef CPPAD_CG_FUNCTOR_MODEL_LIBRARY_INCLUDED
2 #define CPPAD_CG_FUNCTOR_MODEL_LIBRARY_INCLUDED
3 /* --------------------------------------------------------------------------
4  * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5  * Copyright (C) 2013 Ciengis
6  * Copyright (C) 2018 Joao Leal
7  *
8  * CppADCodeGen is distributed under multiple licenses:
9  *
10  * - Eclipse Public License Version 1.0 (EPL1), and
11  * - GNU General Public License Version 3 (GPL3).
12  *
13  * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
14  * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
15  * ----------------------------------------------------------------------------
16  * Author: Joao Leal
17  */
18 
19 namespace CppAD {
20 namespace cg {
21 
27 template<class Base>
28 class FunctorModelLibrary : public ModelLibrary<Base> {
29 protected:
30  std::set<std::string> _modelNames;
31  unsigned long _version; // API version
32  void (*_onClose)();
33  void (*_setThreadPoolDisabled)(int);
34  int (*_isThreadPoolDisabled)();
35  void (*_setThreads)(unsigned int);
36  unsigned int (*_getThreads)();
37  void (*_setSchedulerStrategy)(int);
38  int (*_getSchedulerStrategy)();
39  void (*_setThreadPoolVerbose)(int v);
40  int (*_isThreadPoolVerbose)();
41  void (*_setThreadPoolGuidedMaxWork)(float v);
42  float (*_getThreadPoolGuidedMaxWork)();
43  void (*_setThreadPoolNumberOfTimeMeas)(unsigned int n);
44  unsigned int (*_getThreadPoolNumberOfTimeMeas)();
45 public:
46 
47  std::set<std::string> getModelNames() override {
48  return _modelNames;
49  }
50 
/**
 * Provides a functor-based model object for the model with the given name.
 * Pure virtual: implemented by derived classes. model() forwards to this.
 *
 * @param modelName the name of the model (presumably one of getModelNames();
 *                  NOTE(review): confirm behaviour for unknown names)
 * @return the model object; ownership passes to the caller via unique_ptr
 */
60  virtual std::unique_ptr<FunctorGenericModel<Base>> modelFunctor(const std::string& modelName) = 0;
61 
62  std::unique_ptr<GenericModel<Base>> model(const std::string& modelName) override final {
63  return std::unique_ptr<GenericModel<Base>> (modelFunctor(modelName).release());
64  }
65 
71  virtual unsigned long getAPIVersion() {
72  return _version;
73  }
74 
/**
 * Looks up a function exported by the library. Pure virtual: implemented by
 * derived classes.
 *
 * @param functionName the symbol name to look up
 * @param required     validate() passes false for optional entry points and
 *                     then treats a nullptr result as "not available"
 *                     (NOTE(review): presumably implementations throw when
 *                     required is true and the symbol is missing; confirm)
 * @return an untyped pointer to the function, which callers reinterpret_cast
 *         to the expected function-pointer type
 */
88  virtual void* loadFunction(const std::string& functionName,
89  bool required = true) = 0;
90 
91  void setThreadPoolDisabled(bool disabled) override {
92  if(_setThreadPoolDisabled != nullptr) {
93  (*_setThreadPoolDisabled)(disabled);
94  }
95  }
96 
97  virtual bool isThreadPoolDisabled() const override {
98  if(_isThreadPoolDisabled != nullptr) {
99  return bool((*_isThreadPoolDisabled)());
100  }
101  return true;
102  }
103 
104  unsigned int getThreadNumber() const override {
105  if (_getThreads != nullptr) {
106  return (*_getThreads)();
107  }
108  return 1;
109  }
110 
111  void setThreadNumber(unsigned int n) override {
112  if (_setThreads != nullptr) {
113  (*_setThreads)(n);
114  }
115  }
116 
117  ThreadPoolScheduleStrategy getThreadPoolSchedulerStrategy() const override {
118  if (_getSchedulerStrategy != nullptr) {
119  return ThreadPoolScheduleStrategy((*_getSchedulerStrategy)());
120  }
121  return ThreadPoolScheduleStrategy::DYNAMIC;
122  }
123 
124  void setThreadPoolSchedulerStrategy(ThreadPoolScheduleStrategy s) override {
125  if (_setSchedulerStrategy != nullptr) {
126  (*_setSchedulerStrategy)(int(s));
127  }
128  }
129 
130  void setThreadPoolVerbose(bool v) override {
131  if (_setThreadPoolVerbose != nullptr) {
132  (*_setThreadPoolVerbose)(int(v));
133  }
134  }
135 
136  bool isThreadPoolVerbose() const override {
137  if (_isThreadPoolVerbose != nullptr) {
138  return bool((*_isThreadPoolVerbose)());
139  }
140  return false;
141  }
142 
143  void setThreadPoolGuidedMaxWork(float v) override {
144  if (_setThreadPoolGuidedMaxWork != nullptr) {
145  (*_setThreadPoolGuidedMaxWork)(v);
146  }
147  }
148 
149  float getThreadPoolGuidedMaxWork() const override {
150  if (_getThreadPoolGuidedMaxWork != nullptr) {
151  return (*_getThreadPoolGuidedMaxWork)();
152  }
153  return 1.0;
154  }
155 
156  void setThreadPoolNumberOfTimeMeas(unsigned int n) override {
157  if (_setThreadPoolNumberOfTimeMeas != nullptr) {
158  (*_setThreadPoolNumberOfTimeMeas)(n);
159  }
160  }
161 
162  unsigned int getThreadPoolNumberOfTimeMeas() const override {
163  if (_getThreadPoolNumberOfTimeMeas != nullptr) {
164  return (*_getThreadPoolNumberOfTimeMeas)();
165  }
166  return 0;
167  }
168 
169  inline virtual ~FunctorModelLibrary() = default;
170 
171 protected:
    /*
     * Protected constructor: sets the API version to 0 and every library
     * entry point to nullptr; validate() later loads the actual values.
     * NOTE(review): the constructor's signature line (source line 172) is
     * missing from this listing; presumably `FunctorModelLibrary() :` —
     * confirm against the original header.
     */
173  _version(0), // not really required (but it avoids warnings)
174  _onClose(nullptr),
175  _setThreadPoolDisabled(nullptr),
176  _isThreadPoolDisabled(nullptr),
177  _setThreads(nullptr),
178  _getThreads(nullptr),
179  _setSchedulerStrategy(nullptr),
180  _getSchedulerStrategy(nullptr),
181  _setThreadPoolVerbose(nullptr),
182  _isThreadPoolVerbose(nullptr),
183  _setThreadPoolGuidedMaxWork(nullptr),
184  _getThreadPoolGuidedMaxWork(nullptr),
185  _setThreadPoolNumberOfTimeMeas(nullptr),
186  _getThreadPoolNumberOfTimeMeas(nullptr) {
187  }
188 
/**
 * Loads the library's entry points and checks that it is usable:
 *  1. reads the API version (required function) and throws CGException when
 *     it is incompatible with the current version;
 *  2. reads the list of model names into _modelNames;
 *  3. loads the optional close/thread-pool functions with required=false;
 *  4. if a thread-count setter exists, sets it to
 *     std::thread::hardware_concurrency().
 * NOTE(review): several source lines (e.g. 190-192, 197, 200, 202-204) are
 * missing from this listing, including the version-comparison condition.
 */
189  inline void validate() {
193  unsigned long (*versionFunc)();
194  versionFunc = reinterpret_cast<decltype(versionFunc)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_VERSION));
195 
196  _version = (*versionFunc)();
        // The condition guarding this throw (source line 197) is not shown here.
198  throw CGException("The API version of the dynamic library (", _version,
199  ") is incompatible with the current version (",
201 
        // Retrieve the model names exported by the library.
205  void (*modelsFunc)(char const *const**, int*);
206  modelsFunc = reinterpret_cast<decltype(modelsFunc)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_MODELS));
207 
208  char const*const* model_names = nullptr;
        // model_count is left uninitialized; this relies on modelsFunc always
        // writing it.
209  int model_count;
210  (*modelsFunc)(&model_names, &model_count);
211 
212  for (int i = 0; i < model_count; i++) {
213  _modelNames.insert(model_names[i]);
214  }
215 
        // Optional entry points (required=false): a nullptr result is treated
        // as "not available" by the accessors that use these pointers.
219  _onClose = reinterpret_cast<decltype(_onClose)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_ONCLOSE, false));
220 
224  _setThreadPoolDisabled = reinterpret_cast<decltype(_setThreadPoolDisabled)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLDISABLED, false));
225  _isThreadPoolDisabled = reinterpret_cast<decltype(_isThreadPoolDisabled)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_ISTHREADPOOLDISABLED, false));
226  _setThreads = reinterpret_cast<decltype(_setThreads)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADS, false));
227  _getThreads = reinterpret_cast<decltype(_getThreads)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADS, false));
228  _setSchedulerStrategy = reinterpret_cast<decltype(_setSchedulerStrategy)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADSCHEDULERSTRAT, false));
229  _getSchedulerStrategy = reinterpret_cast<decltype(_getSchedulerStrategy)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADSCHEDULERSTRAT, false));
230  _setThreadPoolVerbose = reinterpret_cast<decltype(_setThreadPoolVerbose)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLVERBOSE, false));
231  _isThreadPoolVerbose = reinterpret_cast<decltype(_isThreadPoolVerbose)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_ISTHREADPOOLVERBOSE, false));
232  _setThreadPoolGuidedMaxWork = reinterpret_cast<decltype(_setThreadPoolGuidedMaxWork)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLGUIDEDMAXGROUPWORK, false));
233  _getThreadPoolGuidedMaxWork = reinterpret_cast<decltype(_getThreadPoolGuidedMaxWork)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADPOOLGUIDEDMAXGROUPWORK, false));
234  _setThreadPoolNumberOfTimeMeas = reinterpret_cast<decltype(_setThreadPoolNumberOfTimeMeas)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLNUMBEROFTIMEMEAS, false));
235  _getThreadPoolNumberOfTimeMeas = reinterpret_cast<decltype(_getThreadPoolNumberOfTimeMeas)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADPOOLNUMBEROFTIMEMEAS, false));
236 
        // Default the library's thread count to the hardware concurrency.
        // NOTE(review): std::thread::hardware_concurrency() may return 0 when
        // the value is not computable; confirm the library handles 0.
237  if(_setThreads != nullptr) {
238  (*_setThreads)(std::thread::hardware_concurrency());
239  }
240  }
241 };
242 
243 } // END cg namespace
244 } // END CppAD namespace
245 
246 #endif
virtual bool isThreadPoolDisabled() const override
void setThreadPoolNumberOfTimeMeas(unsigned int n) override
virtual std::unique_ptr< FunctorGenericModel< Base > > modelFunctor(const std::string &modelName)=0
std::set< std::string > getModelNames() override
void setThreadPoolSchedulerStrategy(ThreadPoolScheduleStrategy s) override
unsigned int getThreadNumber() const override
void setThreadPoolDisabled(bool disabled) override
unsigned int getThreadPoolNumberOfTimeMeas() const override
void setThreadNumber(unsigned int n) override
virtual unsigned long getAPIVersion()
ThreadPoolScheduleStrategy getThreadPoolSchedulerStrategy() const override
std::unique_ptr< GenericModel< Base > > model(const std::string &modelName) override final
virtual void * loadFunction(const std::string &functionName, bool required=true)=0