1 #ifndef CPPAD_CG_DEPENDENT_PATTERN_MATCHER_INCLUDED
2 #define CPPAD_CG_DEPENDENT_PATTERN_MATCHER_INCLUDED
32 if (e1->depRefIndex < e2->depRefIndex) {
44 return p1.eq1->depRefIndex < p2.eq1->depRefIndex || (!(p2.eq1->depRefIndex < p1.eq1->depRefIndex) && p1.eq2->depRefIndex < p2.eq2->depRefIndex);
56 enum class INDEXED_OPERATION_TYPE {
61 using Indexed2OpCountType = std::pair<INDEXED_OPERATION_TYPE, size_t>;
62 using Dep1Dep2SharedType = std::map<size_t, std::map<size_t, std::map<OperationNode<Base>*, Indexed2OpCountType> > >;
63 using DepPairType = std::pair<size_t, size_t>;
64 using TotalOps2validDepsType = std::map<size_t, std::map<DepPairType, const std::map<OperationNode<Base>*, Indexed2OpCountType>* > >;
65 using Eq2totalOps2validDepsType = std::map<UniqueEquationPair<Base>, TotalOps2validDepsType*>;
66 using MaxOps2eq2totalOps2validDepsType = std::map<size_t, Eq2totalOps2validDepsType>;
72 const std::vector<std::set<size_t> >& relatedDepCandidates_;
73 std::vector<CGBase> dependents_;
74 const std::vector<CGBase>& independents_;
75 std::vector<EquationPattern<Base>*> equations_;
77 std::map<size_t, EquationPattern<Base>*> dep2Equation_;
78 std::map<EquationPattern<Base>*,
Loop<Base>*> equation2Loop_;
79 std::vector<Loop<Base>*> loops_;
83 std::map<EquationPattern<Base>*, std::set<EquationPattern<Base>*> > incompatible_;
87 std::map<UniqueEquationPair<Base>, Dep1Dep2SharedType> equationShared_;
92 std::map<OperationNode<Base>*,
size_t> origTemp2Index_;
93 std::vector<std::set<size_t> > id2Deps;
114 const std::vector<CGBase>& dependents,
115 const std::vector<CGBase>& independents) :
116 handler_(independents[0].getCodeHandler()),
118 varIndexed_(*handler_),
119 relatedDepCandidates_(relatedDepCandidates),
120 dependents_(dependents),
121 independents_(independents),
123 origShareNodeId_(*handler_),
125 CPPADCG_ASSERT_UNKNOWN(independents_.size() > 0)
126 CPPADCG_ASSERT_UNKNOWN(independents_[0].getCodeHandler() !=
nullptr)
127 equations_.reserve(relatedDepCandidates_.size());
128 origShareNodeId_.adjustSize();
131 const std::vector<EquationPattern<Base>*>& getEquationPatterns()
const {
135 const std::vector<Loop<Base>*>& getLoops()
const {
151 for (
size_t j = 0; j < independents_.size(); j++) {
152 std::vector<size_t>& info = independents_[j].getOperationNode()->getInfo();
159 nonLoopTape = createNewTape();
162 for (
size_t l = 0; l < loops_.size(); l++) {
164 loopTapes.insert(loop->releaseLoopModel());
169 for (
size_t l = 0; l < loops_.size(); l++) {
182 virtual std::vector<Loop<Base>*> findLoops() {
185 size_t rSize = relatedDepCandidates_.size();
186 for (
size_t r = 0; r < rSize; r++) {
187 const std::set<size_t>& candidates = relatedDepCandidates_[r];
188 for (
size_t iDep : candidates) {
189 OperationNode<Base>* node = dependents_[iDep].getOperationNode();
190 if (node !=
nullptr && node->getOperationType() == CGOpCode::Inv) {
199 CPPADCG_ASSERT_UNKNOWN(handler_ == dependents_[iDep].getCodeHandler())
200 dependents_[iDep] = CG<Base>(*handler_->makeNode(CGOpCode::Alias, *node));
210 id2Deps.resize(idCounter_ + 1);
215 findRelatedVariables();
217 for (EquationPattern<Base>* eq : equations_) {
218 for (
size_t depIt : eq->dependents) {
219 dep2Equation_[depIt] = eq;
223 const size_t eq_size = equations_.size();
224 loops_.reserve(eq_size);
226 SmartSetPointer<set<size_t> > dependentRelations;
227 std::vector<set<size_t>*> dep2Relations(dependents_.size(),
nullptr);
228 map<size_t, set<size_t> > dependentBlackListRelations;
237 varIndexed_.adjustSize();
238 varIndexed_.fill(
false);
240 for (
size_t e = 0; e < eq_size; e++) {
241 EquationPattern<Base>* eq = equations_[e];
244 for (
size_t depIt : eq->dependents) {
245 OperationNode<Base>* node = dependents_[depIt].getOperationNode();
247 markOperationsWithDependent(node, depIt);
254 handler_->startNewOperationTreeVisit();
256 for (
size_t depIt : eq->dependents) {
257 findSharedTemporaries(dependents_[depIt], depIt);
263 for (
size_t depIt : eq->dependents) {
264 OperationNode<Base>* node = dependents_[depIt].getOperationNode();
265 EquationPattern<Base>::uncolor(node, varIndexed_);
270 auto* loop =
new Loop<Base>(*eq);
271 loops_.push_back(loop);
272 equation2Loop_[eq] = loop;
278 MaxOps2eq2totalOps2validDepsType maxOps2Eq2totalOps2validDeps;
279 Eq2totalOps2validDepsType eq2totalOps2validDeps;
280 SmartListPointer<TotalOps2validDepsType> totalOps2validDepsMem;
287 for (
size_t l1 = 0; l1 < loops_.size(); l1++) {
288 Loop<Base>* loop1 = loops_[l1];
289 CPPADCG_ASSERT_UNKNOWN(loop1->equations.size() == 1)
290 EquationPattern<Base>* eq1 = *loop1->equations.begin();
292 for (
size_t l2 = l1 + 1; l2 < loops_.size(); l2++) {
293 Loop<Base>* loop2 = loops_[l2];
294 CPPADCG_ASSERT_UNKNOWN(loop2->equations.size() == 1)
295 EquationPattern<Base>* eq2 = *loop2->equations.begin();
297 UniqueEquationPair<Base> eqRel(eq1, eq2);
298 const auto eqSharedit = equationShared_.find(eqRel);
299 if (eqSharedit == equationShared_.end())
302 const Dep1Dep2SharedType& dep1Dep2Shared = eqSharedit->second;
307 auto* totalOps2validDeps = new TotalOps2validDepsType();
308 totalOps2validDepsMem.push_back(totalOps2validDeps);
311 bool canCombine = true;
316 for (const auto& itDep1Dep2 : dep1Dep2Shared) {
317 size_t dep1 = itDep1Dep2.first;
318 const map<size_t, map<OperationNode<Base>*, Indexed2OpCountType> >& dep2Shared = itDep1Dep2.second;
321 for (
const auto& itDep2 : dep2Shared) {
322 size_t dep2 = itDep2.first;
323 const map<OperationNode<Base>*, Indexed2OpCountType>& sharedTmps = itDep2.second;
326 for (
const auto& itShared : sharedTmps) {
327 if (itShared.second.first == INDEXED_OPERATION_TYPE::BOTH) {
336 totalOps += itShared.second.second;
340 if (!canCombine)
break;
342 DepPairType depRel(dep1, dep2);
343 (*totalOps2validDeps)[totalOps][depRel] = &sharedTmps;
344 maxOps = std::max<size_t>(maxOps, totalOps);
347 if (!canCombine)
break;
351 maxOps2Eq2totalOps2validDeps[maxOps][eqRel] = totalOps2validDeps;
352 eq2totalOps2validDeps[eqRel] = totalOps2validDeps;
354 incompatible_[eq1].insert(eq2);
355 incompatible_[eq2].insert(eq1);
356 totalOps2validDepsMem.pop_back();
357 delete totalOps2validDeps;
365 typename MaxOps2eq2totalOps2validDepsType::const_reverse_iterator itMaxOps;
366 for (itMaxOps = maxOps2Eq2totalOps2validDeps.rbegin(); itMaxOps != maxOps2Eq2totalOps2validDeps.rend(); ++itMaxOps) {
367 #ifdef CPPADCG_PRINT_DEBUG
368 std::cout <<
"\n\nmaxOps: " << itMaxOps->first <<
" count:" << itMaxOps->second.size() << std::endl;
371 for (
const auto& itEqPair : itMaxOps->second) {
372 const UniqueEquationPair<Base>& eqRel = itEqPair.first;
373 #ifdef CPPADCG_PRINT_DEBUG
374 std::cout <<
" eq1: " << *eqRel.eq1->dependents.begin() <<
" eq2: " << *eqRel.eq2->dependents.begin() << std::endl;
377 Loop<Base>* loop1 = equation2Loop_.at(eqRel.eq1);
378 Loop<Base>* loop2 = equation2Loop_.at(eqRel.eq2);
382 if (contains(incompatible_, eqRel.eq1, eqRel.eq2))
389 SmartSetPointer<set<size_t> > dependentRelationsBak;
390 for (
const set<size_t>* its : dependentRelations) {
391 dependentRelationsBak.insert(
new set<size_t>(*its));
395 set<set<size_t>*> loopRelations;
397 set<EquationPattern<Base>*> indexedLoopRelations;
398 std::vector<std::pair<EquationPattern<Base>*, EquationPattern<Base>*> > nonIndexedLoopRelations;
403 bool compatible = isCompatible(loop1, loop2,
404 eq2totalOps2validDeps,
405 dep2Relations, dependentBlackListRelations, dependentRelations,
406 loopRelations, indexedLoopRelations, nonIndexedLoopRelations);
412 for (EquationPattern<Base>* itle : loop2->equations) {
413 equation2Loop_[itle] = loop1;
415 loop1->merge(*loop2, indexedLoopRelations, nonIndexedLoopRelations);
417 typename std::vector<Loop<Base>*>::const_iterator it = std::find(loops_.cbegin(), loops_.cend(), loop2);
418 CPPADCG_ASSERT_UNKNOWN(it != loops_.end())
422 loop1->setLinkedDependents(loopRelations);
427 dependentRelations.s.swap(dependentRelationsBak.s);
429 std::fill(dep2Relations.begin(), dep2Relations.end(),
nullptr);
430 for (set<size_t>* relation : dependentRelations) {
431 for (
size_t itd : *relation) {
432 dep2Relations[itd] = relation;
444 for (
size_t l = 0; l < loops_.size(); l++) {
445 loops_[l]->generateDependentLoopIndexes(dep2Equation_);
451 if (!loops_.empty()) {
452 for (
size_t l1 = 0; l1 < loops_.size() - 1; l1++) {
453 Loop<Base>* loop1 = loops_[l1];
454 for (
size_t l2 = l1 + 1; l2 < loops_.size();) {
455 Loop<Base>* loop2 = loops_[l2];
457 bool canMerge = loop1->getIterationCount() == loop2->getIterationCount();
460 canMerge = !find(loop1, loop2, incompatible_);
464 loop1->mergeEqGroups(*loop2);
465 loops_.erase(loops_.begin() + l2);
474 size_t l_size = loops_.size();
479 for (
size_t l = 0; l < l_size; l++) {
480 Loop<Base>* loop = loops_[l];
483 loop->createLoopModel(dependents_, independents_, dep2Equation_, origTemp2Index_);
489 resetHandlerCounters();
502 inline bool isCompatible(Loop<Base>* loop1,
504 const Eq2totalOps2validDepsType& eq2totalOps2validDeps,
505 std::vector<std::set<size_t>* >& dep2Relations,
506 std::map<
size_t, std::set<size_t> >& dependentBlackListRelations,
507 SmartSetPointer<std::set<size_t> >& dependentRelations,
508 std::set<std::set<size_t>*>& loopRelations,
509 std::set<EquationPattern<Base>*>& indexedLoopRelations,
510 std::vector<std::pair<EquationPattern<Base>*, EquationPattern<Base>*> >& nonIndexedLoopRelations) {
513 bool compatible =
true;
519 map<size_t, map<UniqueEquationPair<Base>, TotalOps2validDepsType*> > totalOp2eq;
521 for (EquationPattern<Base>* eq1 : loop1->equations) {
523 for (EquationPattern<Base>* eq2 : loop2->equations) {
525 UniqueEquationPair<Base> eqRel(eq1, eq2);
527 typename Eq2totalOps2validDepsType::const_iterator eqSharedit = eq2totalOps2validDeps.find(eqRel);
528 if (eqSharedit == eq2totalOps2validDeps.end())
532 size_t maxOps = eqSharedit->second->rbegin()->first;
533 totalOp2eq[maxOps][eqRel] = eqSharedit->second;
537 typename map<size_t, map<UniqueEquationPair<Base>, TotalOps2validDepsType*> >::const_reverse_iterator itr;
538 for (itr = totalOp2eq.rbegin(); itr != totalOp2eq.rend(); ++itr) {
541 for (
const auto& itEq : itr->second) {
542 EquationPattern<Base>* eq1 = itEq.first.eq1;
543 EquationPattern<Base>* eq2 = itEq.first.eq2;
544 TotalOps2validDepsType& totalOps2validDeps = *itEq.second;
550 typename map<size_t, map<DepPairType, const map<OperationNode<Base>*, Indexed2OpCountType>* > >::const_reverse_iterator itOp2Dep2Shared;
551 for (itOp2Dep2Shared = totalOps2validDeps.rbegin(); itOp2Dep2Shared != totalOps2validDeps.rend(); ++itOp2Dep2Shared) {
552 #ifdef CPPADCG_PRINT_DEBUG
553 std::cout <<
" operation count: " << itOp2Dep2Shared->first <<
" relations: " << itOp2Dep2Shared->second.size() << std::endl;
555 for (
const auto& itDep2Shared : itOp2Dep2Shared->second) {
556 DepPairType depRel = itDep2Shared.first;
557 size_t dep1 = depRel.first;
558 size_t dep2 = depRel.second;
560 const map<OperationNode<Base>*, Indexed2OpCountType>& shared = *itDep2Shared.second;
565 compatible = findDepRelations(eq1, dep1, eq2, dep2, shared,
566 dep2Relations, dependentBlackListRelations, dependentRelations);
567 if (!compatible)
break;
570 loopRelations.clear();
576 std::vector<Loop<Base>*> loops(2);
579 bool nonIndexedOnly =
true;
580 for (
size_t l = 0; l < 2; l++) {
581 Loop<Base>* loop = loops[l];
582 for (EquationPattern<Base>* eq : loop->equations) {
583 for (
size_t dep : eq->dependents) {
584 if (dep2Relations[dep] !=
nullptr) {
585 loopRelations.insert(dep2Relations[dep]);
586 nonIndexedOnly =
false;
594 if (nonIndexedOnly) {
595 nonIndexedLoopRelations.push_back(std::make_pair(eq1, eq2));
599 size_t nNonIndexedRel1 = loop1->getLinkedEquationsByNonIndexedCount();
600 size_t nNonIndexedRel2 = loop2->getLinkedEquationsByNonIndexedCount();
601 size_t requiredSize = loop1->equations.size() + loop2->equations.size() - nNonIndexedRel1 - nNonIndexedRel2;
603 for (set<size_t>* relations : loopRelations) {
604 if (relations->size() == requiredSize) {
609 #ifdef CPPADCG_PRINT_DEBUG
611 std::cout <<
" loopRelations:";
612 print(loopRelations);
613 std::cout << std::endl;
619 if (!compatible)
break;
623 incompatible_[eq1].insert(eq2);
624 incompatible_[eq2].insert(eq1);
627 indexedLoopRelations.insert(eq1);
628 indexedLoopRelations.insert(eq2);
632 if (!compatible)
break;
646 bool findDepRelations(EquationPattern<Base>* eq1,
648 EquationPattern<Base>* eq2,
650 const std::map<OperationNode<Base>*, Indexed2OpCountType>& sharedNodes,
651 std::vector<std::set<size_t>* >& dep2Relations,
652 std::map<
size_t, std::set<size_t> >& dependentBlackListRelations,
653 SmartSetPointer<std::set<size_t> >& dependentRelations) {
656 for (
const auto& itShared : sharedNodes) {
657 OperationNode<Base>* sharedNode = itShared.first;
660 bool compatible = canCombineEquations(*eq1, dep1, *eq2, dep2, *sharedNode,
661 dep2Relations, dependentBlackListRelations, dependentRelations);
663 if (!compatible)
return false;
669 void groupByLoopEqOp(EquationPattern<Base>* eq,
670 std::map<Loop<Base>*, std::map<EquationPattern<Base>*, std::map<
size_t, std::pair<OperationNode<Base>*,
bool> > > >& loopSharedTemps,
671 const std::map<OperationNode<Base>*, std::set<size_t> >& opShared,
675 for (OperationNode<Base>* shared : opShared) {
676 const set<size_t>& deps = id2Deps[varId_[*shared]];
678 for (
size_t dep : deps) {
679 EquationPattern<Base>* otherEq = dep2Equation_.at(dep);
681 Loop<Base>* loop = equation2Loop_.at(otherEq);
684 loopSharedTemps[loop][otherEq][origShareNodeId_[shared]] = std::make_pair(shared, indexed);
697 virtual LoopFreeModel<Base>* createNewTape() {
698 CPPADCG_ASSERT_UNKNOWN(handler_ == independents_[0].getCodeHandler());
700 size_t m = dependents_.size();
701 std::vector<bool> inLoop(m,
false);
702 size_t eqInLoopCount = 0;
707 size_t l_size = loops_.size();
709 for (
size_t l = 0; l < l_size; l++) {
710 Loop<Base>* loop = loops_[l];
711 LoopModel<Base>* loopModel = loop->getModel();
716 const std::vector<std::vector<LoopPosition> >& ldeps = loopModel->getDependentIndexes();
717 for (
const auto& ldep : ldeps) {
718 for (
const auto& pos : ldep) {
719 if (pos.original != (std::numeric_limits<size_t>::max)()) {
720 inLoop[pos.original] =
true;
730 assert(m >= eqInLoopCount);
731 size_t nonLoopEq = m - eqInLoopCount;
732 std::vector<CGBase> nonLoopDeps(nonLoopEq + origTemp2Index_.size());
734 if (nonLoopDeps.size() == 0)
741 std::vector<size_t> depTape2Orig(nonLoopEq);
742 if (eqInLoopCount < m) {
743 for (
size_t i = 0; i < inLoop.size(); i++) {
745 depTape2Orig[inl] = i;
746 nonLoopDeps[inl++] = dependents_[i];
750 CPPADCG_ASSERT_UNKNOWN(inl == nonLoopEq)
755 for (
const auto& itTmp : origTemp2Index_) {
756 size_t k = itTmp.second;
757 nonLoopDeps[nonLoopEq + k] = handler_->createCG(Argument<Base>(*itTmp.first));
763 Evaluator<Base, CGBase> evaluator(*handler_);
766 const std::map<size_t, CGAbstractAtomicFun<Base>* >& atomicsOrig = handler_->
getAtomicFunctions();
767 std::map<size_t, atomic_base<CGBase>* > atomics;
768 atomics.insert(atomicsOrig.begin(), atomicsOrig.end());
769 evaluator.addAtomicFunctions(atomics);
771 std::vector<AD<CGBase> > x(independents_.size());
772 for (
size_t j = 0; j < x.size(); j++) {
773 if (independents_[j].isValueDefined())
774 x[j] = independents_[j].getValue();
777 CppAD::Independent(x);
778 std::vector<AD<CGBase> > y = evaluator.evaluate(x, nonLoopDeps);
780 std::unique_ptr<ADFun<CGBase> > tapeNoLoops(
new ADFun<CGBase>());
781 tapeNoLoops->Dependent(y);
783 return new LoopFreeModel<Base>(tapeNoLoops.release(), depTape2Orig);
786 std::vector<EquationPattern<Base>*> findRelatedVariables() {
788 CodeHandlerVector<Base, size_t> varColor(*handler_);
791 varColor.adjustSize();
794 size_t rSize = relatedDepCandidates_.size();
795 for (
size_t r = 0; r < rSize; r++) {
796 const std::set<size_t>& candidates = relatedDepCandidates_[r];
797 std::set<size_t> used;
801 std::set<size_t>::const_iterator itRef;
802 for (itRef = candidates.begin(); itRef != candidates.end(); ++itRef) {
803 size_t iDepRef = *itRef;
806 if (used.find(iDepRef) != used.end()) {
810 if (eqCurr_ ==
nullptr || !used.empty()) {
811 eqCurr_ =
new EquationPattern<Base>(dependents_[iDepRef], iDepRef);
812 equations_.push_back(eqCurr_);
816 for (++it; it != candidates.end(); ++it) {
819 if (used.find(iDep) != used.end()) {
823 if (eqCurr_->testAdd(iDep, dependents_[iDep], color_, varColor)) {
828 if (eqCurr_->dependents.size() == 1) {
832 equations_.pop_back();
841 for (
size_t eq = 0; eq < equations_.size(); eq++) {
855 inline bool findSharedTemporaries(
const CG<Base>& value,
857 OperationNode<Base>* depNode = value.getOperationNode();
859 if (findSharedTemporaries(depNode, depIndex, opCount)) {
860 varIndexed_[*depNode] =
true;
875 inline bool findSharedTemporaries(OperationNode<Base>* node,
881 if (handler_->isVisited(*node)) {
883 return varIndexed_[*node];
886 handler_->markVisited(*node);
888 bool indexedOperation =
false;
890 size_t localOpCount = 1;
891 const std::vector<Argument<Base> >& args = node->getArguments();
892 size_t arg_size = args.size();
893 for (
size_t a = 0; a < arg_size; a++) {
894 OperationNode<Base>*argOp = args[a].getOperation();
895 if (argOp !=
nullptr) {
896 if (argOp->getOperationType() != CGOpCode::Inv) {
897 indexedOperation |= findSharedTemporaries(argOp, depIndex, localOpCount);
899 indexedOperation |= !eqCurr_->containsConstantIndependent(node, a);
904 opCount += localOpCount;
906 varIndexed_[*node] = indexedOperation;
908 size_t id = varId_[*node];
909 std::set<size_t>& deps = id2Deps[id];
911 if (deps.size() > 1 && node->getOperationType() != CGOpCode::Inv) {
915 for (
size_t otherDep : deps) {
917 EquationPattern<Base>* otherEquation = dep2Equation_.at(otherDep);
918 if (otherEquation != eqCurr_) {
922 UniqueEquationPair<Base> eqPair(eqCurr_, otherEquation);
923 Dep1Dep2SharedType& relation = equationShared_[eqPair];
925 std::map<OperationNode<Base>*, Indexed2OpCountType>* reldepdep;
926 if (eqPair.eq1 == eqCurr_)
927 reldepdep = &relation[depIndex][otherDep];
929 reldepdep = &relation[otherDep][depIndex];
931 INDEXED_OPERATION_TYPE expected = indexedOperation ? INDEXED_OPERATION_TYPE::INDEXED : INDEXED_OPERATION_TYPE::NONINDEXED;
932 typename std::map<OperationNode<Base>*, Indexed2OpCountType>::iterator itIndexedType = reldepdep->find(node);
933 if (itIndexedType == reldepdep->end()) {
934 (*reldepdep)[node] = Indexed2OpCountType(expected, localOpCount);
935 }
else if (itIndexedType->second.first != expected) {
936 itIndexedType->second.first = INDEXED_OPERATION_TYPE::BOTH;
944 return indexedOperation;
954 inline void markOperationsWithDependent(
const OperationNode<Base>* node,
956 if (node ==
nullptr || node->getOperationType() == CGOpCode::Inv)
959 size_t id = varId_[*node];
961 std::set<size_t>& deps = id2Deps[id];
966 auto added = deps.insert(dep);
972 const std::vector<Argument<Base> >& args = node->getArguments();
973 size_t arg_size = args.size();
974 for (
size_t i = 0; i < arg_size; i++) {
975 markOperationsWithDependent(args[i].getOperation(), dep);
982 size_t rSize = relatedDepCandidates_.size();
983 for (
size_t r = 0; r < rSize; r++) {
984 const std::set<size_t>& candidates = relatedDepCandidates_[r];
986 for (
size_t it : candidates) {
987 assignIds(dependents_[it].getOperationNode());
992 void assignIds(OperationNode<Base>* node) {
993 if (node ==
nullptr || varId_[*node] > 0)
996 varId_[*node] = idCounter_;
997 origShareNodeId_.adjustSize(*node);
998 origShareNodeId_[*node] = idCounter_;
1001 const std::vector<Argument<Base> >& args = node->getArguments();
1002 size_t arg_size = args.size();
1003 for (
size_t i = 0; i < arg_size; i++) {
1004 assignIds(args[i].getOperation());
1008 void resetHandlerCounters() {
1009 size_t rSize = relatedDepCandidates_.size();
1010 for (
size_t r = 0; r < rSize; r++) {
1011 const std::set<size_t>& candidates = relatedDepCandidates_[r];
1013 for (
size_t it : candidates) {
1014 resetHandlerCounters(dependents_[it].getOperationNode());
1019 void resetHandlerCounters(OperationNode<Base>* node) {
1020 if (node ==
nullptr || varId_[*node] == 0 || origShareNodeId_[*node] == 0)
1024 origShareNodeId_[*node] = 0;
1026 const std::vector<Argument<Base> >& args = node->getArguments();
1027 size_t arg_size = args.size();
1028 for (
size_t i = 0; i < arg_size; i++) {
1029 resetHandlerCounters(args[i].getOperation());
1033 static bool find(Loop<Base>* loop1, Loop<Base>* loop2,
1034 const std::map<EquationPattern<Base>*, std::set<EquationPattern<Base>*> >& blackList) {
1035 for (EquationPattern<Base>* iteq1 : loop1->equations) {
1037 const auto itBlack = blackList.find(iteq1);
1038 if (itBlack != blackList.end()) {
1040 for (EquationPattern<Base>* iteq2 : loop2->equations) {
1041 if (itBlack->second.find(iteq2) != itBlack->second.end()) {
1052 static inline bool contains(
const std::map<T, std::set<T> >& map, T eq1, T eq2) {
1053 typename std::map<T, std::set<T> >::const_iterator itb1;
1054 itb1 = map.find(eq1);
1055 if (itb1 != map.end()) {
1056 if (itb1->second.find(eq2) != itb1->second.end()) {
1063 bool canCombineEquations(
const EquationPattern<Base>& eq1,
1065 const EquationPattern<Base>& eq2,
1067 OperationNode<Base>& sharedTemp,
1068 std::vector<std::set<size_t>* >& dep2Relations,
1069 std::map<
size_t, std::set<size_t> >& dependentBlackListRelations,
1070 SmartSetPointer<std::set<size_t> >& dependentRelations) {
1071 using namespace std;
1074 const set<const OperationNode<Base>*> opWithIndepArgs = eq1.findOperationsUsingIndependents(sharedTemp);
1077 for (
const OperationNode<Base>* op : opWithIndepArgs) {
1080 typename map<const OperationNode<Base>*, OperationIndexedIndependents<Base> >::const_iterator indexed1It;
1081 OperationNode<Base>* op1 = eq1.operationEO2Reference.at(dep1).at(op);
1082 indexed1It = eq1.indexedOpIndep.op2Arguments.find(op1);
1085 typename map<const OperationNode<Base>*, OperationIndexedIndependents<Base> >::const_iterator indexed2It;
1086 OperationNode<Base>* op2 = eq2.operationEO2Reference.at(dep2).at(op);
1087 indexed2It = eq2.indexedOpIndep.op2Arguments.find(op2);
1092 if (indexed1It == eq1.indexedOpIndep.op2Arguments.end()) {
1093 if (indexed2It != eq2.indexedOpIndep.op2Arguments.end()) {
1097 if (indexed2It == eq2.indexedOpIndep.op2Arguments.end()) {
1104 const OperationIndexedIndependents<Base>& indexed1Ops = indexed1It->second;
1105 const OperationIndexedIndependents<Base>& indexed2Ops = indexed2It->second;
1107 size_t a1Size = indexed1Ops.arg2Independents.size();
1108 if (a1Size != indexed2Ops.arg2Independents.size()) {
1112 for (
size_t a = 0; a < a1Size; a++) {
1113 const map<size_t, const OperationNode<Base>*>& eq1Dep2Indep = indexed1Ops.arg2Independents[a];
1114 const map<size_t, const OperationNode<Base>*>& eq2Dep2Indep = indexed2Ops.arg2Independents[a];
1116 if (eq1Dep2Indep.empty() != eq2Dep2Indep.empty())
1121 if (eq1Dep2Indep.empty()) {
1128 using MapIndep2Dep = map<const OperationNode<Base>*, size_t, IndependentNodeSorter<Base> >;
1129 MapIndep2Dep eq1Indep2Dep;
1130 typename MapIndep2Dep::iterator hint = eq1Indep2Dep.begin();
1131 for (
const auto& d2i : eq1Dep2Indep) {
1132 hint = eq1Indep2Dep.insert(hint, std::make_pair(d2i.second, d2i.first));
1136 typename map<const OperationNode<Base>*,
size_t>::const_iterator itHint = eq1Indep2Dep.begin();
1139 for (
const auto& d2i : eq2Dep2Indep) {
1140 size_t dep2 = d2i.first;
1141 const OperationNode<Base>* indep = d2i.second;
1142 typename map<const OperationNode<Base>*,
size_t>::const_iterator it;
1143 if (itHint->first == indep) {
1152 it = eq1Indep2Dep.find(indep);
1155 if (it != eq1Indep2Dep.end()) {
1156 size_t dep1 = it->second;
1159 std::map<size_t, set<size_t> >::const_iterator itBlackL = dependentBlackListRelations.find(dep1);
1160 if (itBlackL != dependentBlackListRelations.end() && itBlackL->second.find(dep2) != itBlackL->second.end()) {
1164 bool related = makeDependentRelation(eq1, dep1, eq2, dep2,
1165 dep2Relations, dependentRelations);
1174 dependentBlackListRelations[dep2].insert(eq1.dependents.begin(), eq1.dependents.end());
1187 bool isNonIndexed(
const EquationPattern<Base>& eq2,
1189 OperationNode<Base>& sharedTemp) {
1190 using namespace std;
1194 const set<const OperationNode<Base>*> opWithIndepArgs = EquationPattern<Base>::findOperationsUsingIndependents(sharedTemp);
1196 for (
const OperationNode<Base>* op : opWithIndepArgs) {
1199 OperationNode<Base>* op2 = eq2.operationEO2Reference.at(dep2).at(op);
1201 const auto indexed2It = eq2.indexedOpIndep.op2Arguments.find(op2);
1202 if (indexed2It != eq2.indexedOpIndep.op2Arguments.end()) {
1210 bool makeDependentRelation(
const EquationPattern<Base>& eq1,
1212 const EquationPattern<Base>& eq2,
1214 std::vector<std::set<size_t>* >& dep2Relations,
1215 SmartSetPointer<std::set<size_t> >& dependentRelations) {
1216 using namespace std;
1218 set<size_t>* related1 = dep2Relations[dep1];
1219 set<size_t>* related2 = dep2Relations[dep2];
1222 if (related1 !=
nullptr) {
1225 if (related2 !=
nullptr) {
1228 if (related1 == related2)
1233 bool canMerge =
true;
1235 for (
size_t dep3 : *related2) {
1237 const EquationPattern<Base>& eq3 = *dep2Equation_.at(dep3);
1239 for (
size_t it : eq3.dependents) {
1240 if (it != dep3 && related1->find(it) != related1->end()) {
1252 for (
size_t dep3 : *related2) {
1253 related1->insert(dep3);
1254 dep2Relations[dep3] = related1;
1257 dependentRelations.erase(related2);
1269 if (related1->find(dep2) == related1->end()) {
1271 bool canMerge =
true;
1272 for (
size_t it : eq2.dependents) {
1273 if (it != dep2 && related1->find(it) != related1->end()) {
1280 related1->insert(dep2);
1281 dep2Relations[dep2] = related1;
1293 }
else if (related2 !=
nullptr) {
1297 bool canMerge =
true;
1298 for (
size_t it : eq1.dependents) {
1299 if (it != dep1 && related2->find(it) != related2->end()) {
1307 related2->insert(dep1);
1308 dep2Relations[dep1] = related2;
1321 auto* related =
new std::set<size_t>();
1322 dependentRelations.insert(related);
1323 related->insert(dep1);
1324 related->insert(dep2);
1325 dep2Relations[dep1] = related;
1326 dep2Relations[dep2] = related;
const std::map< size_t, CGAbstractAtomicFun< Base > * > & getAtomicFunctions() const
DependentPatternMatcher(const std::vector< std::set< size_t > > &relatedDepCandidates, const std::vector< CGBase > &dependents, const std::vector< CGBase > &independents)
virtual void generateTapes(LoopFreeModel< Base > *&nonLoopTape, std::set< LoopModel< Base > * > &loopTapes)
void detectNonIndexedIndependents()