CppADCodeGen 2.4.3
A C++ Algorithmic Differentiation Package with Source Code Generation
Loading...
Searching...
No Matches
evaluator_solve.hpp
1#ifndef CPPAD_CG_EVALUATOR_SOLVE_INCLUDED
2#define CPPAD_CG_EVALUATOR_SOLVE_INCLUDED
3/* --------------------------------------------------------------------------
4 * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5 * Copyright (C) 2016 Ciengis
6 *
7 * CppADCodeGen is distributed under multiple licenses:
8 *
9 * - Eclipse Public License Version 1.0 (EPL1), and
10 * - GNU General Public License Version 3 (GPL3).
11 *
12 * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
13 * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
14 * ----------------------------------------------------------------------------
15 * Author: Joao Leal
16 */
17
18namespace CppAD {
19namespace cg {
20
26template<class Scalar>
27class EvaluatorCloneSolve : public EvaluatorCG<Scalar, Scalar, EvaluatorCloneSolve<Scalar>> {
34public:
35 using ActiveOut = CG<Scalar>;
36 using SourceCodePath = typename CodeHandler<Scalar>::SourceCodePath;
37protected:
39private:
44 const std::vector<const SourceCodePath*>* paths_;
49 const std::vector<const std::vector<CG<Scalar>*>*>* replaceOnPath_;
54 const BidirGraph<Scalar>* pathGraph_;
58 const std::map<const PathNodeEdges<Scalar>*, CG<Scalar>>* replaceOnGraph_;
62 const std::set<const OperationNode<Scalar>*>* clone_;
66 const std::map<const OperationPathNode<Scalar>, CG<Scalar>>* replaceArgument_;
67public:
68
79 const std::vector<const SourceCodePath*>& paths,
80 const std::vector<const std::vector<CG<Scalar>*>*>& replaceOnPath) :
81 Super(handler),
82 paths_(&paths),
83 replaceOnPath_(&replaceOnPath),
84 pathGraph_(nullptr),
85 replaceOnGraph_(nullptr),
86 clone_(nullptr),
87 replaceArgument_(nullptr) {
88 CPPADCG_ASSERT_UNKNOWN(paths_->size() == replaceOnPath_->size());
89#ifndef NDEBUG
90 for (size_t i = 0; i < paths.size(); ++i) {
91 CPPADCG_ASSERT_UNKNOWN(paths[i]->size() == replaceOnPath[i]->size());
92 }
93#endif
94 }
95
104 const BidirGraph<Scalar>& pathGraph,
105 const std::map<const PathNodeEdges<Scalar>*, CG<Scalar> >& replaceOnGraph) :
106 Super(handler),
107 paths_(nullptr),
108 replaceOnPath_(nullptr),
109 pathGraph_(&pathGraph),
110 replaceOnGraph_(&replaceOnGraph),
111 clone_(nullptr),
112 replaceArgument_(nullptr) {
113 }
114
123 const std::set<const OperationNode<Scalar>*>& clone,
124 const std::map<const OperationPathNode<Scalar>, CG<Scalar>>& replaceArgument) :
125 Super(handler),
126 paths_(nullptr),
127 replaceOnPath_(nullptr),
128 pathGraph_(nullptr),
129 replaceOnGraph_(nullptr),
130 clone_(&clone),
131 replaceArgument_(&replaceArgument) {
132 }
133
134protected:
135
141 CPPADCG_ASSERT_UNKNOWN(this->depth_ > 0);
142
143 if(paths_ != nullptr) {
144 const auto& paths = *paths_;
145 for (size_t i = 0; i < paths.size(); ++i) {
146 size_t d = this->depth_ - 1;
147 if (isOnPath(*paths[i])) {
148 // in one of the paths
149
150 auto* r = (*(*replaceOnPath_)[i])[d];
151 if (r != nullptr) {
152 return *r;
153 } else {
154 return Super::evalOperation(node);
155 }
156 }
157 }
158 }
159
160 if(pathGraph_ != nullptr) {
161 const PathNodeEdges<Scalar>* egdes = pathGraph_->find(node);
162 if (egdes != nullptr) {
163 auto it = replaceOnGraph_->find(egdes);
164 if (it != replaceOnGraph_->end()) {
165 return it->second;
166 } else {
167 return Super::evalOperation(node);
168 }
169 }
170 }
171
172 if (clone_ != nullptr) {
173 if (clone_->find(&node) != clone_->end()) {
174 return Super::evalOperation(node);
175 }
176 }
177
178 if (replaceArgument_ != nullptr) {
179 size_t d = this->depth_ - 1;
180 if (d > 0) {
181 auto it = replaceArgument_->find(this->path_[d - 1]);
182 if (it != replaceArgument_->end()) {
183 return it->second;
184 }
185 }
186 }
187
188 return CG<Scalar>(node); // use original
189 }
190
191private:
192 inline bool isOnPath(const SourceCodePath& path) const {
193 size_t d = this->depth_ - 1;
194
195 if (d >= path.size())
196 return false;
197
198 if (this->path_[d].node != path[d].node) // compare only the node
199 return false;
200
201 if (d > 0) {
202 for (size_t j = 0; j < d; ++j) {
203 if (this->path_[j] != path[j]) { // compare node and argument index
204 return false;
205 }
206 }
207 }
208
209 return true;
210 }
211
212};
213
214} // END cg namespace
215} // END CppAD namespace
216
217#endif
EvaluatorCloneSolve(CodeHandler< Scalar > &handler, const std::vector< const SourceCodePath * > &paths, const std::vector< const std::vector< CG< Scalar > * > * > &replaceOnPath)
ActiveOut evalOperation(OperationNode< Scalar > &node)
EvaluatorCloneSolve(CodeHandler< Scalar > &handler, const BidirGraph< Scalar > &pathGraph, const std::map< const PathNodeEdges< Scalar > *, CG< Scalar > > &replaceOnGraph)
EvaluatorCloneSolve(CodeHandler< Scalar > &handler, const std::set< const OperationNode< Scalar > * > &clone, const std::map< const OperationPathNode< Scalar >, CG< Scalar > > &replaceArgument)