25 #ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
26 #define EIGEN_GENERAL_MATRIX_VECTOR_H
45 template<
typename Index,
typename LhsScalar,
bool ConjugateLhs,
typename RhsScalar,
bool ConjugateRhs,
int Version>
48 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
51 Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
52 &&
int(packet_traits<LhsScalar>::size)==
int(packet_traits<RhsScalar>::size),
53 LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
54 RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
55 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
58 typedef typename packet_traits<LhsScalar>::type _LhsPacket;
59 typedef typename packet_traits<RhsScalar>::type _RhsPacket;
60 typedef typename packet_traits<ResScalar>::type _ResPacket;
62 typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
63 typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
64 typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
67 Index rows, Index cols,
68 const LhsScalar* lhs, Index lhsStride,
69 const RhsScalar* rhs, Index rhsIncr,
71 #ifdef EIGEN_INTERNAL_DEBUGGING
77 #ifdef _EIGEN_ACCUMULATE_PACKETS
78 #error _EIGEN_ACCUMULATE_PACKETS has already been defined
80 #define _EIGEN_ACCUMULATE_PACKETS(A0,A13,A2) \
82 padd(pload<ResPacket>(&res[j]), \
84 padd(pcj.pmul(EIGEN_CAT(ploa , A0)<LhsPacket>(&lhs0[j]), ptmp0), \
85 pcj.pmul(EIGEN_CAT(ploa , A13)<LhsPacket>(&lhs1[j]), ptmp1)), \
86 padd(pcj.pmul(EIGEN_CAT(ploa , A2)<LhsPacket>(&lhs2[j]), ptmp2), \
87 pcj.pmul(EIGEN_CAT(ploa , A13)<LhsPacket>(&lhs3[j]), ptmp3)) )))
89 conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
90 conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
94 enum { AllAligned = 0, EvenAligned, FirstAligned, NoneAligned };
95 const Index columnsAtOnce = 4;
96 const Index peels = 2;
97 const Index LhsPacketAlignedMask = LhsPacketSize-1;
98 const Index ResPacketAlignedMask = ResPacketSize-1;
99 const Index PeelAlignedMask = ResPacketSize*peels-1;
100 const Index size = rows;
104 Index alignedStart = internal::first_aligned(res,size);
105 Index alignedSize = ResPacketSize>1 ? alignedStart + ((size-alignedStart) & ~ResPacketAlignedMask) : 0;
106 const Index peeledSize = peels>1 ? alignedStart + ((alignedSize-alignedStart) & ~PeelAlignedMask) : alignedStart;
108 const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
109 Index alignmentPattern = alignmentStep==0 ? AllAligned
110 : alignmentStep==(LhsPacketSize/2) ? EvenAligned
114 const Index lhsAlignmentOffset = internal::first_aligned(lhs,size);
117 Index skipColumns = 0;
119 if( (
size_t(lhs)%
sizeof(LhsScalar)) || (
size_t(res)%
sizeof(ResScalar)) )
124 else if (LhsPacketSize>1)
128 while (skipColumns<LhsPacketSize &&
129 alignedStart != ((lhsAlignmentOffset + alignmentStep*skipColumns)%LhsPacketSize))
131 if (skipColumns==LhsPacketSize)
134 alignmentPattern = NoneAligned;
139 skipColumns = (std::min)(skipColumns,cols);
144 || (skipColumns + columnsAtOnce >= cols)
145 || LhsPacketSize > size
146 || (
size_t(lhs+alignedStart+lhsStride*skipColumns)%
sizeof(LhsPacket))==0);
148 else if(Vectorizable)
152 alignmentPattern = AllAligned;
155 Index offset1 = (FirstAligned && alignmentStep==1?3:1);
156 Index offset3 = (FirstAligned && alignmentStep==1?1:3);
158 Index columnBound = ((cols-skipColumns)/columnsAtOnce)*columnsAtOnce + skipColumns;
159 for (Index i=skipColumns; i<columnBound; i+=columnsAtOnce)
161 RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs[i*rhsIncr]),
162 ptmp1 = pset1<RhsPacket>(alpha*rhs[(i+offset1)*rhsIncr]),
163 ptmp2 = pset1<RhsPacket>(alpha*rhs[(i+2)*rhsIncr]),
164 ptmp3 = pset1<RhsPacket>(alpha*rhs[(i+offset3)*rhsIncr]);
167 const LhsScalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+offset1)*lhsStride,
168 *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+offset3)*lhsStride;
174 for (Index j=0; j<alignedStart; ++j)
176 res[j] = cj.pmadd(lhs0[j],
pfirst(ptmp0), res[j]);
177 res[j] = cj.pmadd(lhs1[j],
pfirst(ptmp1), res[j]);
178 res[j] = cj.pmadd(lhs2[j],
pfirst(ptmp2), res[j]);
179 res[j] = cj.pmadd(lhs3[j],
pfirst(ptmp3), res[j]);
182 if (alignedSize>alignedStart)
184 switch(alignmentPattern)
187 for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
191 for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
197 LhsPacket A00, A01, A02, A03, A10, A11, A12, A13;
200 A01 = pload<LhsPacket>(&lhs1[alignedStart-1]);
201 A02 = pload<LhsPacket>(&lhs2[alignedStart-2]);
202 A03 = pload<LhsPacket>(&lhs3[alignedStart-3]);
204 for (Index j = alignedStart; j<peeledSize; j+=peels*ResPacketSize)
206 A11 = pload<LhsPacket>(&lhs1[j-1+LhsPacketSize]); palign<1>(A01,A11);
207 A12 = pload<LhsPacket>(&lhs2[j-2+LhsPacketSize]); palign<2>(A02,A12);
208 A13 = pload<LhsPacket>(&lhs3[j-3+LhsPacketSize]); palign<3>(A03,A13);
210 A00 = pload<LhsPacket>(&lhs0[j]);
211 A10 = pload<LhsPacket>(&lhs0[j+LhsPacketSize]);
212 T0 = pcj.pmadd(A00, ptmp0, pload<ResPacket>(&res[j]));
213 T1 = pcj.pmadd(A10, ptmp0, pload<ResPacket>(&res[j+ResPacketSize]));
215 T0 = pcj.pmadd(A01, ptmp1, T0);
216 A01 = pload<LhsPacket>(&lhs1[j-1+2*LhsPacketSize]); palign<1>(A11,A01);
217 T0 = pcj.pmadd(A02, ptmp2, T0);
218 A02 = pload<LhsPacket>(&lhs2[j-2+2*LhsPacketSize]); palign<2>(A12,A02);
219 T0 = pcj.pmadd(A03, ptmp3, T0);
221 A03 = pload<LhsPacket>(&lhs3[j-3+2*LhsPacketSize]); palign<3>(A13,A03);
222 T1 = pcj.pmadd(A11, ptmp1, T1);
223 T1 = pcj.pmadd(A12, ptmp2, T1);
224 T1 = pcj.pmadd(A13, ptmp3, T1);
225 pstore(&res[j+ResPacketSize],T1);
228 for (Index j = peeledSize; j<alignedSize; j+=ResPacketSize)
232 for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
240 for (Index j=alignedSize; j<size; ++j)
242 res[j] = cj.pmadd(lhs0[j],
pfirst(ptmp0), res[j]);
243 res[j] = cj.pmadd(lhs1[j],
pfirst(ptmp1), res[j]);
244 res[j] = cj.pmadd(lhs2[j],
pfirst(ptmp2), res[j]);
245 res[j] = cj.pmadd(lhs3[j],
pfirst(ptmp3), res[j]);
251 Index start = columnBound;
254 for (Index k=start; k<end; ++k)
256 RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs[k*rhsIncr]);
257 const LhsScalar* lhs0 = lhs + k*lhsStride;
263 for (Index j=0; j<alignedStart; ++j)
264 res[j] += cj.pmul(lhs0[j],
pfirst(ptmp0));
266 if ((
size_t(lhs0+alignedStart)%
sizeof(LhsPacket))==0)
267 for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
268 pstore(&res[i], pcj.pmadd(ploadu<LhsPacket>(&lhs0[i]), ptmp0, pload<ResPacket>(&res[i])));
270 for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
271 pstore(&res[i], pcj.pmadd(ploadu<LhsPacket>(&lhs0[i]), ptmp0, pload<ResPacket>(&res[i])));
275 for (Index i=alignedSize; i<size; ++i)
276 res[i] += cj.pmul(lhs0[i],
pfirst(ptmp0));
286 }
while(Vectorizable);
287 #undef _EIGEN_ACCUMULATE_PACKETS
301 template<
typename Index,
typename LhsScalar,
bool ConjugateLhs,
typename RhsScalar,
bool ConjugateRhs,
int Version>
304 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
307 Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
308 &&
int(packet_traits<LhsScalar>::size)==
int(packet_traits<RhsScalar>::size),
309 LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
310 RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
311 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
314 typedef typename packet_traits<LhsScalar>::type _LhsPacket;
315 typedef typename packet_traits<RhsScalar>::type _RhsPacket;
316 typedef typename packet_traits<ResScalar>::type _ResPacket;
318 typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
319 typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
320 typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
323 Index rows, Index cols,
324 const LhsScalar* lhs, Index lhsStride,
325 const RhsScalar* rhs, Index rhsIncr,
326 ResScalar* res, Index resIncr,
331 #ifdef _EIGEN_ACCUMULATE_PACKETS
332 #error _EIGEN_ACCUMULATE_PACKETS has already been defined
335 #define _EIGEN_ACCUMULATE_PACKETS(A0,A13,A2) {\
336 RhsPacket b = pload<RhsPacket>(&rhs[j]); \
337 ptmp0 = pcj.pmadd(EIGEN_CAT(ploa,A0) <LhsPacket>(&lhs0[j]), b, ptmp0); \
338 ptmp1 = pcj.pmadd(EIGEN_CAT(ploa,A13)<LhsPacket>(&lhs1[j]), b, ptmp1); \
339 ptmp2 = pcj.pmadd(EIGEN_CAT(ploa,A2) <LhsPacket>(&lhs2[j]), b, ptmp2); \
340 ptmp3 = pcj.pmadd(EIGEN_CAT(ploa,A13)<LhsPacket>(&lhs3[j]), b, ptmp3); }
342 conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
343 conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
345 enum { AllAligned=0, EvenAligned=1, FirstAligned=2, NoneAligned=3 };
346 const Index rowsAtOnce = 4;
347 const Index peels = 2;
348 const Index RhsPacketAlignedMask = RhsPacketSize-1;
349 const Index LhsPacketAlignedMask = LhsPacketSize-1;
350 const Index PeelAlignedMask = RhsPacketSize*peels-1;
351 const Index depth = cols;
356 Index alignedStart = internal::first_aligned(rhs, depth);
357 Index alignedSize = RhsPacketSize>1 ? alignedStart + ((depth-alignedStart) & ~RhsPacketAlignedMask) : 0;
358 const Index peeledSize = peels>1 ? alignedStart + ((alignedSize-alignedStart) & ~PeelAlignedMask) : alignedStart;
360 const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
361 Index alignmentPattern = alignmentStep==0 ? AllAligned
362 : alignmentStep==(LhsPacketSize/2) ? EvenAligned
366 const Index lhsAlignmentOffset = internal::first_aligned(lhs,depth);
371 if( (
sizeof(LhsScalar)!=
sizeof(RhsScalar)) || (
size_t(lhs)%
sizeof(LhsScalar)) || (
size_t(rhs)%
sizeof(RhsScalar)) )
376 else if (LhsPacketSize>1)
380 while (skipRows<LhsPacketSize &&
381 alignedStart != ((lhsAlignmentOffset + alignmentStep*skipRows)%LhsPacketSize))
383 if (skipRows==LhsPacketSize)
386 alignmentPattern = NoneAligned;
391 skipRows = (std::min)(skipRows,Index(rows));
396 || (skipRows + rowsAtOnce >= rows)
397 || LhsPacketSize > depth
398 || (
size_t(lhs+alignedStart+lhsStride*skipRows)%
sizeof(LhsPacket))==0);
400 else if(Vectorizable)
404 alignmentPattern = AllAligned;
407 Index offset1 = (FirstAligned && alignmentStep==1?3:1);
408 Index offset3 = (FirstAligned && alignmentStep==1?1:3);
410 Index rowBound = ((rows-skipRows)/rowsAtOnce)*rowsAtOnce + skipRows;
411 for (Index i=skipRows; i<rowBound; i+=rowsAtOnce)
414 ResScalar tmp1 = ResScalar(0), tmp2 = ResScalar(0), tmp3 = ResScalar(0);
417 const LhsScalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+offset1)*lhsStride,
418 *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+offset3)*lhsStride;
423 ResPacket ptmp0 = pset1<ResPacket>(ResScalar(0)), ptmp1 = pset1<ResPacket>(ResScalar(0)),
424 ptmp2 = pset1<ResPacket>(ResScalar(0)), ptmp3 = pset1<ResPacket>(ResScalar(0));
428 for (Index j=0; j<alignedStart; ++j)
430 RhsScalar b = rhs[j];
431 tmp0 += cj.pmul(lhs0[j],b); tmp1 += cj.pmul(lhs1[j],b);
432 tmp2 += cj.pmul(lhs2[j],b); tmp3 += cj.pmul(lhs3[j],b);
435 if (alignedSize>alignedStart)
437 switch(alignmentPattern)
440 for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
444 for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
456 LhsPacket A01, A02, A03, A11, A12, A13;
457 A01 = pload<LhsPacket>(&lhs1[alignedStart-1]);
458 A02 = pload<LhsPacket>(&lhs2[alignedStart-2]);
459 A03 = pload<LhsPacket>(&lhs3[alignedStart-3]);
461 for (Index j = alignedStart; j<peeledSize; j+=peels*RhsPacketSize)
463 RhsPacket b = pload<RhsPacket>(&rhs[j]);
464 A11 = pload<LhsPacket>(&lhs1[j-1+LhsPacketSize]); palign<1>(A01,A11);
465 A12 = pload<LhsPacket>(&lhs2[j-2+LhsPacketSize]); palign<2>(A02,A12);
466 A13 = pload<LhsPacket>(&lhs3[j-3+LhsPacketSize]); palign<3>(A03,A13);
468 ptmp0 = pcj.pmadd(pload<LhsPacket>(&lhs0[j]), b, ptmp0);
469 ptmp1 = pcj.pmadd(A01, b, ptmp1);
470 A01 = pload<LhsPacket>(&lhs1[j-1+2*LhsPacketSize]); palign<1>(A11,A01);
471 ptmp2 = pcj.pmadd(A02, b, ptmp2);
472 A02 = pload<LhsPacket>(&lhs2[j-2+2*LhsPacketSize]); palign<2>(A12,A02);
473 ptmp3 = pcj.pmadd(A03, b, ptmp3);
474 A03 = pload<LhsPacket>(&lhs3[j-3+2*LhsPacketSize]); palign<3>(A13,A03);
476 b = pload<RhsPacket>(&rhs[j+RhsPacketSize]);
477 ptmp0 = pcj.pmadd(pload<LhsPacket>(&lhs0[j+LhsPacketSize]), b, ptmp0);
478 ptmp1 = pcj.pmadd(A11, b, ptmp1);
479 ptmp2 = pcj.pmadd(A12, b, ptmp2);
480 ptmp3 = pcj.pmadd(A13, b, ptmp3);
483 for (Index j = peeledSize; j<alignedSize; j+=RhsPacketSize)
487 for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
500 for (Index j=alignedSize; j<depth; ++j)
502 RhsScalar b = rhs[j];
503 tmp0 += cj.pmul(lhs0[j],b); tmp1 += cj.pmul(lhs1[j],b);
504 tmp2 += cj.pmul(lhs2[j],b); tmp3 += cj.pmul(lhs3[j],b);
506 res[i*resIncr] += alpha*tmp0;
507 res[(i+offset1)*resIncr] += alpha*tmp1;
508 res[(i+2)*resIncr] += alpha*tmp2;
509 res[(i+offset3)*resIncr] += alpha*tmp3;
514 Index start = rowBound;
517 for (Index i=start; i<end; ++i)
520 ResPacket ptmp0 = pset1<ResPacket>(tmp0);
521 const LhsScalar* lhs0 = lhs + i*lhsStride;
524 for (Index j=0; j<alignedStart; ++j)
525 tmp0 += cj.pmul(lhs0[j], rhs[j]);
527 if (alignedSize>alignedStart)
530 if ((
size_t(lhs0+alignedStart)%
sizeof(LhsPacket))==0)
531 for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
532 ptmp0 = pcj.pmadd(pload<LhsPacket>(&lhs0[j]), pload<RhsPacket>(&rhs[j]), ptmp0);
534 for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
535 ptmp0 = pcj.pmadd(ploadu<LhsPacket>(&lhs0[j]), pload<RhsPacket>(&rhs[j]), ptmp0);
541 for (Index j=alignedSize; j<depth; ++j)
542 tmp0 += cj.pmul(lhs0[j], rhs[j]);
543 res[i*resIncr] += alpha*tmp0;
553 }
while(Vectorizable);
555 #undef _EIGEN_ACCUMULATE_PACKETS
563 #endif // EIGEN_GENERAL_MATRIX_VECTOR_H