37 #ifndef VIGRA_LINEAR_SOLVE_HXX
38 #define VIGRA_LINEAR_SOLVE_HXX
42 #include "mathutil.hxx"
44 #include "singular_value_decomposition.hxx"
55 template <
class T,
class C1>
56 T determinantByLUDecomposition(MultiArrayView<2, T, C1>
const & a)
58 typedef MultiArrayShape<2>::type Shape;
61 vigra_precondition(n == m,
62 "determinant(): square matrix required.");
74 LU(i,j) = LU(i,j) -= s;
99 T givensCoefficients(T a, T b, T & c, T & s)
127 bool givensRotationMatrix(T a, T b, Matrix<T> & gTranspose)
131 givensCoefficients(a, b, gTranspose(0,0), gTranspose(0,1));
132 gTranspose(1,1) = gTranspose(0,0);
133 gTranspose(1,0) = -gTranspose(0,1);
141 givensReflectionMatrix(T a, T b, Matrix<T> & g)
145 givensCoefficients(a, b, g(0,0), g(0,1));
152 template <
class T,
class C1,
class C2>
154 qrGivensStepImpl(
MultiArrayIndex i, MultiArrayView<2, T, C1> r, MultiArrayView<2, T, C2> rhs)
156 typedef typename Matrix<T>::difference_type Shape;
161 vigra_precondition(m ==
rowCount(rhs),
162 "qrGivensStepImpl(): Matrix shape mismatch.");
164 Matrix<T> givens(2,2);
165 for(
int k=m-1; k>(int)i; --k)
167 if(!givensReflectionMatrix(r(k-1,i), r(k,i), givens))
170 r(k-1,i) = givens(0,0)*r(k-1,i) + givens(0,1)*r(k,i);
173 r.subarray(Shape(k-1,i+1), Shape(k+1,n)) = givens*r.subarray(Shape(k-1,i+1), Shape(k+1,n));
174 rhs.subarray(Shape(k-1,0), Shape(k+1,rhsCount)) = givens*rhs.subarray(Shape(k-1,0), Shape(k+1,rhsCount));
176 return r(i,i) != 0.0;
180 template <
class T,
class C1,
class C2,
class Permutation>
183 MultiArrayView<2, T, C1> &r, MultiArrayView<2, T, C2> &rhs, Permutation & permutation)
185 typedef typename Matrix<T>::difference_type Shape;
190 vigra_precondition(i < n && j < n,
191 "upperTriangularCyclicShiftColumns(): Shift indices out of range.");
192 vigra_precondition(m ==
rowCount(rhs),
193 "upperTriangularCyclicShiftColumns(): Matrix shape mismatch.");
205 permutation[k] = permutation[k+1];
210 Matrix<T> givens(2,2);
213 if(!givensReflectionMatrix(r(k,k), r(k+1,k), givens))
216 r(k,k) = givens(0,0)*r(k,k) + givens(0,1)*r(k+1,k);
219 r.subarray(Shape(k,k+1), Shape(k+2,n)) = givens*r.subarray(Shape(k,k+1), Shape(k+2,n));
220 rhs.subarray(Shape(k,0), Shape(k+2,rhsCount)) = givens*rhs.subarray(Shape(k,0), Shape(k+2,rhsCount));
225 template <
class T,
class C1,
class C2,
class Permutation>
228 MultiArrayView<2, T, C1> &r, MultiArrayView<2, T, C2> &rhs, Permutation & permutation)
230 typedef typename Matrix<T>::difference_type Shape;
235 vigra_precondition(i < n && j < n,
236 "upperTriangularSwapColumns(): Swap indices out of range.");
237 vigra_precondition(m ==
rowCount(rhs),
238 "upperTriangularSwapColumns(): Matrix shape mismatch.");
246 std::swap(permutation[i], permutation[j]);
248 Matrix<T> givens(2,2);
249 for(
int k=m-1; k>(int)i; --k)
251 if(!givensReflectionMatrix(r(k-1,i), r(k,i), givens))
254 r(k-1,i) = givens(0,0)*r(k-1,i) + givens(0,1)*r(k,i);
257 r.subarray(Shape(k-1,i+1), Shape(k+1,n)) = givens*r.subarray(Shape(k-1,i+1), Shape(k+1,n));
258 rhs.subarray(Shape(k-1,0), Shape(k+1,rhsCount)) = givens*rhs.subarray(Shape(k-1,0), Shape(k+1,rhsCount));
263 if(!givensReflectionMatrix(r(k,k), r(k+1,k), givens))
266 r(k,k) = givens(0,0)*r(k,k) + givens(0,1)*r(k+1,k);
269 r.subarray(Shape(k,k+1), Shape(k+2,n)) = givens*r.subarray(Shape(k,k+1), Shape(k+2,n));
270 rhs.subarray(Shape(k,0), Shape(k+2,rhsCount)) = givens*rhs.subarray(Shape(k,0), Shape(k+2,rhsCount));
275 template <
class T,
class C1,
class C2,
class U>
276 bool householderVector(MultiArrayView<2, T, C1>
const & v, MultiArrayView<2, T, C2> & u, U & vnorm)
278 vnorm = (v(0,0) > 0.0)
283 if(f == NumericTraits<U>::zero())
285 u.init(NumericTraits<T>::zero());
290 u(0,0) = (v(0,0) - vnorm) / f;
298 template <
class T,
class C1,
class C2,
class C3>
301 MultiArrayView<2, T, C2> & rhs, MultiArrayView<2, T, C3> & householderMatrix)
303 typedef typename Matrix<T>::difference_type Shape;
309 vigra_precondition(i < n && i < m,
310 "qrHouseholderStepImpl(): Index i out of range.");
314 bool nontrivial = householderVector(
columnVector(r, Shape(i,i), m), u, vnorm);
317 columnVector(r, Shape(i+1,i), m).init(NumericTraits<T>::zero());
329 return r(i,i) != 0.0;
332 template <
class T,
class C1,
class C2>
334 qrColumnHouseholderStep(
MultiArrayIndex i, MultiArrayView<2, T, C1> &r, MultiArrayView<2, T, C2> &rhs)
336 Matrix<T> dontStoreHouseholderVectors;
337 return qrHouseholderStepImpl(i, r, rhs, dontStoreHouseholderVectors);
340 template <
class T,
class C1,
class C2>
342 qrRowHouseholderStep(
MultiArrayIndex i, MultiArrayView<2, T, C1> &r, MultiArrayView<2, T, C2> & householderMatrix)
344 Matrix<T> dontTransformRHS;
345 MultiArrayView<2, T, StridedArrayTag> rt =
transpose(r),
347 return qrHouseholderStepImpl(i, rt, dontTransformRHS, ht);
351 template <
class T,
class C1,
class C2,
class SNType>
353 incrementalMaxSingularValueApproximation(MultiArrayView<2, T, C1>
const & newColumn,
354 MultiArrayView<2, T, C2> & z, SNType & v)
356 typedef typename Matrix<T>::difference_type Shape;
367 z(n,0) = s*newColumn(n,0);
371 template <
class T,
class C1,
class C2,
class SNType>
373 incrementalMinSingularValueApproximation(MultiArrayView<2, T, C1>
const & newColumn,
374 MultiArrayView<2, T, C2> & z, SNType & v,
double tolerance)
376 typedef typename Matrix<T>::difference_type Shape;
386 T gamma = newColumn(n,0);
399 z(n,0) = (s - c*yv) / gamma;
400 v *=
norm(gamma) /
hypot(c*gamma, v*(s - c*yv));
404 template <
class T,
class C1,
class C2,
class C3>
406 qrTransformToTriangularImpl(MultiArrayView<2, T, C1> & r, MultiArrayView<2, T, C2> & rhs, MultiArrayView<2, T, C3> & householder,
407 ArrayVector<MultiArrayIndex> & permutation,
double epsilon)
409 typedef typename Matrix<T>::difference_type Shape;
410 typedef typename NormTraits<MultiArrayView<2, T, C1> >::NormType NormType;
411 typedef typename NormTraits<MultiArrayView<2, T, C1> >::SquaredNormType SNType;
417 vigra_precondition(m >= n,
418 "qrTransformToTriangularImpl(): Coefficient matrix with at least as many rows as columns required.");
421 bool transformRHS = rhsCount > 0;
422 vigra_precondition(!transformRHS || m ==
rowCount(rhs),
423 "qrTransformToTriangularImpl(): RHS matrix shape mismatch.");
425 bool storeHouseholderSteps =
columnCount(householder) > 0;
426 vigra_precondition(!storeHouseholderSteps || r.shape() == householder.shape(),
427 "qrTransformToTriangularImpl(): Householder matrix shape mismatch.");
429 bool pivoting = permutation.size() > 0;
430 vigra_precondition(!pivoting || n == (
MultiArrayIndex)permutation.size(),
431 "qrTransformToTriangularImpl(): Permutation array size mismatch.");
436 Matrix<SNType> columnSquaredNorms;
439 columnSquaredNorms.reshape(Shape(1,n));
443 int pivot =
argMax(columnSquaredNorms);
447 std::swap(columnSquaredNorms[0], columnSquaredNorms[pivot]);
448 std::swap(permutation[0], permutation[pivot]);
452 qrHouseholderStepImpl(0, r, rhs, householder);
455 NormType maxApproxSingularValue =
norm(r(0,0)),
456 minApproxSingularValue = maxApproxSingularValue;
458 double tolerance = (epsilon == 0.0)
459 ? m*maxApproxSingularValue*NumericTraits<T>::epsilon()
462 bool simpleSingularValueApproximation = (n < 4);
463 Matrix<T> zmax, zmin;
464 if(minApproxSingularValue <= tolerance)
468 simpleSingularValueApproximation =
true;
470 if(!simpleSingularValueApproximation)
472 zmax.reshape(Shape(m,1));
473 zmin.reshape(Shape(m,1));
475 zmin(0,0) = 1.0 / r(0,0);
488 std::swap(columnSquaredNorms[k], columnSquaredNorms[pivot]);
489 std::swap(permutation[k], permutation[pivot]);
493 qrHouseholderStepImpl(k, r, rhs, householder);
495 if(simpleSingularValueApproximation)
497 NormType nv =
norm(r(k,k));
498 maxApproxSingularValue = std::max(nv, maxApproxSingularValue);
499 minApproxSingularValue = std::min(nv, minApproxSingularValue);
503 incrementalMaxSingularValueApproximation(
columnVector(r, Shape(0,k),k+1), zmax, maxApproxSingularValue);
504 incrementalMinSingularValueApproximation(
columnVector(r, Shape(0,k),k+1), zmin, minApproxSingularValue, tolerance);
508 Matrix<T> u(k+1,k+1), s(k+1, 1), v(k+1,k+1);
510 std::cerr <<
"estimate, svd " << k <<
": " << minApproxSingularValue <<
" " << s(k,0) <<
"\n";
514 tolerance = m*maxApproxSingularValue*NumericTraits<T>::epsilon();
516 if(minApproxSingularValue > tolerance)
521 return (
unsigned int)rank;
524 template <
class T,
class C1,
class C2>
526 qrTransformToUpperTriangular(MultiArrayView<2, T, C1> & r, MultiArrayView<2, T, C2> & rhs,
527 ArrayVector<MultiArrayIndex> & permutation,
double epsilon = 0.0)
529 Matrix<T> dontStoreHouseholderVectors;
530 return qrTransformToTriangularImpl(r, rhs, dontStoreHouseholderVectors, permutation, epsilon);
534 template <
class T,
class C1,
class C2,
class C3>
536 qrTransformToLowerTriangular(MultiArrayView<2, T, C1> & r, MultiArrayView<2, T, C2> & rhs, MultiArrayView<2, T, C3> & householderMatrix,
537 double epsilon = 0.0)
539 ArrayVector<MultiArrayIndex> permutation((
unsigned int)
rowCount(rhs));
542 Matrix<T> dontTransformRHS;
543 MultiArrayView<2, T, StridedArrayTag> rt =
transpose(r),
545 unsigned int rank = qrTransformToTriangularImpl(rt, dontTransformRHS, ht, permutation, epsilon);
548 Matrix<T> tempRHS(rhs);
555 template <
class T,
class C1,
class C2>
557 qrTransformToUpperTriangular(MultiArrayView<2, T, C1> & r, MultiArrayView<2, T, C2> & rhs,
558 double epsilon = 0.0)
560 ArrayVector<MultiArrayIndex> noPivoting;
562 return (qrTransformToUpperTriangular(r, rhs, noPivoting, epsilon) ==
567 template <
class T,
class C1,
class C2>
569 qrTransformToLowerTriangular(MultiArrayView<2, T, C1> & r, MultiArrayView<2, T, C2> & householder,
570 double epsilon = 0.0)
572 Matrix<T> noPivoting;
574 return (qrTransformToLowerTriangular(r, noPivoting, householder, epsilon) ==
579 template <
class T,
class C1,
class C2,
class Permutation>
580 void inverseRowPermutation(MultiArrayView<2, T, C1> &permuted, MultiArrayView<2, T, C2> &res,
581 Permutation
const & permutation)
585 res(permutation[l], k) = permuted(l,k);
588 template <
class T,
class C1,
class C2>
589 void applyHouseholderColumnReflections(MultiArrayView<2, T, C1>
const &householder, MultiArrayView<2, T, C2> &res)
591 typedef typename Matrix<T>::difference_type Shape;
596 for(
int k = m-1; k >= 0; --k)
598 MultiArrayView<2, T, C1> u =
columnVector(householder, Shape(k,k), n);
606 template <
class T,
class C1,
class C2,
class C3>
608 linearSolveQRReplace(MultiArrayView<2, T, C1> &A, MultiArrayView<2, T, C2> &b,
609 MultiArrayView<2, T, C3> & res,
610 double epsilon = 0.0)
612 typedef typename Matrix<T>::difference_type Shape;
622 "linearSolveQR(): RHS and solution must have the same number of columns.");
623 vigra_precondition(m ==
rowCount(b),
624 "linearSolveQR(): Coefficient matrix and RHS must have the same number of rows.");
625 vigra_precondition(n ==
rowCount(res),
626 "linearSolveQR(): Mismatch between column count of coefficient matrix and row count of solution.");
627 vigra_precondition(epsilon >= 0.0,
628 "linearSolveQR(): 'epsilon' must be non-negative.");
633 Matrix<T> householderMatrix(n, m);
634 MultiArrayView<2, T, StridedArrayTag> ht =
transpose(householderMatrix);
635 rank = (
MultiArrayIndex)detail::qrTransformToLowerTriangular(A, b, ht, epsilon);
636 res.subarray(Shape(rank,0), Shape(n, rhsCount)).init(NumericTraits<T>::zero());
640 MultiArrayView<2, T, C1> Asub = A.subarray(ul, Shape(m,rank));
641 detail::qrTransformToUpperTriangular(Asub, b, epsilon);
643 b.subarray(ul, Shape(rank,rhsCount)),
644 res.subarray(ul, Shape(rank, rhsCount)));
650 b.subarray(ul, Shape(rank, rhsCount)),
651 res.subarray(ul, Shape(rank, rhsCount)));
653 detail::applyHouseholderColumnReflections(householderMatrix.subarray(ul, Shape(n, rank)), res);
658 ArrayVector<MultiArrayIndex> permutation((
unsigned int)n);
662 rank = (
MultiArrayIndex)detail::qrTransformToUpperTriangular(A, b, permutation, epsilon);
664 Matrix<T> permutedSolution(n, rhsCount);
668 Matrix<T> householderMatrix(n, rank);
669 MultiArrayView<2, T, StridedArrayTag> ht =
transpose(householderMatrix);
670 MultiArrayView<2, T, C1> Asub = A.subarray(ul, Shape(rank,n));
671 detail::qrTransformToLowerTriangular(Asub, ht, epsilon);
673 b.subarray(ul, Shape(rank, rhsCount)),
674 permutedSolution.subarray(ul, Shape(rank, rhsCount)));
675 detail::applyHouseholderColumnReflections(householderMatrix, permutedSolution);
681 b.subarray(ul, Shape(rank,rhsCount)),
684 detail::inverseRowPermutation(permutedSolution, res, permutation);
686 return (
unsigned int)rank;
689 template <
class T,
class C1,
class C2,
class C3>
690 unsigned int linearSolveQR(MultiArrayView<2, T, C1>
const & A, MultiArrayView<2, T, C2>
const & b,
691 MultiArrayView<2, T, C3> & res)
693 Matrix<T> r(A), rhs(b);
694 return linearSolveQRReplace(r, rhs, res);
718 template <
class T,
class C1,
class C2>
722 vigra_precondition(n <=
rowCount(v),
723 "inverse(): input matrix must have at least as many rows as columns.");
725 "inverse(): shape of output matrix must be the transpose of the input matrix' shape.");
756 template <
class T,
class C>
760 vigra_precondition(
inverse(v, ret),
761 "inverse(): matrix is not invertible.");
766 TemporaryMatrix<T>
inverse(
const TemporaryMatrix<T> &v)
770 vigra_precondition(
inverse(v,
const_cast<TemporaryMatrix<T> &
>(v)),
771 "inverse(): matrix is not invertible.");
777 vigra_precondition(
inverse(v, ret),
778 "inverse(): matrix is not invertible.");
799 template <
class T,
class C1>
803 vigra_precondition(
rowCount(a) == n,
804 "determinant(): Square matrix required.");
806 for(
unsigned int k=0; k<method.size(); ++k)
807 method[k] = tolower(method[k]);
812 return a(0,0)*a(1,1) - a(0,1)*a(1,0);
815 return detail::determinantByLUDecomposition(a);
817 else if(method ==
"cholesky")
821 "determinant(): Cholesky method requires symmetric positive definite matrix.");
829 vigra_precondition(
false,
"determinant(): Unknown solution method.");
843 template <
class T,
class C1>
847 vigra_precondition(
rowCount(a) == n,
848 "logDeterminant(): Square matrix required.");
851 vigra_precondition(a(0,0) > 0.0,
852 "logDeterminant(): Matrix not positive definite.");
857 T det = a(0,0)*a(1,1) - a(0,1)*a(1,0);
858 vigra_precondition(det > 0.0,
859 "logDeterminant(): Matrix not positive definite.");
866 "logDeterminant(): Matrix not positive definite.");
891 template <
class T,
class C1,
class C2>
896 vigra_precondition(
rowCount(A) == n,
897 "choleskyDecomposition(): Input matrix must be square.");
899 "choleskyDecomposition(): Output matrix must have same shape as input matrix.");
901 "choleskyDecomposition(): Input matrix must be symmetric.");
911 s += L(k, i)*L(j, i);
913 L(j, k) = s = (A(j, k) - s)/L(k, k);
946 template <
class T,
class C1,
class C2,
class C3>
949 double epsilon = 0.0)
955 "qrDecomposition(): Matrix shape mismatch.");
957 q = identityMatrix<T>(m);
961 return ((
MultiArrayIndex)detail::qrTransformToUpperTriangular(r, tq, noPivoting, epsilon) == std::min(m,n));
966 template <
class T,
class C1,
class C2,
class C3>
998 template <
class T,
class C1,
class C2,
class C3>
1006 "linearSolveUpperTriangular(): square coefficient matrix required.");
1008 "linearSolveUpperTriangular(): matrix shape mismatch.");
1012 for(
int i=m-1; i>=0; --i)
1014 if(r(i,i) == NumericTraits<T>::zero())
1018 sum -= r(i, j) * x(j, k);
1019 x(i, k) = sum / r(i, i);
1049 template <
class T,
class C1,
class C2,
class C3>
1055 vigra_precondition(m ==
rowCount(l),
1056 "linearSolveLowerTriangular(): square coefficient matrix required.");
1058 "linearSolveLowerTriangular(): matrix shape mismatch.");
1064 if(l(i,i) == NumericTraits<T>::zero())
1068 sum -= l(i, j) * x(j, k);
1069 x(i, k) = sum / l(i, i);
1098 template <
class T,
class C1,
class C2,
class C3>
1157 template <
class T,
class C1,
class C2,
class C3>
1167 vigra_precondition(n <= m,
1168 "linearSolve(): Coefficient matrix A must have at least as many rows as columns.");
1169 vigra_precondition(n ==
rowCount(res) &&
1171 "linearSolve(): matrix shape mismatch.");
1173 for(
unsigned int k=0; k<method.size(); ++k)
1174 method[k] = (std::string::value_type)tolower(method[k]);
1176 if(method ==
"cholesky")
1179 "linearSolve(): Cholesky method requires square coefficient matrix.");
1185 else if(method ==
"qr")
1189 else if(method ==
"ne")
1193 else if(method ==
"svd")
1206 t(k,l) = NumericTraits<T>::zero();
1214 vigra_precondition(
false,
"linearSolve(): Unknown solution method.");
1236 #endif // VIGRA_LINEAR_SOLVE_HXX