33#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
34#define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
41template <
typename Scalar,
typename Index,
42 int Mode,
bool LhsIsTriangular,
43 int LhsStorageOrder,
bool ConjugateLhs,
44 int RhsStorageOrder,
bool ConjugateRhs,
46struct product_triangular_matrix_matrix_trmm :
47 product_triangular_matrix_matrix<Scalar,Index,Mode,
48 LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
49 RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {};
// EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
// Defines the `Specialized` version of product_triangular_matrix_matrix for a
// column-major result and the given scalar type. Every call is forwarded to
// product_triangular_matrix_matrix_trmm. That struct is specialized by
// EIGEN_BLAS_TRMM_L / EIGEN_BLAS_TRMM_R and otherwise falls back to Eigen's
// built-in kernel.
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
           LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \
  static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\
    const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \
      product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \
        LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \
        RhsStorageOrder, ConjugateRhs, ColMajor>::run( \
        _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
  } \
};
68EIGEN_BLAS_TRMM_SPECIALIZE(
double,
true)
69EIGEN_BLAS_TRMM_SPECIALIZE(
double, false)
70EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true)
71EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false)
72EIGEN_BLAS_TRMM_SPECIALIZE(
float, true)
73EIGEN_BLAS_TRMM_SPECIALIZE(
float, false)
74EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
75EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
// EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX)
// Column-major update  res += alpha * T(lhs) * rhs, where T(lhs) is the
// Mode-triangular part of lhs and either operand may be conjugated.
//   EIGTYPE    - Eigen scalar type (double, dcomplex, float, scomplex)
//   BLASTYPE   - scalar type used in the BLAS prototype (real type for complex)
//   EIGPREFIX  - suffix of the matching MatrixX* typedef (d, cd, f, cf)
//   BLASPREFIX - BLAS routine prefix (d, z, s, c), giving BLASPREFIX##trmm_
// A square triangular factor goes to ?TRMM. A non-square one uses either
// Eigen's built-in kernel or a densified copy fed to GEMM.
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
         LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
  enum { \
    IsLower = (Mode&Lower) == Lower, \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper, \
    conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \
  }; \
\
  static void run( \
    Index _rows, Index _cols, Index _depth, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res,        Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
  { \
   Index diagSize  = (std::min)(_rows,_depth); \
   Index rows      = IsLower ? _rows : diagSize; \
   Index depth     = IsLower ? diagSize : _depth; \
   Index cols      = _cols; \
\
   /* An empty product adds nothing to res. Return before the size ratio */ \
   /* below divides by diagSize == 0, and before ?TRMM is handed */ \
   /* ldb == 0 (BLAS requires LDB >= max(1,M)). */ \
   if (rows == 0 || depth == 0 || cols == 0) return; \
\
   typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
   typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
   /* Non-square triangular factor: ?TRMM cannot take it. */ \
   if (rows != depth) { \
\
     /* Assumed single-threaded: the BLAS thread count is not queried. */ \
     int nthr = 1; \
\
     if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
       /* Non-square excess under half of diagSize: use the built-in kernel. */ \
       product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
       LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
           _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
     } else { \
       /* Materialize the triangular part densely and use GEMM. */ \
       Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
       MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
       BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
       gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
       general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
       rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
     } \
     return; \
   } \
\
   char side = 'L', transa, uplo, diag = 'N'; \
   EIGTYPE *b; \
   const EIGTYPE *a; \
   BlasIndex m, n, lda, ldb; \
\
   /* A is diagSize x diagSize; B is diagSize x cols. */ \
   m = convert_index<BlasIndex>(diagSize); \
   n = convert_index<BlasIndex>(cols); \
\
   /* A row-major lhs is the transpose of a column-major one; conjugation */ \
   /* is then folded into 'C'. */ \
   transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
\
   /* ?TRMM overwrites B in place, so work on a (possibly conjugated) copy. */ \
   Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \
   MatrixX##EIGPREFIX b_tmp; \
\
   if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
   b = b_tmp.data(); \
   ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\
   /* Transposition via row-major storage swaps the stored triangle. */ \
   uplo = IsLower ? 'L' : 'U'; \
   if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
\
   /* Copy lhs only when it must be conjugated (column-major case) or its */ \
   /* diagonal overridden: ?TRMM is called with diag = 'N' and therefore */ \
   /* reads the stored diagonal. */ \
   Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
   MatrixLhs a_tmp; \
\
   if ((conjA!=0) || (SetDiag==0)) { \
     if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \
     if (IsZeroDiag) \
       a_tmp.diagonal().setZero(); \
     else if (IsUnitDiag) \
       a_tmp.diagonal().setOnes();\
     a = a_tmp.data(); \
     lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
   } else { \
     a = _lhs; \
     lda = convert_index<BlasIndex>(lhsStride); \
   } \
\
   /* b_tmp := alpha * op(A) * b_tmp */ \
   BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
   /* Accumulate into the caller's result: res += b_tmp. */ \
   Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
   res_tmp=res_tmp+b_tmp; \
  } \
};
183EIGEN_BLAS_TRMM_L(
double,
double, d, d)
184EIGEN_BLAS_TRMM_L(dcomplex,
double, cd, z)
185EIGEN_BLAS_TRMM_L(
float,
float, f, s)
186EIGEN_BLAS_TRMM_L(scomplex,
float, cf, c)
// EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX)
// Column-major update  res += alpha * lhs * T(rhs), where T(rhs) is the
// Mode-triangular part of rhs and either operand may be conjugated.
//   EIGTYPE    - Eigen scalar type (double, dcomplex, float, scomplex)
//   BLASTYPE   - scalar type used in the BLAS prototype (real type for complex)
//   EIGPREFIX  - suffix of the matching MatrixX* typedef (d, cd, f, cf)
//   BLASPREFIX - BLAS routine prefix (d, z, s, c), giving BLASPREFIX##trmm_
// A square triangular factor goes to ?TRMM with side = 'R'. A non-square one
// uses either Eigen's built-in kernel or a densified copy fed to GEMM.
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
         LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
  enum { \
    IsLower = (Mode&Lower) == Lower, \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper, \
    conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \
  }; \
