BlasUtil.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // Eigen is free software; you can redistribute it and/or
7 // modify it under the terms of the GNU Lesser General Public
8 // License as published by the Free Software Foundation; either
9 // version 3 of the License, or (at your option) any later version.
10 //
11 // Alternatively, you can redistribute it and/or
12 // modify it under the terms of the GNU General Public License as
13 // published by the Free Software Foundation; either version 2 of
14 // the License, or (at your option) any later version.
15 //
16 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
17 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
18 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
19 // GNU General Public License for more details.
20 //
21 // You should have received a copy of the GNU Lesser General Public
22 // License and a copy of the GNU General Public License along with
23 // Eigen. If not, see <http://www.gnu.org/licenses/>.
24 
25 #ifndef EIGEN_BLASUTIL_H
26 #define EIGEN_BLASUTIL_H
27 
28 // This file contains many lightweight helper classes used to
29 // implement and control fast level 2 and level 3 BLAS-like routines.
30 
31 namespace Eigen {
32 
33 namespace internal {
34 
35 // forward declarations
36 template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
37 struct gebp_kernel;
38 
39 template<typename Scalar, typename Index, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
40 struct gemm_pack_rhs;
41 
42 template<typename Scalar, typename Index, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
43 struct gemm_pack_lhs;
44 
45 template<
46  typename Index,
47  typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
48  typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
49  int ResStorageOrder>
50 struct general_matrix_matrix_product;
51 
52 template<typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version=Specialized>
54 
55 
56 template<bool Conjugate> struct conj_if;
57 
58 template<> struct conj_if<true> {
59  template<typename T>
60  inline T operator()(const T& x) { return conj(x); }
61  template<typename T>
62  inline T pconj(const T& x) { return internal::pconj(x); }
63 };
64 
65 template<> struct conj_if<false> {
66  template<typename T>
67  inline const T& operator()(const T& x) { return x; }
68  template<typename T>
69  inline const T& pconj(const T& x) { return x; }
70 };
71 
72 template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
73 {
74  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
75  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
76 };
77 
78 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
79 {
80  typedef std::complex<RealScalar> Scalar;
81  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
82  { return c + pmul(x,y); }
83 
84  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
85  { return Scalar(real(x)*real(y) + imag(x)*imag(y), imag(x)*real(y) - real(x)*imag(y)); }
86 };
87 
88 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
89 {
90  typedef std::complex<RealScalar> Scalar;
91  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
92  { return c + pmul(x,y); }
93 
94  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
95  { return Scalar(real(x)*real(y) + imag(x)*imag(y), real(x)*imag(y) - imag(x)*real(y)); }
96 };
97 
98 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
99 {
100  typedef std::complex<RealScalar> Scalar;
101  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
102  { return c + pmul(x,y); }
103 
104  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
105  { return Scalar(real(x)*real(y) - imag(x)*imag(y), - real(x)*imag(y) - imag(x)*real(y)); }
106 };
107 
108 template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
109 {
110  typedef std::complex<RealScalar> Scalar;
111  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
112  { return padd(c, pmul(x,y)); }
113  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
114  { return conj_if<Conj>()(x)*y; }
115 };
116 
117 template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
118 {
119  typedef std::complex<RealScalar> Scalar;
120  EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
121  { return padd(c, pmul(x,y)); }
122  EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
123  { return x*conj_if<Conj>()(y); }
124 };
125 
126 template<typename From,typename To> struct get_factor {
127  static EIGEN_STRONG_INLINE To run(const From& x) { return x; }
128 };
129 
130 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
131  static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return real(x); }
132 };
133 
134 // Lightweight helper class to access matrix coefficients.
135 // Yes, this is somehow redundant with Map<>, but this version is much much lighter,
136 // and so I hope better compilation performance (time and code quality).
137 template<typename Scalar, typename Index, int StorageOrder>
138 class blas_data_mapper
139 {
140  public:
141  blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
142  EIGEN_STRONG_INLINE Scalar& operator()(Index i, Index j)
143  { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
144  protected:
145  Scalar* EIGEN_RESTRICT m_data;
146  Index m_stride;
147 };
148 
149 // lightweight helper class to access matrix coefficients (const version)
150 template<typename Scalar, typename Index, int StorageOrder>
151 class const_blas_data_mapper
152 {
153  public:
154  const_blas_data_mapper(const Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
155  EIGEN_STRONG_INLINE const Scalar& operator()(Index i, Index j) const
156  { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
157  protected:
158  const Scalar* EIGEN_RESTRICT m_data;
159  Index m_stride;
160 };
161 
162 
163 /* Helper class to analyze the factors of a Product expression.
164  * In particular it allows to pop out operator-, scalar multiples,
165  * and conjugate */
166 template<typename XprType> struct blas_traits
167 {
168  typedef typename traits<XprType>::Scalar Scalar;
169  typedef const XprType& ExtractType;
170  typedef XprType _ExtractType;
171  enum {
173  IsTransposed = false,
174  NeedToConjugate = false,
175  HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit)
176  && ( bool(XprType::IsVectorAtCompileTime)
177  || int(inner_stride_at_compile_time<XprType>::ret) == 1)
178  ) ? 1 : 0
179  };
180  typedef typename conditional<bool(HasUsableDirectAccess),
181  ExtractType,
182  typename _ExtractType::PlainObject
183  >::type DirectLinearAccessType;
184  static inline ExtractType extract(const XprType& x) { return x; }
185  static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
186 };
187 
188 // pop conjugate
189 template<typename Scalar, typename NestedXpr>
190 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
191  : blas_traits<NestedXpr>
192 {
193  typedef blas_traits<NestedXpr> Base;
194  typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
195  typedef typename Base::ExtractType ExtractType;
196 
197  enum {
199  NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
200  };
201  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
202  static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
203 };
204 
205 // pop scalar multiple
206 template<typename Scalar, typename NestedXpr>
207 struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> >
208  : blas_traits<NestedXpr>
209 {
210  typedef blas_traits<NestedXpr> Base;
211  typedef CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> XprType;
212  typedef typename Base::ExtractType ExtractType;
213  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
214  static inline Scalar extractScalarFactor(const XprType& x)
215  { return x.functor().m_other * Base::extractScalarFactor(x.nestedExpression()); }
216 };
217 
218 // pop opposite
219 template<typename Scalar, typename NestedXpr>
220 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
221  : blas_traits<NestedXpr>
222 {
223  typedef blas_traits<NestedXpr> Base;
224  typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
225  typedef typename Base::ExtractType ExtractType;
226  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
227  static inline Scalar extractScalarFactor(const XprType& x)
228  { return - Base::extractScalarFactor(x.nestedExpression()); }
229 };
230 
231 // pop/push transpose
232 template<typename NestedXpr>
233 struct blas_traits<Transpose<NestedXpr> >
234  : blas_traits<NestedXpr>
235 {
236  typedef typename NestedXpr::Scalar Scalar;
237  typedef blas_traits<NestedXpr> Base;
238  typedef Transpose<NestedXpr> XprType;
239  typedef Transpose<const typename Base::_ExtractType> ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
240  typedef Transpose<const typename Base::_ExtractType> _ExtractType;
241  typedef typename conditional<bool(Base::HasUsableDirectAccess),
242  ExtractType,
243  typename ExtractType::PlainObject
244  >::type DirectLinearAccessType;
245  enum {
246  IsTransposed = Base::IsTransposed ? 0 : 1
247  };
248  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
249  static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
250 };
251 
252 template<typename T>
253 struct blas_traits<const T>
254  : blas_traits<T>
255 {};
256 
257 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
258 struct extract_data_selector {
259  static const typename T::Scalar* run(const T& m)
260  {
261  return blas_traits<T>::extract(m).data();
262  }
263 };
264 
265 template<typename T>
266 struct extract_data_selector<T,false> {
267  static typename T::Scalar* run(const T&) { return 0; }
268 };
269 
270 template<typename T> const typename T::Scalar* extract_data(const T& m)
271 {
272  return extract_data_selector<T>::run(m);
273 }
274 
275 } // end namespace internal
276 
277 } // end namespace Eigen
278 
279 #endif // EIGEN_BLASUTIL_H