Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
KroneckerTensorProduct.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2011 Kolja Brix <brix@igpm.rwth-aachen.de>
5// Copyright (C) 2011 Andreas Platen <andiplaten@gmx.de>
6// Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net>
7//
8// This Source Code Form is subject to the terms of the Mozilla
9// Public License v. 2.0. If a copy of the MPL was not distributed
10// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11
12#ifndef KRONECKER_TENSOR_PRODUCT_H
13#define KRONECKER_TENSOR_PRODUCT_H
14
15// IWYU pragma: private
16#include "./InternalHeaderCheck.h"
17
18namespace Eigen {
19
27template <typename Derived>
28class KroneckerProductBase : public ReturnByValue<Derived> {
29 private:
30 typedef typename internal::traits<Derived> Traits;
31 typedef typename Traits::Scalar Scalar;
32
33 protected:
34 typedef typename Traits::Lhs Lhs;
35 typedef typename Traits::Rhs Rhs;
36
37 public:
39 KroneckerProductBase(const Lhs& A, const Rhs& B) : m_A(A), m_B(B) {}
40
41 inline Index rows() const { return m_A.rows() * m_B.rows(); }
42 inline Index cols() const { return m_A.cols() * m_B.cols(); }
43
48 Scalar coeff(Index row, Index col) const {
49 return m_A.coeff(row / m_B.rows(), col / m_B.cols()) * m_B.coeff(row % m_B.rows(), col % m_B.cols());
50 }
51
56 Scalar coeff(Index i) const {
57 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
58 return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
59 }
60
61 protected:
62 typename Lhs::Nested m_A;
63 typename Rhs::Nested m_B;
64};
65
78template <typename Lhs, typename Rhs>
79class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs, Rhs> > {
80 private:
82 using Base::m_A;
83 using Base::m_B;
84
85 public:
87 KroneckerProduct(const Lhs& A, const Rhs& B) : Base(A, B) {}
88
90 template <typename Dest>
91 void evalTo(Dest& dst) const;
92};
93
109template <typename Lhs, typename Rhs>
110class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs, Rhs> > {
111 private:
113 using Base::m_A;
114 using Base::m_B;
115
116 public:
118 KroneckerProductSparse(const Lhs& A, const Rhs& B) : Base(A, B) {}
119
121 template <typename Dest>
122 void evalTo(Dest& dst) const;
123};
124
125template <typename Lhs, typename Rhs>
126template <typename Dest>
128 const int BlockRows = Rhs::RowsAtCompileTime, BlockCols = Rhs::ColsAtCompileTime;
129 const Index Br = m_B.rows(), Bc = m_B.cols();
130 for (Index i = 0; i < m_A.rows(); ++i)
131 for (Index j = 0; j < m_A.cols(); ++j)
132 Block<Dest, BlockRows, BlockCols>(dst, i * Br, j * Bc, Br, Bc) = m_A.coeff(i, j) * m_B;
133}
134
135template <typename Lhs, typename Rhs>
136template <typename Dest>
138 Index Br = m_B.rows(), Bc = m_B.cols();
139 dst.resize(this->rows(), this->cols());
140 dst.resizeNonZeros(0);
141
142 // 1 - evaluate the operands if needed:
143 typedef typename internal::nested_eval<Lhs, Dynamic>::type Lhs1;
144 typedef internal::remove_all_t<Lhs1> Lhs1Cleaned;
145 const Lhs1 lhs1(m_A);
146 typedef typename internal::nested_eval<Rhs, Dynamic>::type Rhs1;
147 typedef internal::remove_all_t<Rhs1> Rhs1Cleaned;
148 const Rhs1 rhs1(m_B);
149
150 // 2 - construct respective iterators
151 typedef Eigen::InnerIterator<Lhs1Cleaned> LhsInnerIterator;
152 typedef Eigen::InnerIterator<Rhs1Cleaned> RhsInnerIterator;
153
154 // compute number of non-zeros per innervectors of dst
155 {
156 // TODO VectorXi is not necessarily big enough!
157 VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
158 for (Index kA = 0; kA < m_A.outerSize(); ++kA)
159 for (LhsInnerIterator itA(lhs1, kA); itA; ++itA) nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
160
161 VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
162 for (Index kB = 0; kB < m_B.outerSize(); ++kB)
163 for (RhsInnerIterator itB(rhs1, kB); itB; ++itB) nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
164
165 Matrix<int, Dynamic, Dynamic, ColMajor> nnzAB = nnzB * nnzA.transpose();
166 dst.reserve(VectorXi::Map(nnzAB.data(), nnzAB.size()));
167 }
168
169 for (Index kA = 0; kA < m_A.outerSize(); ++kA) {
170 for (Index kB = 0; kB < m_B.outerSize(); ++kB) {
171 for (LhsInnerIterator itA(lhs1, kA); itA; ++itA) {
172 for (RhsInnerIterator itB(rhs1, kB); itB; ++itB) {
173 Index i = itA.row() * Br + itB.row(), j = itA.col() * Bc + itB.col();
174 dst.insert(i, j) = itA.value() * itB.value();
175 }
176 }
177 }
178 }
179}
180
181namespace internal {
182
183template <typename Lhs_, typename Rhs_>
184struct traits<KroneckerProduct<Lhs_, Rhs_> > {
185 typedef remove_all_t<Lhs_> Lhs;
186 typedef remove_all_t<Rhs_> Rhs;
188 typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
189
190 enum {
191 Rows = size_at_compile_time(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
192 Cols = size_at_compile_time(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
193 MaxRows = size_at_compile_time(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
194 MaxCols = size_at_compile_time(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime)
195 };
196
197 typedef Matrix<Scalar, Rows, Cols> ReturnType;
198};
199
200template <typename Lhs_, typename Rhs_>
201struct traits<KroneckerProductSparse<Lhs_, Rhs_> > {
202 typedef MatrixXpr XprKind;
203 typedef remove_all_t<Lhs_> Lhs;
204 typedef remove_all_t<Rhs_> Rhs;
205 typedef typename ScalarBinaryOpTraits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
206 typedef typename cwise_promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind,
207 scalar_product_op<typename Lhs::Scalar, typename Rhs::Scalar> >::ret
208 StorageKind;
209 typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
210
211 enum {
212 LhsFlags = Lhs::Flags,
213 RhsFlags = Rhs::Flags,
214
215 RowsAtCompileTime = size_at_compile_time(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
216 ColsAtCompileTime = size_at_compile_time(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
217 MaxRowsAtCompileTime = size_at_compile_time(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
218 MaxColsAtCompileTime = size_at_compile_time(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
219
220 EvalToRowMajor = (int(LhsFlags) & int(RhsFlags) & RowMajorBit),
221 RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
222
223 Flags = ((int(LhsFlags) | int(RhsFlags)) & HereditaryBits & RemovedBits) | EvalBeforeNestingBit,
224 CoeffReadCost = HugeCost
225 };
226
227 typedef SparseMatrix<Scalar, 0, StorageIndex> ReturnType;
228};
229
230} // end namespace internal
231
251template <typename A, typename B>
253 return KroneckerProduct<A, B>(a.derived(), b.derived());
254}
255
277template <typename A, typename B>
281
282} // end namespace Eigen
283
284#endif // KRONECKER_TENSOR_PRODUCT_H
Scalar coeff(Index row, Index col) const
Definition KroneckerTensorProduct.h:48
KroneckerProductBase(const Lhs &A, const Rhs &B)
Constructor.
Definition KroneckerTensorProduct.h:39
Scalar coeff(Index i) const
Definition KroneckerTensorProduct.h:56
Kronecker tensor product helper class for sparse matrices.
Definition KroneckerTensorProduct.h:110
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
Definition KroneckerTensorProduct.h:137
KroneckerProductSparse(const Lhs &A, const Rhs &B)
Constructor.
Definition KroneckerTensorProduct.h:118
Kronecker tensor product helper class for dense matrices.
Definition KroneckerTensorProduct.h:79
KroneckerProduct(const Lhs &A, const Rhs &B)
Constructor.
Definition KroneckerTensorProduct.h:87
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
Definition KroneckerTensorProduct.h:127
constexpr Scalar * data()
KroneckerProduct< A, B > kroneckerProduct(const MatrixBase< A > &a, const MatrixBase< B > &b)
Definition KroneckerTensorProduct.h:252
const unsigned int RowMajorBit
Matrix< int, Dynamic, 1 > VectorXi
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
constexpr Derived & derived()