CppADCodeGen 2.4.3
A C++ Algorithmic Differentiation Package with Source Code Generation
Loading...
Searching...
No Matches
atomic_dependency_locator.hpp
1#ifndef CPPAD_CG_ATOMIC_DEPENDENCY_LOCATOR_INCLUDED
2#define CPPAD_CG_ATOMIC_DEPENDENCY_LOCATOR_INCLUDED
3/* --------------------------------------------------------------------------
4 * CppADCodeGen: C++ Algorithmic Differentiation with Source Code Generation:
5 * Copyright (C) 2013 Ciengis
6 *
7 * CppADCodeGen is distributed under multiple licenses:
8 *
9 * - Eclipse Public License Version 1.0 (EPL1), and
10 * - GNU General Public License Version 3 (GPL3).
11 *
12 * EPL1 terms and conditions can be found in the file "epl-v10.txt", while
13 * terms and conditions for the GPL3 can be found in the file "gpl3.txt".
14 * ----------------------------------------------------------------------------
15 * Author: Joao Leal
16 */
17
18namespace CppAD {
19namespace cg {
20
24template<class Base>
26public:
34 std::set<std::pair<size_t, size_t>> sizes;
38 std::set<size_t> outerIndeps;
39public:
40 inline AtomicUseInfo() :
41 atom(nullptr) {
42 }
43};
44
49template<class Base>
51private:
52 ADFun<CG<Base> >& fun_;
53 std::map<size_t, AtomicUseInfo<Base>> atomicInfo_;
54 std::map<OperationNode<Base>*, std::set<size_t> > indeps_;
55 CodeHandler<Base> handler_;
56public:
57
59 fun_(fun) {
60 }
61
62 inline const std::map<size_t, AtomicUseInfo<Base>>& findAtomicsUsage() {
63 if (!atomicInfo_.empty()) {
64 return atomicInfo_;
65 }
66
67 size_t m = fun_.Range();
68 size_t n = fun_.Domain();
69
70 std::vector<CG<Base> > x(n);
71 handler_.makeVariables(x);
72
73 // make sure the position in the code handler is the same as the independent index
74 assert(x.size() == 0 || (x[0].getOperationNode()->getHandlerPosition() == 0 && x[x.size() - 1].getOperationNode()->getHandlerPosition() == x.size() - 1));
75
76 std::vector<CG<Base> > dep = fun_.Forward(0, x);
77
78 for (size_t i = 0; i < m; i++) {
79 findAtomicsUsage(dep[i].getOperationNode());
80 }
81
82 const auto& regAtomics = handler_.getAtomicFunctions();
83 for (auto& pair: atomicInfo_) {
84 size_t id = pair.first;
85
86 pair.second.atom = regAtomics.at(id);
87 }
88
89 return atomicInfo_;
90 }
91
92private:
93
94 inline std::set<size_t> findAtomicsUsage(OperationNode<Base>* node) {
95 if (node == nullptr)
96 return std::set<size_t>();
97
98 CGOpCode op = node->getOperationType();
99 if (op == CGOpCode::Inv) {
100 std::set<size_t> indeps;
101 // particular case where the position in the code handler is the same as the independent index
102 indeps.insert(node->getHandlerPosition());
103 return indeps;
104 }
105
106 if (handler_.isVisited(*node)) {
107 // been here before
108 return indeps_.at(node);
109 }
110
111 handler_.markVisited(*node);
112
113 std::set<size_t> indeps;
114 const std::vector<Argument<Base> >& args = node->getArguments();
115 for (size_t a = 0; a < args.size(); a++) {
116 std::set<size_t> aindeps = findAtomicsUsage(args[a].getOperation());
117 indeps.insert(aindeps.begin(), aindeps.end());
118 }
119 indeps_[node] = indeps;
120
121 if (op == CGOpCode::AtomicForward) {
122 CPPADCG_ASSERT_UNKNOWN(node->getInfo().size() > 1);
123 CPPADCG_ASSERT_UNKNOWN(node->getArguments().size() > 1);
124 size_t id = node->getInfo()[0];
125
126#ifndef NDEBUG
127 size_t p = node->getInfo()[2];
128 CPPADCG_ASSERT_UNKNOWN(p == 0);
129#endif
130
131 OperationNode<Base>* tx = node->getArguments()[0].getOperation();
132 OperationNode<Base>* ty = node->getArguments()[1].getOperation();
133
134 CPPADCG_ASSERT_UNKNOWN(tx != nullptr && tx->getOperationType() == CGOpCode::ArrayCreation);
135 CPPADCG_ASSERT_UNKNOWN(ty != nullptr && ty->getOperationType() == CGOpCode::ArrayCreation);
136
137 auto& info = atomicInfo_[id];
138 info.outerIndeps.insert(indeps.begin(), indeps.end());
139 info.sizes.insert(std::pair<size_t, size_t>(tx->getArguments().size(),
140 ty->getArguments().size()));
141 }
142
143 return indeps;
144 }
145};
146
147} // END cg namespace
148} // END CppAD namespace
149
150#endif
std::set< std::pair< size_t, size_t > > sizes
CGAbstractAtomicFun< Base > * atom
void makeVariables(VectorCG &variables)
const std::map< size_t, CGAbstractAtomicFun< Base > * > & getAtomicFunctions() const
const std::vector< size_t > & getInfo() const
size_t getHandlerPosition() const
const std::vector< Argument< Base > > & getArguments() const
CGOpCode getOperationType() const