MatrixSquareRoot.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_MATRIX_SQUARE_ROOT
11#define EIGEN_MATRIX_SQUARE_ROOT
12
13namespace Eigen {
14
// MatrixSquareRootQuasiTriangular: helper that computes the square root of a
// real matrix through its real Schur form (see compute() below). The matrix
// given to the constructor is stored by reference (m_A), so it must outlive
// this object.
// NOTE(review): this listing skips numbered line 27 (presumably the
// `class MatrixSquareRootQuasiTriangular` head) and lines 31-38 / 45-52
// (presumably Doxygen comments); restore them from upstream before compiling.
26template <typename MatrixType>
28{
29 public:
30
// Constructor: keeps a reference to A (no copy) and asserts that A is square.
39 MatrixSquareRootQuasiTriangular(const MatrixType& A)
40 : m_A(A)
41 {
42 eigen_assert(A.rows() == A.cols());
43 }
44
// Writes the square root of m_A into result.
53 template <typename ResultType> void compute(ResultType &result);
54
55 private:
56 typedef typename MatrixType::Index Index;
57 typedef typename MatrixType::Scalar Scalar;
58
// Block-wise helpers. Each one documents its own pre/post-conditions at its
// definition. "1x1"/"2x2" refer to the sizes of the diagonal blocks of the
// quasi-triangular factor that the row index i and column index j belong to.
59 void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
60 void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
61 void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);
62 void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
63 typename MatrixType::Index i, typename MatrixType::Index j);
64 void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
65 typename MatrixType::Index i, typename MatrixType::Index j);
66 void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
67 typename MatrixType::Index i, typename MatrixType::Index j);
68 void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
69 typename MatrixType::Index i, typename MatrixType::Index j);
70
// Solves A X + X B = C for 2-by-2 matrices.
71 template <typename SmallMatrixType>
72 static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
73 const SmallMatrixType& B, const SmallMatrixType& C);
74
// Held by reference: the constructor argument must outlive this object.
75 const MatrixType& m_A;
76};
77
// Computes result = sqrt(m_A) for a real matrix m_A:
//   1. m_A = U T U^T (real Schur form; T upper quasi-triangular, U orthogonal),
//   2. sqrtT = sqrt(T): diagonal blocks first, then off-diagonal blocks,
//   3. result = U sqrtT U^T (U.adjoint() is the transpose for real U).
// NOTE(review): numbered line 80 (presumably the signature
// `void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)`)
// is missing from this listing.
// NOTE(review): MatrixSquareRoot<MatrixType,0> already passes a Schur factor
// here, so the RealSchur below is a second decomposition of an already
// quasi-triangular matrix.
78template <typename MatrixType>
79template <typename ResultType>
81{
82 // Compute Schur decomposition of m_A
83 const RealSchur<MatrixType> schurOfA(m_A);
84 const MatrixType& T = schurOfA.matrixT();
85 const MatrixType& U = schurOfA.matrixU();
86
87 // Compute square root of T
// sqrtT must start as zero: computeDiagonalPartOfSqrt() requires it.
88 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
89 computeDiagonalPartOfSqrt(sqrtT, T);
90 computeOffDiagonalPartOfSqrt(sqrtT, T);
91
92 // Compute square root of m_A
93 result = U * sqrtT * U.adjoint();
94}
95
96// pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
97// post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T
98template <typename MatrixType>
99void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT,
100 const MatrixType& T)
101{
102 const Index size = m_A.rows();
103 for (Index i = 0; i < size; i++) {
104 if (i == size - 1 || T.coeff(i+1, i) == 0) {
105 eigen_assert(T(i,i) > 0);
106 sqrtT.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
107 }
108 else {
109 compute2x2diagonalBlock(sqrtT, T, i);
110 ++i;
111 }
112 }
113}
114
115// pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T.
116// post: sqrtT is the square root of T.
117template <typename MatrixType>
118void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
119 const MatrixType& T)
120{
121 const Index size = m_A.rows();
122 for (Index j = 1; j < size; j++) {
123 if (T.coeff(j, j-1) != 0) // if T(j-1:j, j-1:j) is a 2-by-2 block
124 continue;
125 for (Index i = j-1; i >= 0; i--) {
126 if (i > 0 && T.coeff(i, i-1) != 0) // if T(i-1:i, i-1:i) is a 2-by-2 block
127 continue;
128 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
129 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
130 if (iBlockIs2x2 && jBlockIs2x2)
131 compute2x2offDiagonalBlock(sqrtT, T, i, j);
132 else if (iBlockIs2x2 && !jBlockIs2x2)
133 compute2x1offDiagonalBlock(sqrtT, T, i, j);
134 else if (!iBlockIs2x2 && jBlockIs2x2)
135 compute1x2offDiagonalBlock(sqrtT, T, i, j);
136 else if (!iBlockIs2x2 && !jBlockIs2x2)
137 compute1x1offDiagonalBlock(sqrtT, T, i, j);
138 }
139 }
140}
141
// Computes the square root of one 2-by-2 diagonal block:
//   1. Diagonalise the block over C: block = V diag(lambda) V^{-1}.
//   2. Take element-wise square roots of the (complex conjugate) eigenvalues.
//   3. Keep the real part of V diag(sqrt(lambda)) V^{-1}.
// For a real block with a conjugate eigenvalue pair this product is real up to
// rounding, and .real() drops the residual imaginary part.
// NOTE(review): this listing skips numbered lines 145 (presumably the
// `void MatrixSquareRootQuasiTriangular<MatrixType>` qualifier) and 151
// (the declaration of `es`, presumably an EigenSolver constructed from
// `block`); restore them from upstream.
142// pre: T.block(i,i,2,2) has complex conjugate eigenvalues
143// post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
144template <typename MatrixType>
146 ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
147{
148 // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
149 // in EigenSolver. If we expose it, we could call it directly from here.
150 Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
152 sqrtT.template block<2,2>(i,i)
153 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
154}
155
// Method: entry (i,j) of sqrtT*sqrtT = T reads
//   T(i,j) = (sqrtT(i,i) + sqrtT(j,j)) * sqrtT(i,j) + sum_{k=i+1}^{j-1} sqrtT(i,k) * sqrtT(k,j).
// The sum is tmp below, and the equation is solved for sqrtT(i,j).
// When j == i+1 the two segments have length 0, so tmp is 0.
// NOTE(review): nothing guards against sqrtT(i,i) + sqrtT(j,j) == 0 (e.g. both
// diagonal roots zero); the division then produces inf/NaN.
// NOTE(review): numbered line 160 (presumably the
// `void MatrixSquareRootQuasiTriangular<MatrixType>` qualifier) is missing here.
156// pre: block structure of T is such that (i,j) is a 1x1 block,
157// all blocks of sqrtT to left of and below (i,j) are correct
158// post: sqrtT(i,j) has the correct value
159template <typename MatrixType>
161 ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
162 typename MatrixType::Index i, typename MatrixType::Index j)
163{
164 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
165 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
166}
167
// Case: row i starts a 1-by-1 diagonal block and column j starts a 2-by-2 one.
// The row block x = sqrtT(i, j:j+1) satisfies
//   sqrtT(i,i) * x + x * S_jj = rhs,   S_jj = sqrtT(j:j+1, j:j+1),
//   rhs = T(i, j:j+1) - sqrtT(i, i+1:j-1) * sqrtT(i+1:j-1, j:j+1).
// Transposing gives the 2-by-2 system (sqrtT(i,i) I + S_jj^T) x^T = rhs^T,
// which is solved with full-pivoting LU.
// NOTE(review): numbered line 170 (class qualifier) is missing from this listing.
168// similar to compute1x1offDiagonalBlock()
169template <typename MatrixType>
171 ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
172 typename MatrixType::Index i, typename MatrixType::Index j)
173{
174 Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
// The coupling sum is empty when j == i+1.
175 if (j-i > 1)
176 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
177 Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
178 A += sqrtT.template block<2,2>(j,j).transpose();
179 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
180}
181
// Case: row i starts a 2-by-2 diagonal block and column j is a 1-by-1 one.
// The column block x = sqrtT(i:i+1, j) satisfies
//   S_ii * x + x * sqrtT(j,j) = rhs,   S_ii = sqrtT(i:i+1, i:i+1),
//   rhs = T(i:i+1, j) - sqrtT(i:i+1, i+2:j-1) * sqrtT(i+2:j-1, j).
// This is the 2-by-2 system (S_ii + sqrtT(j,j) I) x = rhs, solved with
// full-pivoting LU.
// NOTE(review): numbered line 184 (class qualifier) is missing from this listing.
182// similar to compute1x1offDiagonalBlock()
183template <typename MatrixType>
185 ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
186 typename MatrixType::Index i, typename MatrixType::Index j)
187{
188 Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
// The coupling sum is empty when j == i+2.
189 if (j-i > 2)
190 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
191 Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
192 A += sqrtT.template block<2,2>(i,i);
193 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
194}
195
// Case: both row i and column j start 2-by-2 diagonal blocks.
// The block X = sqrtT(i:i+1, j:j+1) satisfies the Sylvester equation
//   A X + X B = C,   A = sqrtT(i:i+1, i:i+1),  B = sqrtT(j:j+1, j:j+1),
//   C = T(i:i+1, j:j+1) - sqrtT(i:i+1, i+2:j-1) * sqrtT(i+2:j-1, j:j+1).
// The equation is solved by solveAuxiliaryEquation().
// NOTE(review): this listing skips numbered lines 198 (class qualifier) and
// 207 (presumably the declaration `Matrix<Scalar,2,2> X;`).
196// similar to compute1x1offDiagonalBlock()
197template <typename MatrixType>
199 ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
200 typename MatrixType::Index i, typename MatrixType::Index j)
201{
202 Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
203 Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
204 Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
// The coupling sum is empty when j == i+2.
205 if (j-i > 2)
206 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
208 solveAuxiliaryEquation(X, A, B, C);
209 sqrtT.template block<2,2>(i,j) = X;
210}
211
// Method: stacking X row by row as x = (X00, X01, X10, X11)^T turns
// A X + X B = C into the 4-by-4 linear system coeffMatrix * x = rhs, with
// rhs = (C00, C01, C10, C11)^T. The system is solved with full-pivoting LU
// and x is unpacked back into X.
// NOTE(review): this listing skips numbered lines 215 (class qualifier),
// 222 (declaration of coeffMatrix) and 236 (declaration of rhs).
// NOTE(review): only 12 of the 16 entries of coeffMatrix are assigned below.
// The other four, (0,3), (1,2), (2,1) and (3,0), must be zero, so the missing
// declaration is presumably zero-initialised; verify this when restoring it.
212// solves the equation A X + X B = C where all matrices are 2-by-2
213template <typename MatrixType>
214template <typename SmallMatrixType>
216 ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
217 const SmallMatrixType& B, const SmallMatrixType& C)
218{
// Compile-time check: only 2-by-2 matrices of the class's Scalar are allowed.
219 EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
220 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
221
// Diagonal: A(r,r) + B(c,c) for the unknown X(r,c).
223 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
224 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
225 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
226 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
// Off-diagonal couplings from the A X term (A(0,1), A(1,0)) and the X B term (B(0,1), B(1,0)).
227 coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
228 coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
229 coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
230 coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
231 coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
232 coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
233 coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
234 coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
235
237 rhs.coeffRef(0) = C.coeff(0,0);
238 rhs.coeffRef(1) = C.coeff(0,1);
239 rhs.coeffRef(2) = C.coeff(1,0);
240 rhs.coeffRef(3) = C.coeff(1,1);
241
242 Matrix<Scalar,4,1> result;
243 result = coeffMatrix.fullPivLu().solve(rhs);
244
245 X.coeffRef(0,0) = result.coeff(0);
246 X.coeffRef(0,1) = result.coeff(1);
247 X.coeffRef(1,0) = result.coeff(2);
248 X.coeffRef(1,1) = result.coeff(3);
249}
250
251
// MatrixSquareRootTriangular: computes a matrix square root through the
// complex Schur decomposition (see compute()). Only a reference to the input
// matrix is kept, so that matrix must outlive this object.
template <typename MatrixType>
class MatrixSquareRootTriangular
{
  private:
    // Matrix whose square root is requested (not owned).
    const MatrixType& m_A;

