Loading...
Searching...
No Matches
TensorContractionBlocking.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
12
13
14namespace Eigen {
15namespace internal {
16
// Sharding selectors accepted as the ShardingType template argument of
// TensorContractionBlocking (which defaults to ShardByCol).
// The enumerators rely on implicit numbering: ShardByRow == 0, ShardByCol == 1.
enum {
  ShardByRow,
  ShardByCol
};
21
22
23// Default Blocking Strategy
24template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol>
25class TensorContractionBlocking {
26 public:
27
28 typedef typename LhsMapper::Scalar LhsScalar;
29 typedef typename RhsMapper::Scalar RhsScalar;
30
31 EIGEN_DEVICE_FUNC TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) :
32 kc_(k), mc_(m), nc_(n)
33 {
34 if (ShardingType == ShardByCol) {
35 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads);
36 }
37 else {
38 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads);
39 }
40 }
41
42 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
43 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
44 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }
45
46 private:
47 Index kc_;
48 Index mc_;
49 Index nc_;
50};
51
52
53} // end namespace internal
54} // end namespace Eigen
55
56#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
Namespace containing all symbols from the Eigen library.
Definition: AdolcForward:45