10#ifndef EIGEN_MATRIX_SQUARE_ROOT
11#define EIGEN_MATRIX_SQUARE_ROOT
// Declaration fragment of the class template whose members are defined below
// as MatrixSquareRootQuasiTriangular<MatrixType>::... (the class head, access
// specifiers and closing brace are not part of this listing).
26template <
typename MatrixType>
// Visible constructor check: the input matrix A must be square.
42 eigen_assert(A.rows() == A.cols());
// Entry point: writes the computed square root into `result`.
53 template <
typename ResultType>
void compute(ResultType &result);
// Shorthand for the index and scalar types of MatrixType.
56 typedef typename MatrixType::Index Index;
57 typedef typename MatrixType::Scalar Scalar;
// Helpers that fill sqrtT from T: first the diagonal blocks, then the
// off-diagonal part.
59 void computeDiagonalPartOfSqrt(MatrixType& sqrtT,
const MatrixType& T);
60 void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
const MatrixType& T);
// Handles the 2x2 diagonal block whose top-left corner is (i,i).
61 void compute2x2diagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
typename MatrixType::Index i);
// One helper per shape of the off-diagonal block with top-left corner (i,j);
// the name gives (rows)x(cols) of that block.
62 void compute1x1offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
63 typename MatrixType::Index i,
typename MatrixType::Index j);
64 void compute1x2offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
65 typename MatrixType::Index i,
typename MatrixType::Index j);
66 void compute2x1offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
67 typename MatrixType::Index i,
typename MatrixType::Index j);
68 void compute2x2offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
69 typename MatrixType::Index i,
typename MatrixType::Index j);
// Static helper used by the 2x2 off-diagonal case; X is the output, A, B and
// C are the inputs of the small equation being solved.
71 template <
typename SmallMatrixType>
72 static void solveAuxiliaryEquation(SmallMatrixType& X,
const SmallMatrixType& A,
73 const SmallMatrixType& B,
const SmallMatrixType& C);
// The input matrix is held by const reference, not copied, so it has to stay
// alive for as long as this object is used.
75 const MatrixType& m_A;
// Out-of-class definition of a member template taking ResultType (the
// signature line, braces and the construction of `schurOfA` are missing from
// this listing).
78template <
typename MatrixType>
79template <
typename ResultType>
// T and U are the two factors exposed by the Schur decomposition object.
84 const MatrixType& T = schurOfA.
matrixT();
85 const MatrixType& U = schurOfA.
matrixU();
// Start from an all-zero square matrix of the same order as m_A, then fill in
// the diagonal blocks followed by the off-diagonal blocks.
88 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
89 computeDiagonalPartOfSqrt(sqrtT, T);
90 computeOffDiagonalPartOfSqrt(sqrtT, T);
// Map the square root of T back with U: result = U * sqrtT * U^*.
93 result = U * sqrtT * U.adjoint();
// Fills the diagonal blocks of sqrtT from T (the rest of the parameter list,
// the braces and the branch between the two cases are missing from this
// listing).
98template <
typename MatrixType>
99void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT,
102 const Index size = m_A.rows();
103 for (Index i = 0; i < size; i++) {
// 1x1 block: either the last row, or the entry just below the diagonal is
// exactly zero.
104 if (i == size - 1 || T.coeff(i+1, i) == 0) {
// The diagonal entry must be strictly positive before its square root is taken.
105 eigen_assert(T(i,i) > 0);
106 sqrtT.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
// Presumably the other branch (non-zero subdiagonal entry): handle rows i and
// i+1 together as a 2x2 block. NOTE(review): the else line and any extra
// increment of i are not visible here -- confirm against the full file.
109 compute2x2diagonalBlock(sqrtT, T, i);
// Fills the part of sqrtT above the diagonal blocks, walking columns j from
// left to right and, within each column, rows i from j-1 up to 0 (the rest of
// the parameter list, the braces and the bodies of the two skip tests are
// missing from this listing).
117template <
typename MatrixType>
118void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
121 const Index size = m_A.rows();
122 for (Index j = 1; j < size; j++) {
// Non-zero entry to the left of T(j,j): presumably j is the second column of a
// 2x2 diagonal block and is skipped -- the statement guarded by this test is
// not visible, TODO confirm.
123 if (T.coeff(j, j-1) != 0)
125 for (Index i = j-1; i >= 0; i--) {
// Same test for row i; the guarded statement is likewise not visible.
126 if (i > 0 && T.coeff(i, i-1) != 0)
// A diagonal block starting at index k is 2x2 when k is not the last index and
// the entry just below T(k,k) is non-zero.
128 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
129 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
// Dispatch on the shape (rows from the i block, columns from the j block) of
// the off-diagonal block at (i,j).
130 if (iBlockIs2x2 && jBlockIs2x2)
131 compute2x2offDiagonalBlock(sqrtT, T, i, j);
132 else if (iBlockIs2x2 && !jBlockIs2x2)
133 compute2x1offDiagonalBlock(sqrtT, T, i, j);
134 else if (!iBlockIs2x2 && jBlockIs2x2)
135 compute1x2offDiagonalBlock(sqrtT, T, i, j);
136 else if (!iBlockIs2x2 && !jBlockIs2x2)
137 compute1x1offDiagonalBlock(sqrtT, T, i, j);
// Fragment of the 2x2 diagonal-block helper (the signature and the
// construction of `es` are missing from this listing).
144template <
typename MatrixType>
// The 2x2 block of sqrtT at (i,i) receives the real part of
// V * diag(sqrt(lambda)) * V^{-1}, where V and lambda are the eigenvectors and
// eigenvalues reported by `es`.
152 sqrtT.template block<2,2>(i,i)
153 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
// Computes the single entry sqrtT(i,j) (the function name line, the leading
// parameters and the braces are missing from this listing).
159template <
typename MatrixType>
162 typename MatrixType::Index i,
typename MatrixType::Index j)
// tmp is the product of the part of row i and the part of column j of sqrtT
// that lie strictly between indices i and j (j-i-1 entries starting at i+1).
164 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
// Solve sqrtT(i,i)*x + x*sqrtT(j,j) = T(i,j) - tmp for the scalar x.
165 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
// Computes the 1x2 block of sqrtT at (i,j) (the function name line, the
// leading parameters and the initial values of `rhs` and `A` are missing from
// this listing).
169template <
typename MatrixType>
172 typename MatrixType::Index i,
typename MatrixType::Index j)
// Subtract the contribution of the entries of sqrtT lying strictly between
// index i and the j block: row i, columns i+1..j-1, times rows i+1..j-1 of
// columns j and j+1.
176 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
// Add the transposed 2x2 diagonal block at (j,j) to the 2x2 system matrix.
178 A += sqrtT.template block<2,2>(j,j).transpose();
// Solve A * x = rhs^T with a full-pivoting LU and store x^T as the 1x2 block.
179 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
// Computes the 2x1 block of sqrtT at (i,j) (the function name line, the
// leading parameters and the initial values of `rhs` and `A` are missing from
// this listing).
183template <
typename MatrixType>
186 typename MatrixType::Index i,
typename MatrixType::Index j)
// Subtract the contribution of the entries of sqrtT lying strictly between the
// i block and index j: rows i and i+1, columns i+2..j-1, times rows i+2..j-1
// of column j.
190 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
// Add the 2x2 diagonal block at (i,i) to the 2x2 system matrix.
192 A += sqrtT.template block<2,2>(i,i);
// Solve A * x = rhs with a full-pivoting LU and store x as the 2x1 block.
193 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
// Computes the 2x2 block of sqrtT at (i,j) (the function name line, the
// leading parameters and the declarations of A, B, C and X are missing from
// this listing).
197template <
typename MatrixType>
200 typename MatrixType::Index i,
typename MatrixType::Index j)
// Subtract the contribution of the entries of sqrtT lying strictly between the
// i block and the j block: rows i and i+1, columns i+2..j-1, times rows
// i+2..j-1 of columns j and j+1.
206 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
// Solve the small auxiliary equation for X and store it as the 2x2 block.
208 solveAuxiliaryEquation(X, A, B, C);
209 sqrtT.template block<2,2>(i,j) = X;
// Solves a 2x2 matrix equation for X by rewriting it as a 4x4 linear system
// (the function name line, the leading parameters, the start of the static
// assertion and the declarations of coeffMatrix, rhs and result are missing
// from this listing).
// With the unknowns ordered (X00, X01, X10, X11), the coefficients set below
// are those of A*X + X*B = C.
213template <
typename MatrixType>
214template <
typename SmallMatrixType>
217 const SmallMatrixType& B,
const SmallMatrixType& C)
// Tail of a compile-time check whose condition is not visible here.
220 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
// Diagonal of the 4x4 system: A(r,r) + B(c,c) for the unknown X(r,c).
223 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
224 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
225 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
226 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
// Off-diagonal couplings. Entries not assigned here are presumably zero --
// the initialization of coeffMatrix is not visible, TODO confirm.
227 coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
228 coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
229 coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
230 coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
231 coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
232 coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
233 coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
234 coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
// Right-hand side: C flattened row by row.
237 rhs.coeffRef(0) = C.coeff(0,0);
238 rhs.coeffRef(1) = C.coeff(0,1);
239 rhs.coeffRef(2) = C.coeff(1,0);
240 rhs.coeffRef(3) = C.coeff(1,1);
// Solve the 4x4 system with a full-pivoting LU decomposition.
243 result = coeffMatrix.
fullPivLu().solve(rhs);
// Unflatten the solution vector back into X, row by row.
245 X.coeffRef(0,0) = result.coeff(0);
246 X.coeffRef(0,1) = result.coeff(1);
247 X.coeffRef(1,0) = result.coeff(2);
248 X.coeffRef(1,1) = result.coeff(3);
// Declaration fragment of MatrixSquareRootTriangular (braces, access
// specifiers and the constructor's initializer list are missing from this
// listing).
263template <
typename MatrixType>
264class MatrixSquareRootTriangular
// Constructor taking the input matrix A by const reference.
267 MatrixSquareRootTriangular(
const MatrixType& A)
// The input matrix must be square.
270 eigen_assert(A.rows() == A.cols());
// Entry point: writes the computed square root into `result`.
282 template <
typename ResultType>
void compute(ResultType &result);
// The input matrix is held by const reference, not copied, so it has to stay
// alive for as long as this object is used.
285 const MatrixType& m_A;
// Out-of-class definition of a member template taking ResultType (the
// signature line, braces, the construction of `schurOfA` and the declaration
// of the matrix `tmp` used at the end are missing from this listing).
288template <
typename MatrixType>
289template <
typename ResultType>
// T and U are the two factors exposed by the Schur decomposition object.
294 const MatrixType& T = schurOfA.
matrixT();
295 const MatrixType& U = schurOfA.
matrixU();
// `result` is used as the working matrix for the square root of T. It is only
// resized here; the loops below write its diagonal and the entries above it.
299 result.resize(m_A.rows(), m_A.cols());
300 typedef typename MatrixType::Index Index;
// Diagonal: entry-wise square root of the diagonal of T.
301 for (Index i = 0; i < m_A.rows(); i++) {
302 result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
// Entries above the diagonal, column by column and, within a column, from the
// row just above the diagonal up to row 0, so every value read on the
// right-hand side has already been written.
304 for (Index j = 1; j < m_A.cols(); j++) {
305 for (Index i = j-1; i >= 0; i--) {
306 typedef typename MatrixType::Scalar Scalar;
// This scalar `tmp` is the product of the parts of row i and column j of
// `result` lying strictly between indices i and j.
308 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
310 result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
// Map back with U in two steps: tmp = U * upper(result), then
// result = tmp * U^*. Only the upper triangular view of `result` is read, so
// entries below the diagonal do not take part. Here `tmp` is a matrix declared
// on a line not shown (distinct from the loop-local scalar above).
316 tmp.noalias() = U * result.template triangularView<Upper>();
317 result.noalias() = tmp * U.adjoint();
// Fragment of a class template with a second, defaulted parameter IsComplex
// taken from NumTraits of MatrixType's scalar type (the class head and braces
// are missing from this listing; the tooltip text further down names the
// class at this position MatrixSquareRoot).
328template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
// Entry point: writes the computed square root into `result`.
349 template <
typename ResultType>
void compute(ResultType &result);
// Fragment of a class template on MatrixType whose compute() goes through a
// real Schur decomposition (the class head, constructor signature and braces
// are missing from this listing; presumably the IsComplex = 0 specialization
// of MatrixSquareRoot -- TODO confirm).
355template <
typename MatrixType>
// The input matrix must be square.
363 eigen_assert(A.rows() == A.cols());
366 template <
typename ResultType>
void compute(ResultType &result)
// Decompose m_A and take the two Schur factors T and U.
369 const RealSchur<MatrixType> schurOfA(m_A);
370 const MatrixType& T = schurOfA.matrixT();
371 const MatrixType& U = schurOfA.matrixU();
// Solver for the square root of T, and a zero-initialized matrix for its
// result. NOTE(review): the call that fills sqrtT from `tmp` is on a line not
// shown here.
374 MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
375 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
// Map back with U: result = U * sqrtT * U^*.
379 result = U * sqrtT * U.adjoint();
// The input matrix is held by const reference, not copied.
383 const MatrixType& m_A;
// Fragment of a class template on MatrixType whose compute() goes through a
// complex Schur decomposition (the class head, constructor signature and
// braces are missing from this listing; presumably the IsComplex = 1
// specialization of MatrixSquareRoot -- TODO confirm).
389template <
typename MatrixType>
// The input matrix must be square.
397 eigen_assert(A.rows() == A.cols());
400 template <
typename ResultType>
void compute(ResultType &result)
// Decompose m_A and take the two Schur factors T and U.
403 const ComplexSchur<MatrixType> schurOfA(m_A);
404 const MatrixType& T = schurOfA.matrixT();
405 const MatrixType& U = schurOfA.matrixU();
// Solver for the square root of T, and a zero-initialized matrix for its
// result. NOTE(review): the call that fills sqrtT from `tmp` is on a line not
// shown here.
408 MatrixSquareRootTriangular<MatrixType> tmp(T);
409 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
// Map back with U: result = U * sqrtT * U^*.
413 result = U * sqrtT * U.adjoint();
// The input matrix is held by const reference, not copied.
417 const MatrixType& m_A;
// Fragment of the MatrixSquareRootReturnValue<Derived> class: the class head,
// constructor, braces and most of evalTo's body are missing from this
// listing. It derives publicly from ReturnByValue of itself.
434:
public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
436 typedef typename Derived::Index Index;
// Evaluates into `result`. The visible statement first evaluates the stored
// expression m_src into a plain object; the rest of the body is not shown.
450 template <
typename ResultType>
451 inline void evalTo(ResultType& result)
const
453 const typename Derived::PlainObject srcEvaluated = m_src.eval();
// Dimensions are forwarded from the stored expression.
458 Index rows()
const {
return m_src.rows(); }
459 Index cols()
const {
return m_src.cols(); }
// The source expression is held by const reference, not copied.
462 const Derived& m_src;
// traits specialization for MatrixSquareRootReturnValue<Derived>: its
// ReturnType is the plain object type of Derived (braces are missing from
// this listing).
468template<
typename Derived>
469struct traits<MatrixSquareRootReturnValue<Derived> >
471 typedef typename Derived::PlainObject ReturnType;
// Fragment of a function templated on Derived; its signature and return
// statement are missing from this listing. Only the check that the object is
// square is visible.
475template <
typename Derived>
478 eigen_assert(rows() == cols());
const ComplexMatrixType & matrixU() const
const ComplexMatrixType & matrixT() const
static const ConstantReturnType Zero()
static const IdentityReturnType Identity()
const FullPivLU< PlainObject > fullPivLu() const
Class for computing matrix square roots of upper quasi-triangular matrices.
Definition MatrixSquareRoot.h:28
void compute(ResultType &result)
Compute the matrix square root.
Definition MatrixSquareRoot.h:80
MatrixSquareRootQuasiTriangular(const MatrixType &A)
Constructor.
Definition MatrixSquareRoot.h:39
Proxy for the matrix square root of some matrix (expression).
Definition MatrixSquareRoot.h:435
MatrixSquareRootReturnValue(const Derived &src)
Constructor.
Definition MatrixSquareRoot.h:443
void evalTo(ResultType &result) const
Compute the matrix square root.
Definition MatrixSquareRoot.h:451
void compute(ResultType &result)
Compute the matrix square root.
Definition MatrixSquareRoot.h:290
Class for computing matrix square roots of general matrices.
Definition MatrixSquareRoot.h:330
void compute(ResultType &result)
Compute the matrix square root.
MatrixSquareRoot(const MatrixType &A)
Constructor.
const MatrixType & matrixU() const
const MatrixType & matrixT() const