CppADCodeGen  2.4.3
A C++ Algorithmic Differentiation Package with Source Code Generation
functor_model_library.hpp
1 #ifndef CPPAD_CG_FUNCTOR_MODEL_LIBRARY_INCLUDED
2 #define CPPAD_CG_FUNCTOR_MODEL_LIBRARY_INCLUDED
3 /* --------------------------------------------------------------------------
4  * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5  * Copyright (C) 2013 Ciengis
6  * Copyright (C) 2018 Joao Leal
7  *
8  * CppADCodeGen is distributed under multiple licenses:
9  *
10  * - Eclipse Public License Version 1.0 (EPL1), and
11  * - GNU General Public License Version 3 (GPL3).
12  *
13  * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
14  * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
15  * ----------------------------------------------------------------------------
16  * Author: Joao Leal
17  */
18 
19 namespace CppAD {
20 namespace cg {
21 
27 template<class Base>
28 class FunctorModelLibrary : public ModelLibrary<Base> {
29 protected:
30  std::set<std::string> _modelNames;
31  unsigned long _version; // API version
32  void (*_onClose)();
33  void (*_setThreadPoolDisabled)(int);
34  int (*_isThreadPoolDisabled)();
35  void (*_setThreads)(unsigned int);
36  unsigned int (*_getThreads)();
37  void (*_setSchedulerStrategy)(int);
38  int (*_getSchedulerStrategy)();
39  void (*_setThreadPoolVerbose)(int v);
40  int (*_isThreadPoolVerbose)();
41  void (*_setThreadPoolGuidedMaxWork)(float v);
42  float (*_getThreadPoolGuidedMaxWork)();
43  void (*_setThreadPoolNumberOfTimeMeas)(unsigned int n);
44  unsigned int (*_getThreadPoolNumberOfTimeMeas)();
45 public:
46 
47  std::set<std::string> getModelNames() override {
48  return _modelNames;
49  }
50 
59  virtual std::unique_ptr<FunctorGenericModel<Base>> modelFunctor(const std::string& modelName) = 0;
60 
61  std::unique_ptr<GenericModel<Base>> model(const std::string& modelName) override final {
62  return std::unique_ptr<GenericModel<Base>> (modelFunctor(modelName).release());
63  }
64 
70  virtual unsigned long getAPIVersion() {
71  return _version;
72  }
73 
87  virtual void* loadFunction(const std::string& functionName,
88  bool required = true) = 0;
89 
90  void setThreadPoolDisabled(bool disabled) override {
91  if(_setThreadPoolDisabled != nullptr) {
92  (*_setThreadPoolDisabled)(disabled);
93  }
94  }
95 
96  bool isThreadPoolDisabled() const override {
97  if(_isThreadPoolDisabled != nullptr) {
98  return bool((*_isThreadPoolDisabled)());
99  }
100  return true;
101  }
102 
103  unsigned int getThreadNumber() const override {
104  if (_getThreads != nullptr) {
105  return (*_getThreads)();
106  }
107  return 1;
108  }
109 
110  void setThreadNumber(unsigned int n) override {
111  if (_setThreads != nullptr) {
112  (*_setThreads)(n);
113  }
114  }
115 
116  ThreadPoolScheduleStrategy getThreadPoolSchedulerStrategy() const override {
117  if (_getSchedulerStrategy != nullptr) {
118  return ThreadPoolScheduleStrategy((*_getSchedulerStrategy)());
119  }
120  return ThreadPoolScheduleStrategy::DYNAMIC;
121  }
122 
123  void setThreadPoolSchedulerStrategy(ThreadPoolScheduleStrategy s) override {
124  if (_setSchedulerStrategy != nullptr) {
125  (*_setSchedulerStrategy)(int(s));
126  }
127  }
128 
129  void setThreadPoolVerbose(bool v) override {
130  if (_setThreadPoolVerbose != nullptr) {
131  (*_setThreadPoolVerbose)(int(v));
132  }
133  }
134 
135  bool isThreadPoolVerbose() const override {
136  if (_isThreadPoolVerbose != nullptr) {
137  return bool((*_isThreadPoolVerbose)());
138  }
139  return false;
140  }
141 
142  void setThreadPoolGuidedMaxWork(float v) override {
143  if (_setThreadPoolGuidedMaxWork != nullptr) {
144  (*_setThreadPoolGuidedMaxWork)(v);
145  }
146  }
147 
148  float getThreadPoolGuidedMaxWork() const override {
149  if (_getThreadPoolGuidedMaxWork != nullptr) {
150  return (*_getThreadPoolGuidedMaxWork)();
151  }
152  return 1.0;
153  }
154 
155  void setThreadPoolNumberOfTimeMeas(unsigned int n) override {
156  if (_setThreadPoolNumberOfTimeMeas != nullptr) {
157  (*_setThreadPoolNumberOfTimeMeas)(n);
158  }
159  }
160 
161  unsigned int getThreadPoolNumberOfTimeMeas() const override {
162  if (_getThreadPoolNumberOfTimeMeas != nullptr) {
163  return (*_getThreadPoolNumberOfTimeMeas)();
164  }
165  return 0;
166  }
167 
168  inline virtual ~FunctorModelLibrary() = default;
169 
170 protected:
172  _version(0), // not really required (but it avoids warnings)
173  _onClose(nullptr),
174  _setThreadPoolDisabled(nullptr),
175  _isThreadPoolDisabled(nullptr),
176  _setThreads(nullptr),
177  _getThreads(nullptr),
178  _setSchedulerStrategy(nullptr),
179  _getSchedulerStrategy(nullptr),
180  _setThreadPoolVerbose(nullptr),
181  _isThreadPoolVerbose(nullptr),
182  _setThreadPoolGuidedMaxWork(nullptr),
183  _getThreadPoolGuidedMaxWork(nullptr),
184  _setThreadPoolNumberOfTimeMeas(nullptr),
185  _getThreadPoolNumberOfTimeMeas(nullptr) {
186  }
187 
188  inline void validate() {
192  unsigned long (*versionFunc)();
193  versionFunc = reinterpret_cast<decltype(versionFunc)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_VERSION));
194 
195  _version = (*versionFunc)();
197  throw CGException("The API version of the dynamic library (", _version,
198  ") is incompatible with the current version (",
200 
204  void (*modelsFunc)(char const *const**, int*);
205  modelsFunc = reinterpret_cast<decltype(modelsFunc)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_MODELS));
206 
207  char const*const* model_names = nullptr;
208  int model_count;
209  (*modelsFunc)(&model_names, &model_count);
210 
211  for (int i = 0; i < model_count; i++) {
212  _modelNames.insert(model_names[i]);
213  }
214 
218  _onClose = reinterpret_cast<decltype(_onClose)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_ONCLOSE, false));
219 
223  _setThreadPoolDisabled = reinterpret_cast<decltype(_setThreadPoolDisabled)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLDISABLED, false));
224  _isThreadPoolDisabled = reinterpret_cast<decltype(_isThreadPoolDisabled)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_ISTHREADPOOLDISABLED, false));
225  _setThreads = reinterpret_cast<decltype(_setThreads)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADS, false));
226  _getThreads = reinterpret_cast<decltype(_getThreads)> (loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADS, false));
227  _setSchedulerStrategy = reinterpret_cast<decltype(_setSchedulerStrategy)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADSCHEDULERSTRAT, false));
228  _getSchedulerStrategy = reinterpret_cast<decltype(_getSchedulerStrategy)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADSCHEDULERSTRAT, false));
229  _setThreadPoolVerbose = reinterpret_cast<decltype(_setThreadPoolVerbose)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLVERBOSE, false));
230  _isThreadPoolVerbose = reinterpret_cast<decltype(_isThreadPoolVerbose)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_ISTHREADPOOLVERBOSE, false));
231  _setThreadPoolGuidedMaxWork = reinterpret_cast<decltype(_setThreadPoolGuidedMaxWork)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLGUIDEDMAXGROUPWORK, false));
232  _getThreadPoolGuidedMaxWork = reinterpret_cast<decltype(_getThreadPoolGuidedMaxWork)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADPOOLGUIDEDMAXGROUPWORK, false));
233  _setThreadPoolNumberOfTimeMeas = reinterpret_cast<decltype(_setThreadPoolNumberOfTimeMeas)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_SETTHREADPOOLNUMBEROFTIMEMEAS, false));
234  _getThreadPoolNumberOfTimeMeas = reinterpret_cast<decltype(_getThreadPoolNumberOfTimeMeas)> (this->loadFunction(ModelLibraryCSourceGen<Base>::FUNCTION_GETTHREADPOOLNUMBEROFTIMEMEAS, false));
235 
236  if(_setThreads != nullptr) {
237  (*_setThreads)(std::thread::hardware_concurrency());
238  }
239  }
240 };
241 
242 } // END cg namespace
243 } // END CppAD namespace
244 
245 #endif
void setThreadPoolNumberOfTimeMeas(unsigned int n) override
bool isThreadPoolDisabled() const override
virtual std::unique_ptr< FunctorGenericModel< Base > > modelFunctor(const std::string &modelName)=0
std::set< std::string > getModelNames() override
void setThreadPoolSchedulerStrategy(ThreadPoolScheduleStrategy s) override
unsigned int getThreadNumber() const override
void setThreadPoolDisabled(bool disabled) override
unsigned int getThreadPoolNumberOfTimeMeas() const override
void setThreadNumber(unsigned int n) override
virtual unsigned long getAPIVersion()
ThreadPoolScheduleStrategy getThreadPoolSchedulerStrategy() const override
std::unique_ptr< GenericModel< Base > > model(const std::string &modelName) override final
virtual void * loadFunction(const std::string &functionName, bool required=true)=0