12#ifndef KRONECKER_TENSOR_PRODUCT_H
13#define KRONECKER_TENSOR_PRODUCT_H
16#include "./InternalHeaderCheck.h"
27template <
typename Derived>
30 typedef typename internal::traits<Derived> Traits;
31 typedef typename Traits::Scalar Scalar;
34 typedef typename Traits::Lhs Lhs;
35 typedef typename Traits::Rhs Rhs;
41 inline Index rows()
const {
return m_A.rows() * m_B.rows(); }
42 inline Index cols()
const {
return m_A.cols() * m_B.cols(); }
49 return m_A.coeff(row / m_B.rows(), col / m_B.cols()) * m_B.coeff(row % m_B.rows(), col % m_B.cols());
57 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
58 return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
62 typename Lhs::Nested m_A;
63 typename Rhs::Nested m_B;
78template <
typename Lhs,
typename Rhs>
90 template <
typename Dest>
91 void evalTo(Dest& dst)
const;
109template <
typename Lhs,
typename Rhs>
121 template <
typename Dest>
122 void evalTo(Dest& dst)
const;
125template <
typename Lhs,
typename Rhs>
126template <
typename Dest>
128 const int BlockRows = Rhs::RowsAtCompileTime, BlockCols = Rhs::ColsAtCompileTime;
129 const Index Br = m_B.rows(), Bc = m_B.cols();
130 for (
Index i = 0; i < m_A.rows(); ++i)
131 for (
Index j = 0; j < m_A.cols(); ++j)
135template <
typename Lhs,
typename Rhs>
136template <
typename Dest>
138 Index Br = m_B.rows(), Bc = m_B.cols();
139 dst.resize(this->rows(), this->cols());
140 dst.resizeNonZeros(0);
143 typedef typename internal::nested_eval<Lhs, Dynamic>::type Lhs1;
144 typedef internal::remove_all_t<Lhs1> Lhs1Cleaned;
145 const Lhs1 lhs1(m_A);
146 typedef typename internal::nested_eval<Rhs, Dynamic>::type Rhs1;
147 typedef internal::remove_all_t<Rhs1> Rhs1Cleaned;
148 const Rhs1 rhs1(m_B);
157 VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
158 for (
Index kA = 0; kA < m_A.outerSize(); ++kA)
159 for (LhsInnerIterator itA(lhs1, kA); itA; ++itA) nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
161 VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
162 for (
Index kB = 0; kB < m_B.outerSize(); ++kB)
163 for (RhsInnerIterator itB(rhs1, kB); itB; ++itB) nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
166 dst.reserve(VectorXi::Map(nnzAB.
data(), nnzAB.size()));
169 for (
Index kA = 0; kA < m_A.outerSize(); ++kA) {
170 for (
Index kB = 0; kB < m_B.outerSize(); ++kB) {
171 for (LhsInnerIterator itA(lhs1, kA); itA; ++itA) {
172 for (RhsInnerIterator itB(rhs1, kB); itB; ++itB) {
173 Index i = itA.row() * Br + itB.row(), j = itA.col() * Bc + itB.col();
174 dst.insert(i, j) = itA.value() * itB.value();
183template <
typename Lhs_,
typename Rhs_>
185 typedef remove_all_t<Lhs_> Lhs;
186 typedef remove_all_t<Rhs_> Rhs;
188 typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
191 Rows = size_at_compile_time(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
192 Cols = size_at_compile_time(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
193 MaxRows = size_at_compile_time(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
194 MaxCols = size_at_compile_time(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime)
197 typedef Matrix<Scalar, Rows, Cols> ReturnType;
200template <
typename Lhs_,
typename Rhs_>
201struct traits<KroneckerProductSparse<Lhs_, Rhs_> > {
202 typedef MatrixXpr XprKind;
203 typedef remove_all_t<Lhs_> Lhs;
204 typedef remove_all_t<Rhs_> Rhs;
205 typedef typename ScalarBinaryOpTraits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
206 typedef typename cwise_promote_storage_type<typename traits<Lhs>::StorageKind,
typename traits<Rhs>::StorageKind,
207 scalar_product_op<typename Lhs::Scalar, typename Rhs::Scalar> >::ret
209 typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
212 LhsFlags = Lhs::Flags,
213 RhsFlags = Rhs::Flags,
215 RowsAtCompileTime = size_at_compile_time(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
216 ColsAtCompileTime = size_at_compile_time(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
217 MaxRowsAtCompileTime = size_at_compile_time(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
218 MaxColsAtCompileTime = size_at_compile_time(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
220 EvalToRowMajor = (int(LhsFlags) & int(RhsFlags) &
RowMajorBit),
221 RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
223 Flags = ((int(LhsFlags) | int(RhsFlags)) & HereditaryBits & RemovedBits) | EvalBeforeNestingBit,
224 CoeffReadCost = HugeCost
227 typedef SparseMatrix<Scalar, 0, StorageIndex> ReturnType;
251template <
typename A,
typename B>
277template <
typename A,
typename B>
Scalar coeff(Index row, Index col) const
Definition KroneckerTensorProduct.h:48
KroneckerProductBase(const Lhs &A, const Rhs &B)
Constructor.
Definition KroneckerTensorProduct.h:39
Scalar coeff(Index i) const
Definition KroneckerTensorProduct.h:56
Kronecker tensor product helper class for sparse matrices.
Definition KroneckerTensorProduct.h:110
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
Definition KroneckerTensorProduct.h:137
KroneckerProductSparse(const Lhs &A, const Rhs &B)
Constructor.
Definition KroneckerTensorProduct.h:118
Kronecker tensor product helper class for dense matrices.
Definition KroneckerTensorProduct.h:79
KroneckerProduct(const Lhs &A, const Rhs &B)
Constructor.
Definition KroneckerTensorProduct.h:87
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
Definition KroneckerTensorProduct.h:127
constexpr Scalar * data()
KroneckerProduct< A, B > kroneckerProduct(const MatrixBase< A > &a, const MatrixBase< B > &b)
Definition KroneckerTensorProduct.h:252
const unsigned int RowMajorBit
Matrix< int, Dynamic, 1 > VectorXi
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
constexpr Derived & derived()