Eigen  3.3.0
 
Loading...
Searching...
No Matches
TriangularMatrixVector.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11#define EIGEN_TRIANGULARMATRIXVECTOR_H
12
13namespace Eigen {
14
15namespace internal {
16
// Primary template of the triangular-matrix * vector kernel (declaration only).
// It is partially specialized on StorageOrder (ColMajor and RowMajor) further
// down in this file. Mode is a bit mask that the specializations test against
// Lower, UnitDiag and ZeroDiag; ConjLhs / ConjRhs select whether the
// corresponding operand is wrapped by conj_expr_if; Version defaults to
// Specialized.
17template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
18struct triangular_matrix_vector_product;
19
20template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
21struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
22{
23 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
24 enum {
25 IsLower = ((Mode&Lower)==Lower),
26 HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
27 HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
28 };
29 static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
30 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha);
31};
32
33template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
34EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
35 ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
36 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha)
37 {
38 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
39 Index size = (std::min)(_rows,_cols);
40 Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
41 Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
42
43 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
44 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
45 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
46
47 typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
48 const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
49 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
50
51 typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
52 ResMap res(_res,rows);
53
54 typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
55 typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
56
57 for (Index pi=0; pi<size; pi+=PanelWidth)
58 {
59 Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
60 for (Index k=0; k<actualPanelWidth; ++k)
61 {
62 Index i = pi + k;
63 Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
64 Index r = IsLower ? actualPanelWidth-k : k+1;
65 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
66 res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
67 if (HasUnitDiag)
68 res.coeffRef(i) += alpha * cjRhs.coeff(i);
69 }
70 Index r = IsLower ? rows - pi - actualPanelWidth : pi;
71 if (r>0)
72 {
73 Index s = IsLower ? pi+actualPanelWidth : 0;
74 general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
75 r, actualPanelWidth,
76 LhsMapper(&lhs.coeffRef(s,pi), lhsStride),
77 RhsMapper(&rhs.coeffRef(pi), rhsIncr),
78 &res.coeffRef(s), resIncr, alpha);
79 }
80 }
81 if((!IsLower) && cols>size)
82 {
83 general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
84 rows, cols-size,
85 LhsMapper(&lhs.coeffRef(0,size), lhsStride),
86 RhsMapper(&rhs.coeffRef(size), rhsIncr),
87 _res, resIncr, alpha);
88 }
89 }
90
91template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
92struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
93{
94 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
95 enum {
96 IsLower = ((Mode&Lower)==Lower),
97 HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
98 HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
99 };
100 static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
101 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
102};
103
104template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
105EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
106 ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
107 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
108 {
109 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
110 Index diagSize = (std::min)(_rows,_cols);
111 Index rows = IsLower ? _rows : diagSize;
112 Index cols = IsLower ? diagSize : _cols;
113
114 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
115 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
116 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
117
118 typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
119 const RhsMap rhs(_rhs,cols);
120 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
121
122 typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
123 ResMap res(_res,rows,InnerStride<>(resIncr));
124
125 typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
126 typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
127
128 for (Index pi=0; pi<diagSize; pi+=PanelWidth)
129 {
130 Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
131 for (Index k=0; k<actualPanelWidth; ++k)
132 {
133 Index i = pi + k;
134 Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
135 Index r = IsLower ? k+1 : actualPanelWidth-k;
136 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
137 res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
138 if (HasUnitDiag)
139 res.coeffRef(i) += alpha * cjRhs.coeff(i);
140 }
141 Index r = IsLower ? pi : cols - pi - actualPanelWidth;
142 if (r>0)
143 {
144 Index s = IsLower ? 0 : pi + actualPanelWidth;
145 general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
146 actualPanelWidth, r,
147 LhsMapper(&lhs.coeffRef(pi,s), lhsStride),
148 RhsMapper(&rhs.coeffRef(s), rhsIncr),
149 &res.coeffRef(pi), resIncr, alpha);
150 }
151 }
152 if(IsLower && rows>diagSize)
153 {
154 general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
155 rows-diagSize, cols,
156 LhsMapper(&lhs.coeffRef(diagSize,0), lhsStride),
157 RhsMapper(&rhs.coeffRef(0), rhsIncr),
158 &res.coeffRef(diagSize), resIncr, alpha);
159 }
160 }
161
162/***************************************************************************
163* Wrapper to product_triangular_vector
164***************************************************************************/
165
// Forward declaration of the front end that prepares the operands and calls
// triangular_matrix_vector_product. It is specialized for StorageOrder ==
// ColMajor and StorageOrder == RowMajor further down in this file; the
// declaration is needed here because triangular_product_impl refers to it
// before those specializations are defined.
166template<int Mode,int StorageOrder>
167struct trmv_selector;
168
169} // end namespace internal
170
171namespace internal {
172
173template<int Mode, typename Lhs, typename Rhs>
174struct triangular_product_impl<Mode,true,Lhs,false,Rhs,true>
175{
176 template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
177 {
178 eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
179
180 internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(lhs, rhs, dst, alpha);
181 }
182};
183
184template<int Mode, typename Lhs, typename Rhs>
185struct triangular_product_impl<Mode,false,Lhs,true,Rhs,false>
186{
187 template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
188 {
189 eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
190
191 Transpose<Dest> dstT(dst);
192 internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
193 (int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
194 ::run(rhs.transpose(),lhs.transpose(), dstT, alpha);
195 }
196};
197
198} // end namespace internal
199
200namespace internal {
201
202// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
203
204template<int Mode> struct trmv_selector<Mode,ColMajor>
205{
206 template<typename Lhs, typename Rhs, typename Dest>
207 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
208 {
209 typedef typename Lhs::Scalar LhsScalar;
210 typedef typename Rhs::Scalar RhsScalar;
211 typedef typename Dest::Scalar ResScalar;
212 typedef typename Dest::RealScalar RealScalar;
213
214 typedef internal::blas_traits<Lhs> LhsBlasTraits;
215 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
216 typedef internal::blas_traits<Rhs> RhsBlasTraits;
217 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
218
219 typedef Map<Matrix<ResScalar,Dynamic,1>, EIGEN_PLAIN_ENUM_MIN(AlignedMax,internal::packet_traits<ResScalar>::size)> MappedDest;
220
221 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
222 typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
223
224 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
225 * RhsBlasTraits::extractScalarFactor(rhs);
226
227 enum {
228 // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
229 // on, the other hand it is good for the cache to pack the vector anyways...
230 EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
231 ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
232 MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
233 };
234
235 gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
236
237 bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
238 bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
239
240 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
241
242 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
243 evalToDest ? dest.data() : static_dest.data());
244
245 if(!evalToDest)
246 {
247 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
248 Index size = dest.size();
249 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
250 #endif
251 if(!alphaIsCompatible)
252 {
253 MappedDest(actualDestPtr, dest.size()).setZero();
254 compatibleAlpha = RhsScalar(1);
255 }
256 else
257 MappedDest(actualDestPtr, dest.size()) = dest;
258 }
259
260 internal::triangular_matrix_vector_product
261 <Index,Mode,
262 LhsScalar, LhsBlasTraits::NeedToConjugate,
263 RhsScalar, RhsBlasTraits::NeedToConjugate,
264 ColMajor>
265 ::run(actualLhs.rows(),actualLhs.cols(),
266 actualLhs.data(),actualLhs.outerStride(),
267 actualRhs.data(),actualRhs.innerStride(),
268 actualDestPtr,1,compatibleAlpha);
269
270 if (!evalToDest)
271 {
272 if(!alphaIsCompatible)
273 dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
274 else
275 dest = MappedDest(actualDestPtr, dest.size());
276 }
277 }
278};
279
280template<int Mode> struct trmv_selector<Mode,RowMajor>
281{
282 template<typename Lhs, typename Rhs, typename Dest>
283 static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
284 {
285 typedef typename Lhs::Scalar LhsScalar;
286 typedef typename Rhs::Scalar RhsScalar;
287 typedef typename Dest::Scalar ResScalar;
288
289 typedef internal::blas_traits<Lhs> LhsBlasTraits;
290 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
291 typedef internal::blas_traits<Rhs> RhsBlasTraits;
292 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
293 typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
294
295 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
296 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
297
298 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
299 * RhsBlasTraits::extractScalarFactor(rhs);
300
301 enum {
302 DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
303 };
304
305 gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
306
307 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
308 DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
309
310 if(!DirectlyUseRhs)
311 {
312 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
313 Index size = actualRhs.size();
314 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
315 #endif
316 Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
317 }
318
319 internal::triangular_matrix_vector_product
320 <Index,Mode,
321 LhsScalar, LhsBlasTraits::NeedToConjugate,
322 RhsScalar, RhsBlasTraits::NeedToConjugate,
323 RowMajor>
324 ::run(actualLhs.rows(),actualLhs.cols(),
325 actualLhs.data(),actualLhs.outerStride(),
326 actualRhsPtr,1,
327 dest.data(),dest.innerStride(),
328 actualAlpha);
329 }
330};
331
332} // end namespace internal
333
334} // end namespace Eigen
335
336#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
@ UnitDiag
Definition: Constants.h:208
@ ZeroDiag
Definition: Constants.h:210
@ Lower
Definition: Constants.h:204
@ Upper
Definition: Constants.h:206
@ ColMajor
Definition: Constants.h:320
@ RowMajor
Definition: Constants.h:322
const unsigned int RowMajorBit
Definition: Constants.h:61
Namespace containing all symbols from the Eigen library.
Definition: Core:287
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:33