CppADCodeGen 2.4.3
A C++ Algorithmic Differentiation Package with Source Code Generation
Loading...
Searching...
No Matches
solver.hpp
1#ifndef CPPAD_CG_SOLVER_INCLUDED
2#define CPPAD_CG_SOLVER_INCLUDED
3/* --------------------------------------------------------------------------
4 * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5 * Copyright (C) 2012 Ciengis
6 *
7 * CppADCodeGen is distributed under multiple licenses:
8 *
9 * - Eclipse Public License Version 1.0 (EPL1), and
10 * - GNU General Public License Version 3 (GPL3).
11 *
12 * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
13 * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
14 * ----------------------------------------------------------------------------
15 * Author: Joao Leal
16 */
17
18#include <cppad/cg/evaluator/evaluator_solve.hpp>
19#include <cppad/cg/lang/dot/dot.hpp>
20
21namespace CppAD {
22namespace cg {
23
24template<class Base>
27 using std::vector;
28
29 // find code in expression
30 if (&expression == &var)
31 return CG<Base>(var);
32
33 size_t bifurcations = (std::numeric_limits<size_t>::max)(); // so that it is possible to enter the loop
34
35 std::vector<SourceCodePath> paths;
36 BidirGraph<Base> foundGraph;
37 OperationNode<Base> *root = &expression;
38
39 while (bifurcations > 0) {
40 CPPADCG_ASSERT_UNKNOWN(root != nullptr);
41
42 // find possible paths from expression to var
43 size_t oldBif = bifurcations;
44 bifurcations = 0;
45 foundGraph = findPathGraph(*root, var, bifurcations, 50000);
46 CPPADCG_ASSERT_UNKNOWN(oldBif > bifurcations);
47
48 if (!foundGraph.contains(var)) {
49 std::cerr << "Missing variable " << var << std::endl;
50 printExpression(expression, std::cerr);
51 throw CGException("The provided variable ", var.getName() != nullptr ? ("(" + *var.getName() + ")") : "", " is not present in the expression");
52 }
53
54 // find a bifurcation which does not contain any other bifurcations
55 size_t bifPos = 0;
56 paths = foundGraph.findSingleBifurcation(*root, var, bifPos);
57 if (paths.empty()) {
58 throw CGException("The provided variable is not present in the expression");
59
60 } else if (paths.size() == 1) {
61 CPPADCG_ASSERT_UNKNOWN(paths[0][0].node == root);
62 CPPADCG_ASSERT_UNKNOWN(paths[0].back().node == &var);
63
64 return solveFor(paths[0]);
65
66 } else {
67 CPPADCG_ASSERT_UNKNOWN(paths.size() >= 1);
68 CPPADCG_ASSERT_UNKNOWN(paths[0].back().node == &var);
69
70 CG<Base> expression2 = collectVariable(*root, paths[0], paths[1], bifPos);
71 root = expression2.getOperationNode();
72 if (root == nullptr) {
73 throw CGException("It is not possible to solve the expression for the requested variable: the variable disappears after symbolic manipulations (e.g., y=x-x).");
74 }
75 }
76 }
77
78 CPPADCG_ASSERT_UNKNOWN(paths.size() == 1);
79 return solveFor(paths[0]);
80}
81
82template<class Base>
83inline CG<Base> CodeHandler<Base>::solveFor(const SourceCodePath& path) {
84
85 CG<Base> rightHs(0.0);
86
87 for (size_t n = 0; n < path.size() - 1; ++n) {
88 const OperationPathNode<Base>& pnodeOp = path[n];
89 size_t argIndex = path[n].argIndex;
90 const std::vector<Argument<Base> >& args = pnodeOp.node->getArguments();
91
92 CGOpCode op = pnodeOp.node->getOperationType();
93 switch (op) {
94 case CGOpCode::Mul:
95 {
96 const Argument<Base>& other = args[argIndex == 0 ? 1 : 0];
97 rightHs /= CG<Base>(other);
98 break;
99 }
100 case CGOpCode::Div:
101 if (argIndex == 0) {
102 const Argument<Base>& other = args[1];
103 rightHs *= CG<Base>(other);
104 } else {
105 const Argument<Base>& other = args[0];
106 rightHs = CG<Base>(other) / rightHs;
107 }
108 break;
109
110 case CGOpCode::UnMinus:
111 rightHs *= Base(-1.0);
112 break;
113 case CGOpCode::Add:
114 {
115 const Argument<Base>& other = args[argIndex == 0 ? 1 : 0];
116 rightHs -= CG<Base>(other);
117 break;
118 }
119 case CGOpCode::Alias:
120 // do nothing
121 break;
122 case CGOpCode::Sub:
123 {
124 if (argIndex == 0) {
125 rightHs += CG<Base>(args[1]);
126 } else {
127 rightHs = CG<Base>(args[0]) - rightHs;
128 }
129 break;
130 }
131 case CGOpCode::Exp:
132 rightHs = log(rightHs);
133 break;
134 case CGOpCode::Log:
135 rightHs = exp(rightHs);
136 break;
137 case CGOpCode::Pow:
138 {
139 if (argIndex == 0) {
140 // base
141 const Argument<Base>& exponent = args[1];
142 if (exponent.getParameter() != nullptr && *exponent.getParameter() == Base(0.0)) {
143 throw CGException("Invalid zero exponent");
144 } else if (exponent.getParameter() != nullptr && *exponent.getParameter() == Base(1.0)) {
145 continue; // do nothing
146 } else {
147 throw CGException("Unable to invert operation '", op, "'");
148 /*
149 if (exponent.getParameter() != nullptr && *exponent.getParameter() == Base(2.0)) {
150 rightHs = sqrt(rightHs); // TODO: should -sqrt(rightHs) somehow be considered???
151 } else {
152 rightHs = pow(rightHs, Base(1.0) / CG<Base>(exponent));
153 }
154 */
155 }
156 } else {
157 //
158 const Argument<Base>& base = args[0];
159 rightHs = log(rightHs) / log(CG<Base>(base));
160 }
161 break;
162 }
163 case CGOpCode::Sqrt:
164 rightHs *= rightHs;
165 break;
166 //case CGAcosOp: // asin(variable)
167 //case CGAsinOp: // asin(variable)
168 //case Atan: // atan(variable)
169 case CGOpCode::Cosh: // cosh(variable)
170 {
171 rightHs = log(rightHs + sqrt(rightHs * rightHs - Base(1.0))); // asinh
172 break;
173 //case Cos: // cos(variable)
174 }
175 case CGOpCode::Sinh: // sinh(variable)
176 rightHs = log(rightHs + sqrt(rightHs * rightHs + Base(1.0))); // asinh
177 break;
178 //case CGSinOp: // sin(variable)
179 case CGOpCode::Tanh: // tanh(variable)
180 rightHs = Base(0.5) * (log(Base(1.0) + rightHs) - log(Base(1.0) - rightHs)); // atanh
181 break;
182 //case CGTanOp: // tan(variable)
183 default:
184 throw CGException("Unable to invert operation '", op, "'");
185 };
186 }
187
188 return rightHs;
189}
190
191template<class Base>
193 OperationNode<Base>& var) {
194 size_t bifurcations = 0;
195 BidirGraph<Base> g = findPathGraph(expression, var, bifurcations);
196
197 if(bifurcations == 0) {
198 size_t bifIndex = 0;
199 auto paths = g.findSingleBifurcation(expression, var, bifIndex);
200 if (paths.empty() || paths[0].empty())
201 return false;
202
203 return isSolvable(paths[0]);
204 } else {
205 // TODO: improve this
206 //bool v = isCollectableVariableAddSub();
207 try {
208 solveFor(expression, var);
209 return true;
210 } catch(const CGException& e) {
211 return false;
212 }
213 }
214}
215
216template<class Base>
217inline bool CodeHandler<Base>::isSolvable(const SourceCodePath& path) const {
218 for (size_t n = 0; n < path.size() - 1; ++n) {
219 const OperationPathNode<Base>& pnodeOp = path[n];
220 size_t argIndex = path[n].argIndex;
221 const std::vector<Argument<Base> >& args = pnodeOp.node->getArguments();
222
223 CGOpCode op = pnodeOp.node->getOperationType();
224 switch (op) {
225 case CGOpCode::Mul:
226 case CGOpCode::Div:
227 case CGOpCode::UnMinus:
228 case CGOpCode::Add:
229 case CGOpCode::Alias:
230 case CGOpCode::Sub:
231 case CGOpCode::Exp:
232 case CGOpCode::Log:
233 case CGOpCode::Sqrt:
234 case CGOpCode::Cosh: // cosh(variable)
235 case CGOpCode::Sinh: // sinh(variable)
236 case CGOpCode::Tanh: // tanh(variable)
237 break;
238 case CGOpCode::Pow:
239 {
240 if (argIndex == 0) {
241 // base
242 const Argument<Base>& exponent = args[1];
243 if (exponent.getParameter() != nullptr && *exponent.getParameter() == Base(0.0)) {
244 return false;
245 } else if (exponent.getParameter() != nullptr && *exponent.getParameter() == Base(1.0)) {
246 break;
247 } else {
248 return false;
249 }
250 } else {
251 break;
252 }
253 break;
254 }
255
256 default:
257 return false;
258 };
259 }
260 return true;
261}
262
263} // END cg namespace
264} // END CppAD namespace
265
266#endif
std::vector< SourceCodePath > findSingleBifurcation(Node &expression, Node &target, size_t &bifIndex) const
CGB solveFor(Node &expression, Node &var)
Definition solver.hpp:25
const std::string * getName() const