// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
#define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H

namespace Eigen {

/** \class TensorCustomUnaryOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor custom class.
  *
  */
namespace internal {
template<typename CustomUnaryFunc, typename XprType>
struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
{
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::StorageKind StorageKind;
  typedef typename XprType::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = traits<XprType>::NumDimensions;
  static const int Layout = traits<XprType>::Layout;
};

template<typename CustomUnaryFunc, typename XprType>
struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
{
  typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type;
};

template<typename CustomUnaryFunc, typename XprType>
struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
{
  typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
};

} // end namespace internal

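// Illustrative sketch (not part of this header): a CustomUnaryFunc is any
// functor exposing the two members used by the evaluator below, namely
// dimensions(input) to report the output shape and eval(input, output, device)
// to fill the output. The functor name DoubleFunctor and the use of the
// customOp() helper on TensorBase are assumptions made for this example only.
//
//   struct DoubleFunctor {
//     // The output has the same shape as the input tensor.
//     Eigen::DSizes<Eigen::DenseIndex, 2>
//     dimensions(const Eigen::Tensor<float, 2>& input) const {
//       return input.dimensions();
//     }
//     // Write the result through the TensorMap provided by the evaluator.
//     template <typename Output, typename Device>
//     void eval(const Eigen::Tensor<float, 2>& input, Output& output,
//               const Device& device) const {
//       output.device(device) = input * input.constant(2.0f);
//     }
//   };
//
//   Eigen::Tensor<float, 2> t(3, 4);
//   t.setRandom();
//   Eigen::Tensor<float, 2> doubled = t.customOp(DoubleFunctor());
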
template<typename CustomUnaryFunc, typename XprType>
class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
    typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
    typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
        : m_expr(expr), m_func(func) {}

    EIGEN_DEVICE_FUNC
    const CustomUnaryFunc& func() const { return m_func; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_expr; }

  protected:
    typename XprType::Nested m_expr;
    const CustomUnaryFunc m_func;
};

// Eval as rvalue
template<typename CustomUnaryFunc, typename XprType, typename Device>
struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
{
  typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType;
  typedef typename internal::traits<ArgType>::Index Index;
  static const int NumDims = internal::traits<ArgType>::NumDimensions;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
  typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;

  enum {
    IsAligned = false,
    PacketAccess = (internal::packet_traits<Scalar>::size > 1),
    BlockAccess = false,
    Layout = TensorEvaluator<XprType, Device>::Layout,
    CoordAccess = false, // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
      : m_op(op), m_device(device), m_result(NULL)
  {
    m_dimensions = op.func().dimensions(op.expression());
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

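  // If the caller supplies a destination buffer, the custom functor is
  // evaluated directly into it and false is returned (there is no temporary to
  // keep alive). Otherwise a temporary buffer is allocated on the device, the
  // result is evaluated into it, and true is returned so that cleanup() later
  // releases the allocation.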
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
    if (data) {
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<CoeffReturnType*>(
          m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }

 protected:
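  // Wrap the raw output buffer in a TensorMap with the dimensions reported by
  // the functor, so the user-supplied eval() can assign to it with the usual
  // tensor expression syntax (e.g. result.device(device) = ...).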
  EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
    TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(
        data, m_dimensions);
    m_op.func().eval(m_op.expression(), result, m_device);
  }

  Dimensions m_dimensions;
  const ArgType m_op;
  const Device& m_device;
  CoeffReturnType* m_result;
};


/** \class TensorCustomBinaryOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor custom class.
  *
  */
namespace internal {
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
{
  typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
                                                  typename RhsXprType::Scalar>::ret Scalar;
  typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
                                                  typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;
  static const int NumDimensions = traits<LhsXprType>::NumDimensions;
  static const int Layout = traits<LhsXprType>::Layout;
};

template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
};

template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
{
  typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
};

} // end namespace internal

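// Illustrative sketch (not part of this header): a CustomBinaryFunc provides
// dimensions(lhs, rhs) and eval(lhs, rhs, output, device), mirroring the calls
// made by the binary evaluator below. The functor name AddFunctor and the
// two-argument customOp() helper on TensorBase are assumptions made for this
// example only.
//
//   struct AddFunctor {
//     // The result has the shape of the left-hand operand.
//     Eigen::DSizes<Eigen::DenseIndex, 2>
//     dimensions(const Eigen::Tensor<float, 2>& lhs,
//                const Eigen::Tensor<float, 2>& /*rhs*/) const {
//       return lhs.dimensions();
//     }
//     // Coefficient-wise sum written through the evaluator-provided TensorMap.
//     template <typename Output, typename Device>
//     void eval(const Eigen::Tensor<float, 2>& lhs,
//               const Eigen::Tensor<float, 2>& rhs,
//               Output& output, const Device& device) const {
//       output.device(device) = lhs + rhs;
//     }
//   };
//
//   Eigen::Tensor<float, 2> a(3, 4), b(3, 4);
//   a.setRandom();
//   b.setRandom();
//   Eigen::Tensor<float, 2> sum = a.customOp(b, AddFunctor());
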
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
  public:
    typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
    typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
    typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
    typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}

    EIGEN_DEVICE_FUNC
    const CustomBinaryFunc& func() const { return m_func; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return m_lhs_xpr; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

  protected:
    typename LhsXprType::Nested m_lhs_xpr;
    typename RhsXprType::Nested m_rhs_xpr;
    const CustomBinaryFunc m_func;
};

// Eval as rvalue
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
{
  typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType;
  typedef typename internal::traits<XprType>::Index Index;
  static const int NumDims = internal::traits<XprType>::NumDimensions;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;

  enum {
    IsAligned = false,
    PacketAccess = (internal::packet_traits<Scalar>::size > 1),
    BlockAccess = false,
    Layout = TensorEvaluator<LhsXprType, Device>::Layout,
    CoordAccess = false, // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_op(op), m_device(device), m_result(NULL)
  {
    m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
    if (data) {
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }

 protected:
  EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
    TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions);
    m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
  }

  Dimensions m_dimensions;
  const XprType m_op;
  const Device& m_device;
  CoeffReturnType* m_result;
};


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H