CppADCodeGen 2.4.3
A C++ Algorithmic Differentiation Package with Source Code Generation
Loading...
Searching...
No Matches
operation_path.hpp
1#ifndef CPPAD_CG_OPERATION_PATH_INCLUDED
2#define CPPAD_CG_OPERATION_PATH_INCLUDED
3/* --------------------------------------------------------------------------
4 * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5 * Copyright (C) 2012 Ciengis
6 *
7 * CppADCodeGen is distributed under multiple licenses:
8 *
9 * - Eclipse Public License Version 1.0 (EPL1), and
10 * - GNU General Public License Version 3 (GPL3).
11 *
12 * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
13 * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
14 * ----------------------------------------------------------------------------
15 * Author: Joao Leal
16 */
17
#include <algorithm>

#include <cppad/cg/operation_path_node.hpp>
#include <cppad/cg/bidir_graph.hpp>
20
21namespace CppAD {
22namespace cg {
23
34template<class Base>
35inline bool findPathGraph(BidirGraph<Base>& foundGraph,
36 OperationNode<Base>& root,
37 OperationNode<Base>& target,
38 size_t& bifurcations,
39 size_t maxBifurcations = (std::numeric_limits<size_t>::max)()) {
40 if (bifurcations >= maxBifurcations) {
41 return false;
42 }
43
44 if (&root == &target) {
45 return true;
46 }
47
48 if(foundGraph.contains(root)) {
49 return true; // been here and it was saved in foundGraph
50 }
51
52 auto* h = root.getCodeHandler();
53
54 if(h->isVisited(root)) {
55 return false; // been here but it was not saved in foundGraph
56 }
57
58 // not visited yet
59 h->markVisited(root); // mark node as visited
60
61 PathNodeEdges<Base>& info = foundGraph[root];
62
63 const auto& args = root.getArguments();
64
65 bool found = false;
66 for(size_t i = 0; i < args.size(); ++i) {
67 const Argument<Base>& a = args[i];
68 if(a.getOperation() != nullptr ) {
69 auto& aNode = *a.getOperation();
70 if(findPathGraph(foundGraph, aNode, target, bifurcations, maxBifurcations)) {
71 foundGraph.connect(info, root, i);
72 if(found) {
73 bifurcations++; // multiple ways to get to target
74 } else {
75 found = true;
76 }
77 }
78 }
79 }
80
81 if(!found) {
82 foundGraph.erase(root);
83 }
84
85 return found;
86}
87
88template<class Base>
89inline BidirGraph<Base> CodeHandler<Base>::findPathGraph(OperationNode<Base>& root,
90 OperationNode<Base>& target) {
91 size_t bifurcations = 0;
92 return findPathGraph(root, target, bifurcations);
93}
94
95template<class Base>
96inline BidirGraph<Base> CodeHandler<Base>::findPathGraph(OperationNode<Base>& root,
97 OperationNode<Base>& target,
98 size_t& bifurcations,
99 size_t maxBifurcations) {
100 startNewOperationTreeVisit();
101
102 BidirGraph<Base> foundGraph;
103
104 if (bifurcations <= maxBifurcations) {
105 if (&root == &target) {
106 foundGraph[root];
107 } else {
108 CppAD::cg::findPathGraph<Base>(foundGraph, root, target, bifurcations, maxBifurcations);
109 }
110 }
111
112 return foundGraph;
113}
114
115
116template<class Base>
117inline std::vector<std::vector<OperationPathNode<Base> > > CodeHandler<Base>::findPaths(OperationNode<Base>& root,
119 size_t max) {
120 std::vector<std::vector<OperationPathNode<Base> > > found;
121
122 startNewOperationTreeVisit();
123
124 if (max > 0) {
125 std::vector<OperationPathNode<Base> > path2node;
126 path2node.reserve(30);
127 path2node.push_back(OperationPathNode<Base> (&root, 0));
128
129 if (&root == &code) {
130 found.push_back(path2node);
131 } else {
132 findPaths(path2node, code, found, max);
133 }
134 }
135
136 return found;
137}
138
139template<class Base>
140inline void CodeHandler<Base>::findPaths(SourceCodePath& currPath,
142 std::vector<SourceCodePath>& found,
143 size_t max) {
144
145 OperationNode<Base>* currNode = currPath.back().node;
146 if (&code == currNode) {
147 found.push_back(currPath);
148 return;
149 }
150
151 const std::vector<Argument<Base> >& args = currNode->getArguments();
152 if (args.empty())
153 return; // nothing to look in
154
155 if (isVisited(*currNode)) {
156 // already searched inside this node
157 // any match would have been saved in found
158 std::vector<SourceCodePath> pathsFromNode = findPathsFromNode(found, *currNode);
159 for (const SourceCodePath& pathFromNode : pathsFromNode) {
160 SourceCodePath newPath(currPath.size() + pathFromNode.size());
161 std::copy(currPath.begin(), currPath.end(), newPath.begin());
162 std::copy(pathFromNode.begin(), pathFromNode.end(), newPath.begin() + currPath.size());
163 found.push_back(newPath);
164 }
165
166 } else {
167 // not visited yet
168 markVisited(*currNode); // mark node as visited
169
170 size_t size = args.size();
171 for (size_t i = 0; i < size; ++i) {
172 OperationNode<Base>* a = args[i].getOperation();
173 if (a != nullptr) {
174 currPath.push_back(OperationPathNode<Base> (a, i));
175 findPaths(currPath, code, found, max);
176 currPath.pop_back();
177 if (found.size() == max) {
178 return;
179 }
180 }
181 }
182 }
183}
184
185template<class Base>
186inline std::vector<std::vector<OperationPathNode<Base> > > CodeHandler<Base>::findPathsFromNode(const std::vector<SourceCodePath> nodePaths,
187 OperationNode<Base>& node) {
188
189 std::vector<SourceCodePath> foundPaths;
190 std::set<size_t> argsFound;
191
192 for (const SourceCodePath& path : nodePaths) {
193 size_t size = path.size();
194 for (size_t i = 0; i < size - 1; i++) {
195 const OperationPathNode<Base>& pnode = path[i];
196 if (pnode.node == &node) {
197 if (argsFound.find(path[i + 1].argIndex) == argsFound.end()) {
198 foundPaths.push_back(SourceCodePath(path.begin() + i + 1, path.end()));
199 argsFound.insert(path[i + 1].argIndex);
200 }
201 }
202 }
203 }
204
205 return foundPaths;
206}
207
208} // END cg namespace
209} // END CppAD namespace
210
211#endif
std::vector< SourceCodePath > findPaths(Node &root, Node &target, size_t max)
const std::vector< Argument< Base > > & getArguments() const