33#ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
34#define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
46template<
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
int StorageOrder>
47struct triangular_matrix_vector_product_trmv :
48 triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,StorageOrder,BuiltIn> {};
// EIGEN_BLAS_TRMV_SPECIALIZE(Scalar)
//
// Routes the `Specialized` variants of triangular_matrix_vector_product whose
// lhs and rhs share the scalar type `Scalar` to
// triangular_matrix_vector_product_trmv, for both ColMajor and RowMajor
// storage. Whether BLAS is actually used is then decided by which
// triangular_matrix_vector_product_trmv specialization matches.
//
// Only /* */ comments may appear inside the macro body: line splicing happens
// before comment removal, so a // comment would run across the trailing
// backslash-newline and swallow the following line.
#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \
  static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
                  const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
    triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor>::run( \
      _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
  } \
}; \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor,Specialized> { \
  static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
                  const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \
    triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor>::run( \
      _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
  } \
};
68EIGEN_BLAS_TRMV_SPECIALIZE(
double)
69EIGEN_BLAS_TRMV_SPECIALIZE(
float)
70EIGEN_BLAS_TRMV_SPECIALIZE(dcomplex)
71EIGEN_BLAS_TRMV_SPECIALIZE(scomplex)
// EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX)
//
// Partially specializes triangular_matrix_vector_product_trmv for a
// column-major triangular matrix whose scalar type EIGTYPE is the same on the
// lhs and rhs side, forwarding the work to BLAS:
//   EIGTYPE    - Eigen scalar type of lhs, rhs and res.
//   BLASTYPE   - pointer type used when handing buffers to BLAS (the real type
//                double/float, also for the complex instantiations).
//   EIGPREFIX  - suffix selecting the dense vector typedef VectorX##EIGPREFIX.
//   BLASPREFIX - BLAS routine prefix; BLASPREFIX##trmv_/axpy_/gemv_ are called.
//
// Flow of run():
//   * a conjugated lhs or a ZeroDiag mode falls back to the BuiltIn kernel.
//     Reference BLAS ?trmv only offers TRANS='N'/'T'/'C' and DIAG='U'/'N', so
//     neither a conjugated non-transposed matrix nor a zero diagonal maps onto it;
//   * otherwise rhs (conjugated when ConjRhs) is copied into a temporary x, and
//     ?trmv overwrites x in place with the leading size x size triangle times x,
//     where size = min(rows, cols);
//   * ?axpy then adds alpha * x into res;
//   * for a non-square triangle, x is rebuilt from rhs, because ?trmv clobbered
//     it. ?gemv then applies alpha * (rectangular remainder) * x to res. beta is
//     set on a line not visible here, presumably 1 (TODO confirm).
// The BLAS alpha is passed as &numext::real_ref(alpha). For the complex
// instantiations this relies on std::complex<T> being layout-compatible with
// T[2] (real part first), so the real part's address is the value's address.
//
// NOTE(review): the embedded source numbers jump (e.g. 76 -> 78, 89 -> 92,
// 145 -> 150), so lines were lost when this copy was extracted. In what remains:
//   * x, y, a, beta and x_tmp are used without a visible declaration;
//   * incx and trans are never assigned;
//   * the macro lacks its closing braces.
// As shown, it does not compile. Restore it from the upstream Eigen header
// rather than patching this copy.
74#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
75template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
76struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
78 IsLower = (Mode&Lower) == Lower, \
79 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
80 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
81 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
82 LowUp = IsLower ? Lower : Upper \
84 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
85 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
87 if (ConjLhs || IsZeroDiag) { \
88 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \
89 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
92 Index size = (std::min)(_rows,_cols); \
93 Index rows = IsLower ? _rows : size; \
94 Index cols = IsLower ? size : _cols; \
96 typedef VectorX##EIGPREFIX VectorRhs; \
100 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
102 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
107 char trans, uplo, diag; \
108 BlasIndex m, n, lda, incx, incy; \
113 n = convert_index<BlasIndex>(size); \
114 lda = convert_index<BlasIndex>(lhsStride); \
116 incy = convert_index<BlasIndex>(resIncr); \
120 uplo = IsLower ? 'L' : 'U'; \
121 diag = IsUnitDiag ? 'U' : 'N'; \
124 BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
127 BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
129 if (size<(std::max)(rows,cols)) { \
130 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
133 y = _res + size*resIncr; \
135 m = convert_index<BlasIndex>(rows-size); \
136 n = convert_index<BlasIndex>(size); \
141 a = _lhs + size*lda; \
142 m = convert_index<BlasIndex>(size); \
143 n = convert_index<BlasIndex>(cols-size); \
145 BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
150EIGEN_BLAS_TRMV_CM(
double,
double, d, d)
151EIGEN_BLAS_TRMV_CM(dcomplex,
double, cd, z)
152EIGEN_BLAS_TRMV_CM(
float,
float, f, s)
153EIGEN_BLAS_TRMV_CM(scomplex,
float, cf, c)
// EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX)
//
// Partially specializes triangular_matrix_vector_product_trmv for a
// row-major triangular matrix whose scalar type EIGTYPE is the same on the
// lhs and rhs side, forwarding the work to BLAS:
//   EIGTYPE    - Eigen scalar type of lhs, rhs and res.
//   BLASTYPE   - pointer type used when handing buffers to BLAS (the real type
//                double/float, also for the complex instantiations).
//   EIGPREFIX  - suffix selecting the dense vector typedef VectorX##EIGPREFIX.
//   BLASPREFIX - BLAS routine prefix; BLASPREFIX##trmv_/axpy_/gemv_ are called.
//
// A row-major matrix is handed to column-major BLAS as its transpose, hence:
//   * uplo is flipped: a lower row-major triangle is passed as 'U';
//   * trans is 'T', or 'C' when ConjLhs. This is presumably why a conjugated
//     lhs can stay on the BLAS path here, but the fallback condition itself
//     is not visible (see NOTE);
//   * ?gemv receives its dimensions swapped (&n, &m).
// Otherwise the flow matches the column-major variant:
//   * rhs (conjugated when ConjRhs) is copied into x;
//   * ?trmv overwrites x with the size x size triangle times x;
//   * ?axpy adds alpha * x into res;
//   * for a non-square triangle, x is rebuilt and ?gemv handles the rectangular
//     remainder. For the rows below the square block, a = _lhs + size*lda,
//     i.e. row `size` of the row-major storage.
// The BLAS alpha is passed as &numext::real_ref(alpha). For the complex
// instantiations this relies on std::complex<T> being layout-compatible with
// T[2] (real part first).
//
// NOTE(review): the embedded source numbers jump (e.g. 158 -> 160,
// 167 -> 170, 218 -> 224), so lines were lost when this copy was extracted.
// Among them is the condition guarding the BuiltIn fallback (just before
// 170), so it is not visible which modes bypass BLAS here. In addition, x, y,
// a, beta and x_tmp are used without a visible declaration. As shown, the
// macro does not compile. Restore it from the upstream Eigen header rather
// than patching this copy.
156#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
157template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
158struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
160 IsLower = (Mode&Lower) == Lower, \
161 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
162 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
163 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
164 LowUp = IsLower ? Lower : Upper \
166 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
167 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
170 triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \
171 _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
174 Index size = (std::min)(_rows,_cols); \
175 Index rows = IsLower ? _rows : size; \
176 Index cols = IsLower ? size : _cols; \
178 typedef VectorX##EIGPREFIX VectorRhs; \
182 Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
184 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
189 char trans, uplo, diag; \
190 BlasIndex m, n, lda, incx, incy; \
195 n = convert_index<BlasIndex>(size); \
196 lda = convert_index<BlasIndex>(lhsStride); \
198 incy = convert_index<BlasIndex>(resIncr); \
201 trans = ConjLhs ? 'C' : 'T'; \
202 uplo = IsLower ? 'U' : 'L'; \
203 diag = IsUnitDiag ? 'U' : 'N'; \
206 BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
209 BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
211 if (size<(std::max)(rows,cols)) { \
212 if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
215 y = _res + size*resIncr; \
216 a = _lhs + size*lda; \
217 m = convert_index<BlasIndex>(rows-size); \
218 n = convert_index<BlasIndex>(size); \
224 m = convert_index<BlasIndex>(size); \
225 n = convert_index<BlasIndex>(cols-size); \
227 BLASPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
232EIGEN_BLAS_TRMV_RM(
double,
double, d, d)
233EIGEN_BLAS_TRMV_RM(dcomplex,
double, cd, z)
234EIGEN_BLAS_TRMV_RM(
float,
float, f, s)
235EIGEN_BLAS_TRMV_RM(scomplex,
float, cf, c)
// Extraction residue (apparently a documentation tooltip for namespace Eigen):
// Namespace containing all symbols from the Eigen library.
// Definition: Core:287