37 #ifndef VIGRA_MATRIX_HXX
38 #define VIGRA_MATRIX_HXX
43 #include "multi_array.hxx"
44 #include "mathutil.hxx"
45 #include "numerictraits.hxx"
64 template <
class T,
class C>
66 rowCount(
const MultiArrayView<2, T, C> &x);
68 template <
class T,
class C>
72 template <
class T,
class C>
73 inline MultiArrayView <2, T, C>
76 template <
class T,
class C>
77 inline MultiArrayView <2, T, C>
// Forward declaration of the TemporaryMatrix class template, so that it can be
// named (e.g. by Matrix's temp_type typedef) before its definition further down.
80 template <
class T,
class ALLOC>
81 class TemporaryMatrix;
// Declaration of transpose(v, r).
//   v - read-only 2D view with element type T
//   r - writable 2D view with the same element type T; its third template
//       argument (C2) may differ from that of v (C1)
// NOTE(review): a "transpose(): arrays must have transposed shapes." message
// appears later in the file, presumably in the definition - confirm there.
83 template <
class T,
class C1,
class C2>
84 void transpose(
const MultiArrayView<2, T, C1> &v, MultiArrayView<2, T, C2> &r);
86 template <
class T,
class C>
// Layout selector accepted by the Matrix constructors that take a raw
// const_pointer init argument; those constructors default to RowMajor.
89 enum RawArrayMemoryLayout { RowMajor, ColumnMajor };
119 template <
class T,
class ALLOC = std::allocator<T> >
127 typedef TemporaryMatrix<T, ALLOC> temp_type;
136 typedef ALLOC allocator_type;
155 ALLOC
const & alloc = allocator_type())
164 Matrix(difference_type_1 rows, difference_type_1 columns,
165 ALLOC
const & alloc = allocator_type())
166 :
BaseType(difference_type(rows, columns), alloc)
175 allocator_type
const & alloc = allocator_type())
184 Matrix(difference_type_1 rows, difference_type_1 columns, const_reference init,
185 allocator_type
const & alloc = allocator_type())
186 :
BaseType(difference_type(rows, columns), init, alloc)
196 Matrix(
const difference_type &shape, const_pointer init, RawArrayMemoryLayout layout = RowMajor,
197 allocator_type
const & alloc = allocator_type())
200 if(layout == RowMajor)
202 difference_type trans(shape[1], shape[0]);
218 Matrix(difference_type_1 rows, difference_type_1 columns, const_pointer init, RawArrayMemoryLayout layout = RowMajor,
219 allocator_type
const & alloc = allocator_type())
220 :
BaseType(difference_type(rows, columns), alloc)
222 if(layout == RowMajor)
224 difference_type trans(columns, rows);
250 Matrix(
const TemporaryMatrix<T, ALLOC> &rhs)
253 this->
swap(
const_cast<TemporaryMatrix<T, ALLOC> &
>(rhs));
259 template<
class U,
class C>
283 if(this->
shape() == rhs.shape())
286 this->
swap(
const_cast<TemporaryMatrix<T, ALLOC> &
>(rhs));
296 template <
class U,
class C>
314 void reshape(difference_type_1 rows, difference_type_1 columns)
321 void reshape(difference_type_1 rows, difference_type_1 columns, const_reference init)
335 void reshape(difference_type
const & shape, const_reference init)
389 value_type &
operator()(difference_type_1 row, difference_type_1 column);
395 value_type
operator()(difference_type_1 row, difference_type_1 column)
const;
399 typename NormTraits<Matrix>::SquaredNormType
squaredNorm()
const;
403 typename NormTraits<Matrix>::NormType
norm()
const;
418 template <
class U,
class C>
427 template <
class U,
class C>
436 template <
class U,
class C>
445 template <
class U,
class C>
489 template <
class T,
class ALLOC = std::allocator<T> >
490 class TemporaryMatrix
491 :
public Matrix<T, ALLOC>
// Type aliases of TemporaryMatrix. BaseType and matrix_type both name the
// Matrix<T, ALLOC> base class, temp_type names TemporaryMatrix<T, ALLOC>
// itself, and the remaining aliases re-export the base class's nested types.
493 typedef Matrix<T, ALLOC> BaseType;
495 typedef Matrix<T, ALLOC> matrix_type;
496 typedef TemporaryMatrix<T, ALLOC> temp_type;
498 typedef typename BaseType::value_type value_type;
499 typedef typename BaseType::pointer pointer;
500 typedef typename BaseType::const_pointer const_pointer;
501 typedef typename BaseType::reference reference;
502 typedef typename BaseType::const_reference const_reference;
503 typedef typename BaseType::difference_type difference_type;
504 typedef typename BaseType::difference_type_1 difference_type_1;
// The allocator type is taken directly from the class's template parameter.
505 typedef ALLOC allocator_type;
507 TemporaryMatrix(difference_type
const & shape)
508 : BaseType(shape, ALLOC())
511 TemporaryMatrix(difference_type
const & shape, const_reference init)
512 : BaseType(shape, init, ALLOC())
515 TemporaryMatrix(difference_type_1 rows, difference_type_1 columns)
516 : BaseType(rows, columns, ALLOC())
519 TemporaryMatrix(difference_type_1 rows, difference_type_1 columns, const_reference init)
520 : BaseType(rows, columns, init, ALLOC())
523 template<
class U,
class C>
524 TemporaryMatrix(
const MultiArrayView<2, U, C> &rhs)
528 TemporaryMatrix(
const TemporaryMatrix &rhs)
531 this->
swap(const_cast<TemporaryMatrix &>(rhs));
535 TemporaryMatrix & init(
const U & init)
541 template <
class U,
class C>
542 TemporaryMatrix &
operator+=(MultiArrayView<2, U, C>
const & other)
548 template <
class U,
class C>
549 TemporaryMatrix &
operator-=(MultiArrayView<2, U, C>
const & other)
555 template <
class U,
class C>
556 TemporaryMatrix &
operator*=(MultiArrayView<2, U, C>
const & other)
562 template <
class U,
class C>
563 TemporaryMatrix &
operator/=(MultiArrayView<2, U, C>
const & other)
594 TemporaryMatrix &
operator=(
const TemporaryMatrix &rhs);
611 template <
class T,
class C>
624 template <
class T,
class C>
637 template <
class T,
class C>
653 template <
class T,
class C>
658 return m.
subarray(first, Shape(first[0]+1, end));
667 template <
class T,
class C>
682 template <
class T,
class C>
687 return m.
subarray(first, Shape(end, first[1]+1));
701 template <
class T,
class C>
707 return m.
subarray(Shape(first, 0), Shape(end, 1));
708 vigra_precondition(
rowCount(m) == 1,
709 "linalg::subVector(): Input must be a vector (1xN or Nx1).");
710 return m.
subarray(Shape(0, first), Shape(1, end));
719 template <
class T,
class C>
729 if(m(j, i) != m(i, j))
741 template <
class T,
class C>
742 typename NumericTraits<T>::Promote
745 typedef typename NumericTraits<T>::Promote SumType;
748 vigra_precondition(size ==
columnCount(m),
"linalg::trace(): Matrix must be square.");
750 SumType sum = NumericTraits<SumType>::zero();
756 #ifdef DOXYGEN // documentation only -- function is already defined in vigra/multi_array.hxx
764 template <
class T,
class ALLOC>
765 typename Matrix<T, ALLLOC>::SquaredNormType
// Declaration of norm(a) for a Matrix<T, ALLOC>; returns the matrix's NormType.
// NOTE(review): this declaration appears to sit in the "#ifdef DOXYGEN"
// (documentation only) section opened above, whose comment says the function is
// already defined in vigra/multi_array.hxx - confirm the matching #endif.
// The allocator template parameter is ALLOC; the former "ALLLOC" spelling named
// an undeclared identifier and has been corrected.
774 template <
class T,
class ALLOC>
775 typename Matrix<T, ALLOC>::NormType
776 norm(
const Matrix<T, ALLOC> &a);
786 template <
class T,
class C>
791 "identityMatrix(): Matrix must be square.");
794 r(j, i) = NumericTraits<T>::zero();
795 r(i, i) = NumericTraits<T>::one();
813 TemporaryMatrix<T> ret(size, size, NumericTraits<T>::zero());
815 ret(i, i) = NumericTraits<T>::one();
819 template <
class T,
class C1,
class C2>
824 "diagonalMatrix(): result must be a square matrix.");
837 template <
class T,
class C1,
class C2>
841 "diagonalMatrix(): input must be a vector.");
842 r.
init(NumericTraits<T>::zero());
865 template <
class T,
class C>
869 "diagonalMatrix(): input must be a vector.");
871 TemporaryMatrix<T> ret(size, size, NumericTraits<T>::zero());
887 template <
class T,
class C1,
class C2>
893 "transpose(): arrays must have transposed shapes.");
916 template <
class T,
class C>
931 template <
class T,
class C1,
class C2>
932 inline TemporaryMatrix<T>
935 typedef typename TemporaryMatrix<T>::difference_type Shape;
939 "joinVertically(): shape mismatch.");
943 TemporaryMatrix<T> t(ma + mb, n, T());
944 t.subarray(Shape(0,0), Shape(ma, n)) = a;
945 t.subarray(Shape(ma,0), Shape(ma+mb, n)) = b;
957 template <
class T,
class C1,
class C2>
958 inline TemporaryMatrix<T>
961 typedef typename TemporaryMatrix<T>::difference_type Shape;
964 vigra_precondition(m ==
rowCount(b),
965 "joinHorizontally(): shape mismatch.");
969 TemporaryMatrix<T> t(m, na + nb, T());
970 t.subarray(Shape(0,0), Shape(m, na)) = a;
971 t.subarray(Shape(0, na), Shape(m, na + nb)) = b;
985 template <
class T,
class C1,
class C2>
987 unsigned int verticalCount,
unsigned int horizontalCount)
993 "repeatMatrix(): Shape mismatch.");
999 r.
subarray(Shape(k*m, l*n), Shape((k+1)*m, (l+1)*n)) = v;
1015 template <
class T,
class C>
1020 TemporaryMatrix<T> ret(verticalCount*m, horizontalCount*n);
1032 template <
class T,
class C1,
class C2,
class C3>
1040 "add(): Matrix shapes must agree.");
1044 r(j, i) = a(j, i) + b(j, i);
1057 template <
class T,
class C1,
class C2>
1058 inline TemporaryMatrix<T>
1061 return TemporaryMatrix<T>(a) += b;
1064 template <
class T,
class C>
1065 inline TemporaryMatrix<T>
1068 return const_cast<TemporaryMatrix<T> &
>(a) += b;
1071 template <
class T,
class C>
1072 inline TemporaryMatrix<T>
1073 operator+(
const MultiArrayView<2, T, C> &a,
const TemporaryMatrix<T> &b)
1075 return const_cast<TemporaryMatrix<T> &
>(b) += a;
1079 inline TemporaryMatrix<T>
1080 operator+(
const TemporaryMatrix<T> &a,
const TemporaryMatrix<T> &b)
1082 return const_cast<TemporaryMatrix<T> &
>(a) += b;
1092 template <
class T,
class C>
1093 inline TemporaryMatrix<T>
1096 return TemporaryMatrix<T>(a) += b;
1100 inline TemporaryMatrix<T>
1101 operator+(
const TemporaryMatrix<T> &a, T b)
1103 return const_cast<TemporaryMatrix<T> &
>(a) += b;
1113 template <
class T,
class C>
1114 inline TemporaryMatrix<T>
1117 return TemporaryMatrix<T>(b) += a;
1121 inline TemporaryMatrix<T>
1122 operator+(T a,
const TemporaryMatrix<T> &b)
1124 return const_cast<TemporaryMatrix<T> &
>(b) += a;
1134 template <
class T,
class C1,
class C2,
class C3>
1142 "subtract(): Matrix shapes must agree.");
1146 r(j, i) = a(j, i) - b(j, i);
1159 template <
class T,
class C1,
class C2>
1160 inline TemporaryMatrix<T>
1163 return TemporaryMatrix<T>(a) -= b;
1166 template <
class T,
class C>
1167 inline TemporaryMatrix<T>
1170 return const_cast<TemporaryMatrix<T> &
>(a) -= b;
1173 template <
class T,
class C>
1175 operator-(
const MultiArrayView<2, T, C> &a,
const TemporaryMatrix<T> &b)
1179 vigra_precondition(rows == b.rowCount() && cols == b.columnCount(),
1180 "Matrix::operator-(): Shape mismatch.");
1184 const_cast<TemporaryMatrix<T> &
>(b)(j, i) = a(j, i) - b(j, i);
1189 inline TemporaryMatrix<T>
1190 operator-(
const TemporaryMatrix<T> &a,
const TemporaryMatrix<T> &b)
1192 return const_cast<TemporaryMatrix<T> &
>(a) -= b;
1202 template <
class T,
class C>
1203 inline TemporaryMatrix<T>
1206 return TemporaryMatrix<T>(a) *= -NumericTraits<T>::one();
1210 inline TemporaryMatrix<T>
1213 return const_cast<TemporaryMatrix<T> &
>(a) *= -NumericTraits<T>::one();
1223 template <
class T,
class C>
1224 inline TemporaryMatrix<T>
1227 return TemporaryMatrix<T>(a) -= b;
1231 inline TemporaryMatrix<T>
1232 operator-(
const TemporaryMatrix<T> &a, T b)
1234 return const_cast<TemporaryMatrix<T> &
>(a) -= b;
1244 template <
class T,
class C>
1245 inline TemporaryMatrix<T>
1248 return TemporaryMatrix<T>(b.
shape(), a) -= b;
1263 template <
class T,
class C1,
class C2>
1264 typename NormTraits<T>::SquaredNormType
1267 typename NormTraits<T>::SquaredNormType ret =
1268 NumericTraits<typename NormTraits<T>::SquaredNormType>::zero();
1271 std::ptrdiff_t size = y.
shape(0);
1273 for(std::ptrdiff_t i = 0; i < size; ++i)
1274 ret += x(0, i) * y(i, 0);
1275 else if(x.
shape(1) == 1u && x.
shape(0) == size)
1276 for(std::ptrdiff_t i = 0; i < size; ++i)
1277 ret += x(i, 0) * y(i, 0);
1279 vigra_precondition(
false,
"dot(): wrong matrix shapes.");
1281 else if(y.
shape(0) == 1)
1283 std::ptrdiff_t size = y.
shape(1);
1285 for(std::ptrdiff_t i = 0; i < size; ++i)
1286 ret += x(0, i) * y(0, i);
1287 else if(x.
shape(1) == 1u && x.
shape(0) == size)
1288 for(std::ptrdiff_t i = 0; i < size; ++i)
1289 ret += x(i, 0) * y(0, i);
1291 vigra_precondition(
false,
"dot(): wrong matrix shapes.");
1294 vigra_precondition(
false,
"dot(): wrong matrix shapes.");
1305 template <
class T,
class C1,
class C2>
1306 typename NormTraits<T>::SquaredNormType
1311 "dot(): shape mismatch.");
1312 typename NormTraits<T>::SquaredNormType ret =
1313 NumericTraits<typename NormTraits<T>::SquaredNormType>::zero();
1326 template <
class T,
class C1,
class C2,
class C3>
1331 "cross(): vectors must have length 3.");
1332 r(0) = x(1)*y(2) - x(2)*y(1);
1333 r(1) = x(2)*y(0) - x(0)*y(2);
1334 r(2) = x(0)*y(1) - x(1)*y(0);
1345 template <
class T,
class C1,
class C2,
class C3>
1350 "cross(): vectors must have length 3.");
1351 r(0,0) = x(1,0)*y(2,0) - x(2,0)*y(1,0);
1352 r(1,0) = x(2,0)*y(0,0) - x(0,0)*y(2,0);
1353 r(2,0) = x(0,0)*y(1,0) - x(1,0)*y(0,0);
1364 template <
class T,
class C1,
class C2>
1368 TemporaryMatrix<T> ret(3, 1);
1381 template <
class T,
class C1,
class C2,
class C3>
1389 "outer(): shape mismatch.");
1392 r(j, i) = x(j, 0) * y(0, i);
1404 template <
class T,
class C1,
class C2>
1411 "outer(): shape mismatch.");
1412 TemporaryMatrix<T> ret(rows, cols);
1424 template <
class T,
class C>
1430 vigra_precondition(rows == 1 || cols == 1,
1431 "outer(): matrix does not represent a vector.");
1433 TemporaryMatrix<T> ret(size, size);
1439 ret(j, i) = x(0, j) * x(0, i);
1445 ret(j, i) = x(j, 0) * x(i, 0);
1456 PointWise(T
const & it)
1462 PointWise<T> pointWise(T
const & t)
1464 return PointWise<T>(t);
1475 template <
class T,
class C1,
class C2>
1481 "smul(): Matrix sizes must agree.");
1485 r(j, i) = a(j, i) * b;
1495 template <
class T,
class C2,
class C3>
1508 template <
class T,
class C1,
class C2,
class C3>
1516 "mmul(): Matrix shapes must agree.");
1522 r(j, i) = a(j, 0) * b(0, i);
1525 r(j, i) += a(j, k) * b(k, i);
1537 template <
class T,
class C1,
class C2>
1538 inline TemporaryMatrix<T>
1553 template <
class T,
class C1,
class C2,
class C3>
1561 "pmul(): Matrix shapes must agree.");
1565 r(j, i) = a(j, i) * b(j, i);
1578 template <
class T,
class C1,
class C2>
1579 inline TemporaryMatrix<T>
1582 TemporaryMatrix<T> ret(a.
shape());
1605 template <
class T,
class C,
class U>
1606 inline TemporaryMatrix<T>
1609 return pmul(a, b.t);
1619 template <
class T,
class C>
1620 inline TemporaryMatrix<T>
1623 return TemporaryMatrix<T>(a) *= b;
1627 inline TemporaryMatrix<T>
1628 operator*(
const TemporaryMatrix<T> &a, T b)
1630 return const_cast<TemporaryMatrix<T> &
>(a) *= b;
1640 template <
class T,
class C>
1641 inline TemporaryMatrix<T>
1644 return TemporaryMatrix<T>(b) *= a;
1648 inline TemporaryMatrix<T>
1649 operator*(T a,
const TemporaryMatrix<T> &b)
1651 return const_cast<TemporaryMatrix<T> &
>(b) *= a;
1662 template <
class T,
class A,
int N,
class DATA,
class DERIVED>
1667 "operator*(Matrix, TinyVector): Shape mismatch.");
1683 template <
class T,
int N,
class DATA,
class DERIVED,
class A>
1688 "operator*(TinyVector, Matrix): Shape mismatch.");
1704 template <
class T,
class C1,
class C2>
1705 inline TemporaryMatrix<T>
1720 template <
class T,
class C1,
class C2>
1726 "sdiv(): Matrix sizes must agree.");
1730 r(j, i) = a(j, i) / b;
1740 template <
class T,
class C1,
class C2,
class C3>
1748 "pdiv(): Matrix shapes must agree.");
1752 r(j, i) = a(j, i) / b(j, i);
1765 template <
class T,
class C1,
class C2>
1766 inline TemporaryMatrix<T>
1769 TemporaryMatrix<T> ret(a.
shape());
1792 template <
class T,
class C,
class U>
1793 inline TemporaryMatrix<T>
1796 return pdiv(a, b.t);
1806 template <
class T,
class C>
1807 inline TemporaryMatrix<T>
1810 return TemporaryMatrix<T>(a) /= b;
1814 inline TemporaryMatrix<T>
1815 operator/(
const TemporaryMatrix<T> &a, T b)
1817 return const_cast<TemporaryMatrix<T> &
>(a) /= b;
1827 template <
class T,
class C>
1828 inline TemporaryMatrix<T>
1831 return TemporaryMatrix<T>(b.
shape(), a) / pointWise(b);
1856 template <
class T,
class C>
1859 T vopt = NumericTraits<T>::max();
1861 for(
int k=0; k < a.
size(); ++k)
1889 template <
class T,
class C>
1892 T vopt = NumericTraits<T>::min();
1894 for(
int k=0; k < a.
size(); ++k)
1924 template <
class T,
class C,
class UnaryFunctor>
1927 T vopt = NumericTraits<T>::max();
1929 for(
int k=0; k < a.
size(); ++k)
1931 if(condition(a[k]) && a[k] < vopt)
1959 template <
class T,
class C,
class UnaryFunctor>
1962 T vopt = NumericTraits<T>::min();
1964 for(
int k=0; k < a.
size(); ++k)
1966 if(condition(a[k]) && vopt < a[k])
1977 template <
class T,
class C>
1980 linalg::TemporaryMatrix<T> t(v.
shape());
1990 linalg::TemporaryMatrix<T>
pow(linalg::TemporaryMatrix<T>
const & v, T exponent)
1992 linalg::TemporaryMatrix<T> & t =
const_cast<linalg::TemporaryMatrix<T> &
>(v);
2001 template <
class T,
class C>
2002 linalg::TemporaryMatrix<T>
pow(MultiArrayView<2, T, C>
const & v,
int exponent)
2004 linalg::TemporaryMatrix<T> t(v.shape());
2014 linalg::TemporaryMatrix<T>
pow(linalg::TemporaryMatrix<T>
const & v,
int exponent)
2016 linalg::TemporaryMatrix<T> & t =
const_cast<linalg::TemporaryMatrix<T> &
>(v);
2026 linalg::TemporaryMatrix<int>
pow(MultiArrayView<2, int, C>
const & v,
int exponent)
2028 linalg::TemporaryMatrix<int> t(v.shape());
2033 t(j, i) = (int)
vigra::pow((
double)v(j, i), exponent);
2038 linalg::TemporaryMatrix<int>
pow(linalg::TemporaryMatrix<int>
const & v,
int exponent)
2040 linalg::TemporaryMatrix<int> & t =
const_cast<linalg::TemporaryMatrix<int> &
>(v);
2045 t(j, i) = (int)
vigra::pow((
double)t(j, i), exponent);
// Forward declarations of the unary functions taking a 2D MultiArrayView and
// returning a linalg::TemporaryMatrix<T>. The same sixteen names (sqrt, exp,
// log, log10, sin, asin, cos, acos, tan, atan, round, floor, ceil, abs, sq,
// sign) are instantiated below via VIGRA_MATRIX_UNARY_FUNCTION, which applies
// NAMESPACE::FUNCTION to every element v(j, i).
2050 template <
class T,
class C>
2051 linalg::TemporaryMatrix<T>
sqrt(MultiArrayView<2, T, C>
const & v);
2053 template <
class T,
class C>
2054 linalg::TemporaryMatrix<T>
exp(MultiArrayView<2, T, C>
const & v);
2056 template <
class T,
class C>
2057 linalg::TemporaryMatrix<T>
log(MultiArrayView<2, T, C>
const & v);
2059 template <
class T,
class C>
2060 linalg::TemporaryMatrix<T>
log10(MultiArrayView<2, T, C>
const & v);
2062 template <
class T,
class C>
2063 linalg::TemporaryMatrix<T>
sin(MultiArrayView<2, T, C>
const & v);
2065 template <
class T,
class C>
2066 linalg::TemporaryMatrix<T>
asin(MultiArrayView<2, T, C>
const & v);
2068 template <
class T,
class C>
2069 linalg::TemporaryMatrix<T>
cos(MultiArrayView<2, T, C>
const & v);
2071 template <
class T,
class C>
2072 linalg::TemporaryMatrix<T>
acos(MultiArrayView<2, T, C>
const & v);
2074 template <
class T,
class C>
2075 linalg::TemporaryMatrix<T>
tan(MultiArrayView<2, T, C>
const & v);
2077 template <
class T,
class C>
2078 linalg::TemporaryMatrix<T>
atan(MultiArrayView<2, T, C>
const & v);
2080 template <
class T,
class C>
2081 linalg::TemporaryMatrix<T>
round(MultiArrayView<2, T, C>
const & v);
2083 template <
class T,
class C>
2084 linalg::TemporaryMatrix<T>
floor(MultiArrayView<2, T, C>
const & v);
2086 template <
class T,
class C>
2087 linalg::TemporaryMatrix<T>
ceil(MultiArrayView<2, T, C>
const & v);
2089 template <
class T,
class C>
2090 linalg::TemporaryMatrix<T>
abs(MultiArrayView<2, T, C>
const & v);
2092 template <
class T,
class C>
2093 linalg::TemporaryMatrix<T>
sq(MultiArrayView<2, T, C>
const & v);
2095 template <
class T,
class C>
2096 linalg::TemporaryMatrix<T>
sign(MultiArrayView<2, T, C>
const & v);
2098 #define VIGRA_MATRIX_UNARY_FUNCTION(FUNCTION, NAMESPACE) \
2099 using NAMESPACE::FUNCTION; \
2100 template <class T, class C> \
2101 linalg::TemporaryMatrix<T> FUNCTION(MultiArrayView<2, T, C> const & v) \
2103 linalg::TemporaryMatrix<T> t(v.shape()); \
2104 MultiArrayIndex m = rowCount(v), n = columnCount(v); \
2106 for(MultiArrayIndex i = 0; i < n; ++i) \
2107 for(MultiArrayIndex j = 0; j < m; ++j) \
2108 t(j, i) = NAMESPACE::FUNCTION(v(j, i)); \
2112 template <class T> \
2113 linalg::TemporaryMatrix<T> FUNCTION(linalg::Matrix<T> const & v) \
2115 linalg::TemporaryMatrix<T> t(v.shape()); \
2116 MultiArrayIndex m = rowCount(v), n = columnCount(v); \
2118 for(MultiArrayIndex i = 0; i < n; ++i) \
2119 for(MultiArrayIndex j = 0; j < m; ++j) \
2120 t(j, i) = NAMESPACE::FUNCTION(v(j, i)); \
2124 template <class T> \
2125 linalg::TemporaryMatrix<T> FUNCTION(linalg::TemporaryMatrix<T> const & v) \
2127 linalg::TemporaryMatrix<T> & t = const_cast<linalg::TemporaryMatrix<T> &>(v); \
2128 MultiArrayIndex m = rowCount(t), n = columnCount(t); \
2130 for(MultiArrayIndex i = 0; i < n; ++i) \
2131 for(MultiArrayIndex j = 0; j < m; ++j) \
2132 t(j, i) = NAMESPACE::FUNCTION(t(j, i)); \
2136 using linalg::FUNCTION;\
// Instantiate the element-wise matrix overloads, one macro expansion per
// function. The second macro argument is the namespace whose scalar function
// is applied to each element: std for sqrt through atan, vigra for round
// through sign.
2139 VIGRA_MATRIX_UNARY_FUNCTION(
sqrt, std)
2140 VIGRA_MATRIX_UNARY_FUNCTION(
exp, std)
2141 VIGRA_MATRIX_UNARY_FUNCTION(
log, std)
2142 VIGRA_MATRIX_UNARY_FUNCTION(
log10, std)
2143 VIGRA_MATRIX_UNARY_FUNCTION(
sin, std)
2144 VIGRA_MATRIX_UNARY_FUNCTION(
asin, std)
2145 VIGRA_MATRIX_UNARY_FUNCTION(
cos, std)
2146 VIGRA_MATRIX_UNARY_FUNCTION(
acos, std)
2147 VIGRA_MATRIX_UNARY_FUNCTION(
tan, std)
2148 VIGRA_MATRIX_UNARY_FUNCTION(
atan, std)
2149 VIGRA_MATRIX_UNARY_FUNCTION(
round, vigra)
2150 VIGRA_MATRIX_UNARY_FUNCTION(
floor, vigra)
2151 VIGRA_MATRIX_UNARY_FUNCTION(
ceil, vigra)
2152 VIGRA_MATRIX_UNARY_FUNCTION(
abs, vigra)
2153 VIGRA_MATRIX_UNARY_FUNCTION(
sq, vigra)
2154 VIGRA_MATRIX_UNARY_FUNCTION(
sign, vigra)
// The helper macro is only needed for the expansions above; remove it again.
2156 #undef VIGRA_MATRIX_UNARY_FUNCTION
// Using-declarations: make these linalg names usable in the enclosing scope
// without the linalg:: qualifier (the enclosing scope is not visible here).
2162 using linalg::RowMajor;
2163 using linalg::ColumnMajor;
2164 using linalg::Matrix;
2168 using linalg::pointWise;
// NormTraits specialization for Matrix<T, ALLOC>. It derives from the traits
// of MultiArray<2, T, ALLOC> and re-exports that base's SquaredNormType and
// NormType, while Type names the Matrix itself.
2192 template <
class T,
class ALLOC>
2193 struct NormTraits<Matrix<T, ALLOC> >
2194 :
public NormTraits<MultiArray<2, T, ALLOC> >
2196 typedef NormTraits<MultiArray<2, T, ALLOC> > BaseType;
2197 typedef Matrix<T, ALLOC> Type;
2198 typedef typename BaseType::SquaredNormType SquaredNormType;
2199 typedef typename BaseType::NormType NormType;
// NormTraits specialization for linalg::TemporaryMatrix<T, ALLOC>. It derives
// from the traits of Matrix<T, ALLOC> and re-exports that base's
// SquaredNormType and NormType, while Type names the TemporaryMatrix itself.
2202 template <
class T,
class ALLOC>
2203 struct NormTraits<linalg::TemporaryMatrix<T, ALLOC> >
2204 :
public NormTraits<Matrix<T, ALLOC> >
2206 typedef NormTraits<Matrix<T, ALLOC> > BaseType;
2207 typedef linalg::TemporaryMatrix<T, ALLOC> Type;
2208 typedef typename BaseType::SquaredNormType SquaredNormType;
2209 typedef typename BaseType::NormType NormType;
2226 template <
class T,
class C>
2228 operator<<(ostream & s, const vigra::MultiArrayView<2, T, C> &m)
2232 ios::fmtflags flags = s.setf(ios::right | ios::fixed, ios::adjustfield | ios::floatfield);
2237 s << m(j, i) <<
" ";
2255 template <
class T1,
class C1,
class T2,
class C2,
class T3,
class C3>
2257 columnStatisticsImpl(MultiArrayView<2, T1, C1>
const & A,
2258 MultiArrayView<2, T2, C2> & mean, MultiArrayView<2, T3, C3> & sumOfSquaredDifferences)
2264 "columnStatistics(): Shape mismatch between input and output.");
2267 mean.init(NumericTraits<T2>::zero());
2268 sumOfSquaredDifferences.init(NumericTraits<T3>::zero());
2273 typename NumericTraits<T2>::RealPromote f = 1.0 / (k + 1.0),
2276 sumOfSquaredDifferences += f1*
sq(t);
2280 template <
class T1,
class C1,
class T2,
class C2,
class T3,
class C3>
2282 columnStatistics2PassImpl(MultiArrayView<2, T1, C1>
const & A,
2283 MultiArrayView<2, T2, C2> & mean, MultiArrayView<2, T3, C3> & sumOfSquaredDifferences)
2289 "columnStatistics(): Shape mismatch between input and output.");
2292 mean.init(NumericTraits<T2>::zero());
2299 sumOfSquaredDifferences.init(NumericTraits<T3>::zero());
2302 sumOfSquaredDifferences +=
sq(
rowVector(A, k) - mean);
2367 template <
class T1,
class C1,
class T2,
class C2>
2370 MultiArrayView<2, T2, C2> & mean)
2375 "columnStatistics(): Shape mismatch between input and output.");
2377 mean.init(NumericTraits<T2>::zero());
2386 template <
class T1,
class C1,
class T2,
class C2,
class T3,
class C3>
2389 MultiArrayView<2, T2, C2> & mean, MultiArrayView<2, T3, C3> & stdDev)
2391 detail::columnStatisticsImpl(A, mean, stdDev);
2397 template <
class T1,
class C1,
class T2,
class C2,
class T3,
class C3,
class T4,
class C4>
2400 MultiArrayView<2, T2, C2> & mean, MultiArrayView<2, T3, C3> & stdDev, MultiArrayView<2, T4, C4> &
norm)
2407 "columnStatistics(): Shape mismatch between input and output.");
2409 detail::columnStatisticsImpl(A, mean, stdDev);
2410 norm =
sqrt(stdDev + T2(m) *
sq(mean));
2411 stdDev =
sqrt(stdDev / T3(m - 1.0));
2470 template <
class T1,
class C1,
class T2,
class C2>
2473 MultiArrayView<2, T2, C2> & mean)
2476 "rowStatistics(): Shape mismatch between input and output.");
2477 MultiArrayView<2, T2, StridedArrayTag> tm =
transpose(mean);
2481 template <
class T1,
class C1,
class T2,
class C2,
class T3,
class C3>
2484 MultiArrayView<2, T2, C2> & mean, MultiArrayView<2, T3, C3> & stdDev)
2488 "rowStatistics(): Shape mismatch between input and output.");
2489 MultiArrayView<2, T2, StridedArrayTag> tm =
transpose(mean);
2490 MultiArrayView<2, T3, StridedArrayTag> ts =
transpose(stdDev);
2494 template <
class T1,
class C1,
class T2,
class C2,
class T3,
class C3,
class T4,
class C4>
2497 MultiArrayView<2, T2, C2> & mean, MultiArrayView<2, T3, C3> & stdDev, MultiArrayView<2, T4, C4> & norm)
2502 "rowStatistics(): Shape mismatch between input and output.");
2503 MultiArrayView<2, T2, StridedArrayTag> tm =
transpose(mean);
2504 MultiArrayView<2, T3, StridedArrayTag> ts =
transpose(stdDev);
2505 MultiArrayView<2, T4, StridedArrayTag> tn =
transpose(norm);
2511 template <
class T1,
class C1,
class U,
class T2,
class C2,
class T3,
class C3>
2512 void updateCovarianceMatrix(MultiArrayView<2, T1, C1>
const & features,
2513 U & count, MultiArrayView<2, T2, C2> & mean, MultiArrayView<2, T3, C3> & covariance)
2517 "updateCovarianceMatrix(): Features must be a row or column vector.");
2518 vigra_precondition(mean.shape() == features.shape(),
2519 "updateCovarianceMatrix(): Shape mismatch between feature vector and mean vector.");
2521 "updateCovarianceMatrix(): Shape mismatch between feature vector and covariance matrix.");
2524 Matrix<T2> t = features - mean;
2526 double f = 1.0 / count,
2534 covariance(k, k) += f1*
sq(t(0, k));
2537 covariance(k, l) += f1*t(0, k)*t(0, l);
2538 covariance(l, k) = covariance(k, l);
2546 covariance(k, k) += f1*
sq(t(k, 0));
2549 covariance(k, l) += f1*t(k, 0)*t(l, 0);
2550 covariance(l, k) = covariance(k, l);
2566 template <
class T1,
class C1,
class T2,
class C2>
2572 "covarianceMatrixOfColumns(): Shape mismatch between feature matrix and covariance matrix.");
2575 covariance.
init(NumericTraits<T2>::zero());
2577 detail::updateCovarianceMatrix(
rowVector(features, k), count, means, covariance);
2578 covariance /= T2(m - 1);
2589 template <
class T,
class C>
2606 template <
class T1,
class C1,
class T2,
class C2>
2612 "covarianceMatrixOfRows(): Shape mismatch between feature matrix and covariance matrix.");
2615 covariance.
init(NumericTraits<T2>::zero());
2617 detail::updateCovarianceMatrix(
columnVector(features, k), count, means, covariance);
2618 covariance /= T2(m - 1);
2629 template <
class T,
class C>
// Goals that can be requested from the data preparation functions. The values
// are distinct bits (1, 2, 4) so that several goals can be combined with
// operator| below. prepareDataImpl() rejects the combination of UnitVariance
// and UnitNorm via a precondition.
2638 enum DataPreparationGoals { ZeroMean = 1, UnitVariance = 2, UnitNorm = 4 };
// Combine two goal sets: OR the underlying integer values and convert the
// result back to DataPreparationGoals.
2640 inline DataPreparationGoals operator|(DataPreparationGoals l, DataPreparationGoals r)
2642 return DataPreparationGoals(
int(l) |
int(r));
2647 template <
class T,
class C1,
class C2,
class C3,
class C4>
2649 prepareDataImpl(
const MultiArrayView<2, T, C1> & A,
2650 MultiArrayView<2, T, C2> & res, MultiArrayView<2, T, C3> & offset, MultiArrayView<2, T, C4> & scaling,
2651 DataPreparationGoals goals)
2655 vigra_precondition(A.shape() == res.shape() &&
2658 "prepareDataImpl(): Shape mismatch between input and output.");
2663 offset.init(NumericTraits<T>::zero());
2664 scaling.init(NumericTraits<T>::one());
2668 bool zeroMean = (goals & ZeroMean) != 0;
2669 bool unitVariance = (goals & UnitVariance) != 0;
2670 bool unitNorm = (goals & UnitNorm) != 0;
2672 vigra_precondition(!(unitVariance && unitNorm),
2673 "prepareDataImpl(): Unit variance and unit norm cannot be achieved at the same time.");
2675 Matrix<T> mean(1, n), sumOfSquaredDifferences(1, n);
2676 detail::columnStatisticsImpl(A, mean, sumOfSquaredDifferences);
2680 T stdDev =
std::sqrt(sumOfSquaredDifferences(0, k) / T(m-1));
2682 stdDev = NumericTraits<T>::zero();
2683 if(zeroMean && stdDev > NumericTraits<T>::zero())
2686 offset(0, k) = mean(0, k);
2687 mean(0, k) = NumericTraits<T>::zero();
2692 offset(0, k) = NumericTraits<T>::zero();
2695 T norm = mean(0,k) == NumericTraits<T>::zero()
2696 ?
std::sqrt(sumOfSquaredDifferences(0, k))
2697 : std::
sqrt(sumOfSquaredDifferences(0, k) + T(m) *
sq(mean(0,k)));
2698 if(unitNorm && norm > NumericTraits<T>::zero())
2701 scaling(0, k) = NumericTraits<T>::one() /
norm;
2703 else if(unitVariance && stdDev > NumericTraits<T>::zero())
2706 scaling(0, k) = NumericTraits<T>::one() / stdDev;
2710 scaling(0, k) = NumericTraits<T>::one();
2790 template <
class T,
class C1,
class C2,
class C3,
class C4>
2793 MultiArrayView<2, T, C2> & res, MultiArrayView<2, T, C3> & offset, MultiArrayView<2, T, C4> & scaling,
2794 DataPreparationGoals goals = ZeroMean | UnitVariance)
2796 detail::prepareDataImpl(A, res, offset, scaling, goals);
2799 template <
class T,
class C1,
class C2>
2801 prepareColumns(MultiArrayView<2, T, C1>
const & A, MultiArrayView<2, T, C2> & res,
2802 DataPreparationGoals goals = ZeroMean | UnitVariance)
2805 detail::prepareDataImpl(A, res, offset, scaling, goals);
2866 template <
class T,
class C1,
class C2,
class C3,
class C4>
2869 MultiArrayView<2, T, C2> & res, MultiArrayView<2, T, C3> & offset, MultiArrayView<2, T, C4> & scaling,
2870 DataPreparationGoals goals = ZeroMean | UnitVariance)
2873 detail::prepareDataImpl(
transpose(A), tr, to, ts, goals);
2876 template <
class T,
class C1,
class C2>
2878 prepareRows(MultiArrayView<2, T, C1>
const & A, MultiArrayView<2, T, C2> & res,
2879 DataPreparationGoals goals = ZeroMean | UnitVariance)
2881 MultiArrayView<2, T, StridedArrayTag> tr =
transpose(res);
2883 detail::prepareDataImpl(
transpose(A), tr, offset, scaling, goals);
// Using-declarations: make the DataPreparationGoals enumerators usable in the
// enclosing scope without the linalg:: qualifier.
2894 using linalg::ZeroMean;
2895 using linalg::UnitVariance;
2896 using linalg::UnitNorm;
2902 #endif // VIGRA_MATRIX_HXX