1#ifndef CPPAD_CG_EVALUATOR_AD_INCLUDED
2#define CPPAD_CG_EVALUATOR_AD_INCLUDED
26template<
class ScalarIn,
class ScalarOut,
class FinalEvaluatorType>
40 using Super::handler_;
41 using Super::evalArrayCreationOperation;
43 std::set<NodeIn*> evalsAtomic_;
44 std::map<size_t, CppAD::atomic_base<ScalarOut>* > atomicFunctions_;
84 bool exists = atomicFunctions_.find(
id) != atomicFunctions_.end();
85 atomicFunctions_[
id] = &atomic;
89 virtual void addAtomicFunctions(
const std::map<
size_t, atomic_base<ScalarOut>* >&
atomics) {
91 atomic_base<ScalarOut>* atomic =
it.second;
92 if (atomic !=
nullptr) {
93 atomicFunctions_[
it.first] = atomic;
105 return evalsAtomic_.size();
120 Super::prepareNewEvaluation();
122 evalsAtomic_.clear();
133 if (evalsAtomic_.find(&node) != evalsAtomic_.end()) {
138 throw CGException(
"Evaluator can only handle zero forward mode for atomic functions");
143 CPPADCG_ASSERT_KNOWN(
args.size() == 2,
"Invalid number of arguments for atomic forward mode")
144 CPPADCG_ASSERT_KNOWN(
info.size() == 3,
"Invalid number of information data for atomic forward mode")
148 typename std::map<size_t, atomic_base<ScalarOut>* >::const_iterator
itaf = atomicFunctions_.find(
id);
150 if (
itaf != atomicFunctions_.end()) {
155 std::stringstream
ss;
156 ss <<
"No atomic function defined in the evaluator for ";
161 ss <<
"id '" <<
id <<
"'";
167 throw CGException(
"Evaluator can only handle zero forward mode for atomic functions");
169 const std::vector<ActiveOut>&
ax = evalArrayCreationOperation(*
args[0].getOperation());
170 std::vector<ActiveOut>&
ay = evalArrayCreationOperation(*
args[1].getOperation());
172 (*atomicFunction)(
ax,
ay);
174 evalsAtomic_.insert(&node);
183 CPPADCG_ASSERT_KNOWN(
args.size() == 1,
"Invalid number of arguments for print()")
201template<
class ScalarIn,
class ScalarOut>
std::string getAtomicFunctionName(size_t id) const
void evalAtomicOperation(NodeIn &node)
bool printOutPriOperations_
void prepareNewEvaluation()
bool isPrintOutPrintOperations() const
size_t getNumberOfEvaluatedAtomics() const
void setPrintOutPrintOperations(bool print)
ActiveOut evalPrint(const NodeIn &node)
virtual bool addAtomicFunction(size_t id, atomic_base< ScalarOut > &atomic)
const std::vector< size_t > & getInfo() const
const std::vector< Argument< Base > > & getArguments() const
CGOpCode getOperationType() const
bool GreaterThanZero(const cg::CG< Base > &x)