CppADCodeGen 2.4.3
A C++ Algorithmic Differentiation Package with Source Code Generation
Loading...
Searching...
No Matches
evaluator_ad.hpp
1#ifndef CPPAD_CG_EVALUATOR_AD_INCLUDED
2#define CPPAD_CG_EVALUATOR_AD_INCLUDED
3/* --------------------------------------------------------------------------
4 * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5 * Copyright (C) 2016 Ciengis
6 * Copyright (C) 2020 Joao Leal
7 *
8 * CppADCodeGen is distributed under multiple licenses:
9 *
10 * - Eclipse Public License Version 1.0 (EPL1), and
11 * - GNU General Public License Version 3 (GPL3).
12 *
13 * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
14 * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
15 * ----------------------------------------------------------------------------
16 * Author: Joao Leal
17 */
18
19namespace CppAD {
20namespace cg {
21
26template<class ScalarIn, class ScalarOut, class FinalEvaluatorType>
27class EvaluatorAD : public EvaluatorOperations<ScalarIn, ScalarOut, CppAD::AD<ScalarOut>, FinalEvaluatorType> {
34public:
39protected:
40 using Super::handler_;
41 using Super::evalArrayCreationOperation;
42protected:
// Atomic operation nodes that evalAtomicOperation() has already processed;
// emptied by prepareNewEvaluation().
43 std::set<NodeIn*> evalsAtomic_;
// Registered atomic functions keyed by their id. Only raw pointers are kept
// and nothing in the visible code deletes them, so the pointed-to objects are
// presumably owned by the caller - TODO confirm.
44 std::map<size_t, CppAD::atomic_base<ScalarOut>* > atomicFunctions_;
50public:
51
/**
 * Creates an evaluator for the given code handler, which is forwarded to the
 * Super constructor.
 *
 * NOTE(review): the line numbered 54 in this listing is absent, so the
 * remainder of the member initializer list and the opening brace of the body
 * are not visible here - confirm against the original header.
 *
 * @param handler the code handler passed on to the base class
 */
52 inline EvaluatorAD(CodeHandler<ScalarIn>& handler) :
53 Super(handler),
55 }
56
/**
 * Virtual, compiler-generated destructor. The pointers stored in
 * atomicFunctions_ are not deleted here.
 */
57 inline virtual ~EvaluatorAD() = default;
58
/**
 * Setter taking a boolean flag named print.
 *
 * NOTE(review): the line numbered 64 in this listing (the statement that
 * consumes the flag) is absent. Presumably the flag controls whether
 * evalPrint() echoes print operations to std::cout - TODO confirm.
 *
 * @param print the new value of the flag
 */
63 inline void setPrintOutPrintOperations(bool print) {
65 }
66
/**
 * Const getter returning a boolean.
 *
 * NOTE(review): the line numbered 72 in this listing (the return statement)
 * is absent. Presumably this returns the flag written by
 * setPrintOutPrintOperations() - TODO confirm.
 *
 * @return the current value of the flag (see note above)
 */
71 inline bool isPrintOutPrintOperations() const {
73 }
74
/**
 * Registers an atomic function under the given id, replacing any function
 * previously registered with the same id.
 *
 * Only the address of the argument is stored, so the object must remain
 * valid for as long as this evaluator may use it.
 *
 * @param id the key under which the atomic function is stored
 * @param atomic the atomic function to register
 * @return true if an entry with this id was already present before the call,
 *         false otherwise
 */
83 virtual bool addAtomicFunction(size_t id, atomic_base<ScalarOut>& atomic) {
    // the existence check must happen before the assignment, which always
    // creates or overwrites the entry
84 bool exists = atomicFunctions_.find(id) != atomicFunctions_.end();
85 atomicFunctions_[id] = &atomic;
86 return exists;
87 }
88
/**
 * Registers every non-null atomic function of the given map under its id.
 * Entries whose pointer is null are skipped; an id that is already
 * registered is overwritten with the new pointer.
 *
 * @param atomics atomic functions keyed by id; the pointers are stored
 *                as-is, not copied
 */
89 virtual void addAtomicFunctions(const std::map<size_t, atomic_base<ScalarOut>* >& atomics) {
90 for (const auto& it : atomics) {
91 atomic_base<ScalarOut>* atomic = it.second;
    // null entries are ignored rather than registered
92 if (atomic != nullptr) {
93 atomicFunctions_[it.first] = atomic;
94 }
95 }
96 }
97
105 return evalsAtomic_.size();
106 }
107
108protected:
109
/**
 * Resets the evaluator state for a new evaluation: delegates to
 * Super::prepareNewEvaluation() and then empties the set of already
 * evaluated atomic nodes. The registered atomic functions are kept.
 */
114 inline void prepareNewEvaluation() {
120 Super::prepareNewEvaluation();
121
    // forget which atomic nodes were processed so they are evaluated again
122 evalsAtomic_.clear();
123 }
124
/**
 * Evaluates an atomic function node by calling the atomic function that was
 * registered for the id stored in the node.
 *
 * A node that was already evaluated since the last prepareNewEvaluation()
 * is skipped. The node must be of type CGOpCode::AtomicForward, have two
 * arguments (used as the input and output arrays of the call) and three
 * info entries: info[0] is the atomic function id and info[2] must be zero
 * (presumably the forward mode order - TODO confirm). info[1] is not read
 * here.
 *
 * @param node the atomic operation node to evaluate
 * @throws CGException if the node is not an AtomicForward operation, if no
 *         non-null atomic function is registered for its id, or if info[2]
 *         is not zero
 */
131 inline void evalAtomicOperation(NodeIn& node) {
132
    // each atomic node is evaluated at most once per evaluation
133 if (evalsAtomic_.find(&node) != evalsAtomic_.end()) {
134 return;
135 }
136
137 if (node.getOperationType() != CGOpCode::AtomicForward) {
138 throw CGException("Evaluator can only handle zero forward mode for atomic functions");
139 }
140
141 const std::vector<size_t>& info = node.getInfo();
142 const std::vector<Argument<ScalarIn> >& args = node.getArguments();
143 CPPADCG_ASSERT_KNOWN(args.size() == 2, "Invalid number of arguments for atomic forward mode")
144 CPPADCG_ASSERT_KNOWN(info.size() == 3, "Invalid number of information data for atomic forward mode")
145
146 // find the atomic function
147 size_t id = info[0];
148 typename std::map<size_t, atomic_base<ScalarOut>* >::const_iterator itaf = atomicFunctions_.find(id);
149 atomic_base<ScalarOut>* atomicFunction = nullptr;
150 if (itaf != atomicFunctions_.end()) {
151 atomicFunction = itaf->second;
152 }
153
    // a missing entry and a null entry are reported the same way; the
    // message uses the name known to the handler, or the id when that
    // name is empty
154 if (atomicFunction == nullptr) {
155 std::stringstream ss;
156 ss << "No atomic function defined in the evaluator for ";
157 const std::string & atomName = handler_.getAtomicFunctionName(id);
158 if (!atomName.empty()) {
159 ss << "'" << atomName << "'";
160 } else
161 ss << "id '" << id << "'";
162 throw CGException(ss.str());
163 }
164
165 size_t p = info[2];
166 if (p != 0) {
167 throw CGException("Evaluator can only handle zero forward mode for atomic functions");
168 }
    // first argument: array read by the atomic function;
    // second argument: array the atomic function writes into
169 const std::vector<ActiveOut>& ax = evalArrayCreationOperation(*args[0].getOperation());
170 std::vector<ActiveOut>& ay = evalArrayCreationOperation(*args[1].getOperation());
171
172 (*atomicFunction)(ax, ay);
173
    // recorded only after the call returned, so a node whose evaluation
    // threw is not marked as done
174 evalsAtomic_.insert(&node);
175 }
176
/**
 * Evaluates a print operation node: evaluates its single argument, writes
 * the before-string, the value and the after-string of the node to
 * std::cout, hands the same strings and value to CppAD::PrintFor, and
 * returns the evaluated value.
 *
 * The node is cast to PrintOperationNode<ScalarIn> without a runtime check,
 * so callers must only pass nodes of that type.
 *
 * @param node the print operation node
 * @return the evaluated argument of the print operation
 */
181 inline ActiveOut evalPrint(const NodeIn& node) {
182 const std::vector<ArgIn>& args = node.getArguments();
183 CPPADCG_ASSERT_KNOWN(args.size() == 1, "Invalid number of arguments for print()")
184 ActiveOut out(this->evalArg(args, 0));
185
186 const auto& nodePri = static_cast<const PrintOperationNode<ScalarIn>&>(node);
    // NOTE(review): the line numbered 187 in this listing is absent and the
    // closing brace two lines below has no visible opener, so the std::cout
    // echo is presumably guarded by a condition (possibly the flag exposed
    // by isPrintOutPrintOperations) - confirm against the original header.
188 std::cout << nodePri.getBeforeString() << out << nodePri.getAfterString();
189 }
190
    // presumably registers the print with CppAD so it is repeated when the
    // recorded operations are replayed - see the CppAD PrintFor documentation
191 CppAD::PrintFor(ActiveOut(0), nodePri.getBeforeString().c_str(), out, nodePri.getAfterString().c_str());
192
193 return out;
194 }
195
196};
197
/**
 * Partial specialization of Evaluator for an output active type of
 * CppAD::AD<ScalarOut>. It derives from EvaluatorAD and passes itself as the
 * FinalEvaluatorType template argument.
 *
 * NOTE(review): the lines numbered 204-205 in this listing are absent;
 * presumably they declare the Super alias used by the constructor - confirm
 * against the original header.
 */
201template<class ScalarIn, class ScalarOut>
202class Evaluator<ScalarIn, ScalarOut, CppAD::AD<ScalarOut> > : public EvaluatorAD<ScalarIn, ScalarOut, Evaluator<ScalarIn, ScalarOut, CppAD::AD<ScalarOut> > > {
203public:
206public:
207
    // forwards the code handler to the base class constructor
208 inline Evaluator(CodeHandler<ScalarIn>& handler) :
209 Super(handler) {
210 }
211
212};
213
214} // END cg namespace
215} // END CppAD namespace
216
217#endif
std::string getAtomicFunctionName(size_t id) const
void evalAtomicOperation(NodeIn &node)
bool isPrintOutPrintOperations() const
size_t getNumberOfEvaluatedAtomics() const
void setPrintOutPrintOperations(bool print)
ActiveOut evalPrint(const NodeIn &node)
virtual bool addAtomicFunction(size_t id, atomic_base< ScalarOut > &atomic)
const std::vector< size_t > & getInfo() const
const std::vector< Argument< Base > > & getArguments() const
CGOpCode getOperationType() const