Loading...
Searching...
No Matches
TensorRef.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H
11#define EIGEN_CXX11_TENSOR_TENSOR_REF_H
12
13namespace Eigen {
14
15namespace internal {
16
17template <typename Dimensions, typename Scalar>
18class TensorLazyBaseEvaluator {
19 public:
20 TensorLazyBaseEvaluator() : m_refcount(0) { }
21 virtual ~TensorLazyBaseEvaluator() { }
22
23 EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const = 0;
24 EIGEN_DEVICE_FUNC virtual const Scalar* data() const = 0;
25
26 EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const = 0;
27 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) = 0;
28
29 void incrRefCount() { ++m_refcount; }
30 void decrRefCount() { --m_refcount; }
31 int refCount() const { return m_refcount; }
32
33 private:
34 // No copy, no assigment;
35 TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other);
36 TensorLazyBaseEvaluator& operator = (const TensorLazyBaseEvaluator& other);
37
38 int m_refcount;
39};
40
41
42template <typename Dimensions, typename Expr, typename Device>
43class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> {
44 public:
45 // typedef typename TensorEvaluator<Expr, Device>::Dimensions Dimensions;
46 typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar;
47
48 TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
49 m_dims = m_impl.dimensions();
50 m_impl.evalSubExprsIfNeeded(NULL);
51 }
52 virtual ~TensorLazyEvaluatorReadOnly() {
53 m_impl.cleanup();
54 }
55
56 EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const {
57 return m_dims;
58 }
59 EIGEN_DEVICE_FUNC virtual const Scalar* data() const {
60 return m_impl.data();
61 }
62
63 EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const {
64 return m_impl.coeff(index);
65 }
66 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex /*index*/) {
67 eigen_assert(false && "can't reference the coefficient of a rvalue");
68 return m_dummy;
69 };
70
71 protected:
72 TensorEvaluator<Expr, Device> m_impl;
73 Dimensions m_dims;
74 Scalar m_dummy;
75};
76
77template <typename Dimensions, typename Expr, typename Device>
78class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> {
79 public:
80 typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base;
81 typedef typename Base::Scalar Scalar;
82
83 TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) {
84 }
85 virtual ~TensorLazyEvaluatorWritable() {
86 }
87
88 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) {
89 return this->m_impl.coeffRef(index);
90 }
91};
92
93template <typename Dimensions, typename Expr, typename Device>
94class TensorLazyEvaluator : public internal::conditional<bool(internal::is_lvalue<Expr>::value),
95 TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
96 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type {
97 public:
98 typedef typename internal::conditional<bool(internal::is_lvalue<Expr>::value),
99 TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
100 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type Base;
101 typedef typename Base::Scalar Scalar;
102
103 TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) {
104 }
105 virtual ~TensorLazyEvaluator() {
106 }
107};
108
109} // namespace internal
110
111
119template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef<PlainObjectType> >
120{
121 public:
123 typedef typename PlainObjectType::Base Base;
124 typedef typename Eigen::internal::nested<Self>::type Nested;
125 typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
126 typedef typename internal::traits<PlainObjectType>::Index Index;
127 typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
128 typedef typename NumTraits<Scalar>::Real RealScalar;
129 typedef typename Base::CoeffReturnType CoeffReturnType;
130 typedef Scalar* PointerType;
131 typedef PointerType PointerArgType;
132
133 static const Index NumIndices = PlainObjectType::NumIndices;
134 typedef typename PlainObjectType::Dimensions Dimensions;
135
136 enum {
137 IsAligned = false,
138 PacketAccess = false,
139 Layout = PlainObjectType::Layout,
140 CoordAccess = false, // to be implemented
141 RawAccess = false
142 };
143
144 EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
145 }
146
147 template <typename Expression>
148 EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) {
149 m_evaluator->incrRefCount();
150 }
151
152 template <typename Expression>
153 EIGEN_STRONG_INLINE TensorRef& operator = (const Expression& expr) {
154 unrefEvaluator();
155 m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice());
156 m_evaluator->incrRefCount();
157 return *this;
158 }
159
160 ~TensorRef() {
161 unrefEvaluator();
162 }
163
164 TensorRef(const TensorRef& other) : m_evaluator(other.m_evaluator) {
165 eigen_assert(m_evaluator->refCount() > 0);
166 m_evaluator->incrRefCount();
167 }
168
169 TensorRef& operator = (const TensorRef& other) {
170 if (this != &other) {
171 unrefEvaluator();
172 m_evaluator = other.m_evaluator;
173 eigen_assert(m_evaluator->refCount() > 0);
174 m_evaluator->incrRefCount();
175 }
176 return *this;
177 }
178
179 EIGEN_DEVICE_FUNC
180 EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
181 EIGEN_DEVICE_FUNC
182 EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
183 EIGEN_DEVICE_FUNC
184 EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
185 EIGEN_DEVICE_FUNC
186 EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); }
187 EIGEN_DEVICE_FUNC
188 EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); }
189
190 EIGEN_DEVICE_FUNC
191 EIGEN_STRONG_INLINE const Scalar operator()(Index index) const
192 {
193 return m_evaluator->coeff(index);
194 }
195
196#if EIGEN_HAS_VARIADIC_TEMPLATES
197 template<typename... IndexTypes> EIGEN_DEVICE_FUNC
198 EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const
199 {
200 const std::size_t num_indices = (sizeof...(otherIndices) + 1);
201 const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
202 return coeff(indices);
203 }
204 template<typename... IndexTypes> EIGEN_DEVICE_FUNC
205 EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices)
206 {
207 const std::size_t num_indices = (sizeof...(otherIndices) + 1);
208 const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
209 return coeffRef(indices);
210 }
211#else
212
213 EIGEN_DEVICE_FUNC
214 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1) const
215 {
216 array<Index, 2> indices;
217 indices[0] = i0;
218 indices[1] = i1;
219 return coeff(indices);
220 }
221 EIGEN_DEVICE_FUNC
222 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2) const
223 {
224 array<Index, 3> indices;
225 indices[0] = i0;
226 indices[1] = i1;
227 indices[2] = i2;
228 return coeff(indices);
229 }
230 EIGEN_DEVICE_FUNC
231 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3) const
232 {
233 array<Index, 4> indices;
234 indices[0] = i0;
235 indices[1] = i1;
236 indices[2] = i2;
237 indices[3] = i3;
238 return coeff(indices);
239 }
240 EIGEN_DEVICE_FUNC
241 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
242 {
243 array<Index, 5> indices;
244 indices[0] = i0;
245 indices[1] = i1;
246 indices[2] = i2;
247 indices[3] = i3;
248 indices[4] = i4;
249 return coeff(indices);
250 }
251 EIGEN_DEVICE_FUNC
252 EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1)
253 {
254 array<Index, 2> indices;
255 indices[0] = i0;
256 indices[1] = i1;
257 return coeffRef(indices);
258 }
259 EIGEN_DEVICE_FUNC
260 EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2)
261 {
262 array<Index, 3> indices;
263 indices[0] = i0;
264 indices[1] = i1;
265 indices[2] = i2;
266 return coeffRef(indices);
267 }
268 EIGEN_DEVICE_FUNC
269 EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
270 {
271 array<Index, 4> indices;
272 indices[0] = i0;
273 indices[1] = i1;
274 indices[2] = i2;
275 indices[3] = i3;
276 return coeffRef(indices);
277 }
278 EIGEN_DEVICE_FUNC
279 EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4)
280 {
281 array<Index, 5> indices;
282 indices[0] = i0;
283 indices[1] = i1;
284 indices[2] = i2;
285 indices[3] = i3;
286 indices[4] = i4;
287 return coeffRef(indices);
288 }
289#endif
290
291 template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
292 EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const
293 {
294 const Dimensions& dims = this->dimensions();
295 Index index = 0;
296 if (PlainObjectType::Options & RowMajor) {
297 index += indices[0];
298 for (size_t i = 1; i < NumIndices; ++i) {
299 index = index * dims[i] + indices[i];
300 }
301 } else {
302 index += indices[NumIndices-1];
303 for (int i = NumIndices-2; i >= 0; --i) {
304 index = index * dims[i] + indices[i];
305 }
306 }
307 return m_evaluator->coeff(index);
308 }
309 template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
310 EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
311 {
312 const Dimensions& dims = this->dimensions();
313 Index index = 0;
314 if (PlainObjectType::Options & RowMajor) {
315 index += indices[0];
316 for (size_t i = 1; i < NumIndices; ++i) {
317 index = index * dims[i] + indices[i];
318 }
319 } else {
320 index += indices[NumIndices-1];
321 for (int i = NumIndices-2; i >= 0; --i) {
322 index = index * dims[i] + indices[i];
323 }
324 }
325 return m_evaluator->coeffRef(index);
326 }
327
328 EIGEN_DEVICE_FUNC
329 EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
330 {
331 return m_evaluator->coeff(index);
332 }
333
334 EIGEN_DEVICE_FUNC
335 EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
336 {
337 return m_evaluator->coeffRef(index);
338 }
339
340 private:
341 EIGEN_STRONG_INLINE void unrefEvaluator() {
342 if (m_evaluator) {
343 m_evaluator->decrRefCount();
344 if (m_evaluator->refCount() == 0) {
345 delete m_evaluator;
346 }
347 }
348 }
349
350 internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
351};
352
353
354// evaluator for rvalues
355template<typename Derived, typename Device>
356struct TensorEvaluator<const TensorRef<Derived>, Device>
357{
358 typedef typename Derived::Index Index;
359 typedef typename Derived::Scalar Scalar;
360 typedef typename Derived::Scalar CoeffReturnType;
361 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
362 typedef typename Derived::Dimensions Dimensions;
363
364 enum {
365 IsAligned = false,
366 PacketAccess = false,
368 CoordAccess = false, // to be implemented
369 RawAccess = false
370 };
371
372 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)
373 : m_ref(m)
374 { }
375
376 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); }
377
378 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
379 return true;
380 }
381
382 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }
383
384 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
385 return m_ref.coeff(index);
386 }
387
388 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
389 return m_ref.coeffRef(index);
390 }
391
392 EIGEN_DEVICE_FUNC Scalar* data() const { return m_ref.data(); }
393
394 protected:
395 TensorRef<Derived> m_ref;
396};
397
398
399// evaluator for lvalues
400template<typename Derived, typename Device>
401struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device>
402{
403 typedef typename Derived::Index Index;
404 typedef typename Derived::Scalar Scalar;
405 typedef typename Derived::Scalar CoeffReturnType;
406 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
407 typedef typename Derived::Dimensions Dimensions;
408
409 typedef TensorEvaluator<const TensorRef<Derived>, Device> Base;
410
411 enum {
412 IsAligned = false,
413 PacketAccess = false,
414 RawAccess = false
415 };
416
417 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d)
418 { }
419
420 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
421 return this->m_ref.coeffRef(index);
422 }
423};
424
425
426
427} // end namespace Eigen
428
429#endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H
The tensor base class.
Definition: TensorBase.h:827
A reference to a tensor expression. The expression will be evaluated lazily (as much as possible).
Definition: TensorRef.h:120
Namespace containing all symbols from the Eigen library.
Definition: AdolcForward:45
A cost model used to limit the number of threads used for evaluating tensor expressions.
Definition: TensorEvaluator.h:29