PaStiXSupport.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2012 Désiré Nuentsa-Wakam <desire.nuentsa_wakam@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_PASTIXSUPPORT_H
11#define EIGEN_PASTIXSUPPORT_H
12
13namespace Eigen {
14
23template<typename _MatrixType, bool IsStrSym = false> class PastixLU;
24template<typename _MatrixType, int Options> class PastixLLT;
25template<typename _MatrixType, int Options> class PastixLDLT;
26
27namespace internal
28{
29
30 template<class Pastix> struct pastix_traits;
31
32 template<typename _MatrixType>
33 struct pastix_traits< PastixLU<_MatrixType> >
34 {
35 typedef _MatrixType MatrixType;
36 typedef typename _MatrixType::Scalar Scalar;
37 typedef typename _MatrixType::RealScalar RealScalar;
38 typedef typename _MatrixType::Index Index;
39 };
40
41 template<typename _MatrixType, int Options>
42 struct pastix_traits< PastixLLT<_MatrixType,Options> >
43 {
44 typedef _MatrixType MatrixType;
45 typedef typename _MatrixType::Scalar Scalar;
46 typedef typename _MatrixType::RealScalar RealScalar;
47 typedef typename _MatrixType::Index Index;
48 };
49
50 template<typename _MatrixType, int Options>
51 struct pastix_traits< PastixLDLT<_MatrixType,Options> >
52 {
53 typedef _MatrixType MatrixType;
54 typedef typename _MatrixType::Scalar Scalar;
55 typedef typename _MatrixType::RealScalar RealScalar;
56 typedef typename _MatrixType::Index Index;
57 };
58
59 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, float *vals, int *perm, int * invp, float *x, int nbrhs, int *iparm, double *dparm)
60 {
61 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
62 if (nbrhs == 0) {x = NULL; nbrhs=1;}
63 s_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
64 }
65
66 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, double *vals, int *perm, int * invp, double *x, int nbrhs, int *iparm, double *dparm)
67 {
68 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
69 if (nbrhs == 0) {x = NULL; nbrhs=1;}
70 d_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
71 }
72
73 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<float> *vals, int *perm, int * invp, std::complex<float> *x, int nbrhs, int *iparm, double *dparm)
74 {
75 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
76 if (nbrhs == 0) {x = NULL; nbrhs=1;}
77 c_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<COMPLEX*>(vals), perm, invp, reinterpret_cast<COMPLEX*>(x), nbrhs, iparm, dparm);
78 }
79
80 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<double> *vals, int *perm, int * invp, std::complex<double> *x, int nbrhs, int *iparm, double *dparm)
81 {
82 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
83 if (nbrhs == 0) {x = NULL; nbrhs=1;}
84 z_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<DCOMPLEX*>(vals), perm, invp, reinterpret_cast<DCOMPLEX*>(x), nbrhs, iparm, dparm);
85 }
86
// Convert the compressed index arrays of mat from C-style (0-based) to
// Fortran-style (1-based) numbering, in place.
// The matrix is treated as 0-based when its first outer index is 0; otherwise
// nothing is changed. Every outer index and every inner index is incremented.
// The outer index array holds outerSize()+1 entries, so the loop is bounded by
// outerSize() and not rows(): the two differ whenever the number of rows is
// not the outer dimension of the storage.
template <typename MatrixType>
void c_to_fortran_numbering (MatrixType& mat)
{
  if ( mat.outerIndexPtr()[0] == 0 )
  {
    for(int j = 0; j <= mat.outerSize(); ++j)
      ++mat.outerIndexPtr()[j];
    for(int k = 0; k < mat.nonZeros(); ++k)
      ++mat.innerIndexPtr()[k];
  }
}
100
// Convert the compressed index arrays of mat from Fortran-style (1-based) to
// C-style (0-based) numbering, in place.
// The matrix is treated as 1-based when its first outer index is 1; otherwise
// nothing is changed. Every outer index and every inner index is decremented.
// The outer index array holds outerSize()+1 entries, so the loop is bounded by
// outerSize() and not rows(): the two differ whenever the number of rows is
// not the outer dimension of the storage.
template <typename MatrixType>
void fortran_to_c_numbering (MatrixType& mat)
{
  if ( mat.outerIndexPtr()[0] == 1 )
  {
    for(int j = 0; j <= mat.outerSize(); ++j)
      --mat.outerIndexPtr()[j];
    for(int k = 0; k < mat.nonZeros(); ++k)
      --mat.innerIndexPtr()[k];
  }
}
115}
116
117// This is the base class to interface with PaStiX functions.
118// Users should not used this class directly.
119template <class Derived>
120class PastixBase : internal::noncopyable
121{
122 public:
123 typedef typename internal::pastix_traits<Derived>::MatrixType _MatrixType;
124 typedef _MatrixType MatrixType;
125 typedef typename MatrixType::Scalar Scalar;
126 typedef typename MatrixType::RealScalar RealScalar;
127 typedef typename MatrixType::Index Index;
128 typedef Matrix<Scalar,Dynamic,1> Vector;
129 typedef SparseMatrix<Scalar, ColMajor> ColSpMatrix;
130
131 public:
132
133 PastixBase() : m_initisOk(false), m_analysisIsOk(false), m_factorizationIsOk(false), m_isInitialized(false), m_pastixdata(0), m_size(0)
134 {
135 init();
136 }
137
138 ~PastixBase()
139 {
140 clean();
141 }
142
147 template<typename Rhs>
148 inline const internal::solve_retval<PastixBase, Rhs>
149 solve(const MatrixBase<Rhs>& b) const
150 {
151 eigen_assert(m_isInitialized && "Pastix solver is not initialized.");
152 eigen_assert(rows()==b.rows()
153 && "PastixBase::solve(): invalid number of rows of the right hand side matrix b");
154 return internal::solve_retval<PastixBase, Rhs>(*this, b.derived());
155 }
156
157 template<typename Rhs,typename Dest>
158 bool _solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const;
159
161 template<typename Rhs, typename DestScalar, int DestOptions, typename DestIndex>
162 void _solve_sparse(const Rhs& b, SparseMatrix<DestScalar,DestOptions,DestIndex> &dest) const
163 {
164 eigen_assert(m_factorizationIsOk && "The decomposition is not in a valid state for solving, you must first call either compute() or symbolic()/numeric()");
165 eigen_assert(rows()==b.rows());
166
167 // we process the sparse rhs per block of NbColsAtOnce columns temporarily stored into a dense matrix.
168 static const int NbColsAtOnce = 1;
169 int rhsCols = b.cols();
170 int size = b.rows();
171 Eigen::Matrix<DestScalar,Dynamic,Dynamic> tmp(size,rhsCols);
172 for(int k=0; k<rhsCols; k+=NbColsAtOnce)
173 {
174 int actualCols = std::min<int>(rhsCols-k, NbColsAtOnce);
175 tmp.leftCols(actualCols) = b.middleCols(k,actualCols);
176 tmp.leftCols(actualCols) = derived().solve(tmp.leftCols(actualCols));
177 dest.middleCols(k,actualCols) = tmp.leftCols(actualCols).sparseView();
178 }
179 }
180
181 Derived& derived()
182 {
183 return *static_cast<Derived*>(this);
184 }
185 const Derived& derived() const
186 {
187 return *static_cast<const Derived*>(this);
188 }
189
195 Array<Index,IPARM_SIZE,1>& iparm()
196 {
197 return m_iparm;
198 }
199
203
204 int& iparm(int idxparam)
205 {
206 return m_iparm(idxparam);
207 }
208
213 Array<RealScalar,IPARM_SIZE,1>& dparm()
214 {
215 return m_dparm;
216 }
217
218
222 double& dparm(int idxparam)
223 {
224 return m_dparm(idxparam);
225 }
226
227 inline Index cols() const { return m_size; }
228 inline Index rows() const { return m_size; }
229
238 ComputationInfo info() const
239 {
240 eigen_assert(m_isInitialized && "Decomposition is not initialized.");
241 return m_info;
242 }
243
248 template<typename Rhs>
249 inline const internal::sparse_solve_retval<PastixBase, Rhs>
250 solve(const SparseMatrixBase<Rhs>& b) const
251 {
252 eigen_assert(m_isInitialized && "Pastix LU, LLT or LDLT is not initialized.");
253 eigen_assert(rows()==b.rows()
254 && "PastixBase::solve(): invalid number of rows of the right hand side matrix b");
255 return internal::sparse_solve_retval<PastixBase, Rhs>(*this, b.derived());
256 }
257
258 protected:
259
260 // Initialize the Pastix data structure, check the matrix
261 void init();
262
263 // Compute the ordering and the symbolic factorization
264 void analyzePattern(ColSpMatrix& mat);
265
266 // Compute the numerical factorization
267 void factorize(ColSpMatrix& mat);
268
269 // Free all the data allocated by Pastix
270 void clean()
271 {
272 eigen_assert(m_initisOk && "The Pastix structure should be allocated first");
273 m_iparm(IPARM_START_TASK) = API_TASK_CLEAN;
274 m_iparm(IPARM_END_TASK) = API_TASK_CLEAN;
275 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
276 m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
277 }
278
279 void compute(ColSpMatrix& mat);
280
281 int m_initisOk;
282 int m_analysisIsOk;
283 int m_factorizationIsOk;
284 bool m_isInitialized;
285 mutable ComputationInfo m_info;
286 mutable pastix_data_t *m_pastixdata; // Data structure for pastix
287 mutable int m_comm; // The MPI communicator identifier
288 mutable Matrix<int,IPARM_SIZE,1> m_iparm; // integer vector for the input parameters
289 mutable Matrix<double,DPARM_SIZE,1> m_dparm; // Scalar vector for the input parameters
290 mutable Matrix<Index,Dynamic,1> m_perm; // Permutation vector
291 mutable Matrix<Index,Dynamic,1> m_invp; // Inverse permutation vector
292 mutable int m_size; // Size of the matrix
293};
294
299template <class Derived>
300void PastixBase<Derived>::init()
301{
302 m_size = 0;
303 m_iparm.setZero(IPARM_SIZE);
304 m_dparm.setZero(DPARM_SIZE);
305
306 m_iparm(IPARM_MODIFY_PARAMETER) = API_NO;
307 pastix(&m_pastixdata, MPI_COMM_WORLD,
308 0, 0, 0, 0,
309 0, 0, 0, 1, m_iparm.data(), m_dparm.data());
310
311 m_iparm[IPARM_MATRIX_VERIFICATION] = API_NO;
312 m_iparm[IPARM_VERBOSE] = 2;
313 m_iparm[IPARM_ORDERING] = API_ORDER_SCOTCH;
314 m_iparm[IPARM_INCOMPLETE] = API_NO;
315 m_iparm[IPARM_OOC_LIMIT] = 2000;
316 m_iparm[IPARM_RHS_MAKING] = API_RHS_B;
317 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
318
319 m_iparm(IPARM_START_TASK) = API_TASK_INIT;
320 m_iparm(IPARM_END_TASK) = API_TASK_INIT;
321 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
322 0, 0, 0, 0, m_iparm.data(), m_dparm.data());
323
324 // Check the returned error
325 if(m_iparm(IPARM_ERROR_NUMBER)) {
326 m_info = InvalidInput;
327 m_initisOk = false;
328 }
329 else {
330 m_info = Success;
331 m_initisOk = true;
332 }
333}
334
335template <class Derived>
336void PastixBase<Derived>::compute(ColSpMatrix& mat)
337{
338 eigen_assert(mat.rows() == mat.cols() && "The input matrix should be squared");
339
340 analyzePattern(mat);
341 factorize(mat);
342
343 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
344 m_isInitialized = m_factorizationIsOk;
345}
346
347
348template <class Derived>
349void PastixBase<Derived>::analyzePattern(ColSpMatrix& mat)
350{
351 eigen_assert(m_initisOk && "The initialization of PaSTiX failed");
352
353 // clean previous calls
354 if(m_size>0)
355 clean();
356
357 m_size = mat.rows();
358 m_perm.resize(m_size);
359 m_invp.resize(m_size);
360
361 m_iparm(IPARM_START_TASK) = API_TASK_ORDERING;
362 m_iparm(IPARM_END_TASK) = API_TASK_ANALYSE;
363 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
364 mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
365
366 // Check the returned error
367 if(m_iparm(IPARM_ERROR_NUMBER))
368 {
369 m_info = NumericalIssue;
370 m_analysisIsOk = false;
371 }
372 else
373 {
374 m_info = Success;
375 m_analysisIsOk = true;
376 }
377}
378
379template <class Derived>
380void PastixBase<Derived>::factorize(ColSpMatrix& mat)
381{
382// if(&m_cpyMat != &mat) m_cpyMat = mat;
383 eigen_assert(m_analysisIsOk && "The analysis phase should be called before the factorization phase");
384 m_iparm(IPARM_START_TASK) = API_TASK_NUMFACT;
385 m_iparm(IPARM_END_TASK) = API_TASK_NUMFACT;
386 m_size = mat.rows();
387
388 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
389 mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
390
391 // Check the returned error
392 if(m_iparm(IPARM_ERROR_NUMBER))
393 {
394 m_info = NumericalIssue;
395 m_factorizationIsOk = false;
396 m_isInitialized = false;
397 }
398 else
399 {
400 m_info = Success;
401 m_factorizationIsOk = true;
402 m_isInitialized = true;
403 }
404}
405
406/* Solve the system */
407template<typename Base>
408template<typename Rhs,typename Dest>
409bool PastixBase<Base>::_solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const
410{
411 eigen_assert(m_isInitialized && "The matrix should be factorized first");
412 EIGEN_STATIC_ASSERT((Dest::Flags&RowMajorBit)==0,
413 THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
414 int rhs = 1;
415
416 x = b; /* on return, x is overwritten by the computed solution */
417
418 for (int i = 0; i < b.cols(); i++){
419 m_iparm[IPARM_START_TASK] = API_TASK_SOLVE;
420 m_iparm[IPARM_END_TASK] = API_TASK_REFINE;
421
422 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, x.rows(), 0, 0, 0,
423 m_perm.data(), m_invp.data(), &x(0, i), rhs, m_iparm.data(), m_dparm.data());
424 }
425
426 // Check the returned error
427 m_info = m_iparm(IPARM_ERROR_NUMBER)==0 ? Success : NumericalIssue;
428
429 return m_iparm(IPARM_ERROR_NUMBER)==0;
430}
431
451template<typename _MatrixType, bool IsStrSym>
452class PastixLU : public PastixBase< PastixLU<_MatrixType> >
453{
454 public:
455 typedef _MatrixType MatrixType;
456 typedef PastixBase<PastixLU<MatrixType> > Base;
457 typedef typename Base::ColSpMatrix ColSpMatrix;
458 typedef typename MatrixType::Index Index;
459
460 public:
461 PastixLU() : Base()
462 {
463 init();
464 }
465
466 PastixLU(const MatrixType& matrix):Base()
467 {
468 init();
469 compute(matrix);
470 }
476 void compute (const MatrixType& matrix)
477 {
478 m_structureIsUptodate = false;
479 ColSpMatrix temp;
480 grabMatrix(matrix, temp);
481 Base::compute(temp);
482 }
483
488 void analyzePattern(const MatrixType& matrix)
489 {
490 m_structureIsUptodate = false;
491 ColSpMatrix temp;
492 grabMatrix(matrix, temp);
493 Base::analyzePattern(temp);
494 }
495
501 void factorize(const MatrixType& matrix)
502 {
503 ColSpMatrix temp;
504 grabMatrix(matrix, temp);
505 Base::factorize(temp);
506 }
507 protected:
508
509 void init()
510 {
511 m_structureIsUptodate = false;
512 m_iparm(IPARM_SYM) = API_SYM_NO;
513 m_iparm(IPARM_FACTORIZATION) = API_FACT_LU;
514 }
515
516 void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
517 {
518 if(IsStrSym)
519 out = matrix;
520 else
521 {
522 if(!m_structureIsUptodate)
523 {
524 // update the transposed structure
525 m_transposedStructure = matrix.transpose();
526
527 // Set the elements of the matrix to zero
528 for (Index j=0; j<m_transposedStructure.outerSize(); ++j)
529 for(typename ColSpMatrix::InnerIterator it(m_transposedStructure, j); it; ++it)
530 it.valueRef() = 0.0;
531
532 m_structureIsUptodate = true;
533 }
534
535 out = m_transposedStructure + matrix;
536 }
537 internal::c_to_fortran_numbering(out);
538 }
539
540 using Base::m_iparm;
541 using Base::m_dparm;
542
543 ColSpMatrix m_transposedStructure;
544 bool m_structureIsUptodate;
545};
546
561template<typename _MatrixType, int _UpLo>
562class PastixLLT : public PastixBase< PastixLLT<_MatrixType, _UpLo> >
563{
564 public:
565 typedef _MatrixType MatrixType;
566 typedef PastixBase<PastixLLT<MatrixType, _UpLo> > Base;
567 typedef typename Base::ColSpMatrix ColSpMatrix;
568
569 public:
570 enum { UpLo = _UpLo };
571 PastixLLT() : Base()
572 {
573 init();
574 }
575
576 PastixLLT(const MatrixType& matrix):Base()
577 {
578 init();
579 compute(matrix);
580 }
581
585 void compute (const MatrixType& matrix)
586 {
587 ColSpMatrix temp;
588 grabMatrix(matrix, temp);
589 Base::compute(temp);
590 }
591
596 void analyzePattern(const MatrixType& matrix)
597 {
598 ColSpMatrix temp;
599 grabMatrix(matrix, temp);
600 Base::analyzePattern(temp);
601 }
602
605 void factorize(const MatrixType& matrix)
606 {
607 ColSpMatrix temp;
608 grabMatrix(matrix, temp);
609 Base::factorize(temp);
610 }
611 protected:
612 using Base::m_iparm;
613
614 void init()
615 {
616 m_iparm(IPARM_SYM) = API_SYM_YES;
617 m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT;
618 }
619
620 void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
621 {
622 // Pastix supports only lower, column-major matrices
623 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
624 internal::c_to_fortran_numbering(out);
625 }
626};
627
642template<typename _MatrixType, int _UpLo>
643class PastixLDLT : public PastixBase< PastixLDLT<_MatrixType, _UpLo> >
644{
645 public:
646 typedef _MatrixType MatrixType;
647 typedef PastixBase<PastixLDLT<MatrixType, _UpLo> > Base;
648 typedef typename Base::ColSpMatrix ColSpMatrix;
649
650 public:
651 enum { UpLo = _UpLo };
652 PastixLDLT():Base()
653 {
654 init();
655 }
656
657 PastixLDLT(const MatrixType& matrix):Base()
658 {
659 init();
660 compute(matrix);
661 }
662
666 void compute (const MatrixType& matrix)
667 {
668 ColSpMatrix temp;
669 grabMatrix(matrix, temp);
670 Base::compute(temp);
671 }
672
677 void analyzePattern(const MatrixType& matrix)
678 {
679 ColSpMatrix temp;
680 grabMatrix(matrix, temp);
681 Base::analyzePattern(temp);
682 }
683
686 void factorize(const MatrixType& matrix)
687 {
688 ColSpMatrix temp;
689 grabMatrix(matrix, temp);
690 Base::factorize(temp);
691 }
692
693 protected:
694 using Base::m_iparm;
695
696 void init()
697 {
698 m_iparm(IPARM_SYM) = API_SYM_YES;
699 m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT;
700 }
701
702 void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
703 {
704 // Pastix supports only lower, column-major matrices
705 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
706 internal::c_to_fortran_numbering(out);
707 }
708};
709
710namespace internal {
711
712template<typename _MatrixType, typename Rhs>
713struct solve_retval<PastixBase<_MatrixType>, Rhs>
714 : solve_retval_base<PastixBase<_MatrixType>, Rhs>
715{
716 typedef PastixBase<_MatrixType> Dec;
717 EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs)
718
719 template<typename Dest> void evalTo(Dest& dst) const
720 {
721 dec()._solve(rhs(),dst);
722 }
723};
724
725template<typename _MatrixType, typename Rhs>
726struct sparse_solve_retval<PastixBase<_MatrixType>, Rhs>
727 : sparse_solve_retval_base<PastixBase<_MatrixType>, Rhs>
728{
729 typedef PastixBase<_MatrixType> Dec;
730 EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs)
731
732 template<typename Dest> void evalTo(Dest& dst) const
733 {
734 dec()._solve_sparse(rhs(),dst);
735 }
736};
737
738} // end namespace internal
739
740} // end namespace Eigen
741
742#endif
Base class for all dense matrices, vectors, and expressions.
Definition MatrixBase.h:50
A sparse direct supernodal Cholesky (LDLT) factorization and solver based on the PaStiX library.
Definition PaStiXSupport.h:644
void compute(const MatrixType &matrix)
Definition PaStiXSupport.h:666
void factorize(const MatrixType &matrix)
Definition PaStiXSupport.h:686
void analyzePattern(const MatrixType &matrix)
Definition PaStiXSupport.h:677
A sparse direct supernodal Cholesky (LLT) factorization and solver based on the PaStiX library.
Definition PaStiXSupport.h:563
void compute(const MatrixType &matrix)
Definition PaStiXSupport.h:585
void factorize(const MatrixType &matrix)
Definition PaStiXSupport.h:605
void analyzePattern(const MatrixType &matrix)
Definition PaStiXSupport.h:596
Interface to the PaStiX solver.
Definition PaStiXSupport.h:453
void compute(const MatrixType &matrix)
Definition PaStiXSupport.h:476
void factorize(const MatrixType &matrix)
Definition PaStiXSupport.h:501
void analyzePattern(const MatrixType &matrix)
Definition PaStiXSupport.h:488
Index outerSize() const
Definition SparseMatrix.h:126
ComputationInfo
Definition Constants.h:367
@ NumericalIssue
Definition Constants.h:371
@ InvalidInput
Definition Constants.h:376
@ Success
Definition Constants.h:369
const unsigned int RowMajorBit
Definition Constants.h:48
Definition LDLT.h:18