\
  static void run( \
    Index _rows, Index _cols, Index _depth, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res,        Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
  { \
   Index diagSize  = (std::min)(_cols,_depth); \
   Index rows      = _rows; \
   Index depth     = IsLower ? _depth : diagSize; \
   Index cols      = IsLower ? diagSize : _cols; \
\
   /* An empty product adds nothing to res. Return before the size ratio */ \
   /* below divides by diagSize == 0, and before ?TRMM is handed */ \
   /* ldb == 0 (BLAS requires LDB >= max(1,M)). */ \
   if (rows == 0 || depth == 0 || cols == 0) return; \
\
   typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
   typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
   /* Non-square triangular factor: ?TRMM cannot take it. */ \
   if (cols != depth) { \
\
     /* Assumed single-threaded: the BLAS thread count is not queried. */ \
     int nthr = 1; \
\
     if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
       /* Non-square excess under half of diagSize: use the built-in kernel. */ \
       product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
       LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
           _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
     } else { \
       /* Materialize the triangular part densely and use GEMM. */ \
       Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
       MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
       BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
       gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
       general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
       rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
     } \
     return; \
   } \
\
   char side = 'R', transa, uplo, diag = 'N'; \
   EIGTYPE *b; \
   const EIGTYPE *a; \
   BlasIndex m, n, lda, ldb; \
\
   /* B is rows x diagSize; A is diagSize x diagSize. */ \
   m = convert_index<BlasIndex>(rows); \
   n = convert_index<BlasIndex>(diagSize); \
\
   /* A row-major rhs is the transpose of a column-major one; conjugation */ \
   /* is then folded into 'C'. */ \
   transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
\
   /* ?TRMM overwrites B in place, so work on a (possibly conjugated) copy. */ \
   Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
   MatrixX##EIGPREFIX b_tmp; \
\
   if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
   b = b_tmp.data(); \
   ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\
   /* Transposition via row-major storage swaps the stored triangle. */ \
   uplo = IsLower ? 'L' : 'U'; \
   if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
\
   /* Copy rhs only when it must be conjugated (column-major case) or its */ \
   /* diagonal overridden: ?TRMM is called with diag = 'N' and therefore */ \
   /* reads the stored diagonal. */ \
   Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \
   MatrixRhs a_tmp; \
\
   if ((conjA!=0) || (SetDiag==0)) { \
     if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \
     if (IsZeroDiag) \
       a_tmp.diagonal().setZero(); \
     else if (IsUnitDiag) \
       a_tmp.diagonal().setOnes();\
     a = a_tmp.data(); \
     lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
   } else { \
     a = _rhs; \
     lda = convert_index<BlasIndex>(rhsStride); \
   } \
\
   /* b_tmp := alpha * b_tmp * op(A) */ \
   BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
   /* Accumulate into the caller's result: res += b_tmp. */ \
   Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
   res_tmp=res_tmp+b_tmp; \
  } \
};
293EIGEN_BLAS_TRMM_R(
double,
double, d, d)
294EIGEN_BLAS_TRMM_R(dcomplex,
double, cd, z)
295EIGEN_BLAS_TRMM_R(
float,
float, f, s)
296EIGEN_BLAS_TRMM_R(scomplex,
float, cf, c)
/* Namespace containing all symbols from the Eigen library.
 * Definition: Core:287
 *
 * EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
 * The Index type as used for the API.
 * Definition: Meta.h:33
 */