TensorContraction.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

namespace Eigen {

namespace internal {

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type,
                               typename remove_const<typename RhsXprType::Scalar>::type>::ResScalar Scalar;

  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;

  // From NumDims below.
  static const int NumDimensions = traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
  static const int Layout = traits<LhsXprType>::Layout;

  enum {
    Flags = 0
  };
};
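
// Worked example of the NumDimensions formula above (illustrative): a rank-3
// lhs contracted with a rank-2 rhs over a single index pair yields
//   NumDimensions = 3 + 2 - 2*1 = 3,
// i.e. the result keeps the non-contracted axes of both operands.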

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
{
  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
};

template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef Device_ Device;

  // From NumDims below.
  static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
};

} // end namespace internal

template<typename Indices, typename LhsXprType, typename RhsXprType>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
    typedef typename internal::gebp_traits<typename LhsXprType::CoeffReturnType,
                                           typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
        const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}

    EIGEN_DEVICE_FUNC
    const Indices& indices() const { return m_indices; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return m_lhs_xpr; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

  protected:
    typename LhsXprType::Nested m_lhs_xpr;
    typename RhsXprType::Nested m_rhs_xpr;
    const Indices m_indices;
};
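
// Minimal usage sketch, assuming the high-level Tensor class and its
// contract() method from this module: contracting axis 1 of A with axis 0 of
// B builds a TensorContractionOp that evaluates to an ordinary matrix product.
//
//   Eigen::Tensor<float, 2> A(30, 50), B(50, 40);
//   A.setRandom();
//   B.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   Eigen::Tensor<float, 2> C = A.contract(B, dims);  // C has dimensions 30 x 40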


template<typename Derived>
struct TensorContractionEvaluatorBase
{
  typedef typename internal::traits<Derived>::Indices Indices;
  typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
  typedef typename internal::traits<Derived>::RightArgType RightArgType;
  typedef typename internal::traits<Derived>::Device Device;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    IsAligned = true,
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = true
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
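  // For intuition: a RowMajor matrix read through ColMajor strides is its
  // transpose, so computing C = A * B in RowMajor is the same as computing
  // C^T = B^T * A^T in ColMajor. Swapping the operands as above therefore
  // lets the ColMajor kernels below be reused unchanged.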

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  typedef DSizes<Index, NumDims> Dimensions;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        m_device(device),
        m_result(NULL) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);


    DSizes<Index, LDims> eval_left_dims;
    DSizes<Index, RDims> eval_right_dims;
    array<IndexPair<Index>, ContractDims> eval_op_indices;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // For ColMajor, we keep using the existing dimensions
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[i];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[i];
      }
      // We keep the pairs of contracting indices.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = op.indices()[i].first;
        eval_op_indices[i].second = op.indices()[i].second;
      }
    } else {
      // For RowMajor, we need to reverse the existing dimensions
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1];
      }
      // We need to flip all the pairs of contracting indices as well as
      // reversing the dimensions.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = LDims - 1 - op.indices()[ContractDims - 1 - i].second;
        eval_op_indices[i].second = RDims - 1 - op.indices()[ContractDims - 1 - i].first;
      }
    }
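    // Worked example of the RowMajor remapping above (illustrative): with a
    // rank-2 lhs A and a rank-3 rhs B contracted over the user-supplied pair
    // (1, 0), the operands are swapped (B becomes the eval-left side, so
    // LDims = 3 and RDims = 2) and the axes reversed, turning the pair into
    // (LDims - 1 - 0, RDims - 1 - 1) = (2, 0).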

    // Check for duplicate axes and make sure the first index in eval_op_indices
    // is increasing. Using O(n^2) sorting is OK since ContractDims is small
    for (int i = 0; i < ContractDims; i++) {
      for (int j = i + 1; j < ContractDims; j++) {
        eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
                     eval_op_indices[j].second != eval_op_indices[i].second &&
                     "contraction axes should be unique");
        if (eval_op_indices[j].first < eval_op_indices[i].first) {
          numext::swap(eval_op_indices[j], eval_op_indices[i]);
        }
      }
    }

    array<Index, LDims> lhs_strides;
    lhs_strides[0] = 1;
    for (int i = 0; i < LDims-1; ++i) {
      lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
    }

    array<Index, RDims> rhs_strides;
    rhs_strides[0] = 1;
    for (int i = 0; i < RDims-1; ++i) {
      rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
    }
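    // Stride example (illustrative): for column-major dimensions (30, 50, 70)
    // the loop above produces strides (1, 30, 1500), i.e. each stride is the
    // product of all faster-varying dimensions.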

    if (m_i_strides.size() > 0) m_i_strides[0] = 1;
    if (m_j_strides.size() > 0) m_j_strides[0] = 1;
    if (m_k_strides.size() > 0) m_k_strides[0] = 1;

    m_i_size = 1;
    m_j_size = 1;
    m_k_size = 1;

    // To compute the output dimensions, we simply concatenate the
    // non-contracting dimensions of the left tensor followed by those of the
    // right tensor. Along the way we also compute the strides corresponding
    // to the left and right non-contracting dimensions.
    m_lhs_inner_dim_contiguous = true;
    int dim_idx = 0;
    unsigned int nocontract_idx = 0;

    for (int i = 0; i < LDims; i++) {
      // find if we are contracting on index i of left tensor
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].first == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        // add dimension size to output dimensions
        m_dimensions[dim_idx] = eval_left_dims[i];
        m_left_nocontract_strides[nocontract_idx] = lhs_strides[i];
        if (dim_idx != i) {
          m_lhs_inner_dim_contiguous = false;
        }
        if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) {
          m_i_strides[nocontract_idx+1] =
              m_i_strides[nocontract_idx] * eval_left_dims[i];
        } else {
          m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i];
        }
        dim_idx++;
        nocontract_idx++;
      }
    }

    nocontract_idx = 0;
    for (int i = 0; i < RDims; i++) {
      bool contracting = false;
      // find if we are contracting on index i of right tensor
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].second == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_right_dims[i];
        if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) {
          m_j_strides[nocontract_idx+1] =
              m_j_strides[nocontract_idx] * eval_right_dims[i];
        } else {
          m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i];
        }
        m_right_nocontract_strides[nocontract_idx] = rhs_strides[i];
        dim_idx++;
        nocontract_idx++;
      }
    }
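    // Output shape example (illustrative, ColMajor): contracting axis 1 of a
    // (30, 50, 70) left tensor with axis 0 of a (50, 40) right tensor keeps
    // the non-contracted axes in order, so m_dimensions = (30, 70, 40),
    // m_i_size = 30 * 70 = 2100 and m_j_size = 40.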

    // Now compute the strides corresponding to the contracting dimensions. We
    // assumed above that non-contracting axes are represented in the same order
    // in the matrix as they are in the tensor. This is not the case for
    // contracting axes. As the contracting axes must be of the same size in
    // each tensor, we'll only look at the first tensor here.
    m_rhs_inner_dim_contiguous = true;
    m_rhs_inner_dim_reordered = false;
    for (int i = 0; i < ContractDims; i++) {
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;

      Index size = eval_left_dims[left];
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i+1 < static_cast<int>(internal::array_size<contract_t>::value)) {
        m_k_strides[i+1] = m_k_strides[i] * size;
      } else {
        m_k_size = m_k_strides[i] * size;
      }
      m_left_contracting_strides[i] = lhs_strides[left];
      m_right_contracting_strides[i] = rhs_strides[right];

      if (i > 0 && right < eval_op_indices[i-1].second) {
        m_rhs_inner_dim_reordered = true;
      }
      if (right != i) {
        m_rhs_inner_dim_contiguous = false;
      }
    }
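    // Flag example (illustrative): with contraction pairs sorted by left axis
    // as (0, 1) and (2, 0), the right axes are visited as 1 then 0. Since
    // 0 < 1 the pairs are out of order on the rhs (m_rhs_inner_dim_reordered
    // becomes true), and since the first pair's right axis is 1 rather than 0,
    // m_rhs_inner_dim_contiguous becomes false.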

    // If the layout is RowMajor, we need to reverse the m_dimensions
    if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
      for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
        numext::swap(m_dimensions[i], m_dimensions[j]);
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }

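  // evalTo() forwards to Derived::evalProduct() with three compile-time
  // booleans describing the operand layout: whether the lhs and rhs inner
  // contracting dimensions are contiguous in memory, and whether the rhs
  // contracting dimensions are visited out of order. This lets the input
  // mappers be specialized for each combination.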
  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemv(Scalar* buffer) const {
    const Index rows = m_i_size;
    const Index cols = m_k_size;

    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
    const int lhs_alignment = LeftEvaluator::IsAligned ? Aligned : Unaligned;
    const int rhs_alignment = RightEvaluator::IsAligned ? Aligned : Unaligned;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, lhs_alignment> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, rhs_alignment> RhsMapper;

    LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                  m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                  m_right_contracting_strides, m_k_strides);

    const Scalar alpha(1);
    const Index resIncr(1);

    // zero out the result buffer (which must be of size at least rows * sizeof(Scalar))
    m_device.memset(buffer, 0, rows * sizeof(Scalar));

    internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run(
        rows, cols, lhs, rhs,
        buffer, resIncr, alpha);
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    // define mr, nr, and all of my data mapper types
    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

    const Index nr = Traits::nr;
    const Index mr = Traits::mr;

    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // Declare GEBP packing and kernel structs
    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs;
    internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;

    internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // Sizes of the blocks to load in cache. See the Goto paper for details.
    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;

    LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
    RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));

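    // Blocked GEMM loop nest: each mc x kc panel of the (logical) lhs matrix
    // is packed once into blockA, then kc x nc blocks of the rhs are packed
    // into blockB and their product is accumulated into the output by the
    // GEBP micro-kernel.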
    for(Index i2=0; i2<m; i2+=mc)
    {
      const Index actual_mc = numext::mini(i2+mc,m)-i2;
      for (Index k2 = 0; k2 < k; k2 += kc) {
        // make sure we don't overshoot right edge of left matrix, then pack vertical panel
        const Index actual_kc = numext::mini(k2 + kc, k) - k2;
        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);

        // series of horizontal blocks
        for (Index j2 = 0; j2 < n; j2 += nc) {
          // make sure we don't overshoot right edge of right matrix, then pack block
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);

          // call gebp (matrix kernel)
          // The parameters here are copied from Eigen's GEMM implementation
          gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0);
        }
      }
    }

    this->m_device.deallocate(blockA);
    this->m_device.deallocate(blockB);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();

    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar* data() const { return m_result; }

  protected:
  // Prevent assignment
  TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&);
  Dimensions m_dimensions;

  contract_t m_k_strides;
  contract_t m_left_contracting_strides;
  contract_t m_right_contracting_strides;

  bool m_lhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_reordered;

  left_nocontract_t m_i_strides;
  right_nocontract_t m_j_strides;
  left_nocontract_t m_left_nocontract_strides;
  right_nocontract_t m_right_nocontract_strides;

  Index m_i_size;
  Index m_j_size;
  Index m_k_size;

  TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
  TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
  const Device& m_device;
  Scalar* m_result;
};


// evaluator for default device
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
    public TensorContractionEvaluatorBase<
      TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  // Could we use NumDimensions here?
  typedef DSizes<Index, NumDims> Dimensions;

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) { }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }

    this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
  }
};
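
// Dispatch sketch, assuming the high-level Tensor API: when the right operand
// has no non-contracted dimension, m_j_size == 1 and the contraction
// degenerates to a matrix-vector product, so evalProduct() takes the GEMV
// path; otherwise it runs the blocked GEMM path above.
//
//   Eigen::Tensor<float, 2> A(30, 50);
//   Eigen::Tensor<float, 1> x(50);
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   Eigen::Tensor<float, 1> y = A.contract(x, dims);  // rank-1 result, GEMV path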

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H