  public:
    // Binds to the square matrix A; squareness is checked with eigen_assert.
    MatrixSquareRootTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    // Stores the square root of the bound matrix in result, resizing it to match.
    template <typename ResultType> void compute(ResultType &result);
};
287
// Computes result = sqrt(m_A) through the complex Schur form m_A = U T U^*:
//   1. The upper-triangular square root of T is built directly in the upper
//      triangle of result: first the diagonal, then column by column from the
//      diagonal upwards.
//   2. The strictly lower triangle of result is never written and is ignored
//      via triangularView<Upper>.
//   3. result = U * sqrt(T) * U^* is formed through the temporary tmp.
// NOTE(review): numbered line 290 (presumably the signature
// `void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)`)
// is missing from this listing.
// NOTE(review): MatrixSquareRoot<MatrixType,1> already passes a Schur factor
// here, so the ComplexSchur below is a second decomposition of a triangular matrix.
288template <typename MatrixType>
289template <typename ResultType>
291{
292 // Compute Schur decomposition of m_A
293 const ComplexSchur<MatrixType> schurOfA(m_A);
294 const MatrixType& T = schurOfA.matrixT();
295 const MatrixType& U = schurOfA.matrixU();
296
297 // Compute square root of T and store it in upper triangular part of result
298 // This uses that the square root of triangular matrices can be computed directly.
299 result.resize(m_A.rows(), m_A.cols());
300 typedef typename MatrixType::Index Index;
301 for (Index i = 0; i < m_A.rows(); i++) {
302 result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
303 }
// Entry (i,j) depends on entries to its left in row i and below it in column j;
// the loop order guarantees those are already computed.
304 for (Index j = 1; j < m_A.cols(); j++) {
305 for (Index i = j-1; i >= 0; i--) {
306 typedef typename MatrixType::Scalar Scalar;
307 // if i = j-1, then segment has length 0 so tmp = 0
308 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
309 // denominator may be zero if original matrix is singular
310 result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
311 }
312 }
313
314 // Compute square root of m_A as U * result * U.adjoint()
// tmp is needed because result appears on both sides of the product.
315 MatrixType tmp;
316 tmp.noalias() = U * result.template triangularView<Upper>();
317 result.noalias() = tmp * U.adjoint();
318}
319
320
// MatrixSquareRoot: computes the square root of a general square matrix.
// The primary template only declares the interface (no definitions follow).
// The partial specializations below provide the implementation:
//   - IsComplex == 0: real scalars, real Schur form;
//   - IsComplex == 1: complex scalars, complex Schur form.
// IsComplex defaults to NumTraits<Scalar>::IsComplex.
// NOTE(review): numbered line 329 (presumably the `class MatrixSquareRoot` head)
// and the Doxygen comments (333-339, 342-348) are missing from this listing.
328template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
330{
331 public:
332
340 MatrixSquareRoot(const MatrixType& A);
341
349 template <typename ResultType> void compute(ResultType &result);
350};
351
352
353// ********** Partial specialization for real matrices **********
354
355template <typename MatrixType>
356class MatrixSquareRoot<MatrixType, 0>
357{
358 public:
359
360 MatrixSquareRoot(const MatrixType& A)
361 : m_A(A)
362 {
363 eigen_assert(A.rows() == A.cols());
364 }
365
366 template <typename ResultType> void compute(ResultType &result)
367 {
368 // Compute Schur decomposition of m_A
369 const RealSchur<MatrixType> schurOfA(m_A);
370 const MatrixType& T = schurOfA.matrixT();
371 const MatrixType& U = schurOfA.matrixU();
372
373 // Compute square root of T
374 MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
375 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
376 tmp.compute(sqrtT);
377
378 // Compute square root of m_A
379 result = U * sqrtT * U.adjoint();
380 }
381
382 private:
383 const MatrixType& m_A;
384};
385
386
387// ********** Partial specialization for complex matrices **********
388
389template <typename MatrixType>
390class MatrixSquareRoot<MatrixType, 1>
391{
392 public:
393
394 MatrixSquareRoot(const MatrixType& A)
395 : m_A(A)
396 {
397 eigen_assert(A.rows() == A.cols());
398 }
399
400 template <typename ResultType> void compute(ResultType &result)
401 {
402 // Compute Schur decomposition of m_A
403 const ComplexSchur<MatrixType> schurOfA(m_A);
404 const MatrixType& T = schurOfA.matrixT();
405 const MatrixType& U = schurOfA.matrixU();
406
407 // Compute square root of T
408 MatrixSquareRootTriangular<MatrixType> tmp(T);
409 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
410 tmp.compute(sqrtT);
411
412 // Compute square root of m_A
413 result = U * sqrtT * U.adjoint();
414 }
415
416 private:
417 const MatrixType& m_A;
418};
419
420
// MatrixSquareRootReturnValue: lazy proxy for the square root of a matrix
// expression. It keeps a reference to the source expression (m_src), and the
// computation only happens in evalTo().
// NOTE(review): this listing skips several numbered lines:
//   - 438-442 and 444-449: presumably Doxygen comments;
//   - 454: presumably the declaration of `me`, a MatrixSquareRoot built on
//     srcEvaluated;
//   - 464: a private member, presumably an assignment operator declared to
//     disable assignment.
// Restore them from upstream.
433template<typename Derived> class MatrixSquareRootReturnValue
434: public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
435{
436 typedef typename Derived::Index Index;
437 public:
443 MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
444
450 template <typename ResultType>
451 inline void evalTo(ResultType& result) const
452 {
// Evaluate the (possibly lazy) source expression once into a plain matrix.
453 const typename Derived::PlainObject srcEvaluated = m_src.eval();
455 me.compute(result);
456 }
457
// The square root has the same (square) shape as the source.
458 Index rows() const { return m_src.rows(); }
459 Index cols() const { return m_src.cols(); }
460
461 protected:
// Held by reference: the source expression must outlive this proxy.
462 const Derived& m_src;
463 private:
465};
466
467namespace internal {
468template<typename Derived>
469struct traits<MatrixSquareRootReturnValue<Derived> >
470{
471 typedef typename Derived::PlainObject ReturnType;
472};
473}
474
// Presumably the definition of MatrixBase<Derived>::sqrt().
// It returns a lazy MatrixSquareRootReturnValue proxy for the square root of
// *this, and asserts that the matrix is square.
// NOTE(review): numbered line 476 (the signature) is missing from this listing.
475template <typename Derived>
477{
478 eigen_assert(rows() == cols());
479 return MatrixSquareRootReturnValue<Derived>(derived());
480}
481
482} // end namespace Eigen
483
484#endif // EIGEN_MATRIX_SQUARE_ROOT
const ComplexMatrixType & matrixU() const
const ComplexMatrixType & matrixT() const
static const ConstantReturnType Zero()
static const IdentityReturnType Identity()
const FullPivLU< PlainObject > fullPivLu() const
Class for computing matrix square roots of upper quasi-triangular matrices.
Definition MatrixSquareRoot.h:28
void compute(ResultType &result)
Compute the matrix square root.
Definition MatrixSquareRoot.h:80
MatrixSquareRootQuasiTriangular(const MatrixType &A)
Constructor.
Definition MatrixSquareRoot.h:39
Proxy for the matrix square root of some matrix (expression).
Definition MatrixSquareRoot.h:435
MatrixSquareRootReturnValue(const Derived &src)
Constructor.
Definition MatrixSquareRoot.h:443
void evalTo(ResultType &result) const
Compute the matrix square root.
Definition MatrixSquareRoot.h:451
void compute(ResultType &result)
Compute the matrix square root.
Definition MatrixSquareRoot.h:290
Class for computing matrix square roots of general matrices.
Definition MatrixSquareRoot.h:330
void compute(ResultType &result)
Compute the matrix square root.
MatrixSquareRoot(const MatrixType &A)
Constructor.
const MatrixType & matrixU() const
const MatrixType & matrixT() const