1 #ifndef CPPAD_CG_EQUATION_PATTERN_INCLUDED
2 #define CPPAD_CG_EQUATION_PATTERN_INCLUDED
29 using MapDep2Indep_type = std::map<size_t, const OperationNode<Base>*>;
45 return op2Arguments.at(operation);
49 const auto itIndexes = op2Arguments.find(node);
50 if (itIndexes == op2Arguments.end()) {
// Index of the reference dependent of this pattern. It is used as the
// dependent key for the reference's own entries (e.g. when saving operation
// references and indexed independents for the reference path).
67 const size_t depRefIndex;
// Indices of the dependents grouped into this pattern; testAdd() inserts the
// candidate index iDep2, and a size of 1 is treated as "only the reference so far".
68 std::set<size_t> dependents;
99 handler_(ref.getCodeHandler()) {
// EquationPattern is non-copyable: copy construction is explicitly deleted.
// NOTE: because a copy constructor is user-declared (even as deleted), the
// compiler does not implicitly declare a move constructor either.
102 EquationPattern(
const EquationPattern<Base>& other) =
delete;
// Copy assignment is explicitly deleted for the same reason; likewise no
// implicit move assignment operator is declared.
104 EquationPattern& operator=(
const EquationPattern<Base>& rhs) =
delete;
106 bool testAdd(
size_t iDep2,
107 const CG<Base>& dep2,
109 CodeHandlerVector<Base, size_t>& varColor) {
112 std::map<size_t, std::map<const OperationNode<Base>*, OperationNode<Base>*> > operation2ReferenceBackup =
operationEO2Reference;
115 minColor_ = minColor;
116 cmpColor_ = minColor_;
118 bool equals = comparePath(depRef, dep2, iDep2, varColor);
120 minColor = cmpColor_;
123 dependents.insert(iDep2);
128 indexedOpIndep.op2Arguments.swap(independentsBackup.op2Arguments);
136 inline void findIndexedPath(
size_t dep,
137 const std::vector<CG<Base> >& depVals,
138 CodeHandlerVector<Base, bool>& varIndexed,
139 std::set<
const OperationNode<Base>*>& indexedOperations) {
140 findIndexedPath(depRef, depVals[dep], varIndexed, indexedOperations);
143 std::set<const OperationNode<Base>*> findOperationsUsingIndependents(OperationNode<Base>& node)
const {
144 std::set<const OperationNode<Base>*> ops;
146 handler_->startNewOperationTreeVisit();
148 findOperationsWithIndeps(node, ops);
153 static inline void uncolor(OperationNode<Base>* node,
154 CodeHandlerVector<Base, bool>& varIndexed) {
155 if (node ==
nullptr || !varIndexed[*node])
158 varIndexed[*node] =
false;
160 const std::vector<Argument<Base> >& args = node->getArguments();
161 size_t size = args.size();
162 for (
size_t a = 0; a < size; a++) {
163 uncolor(args[a].getOperation(), varIndexed);
167 inline bool containsConstantIndependent(
const OperationNode<Base>* operation,
size_t argumentIndex)
const {
170 if (it->second.find(argumentIndex) != it->second.end()) {
182 using MapIndep2Dep_type =
typename OperationIndexedIndependents<Base>::MapDep2Indep_type;
194 for (
size_t argIndex = 0; argIndex < aSize; argIndex++) {
201 bool isIndexed =
false;
202 typename MapIndep2Dep_type::const_iterator itDep2Ind = dep2Ind.begin();
205 for (++itDep2Ind; itDep2Ind != dep2Ind.end(); ++itDep2Ind) {
206 if (indep != itDep2Ind->second) {
238 bool comparePath(
const CG<Base>& dep1,
246 if (h1 !=
nullptr && h2 !=
nullptr)
247 throw CGException(
"Only one code handler allowed");
255 OperationNode<Base>* depRefOp = dep1.getOperationNode();
256 OperationNode<Base>* dep2Op = dep2.getOperationNode();
257 CPPADCG_ASSERT_UNKNOWN(depRefOp->getOperationType() != CGOpCode::Inv)
259 return comparePath(depRefOp, dep2Op, dep2Index, varColor);
265 bool comparePath(OperationNode<Base>* scRef,
266 OperationNode<Base>* sc2,
268 CodeHandlerVector<Base, size_t>& varColor) {
269 saveOperationReference(dep2, sc2, scRef);
270 if (dependents.size() == 1) {
271 saveOperationReference(depRefIndex, scRef, scRef);
274 while (scRef->getOperationType() == CGOpCode::Alias) {
275 CPPADCG_ASSERT_KNOWN(scRef->getArguments().size() == 1,
"Invalid number of arguments for alias")
276 OperationNode<Base>* sc = scRef->getArguments()[0].getOperation();
277 if (sc !=
nullptr && sc->getOperationType() == CGOpCode::Inv) break;
280 while (sc2->getOperationType() == CGOpCode::Alias) {
281 CPPADCG_ASSERT_KNOWN(sc2->getArguments().size() == 1,
"Invalid number of arguments for alias")
282 OperationNode<Base>* sc = sc2->getArguments()[0].getOperation();
283 if (sc !=
nullptr && sc->getOperationType() == CGOpCode::Inv) break;
288 if (varColor[*sc2] >= minColor_ && varColor[*scRef] >= minColor_) {
298 if (varColor[*sc2] == varColor[*scRef])
301 varColor[*scRef] = cmpColor_;
302 varColor[*sc2] = cmpColor_;
306 if (scRef->getOperationType() != sc2->getOperationType()) {
310 CPPADCG_ASSERT_UNKNOWN(scRef->getOperationType() != CGOpCode::Inv)
312 const std::vector<size_t>& info1 = scRef->getInfo();
313 const std::vector<size_t>& info2 = sc2->getInfo();
314 if (info1.size() != info2.size()) {
318 for (
size_t e = 0; e < info1.size(); e++) {
319 if (info1[e] != info2[e]) {
324 const std::vector<Argument<Base> >& args1 = scRef->getArguments();
325 const std::vector<Argument<Base> >& args2 = sc2->getArguments();
326 size_t size = args1.size();
327 if (size != args2.size()) {
330 for (
size_t a = 0; a < size; a++) {
331 const Argument<Base>& a1 = args1[a];
332 const Argument<Base>& a2 = args2[a];
334 if (a1.getParameter() !=
nullptr) {
335 if (a2.getParameter() ==
nullptr || *a1.getParameter() != *a2.getParameter())
338 if (a2.getOperation() ==
nullptr) {
341 OperationNode<Base>* argRefOp = a1.getOperation();
342 OperationNode<Base>* arg2Op = a2.getOperation();
344 if (argRefOp->getOperationType() == CGOpCode::Inv) {
345 related = saveIndependent(scRef, a, argRefOp, arg2Op);
347 related = comparePath(argRefOp, arg2Op, dep2, varColor);
358 inline void saveOperationReference(
size_t dep2,
359 const OperationNode<Base>* sc2,
360 OperationNode<Base>* scRef) {
364 bool saveIndependent(
const OperationNode<Base>* parentOp,
366 const OperationNode<Base>* argRefOp,
367 const OperationNode<Base>* arg2Op) {
368 if (argRefOp->getOperationType() != CGOpCode::Inv || arg2Op->getOperationType() != CGOpCode::Inv) {
378 if (it->second.find(argIndex) != it->second.end()) {
383 OperationIndexedIndependents<Base>& opIndexedIndep =
indexedOpIndep.op2Arguments[parentOp];
384 opIndexedIndep.arg2Independents.resize(parentOp !=
nullptr ? parentOp->getArguments().size() : 1);
386 std::map<size_t, const OperationNode<Base>*>& dep2Indeps = opIndexedIndep.arg2Independents[argIndex];
387 if (dep2Indeps.empty())
388 dep2Indeps[depRefIndex] = argRefOp;
389 dep2Indeps[currDep_] = arg2Op;
394 inline void findIndexedPath(
const CG<Base>& depRef,
395 const CG<Base>& dep2,
396 CodeHandlerVector<Base, bool>& varIndexed,
397 std::set<
const OperationNode<Base>*>& indexedOperations) {
398 if (depRef.isVariable() && dep2.isVariable()) {
399 OperationNode<Base>* depRefOp = depRef.getOperationNode();
400 OperationNode<Base>* dep2Op = dep2.getOperationNode();
401 if (depRefOp->getOperationType() != CGOpCode::Inv) {
402 findIndexedPath(depRefOp, dep2Op, varIndexed, indexedOperations);
405 typename std::map<const OperationNode<Base>*, OperationIndexedIndependents<Base> >::iterator itop2a;
407 if (itop2a !=
indexedOpIndep.op2Arguments.end() && !itop2a->second.arg2Independents[0].empty()) {
409 indexedOperations.insert(
nullptr);
415 inline bool findIndexedPath(
const OperationNode<Base>* scRef,
416 OperationNode<Base>* sc2,
417 CodeHandlerVector<Base, bool>& varIndexed,
418 std::set<
const OperationNode<Base>*>& indexedOperations) {
420 while (scRef->getOperationType() == CGOpCode::Alias) {
421 CPPADCG_ASSERT_KNOWN(scRef->getArguments().size() == 1,
"Invalid number of arguments for alias")
422 OperationNode<Base>* sc = scRef->getArguments()[0].getOperation();
423 if (sc !=
nullptr && sc->getOperationType() == CGOpCode::Inv) break;
426 while (sc2->getOperationType() == CGOpCode::Alias) {
427 CPPADCG_ASSERT_KNOWN(sc2->getArguments().size() == 1,
"Invalid number of arguments for alias")
428 OperationNode<Base>* sc = sc2->getArguments()[0].getOperation();
429 if (sc !=
nullptr && sc->getOperationType() == CGOpCode::Inv) break;
433 CPPADCG_ASSERT_UNKNOWN(scRef->getOperationType() == sc2->getOperationType())
435 const std::vector<Argument<Base> >& argsRef = scRef->getArguments();
437 typename std::map<const OperationNode<Base>*, OperationIndexedIndependents<Base> >::iterator itop2a;
438 bool searched = false;
439 bool indexedDependentPath = false;
440 bool usesIndexedIndependent = false;
442 size_t size = argsRef.size();
443 for (
size_t a = 0; a < size; a++) {
444 OperationNode<Base>* argRefOp = argsRef[a].getOperation();
445 if (argRefOp !=
nullptr) {
446 bool indexedArg =
false;
447 if (argRefOp->getOperationType() == CGOpCode::Inv) {
453 if (itop2a !=
indexedOpIndep.op2Arguments.end() && !itop2a->second.arg2Independents[a].empty()) {
456 indexedDependentPath =
true;
457 usesIndexedIndependent =
true;
462 const std::vector<Argument<Base> >& args2 = sc2->getArguments();
463 CPPADCG_ASSERT_UNKNOWN(size == args2.size())
464 indexedDependentPath |= findIndexedPath(argsRef[a].getOperation(), args2[a].getOperation(), varIndexed, indexedOperations);
469 varIndexed[*sc2] = indexedDependentPath;
471 if (usesIndexedIndependent)
472 indexedOperations.insert(sc2);
474 return indexedDependentPath;
477 void findOperationsWithIndeps(OperationNode<Base>& node,
478 std::set<const OperationNode<Base>*>& ops)
const {
479 if (handler_->isVisited(node))
482 handler_->markVisited(node);
484 const std::vector<Argument<Base> >& args = node.getArguments();
485 size_t size = args.size();
486 for (
size_t a = 0; a < size; a++) {
487 OperationNode<Base>* argOp = args[a].getOperation();
488 if (argOp !=
nullptr) {
489 if (argOp->getOperationType() == CGOpCode::Inv) {
492 findOperationsWithIndeps(*argOp, ops);