Loading...
Searching...
No Matches
TensorGenerator.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_GENERATOR_H
11#define EIGEN_CXX11_TENSOR_TENSOR_GENERATOR_H
12
13namespace Eigen {
14
22namespace internal {
23template<typename Generator, typename XprType>
24struct traits<TensorGeneratorOp<Generator, XprType> > : public traits<XprType>
25{
26 typedef typename XprType::Scalar Scalar;
27 typedef traits<XprType> XprTraits;
28 typedef typename XprTraits::StorageKind StorageKind;
29 typedef typename XprTraits::Index Index;
30 typedef typename XprType::Nested Nested;
31 typedef typename remove_reference<Nested>::type _Nested;
32 static const int NumDimensions = XprTraits::NumDimensions;
33 static const int Layout = XprTraits::Layout;
34};
35
36template<typename Generator, typename XprType>
37struct eval<TensorGeneratorOp<Generator, XprType>, Eigen::Dense>
38{
39 typedef const TensorGeneratorOp<Generator, XprType>& type;
40};
41
42template<typename Generator, typename XprType>
43struct nested<TensorGeneratorOp<Generator, XprType>, 1, typename eval<TensorGeneratorOp<Generator, XprType> >::type>
44{
45 typedef TensorGeneratorOp<Generator, XprType> type;
46};
47
48} // end namespace internal
49
50
51
52template<typename Generator, typename XprType>
53class TensorGeneratorOp : public TensorBase<TensorGeneratorOp<Generator, XprType>, ReadOnlyAccessors>
54{
55 public:
56 typedef typename Eigen::internal::traits<TensorGeneratorOp>::Scalar Scalar;
57 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
58 typedef typename XprType::CoeffReturnType CoeffReturnType;
59 typedef typename Eigen::internal::nested<TensorGeneratorOp>::type Nested;
60 typedef typename Eigen::internal::traits<TensorGeneratorOp>::StorageKind StorageKind;
61 typedef typename Eigen::internal::traits<TensorGeneratorOp>::Index Index;
62
63 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorGeneratorOp(const XprType& expr, const Generator& generator)
64 : m_xpr(expr), m_generator(generator) {}
65
66 EIGEN_DEVICE_FUNC
67 const Generator& generator() const { return m_generator; }
68
69 EIGEN_DEVICE_FUNC
70 const typename internal::remove_all<typename XprType::Nested>::type&
71 expression() const { return m_xpr; }
72
73 protected:
74 typename XprType::Nested m_xpr;
75 const Generator m_generator;
76};
77
78
79// Eval as rvalue
80template<typename Generator, typename ArgType, typename Device>
81struct TensorEvaluator<const TensorGeneratorOp<Generator, ArgType>, Device>
82{
83 typedef TensorGeneratorOp<Generator, ArgType> XprType;
84 typedef typename XprType::Index Index;
85 typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
86 static const int NumDims = internal::array_size<Dimensions>::value;
87 typedef typename XprType::Scalar Scalar;
88 typedef typename XprType::CoeffReturnType CoeffReturnType;
89 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
90 enum {
91 IsAligned = false,
92 PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
93 BlockAccess = false,
94 Layout = TensorEvaluator<ArgType, Device>::Layout,
95 CoordAccess = false, // to be implemented
96 RawAccess = false
97 };
98
99 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
100 : m_generator(op.generator())
101 {
102 TensorEvaluator<ArgType, Device> impl(op.expression(), device);
103 m_dimensions = impl.dimensions();
104
105 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
106 m_strides[0] = 1;
107 for (int i = 1; i < NumDims; ++i) {
108 m_strides[i] = m_strides[i - 1] * m_dimensions[i - 1];
109 }
110 } else {
111 m_strides[NumDims - 1] = 1;
112 for (int i = NumDims - 2; i >= 0; --i) {
113 m_strides[i] = m_strides[i + 1] * m_dimensions[i + 1];
114 }
115 }
116 }
117
118 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
119
120 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
121 return true;
122 }
123 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
124 }
125
126 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
127 {
128 array<Index, NumDims> coords;
129 extract_coordinates(index, coords);
130 return m_generator(coords);
131 }
132
133 template<int LoadMode>
134 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
135 {
136 const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
137 EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
138 eigen_assert(index+packetSize-1 < dimensions().TotalSize());
139
140 EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[packetSize];
141 for (int i = 0; i < packetSize; ++i) {
142 values[i] = coeff(index+i);
143 }
144 PacketReturnType rslt = internal::pload<PacketReturnType>(values);
145 return rslt;
146 }
147
148 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
149 costPerCoeff(bool) const {
150 // TODO(rmlarsen): This is just a placeholder. Define interface to make
151 // generators return their cost.
152 return TensorOpCost(0, 0, TensorOpCost::AddCost<Scalar>() +
153 TensorOpCost::MulCost<Scalar>());
154 }
155
156 EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
157
158 protected:
159 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
160 void extract_coordinates(Index index, array<Index, NumDims>& coords) const {
161 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
162 for (int i = NumDims - 1; i > 0; --i) {
163 const Index idx = index / m_strides[i];
164 index -= idx * m_strides[i];
165 coords[i] = idx;
166 }
167 coords[0] = index;
168 } else {
169 for (int i = 0; i < NumDims - 1; ++i) {
170 const Index idx = index / m_strides[i];
171 index -= idx * m_strides[i];
172 coords[i] = idx;
173 }
174 coords[NumDims-1] = index;
175 }
176 }
177
178 Dimensions m_dimensions;
179 array<Index, NumDims> m_strides;
180 Generator m_generator;
181};
182
183} // end namespace Eigen
184
185#endif // EIGEN_CXX11_TENSOR_TENSOR_GENERATOR_H
Namespace containing all symbols from the Eigen library.
Definition: AdolcForward:45
const Device & device() const
required by sycl in order to construct sycl buffer from raw pointer
Definition: TensorEvaluator.h:114