11#include "./InternalHeaderCheck.h"
// Evaluates an MKL DFTI call and asserts (through eigen_assert) that it returned DFTI_NO_ERROR.
//
// - EXPR is always evaluated, even when assertions are compiled out, because it is the
//   initializer of a local variable rather than an argument of the assertion.
// - The expansion is a braced block, so the macro can be used several times in the same scope
//   and call sites need no trailing semicolon.
// - The status variable uses an unlikely name so it cannot shadow a caller variable that
//   appears inside EXPR (e.g. RUN_OR_ASSERT(status, ...)).
#define RUN_OR_ASSERT(EXPR, ERROR_MSG)                                   \
  {                                                                      \
    MKL_LONG eigen_imklfft_status_ = (EXPR);                             \
    eigen_assert(eigen_imklfft_status_ == DFTI_NO_ERROR && (ERROR_MSG)); \
    (void)eigen_imklfft_status_; /* silence unused warning if NDEBUG */  \
  }
26inline MKL_Complex16* complex_cast(
const std::complex<double>* p) {
27 return const_cast<MKL_Complex16*
>(
reinterpret_cast<const MKL_Complex16*
>(p));
30inline MKL_Complex8* complex_cast(
const std::complex<float>* p) {
31 return const_cast<MKL_Complex8*
>(
reinterpret_cast<const MKL_Complex8*
>(p));
43inline void configure_descriptor(std::shared_ptr<DFTI_DESCRIPTOR>& handl,
enum DFTI_CONFIG_VALUE precision,
44 enum DFTI_CONFIG_VALUE forward_domain, MKL_LONG dimension, MKL_LONG* sizes) {
45 eigen_assert(dimension == 1 || dimension == 2 &&
"Transformation dimension must be less than 3.");
47 DFTI_DESCRIPTOR_HANDLE res =
nullptr;
49 RUN_OR_ASSERT(DftiCreateDescriptor(&res, precision, forward_domain, dimension, *sizes),
50 "DftiCreateDescriptor failed.")
51 handl.reset(res, [](DFTI_DESCRIPTOR_HANDLE handle) { DftiFreeDescriptor(&handle); });
52 if (forward_domain == DFTI_REAL) {
54 RUN_OR_ASSERT(DftiSetValue(handl.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX),
55 "DftiSetValue failed.")
58 RUN_OR_ASSERT(DftiCreateDescriptor(&res, precision, DFTI_COMPLEX, dimension, sizes),
"DftiCreateDescriptor failed.")
59 handl.reset(res, [](DFTI_DESCRIPTOR_HANDLE handle) { DftiFreeDescriptor(&handle); });
62 RUN_OR_ASSERT(DftiSetValue(handl.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE),
"DftiSetValue failed.")
63 RUN_OR_ASSERT(DftiCommitDescriptor(handl.get()),
"DftiCommitDescriptor failed.")
71 typedef float scalar_type;
72 typedef MKL_Complex8 complex_type;
74 std::shared_ptr<DFTI_DESCRIPTOR> m_plan;
78 enum DFTI_CONFIG_VALUE precision = DFTI_SINGLE;
80 inline void forward(complex_type* dst, complex_type* src, MKL_LONG nfft) {
82 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
84 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst),
"DftiComputeForward failed.")
87 inline void inverse(complex_type* dst, complex_type* src, MKL_LONG nfft) {
89 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
91 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst),
"DftiComputeBackward failed.")
94 inline void forward(complex_type* dst, scalar_type* src, MKL_LONG nfft) {
96 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
98 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst),
"DftiComputeForward failed.")
101 inline void inverse(scalar_type* dst, complex_type* src, MKL_LONG nfft) {
103 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
105 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst),
"DftiComputeBackward failed.")
108 inline void forward2(complex_type* dst, complex_type* src,
int n0,
int n1) {
110 MKL_LONG sizes[2] = {n0, n1};
111 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
113 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst),
"DftiComputeForward failed.")
116 inline void inverse2(complex_type* dst, complex_type* src,
int n0,
int n1) {
118 MKL_LONG sizes[2] = {n0, n1};
119 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
121 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst),
"DftiComputeBackward failed.")
127 typedef double scalar_type;
128 typedef MKL_Complex16 complex_type;
130 std::shared_ptr<DFTI_DESCRIPTOR> m_plan;
134 enum DFTI_CONFIG_VALUE precision = DFTI_DOUBLE;
136 inline void forward(complex_type* dst, complex_type* src, MKL_LONG nfft) {
138 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
140 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst),
"DftiComputeForward failed.")
143 inline void inverse(complex_type* dst, complex_type* src, MKL_LONG nfft) {
145 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
147 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst),
"DftiComputeBackward failed.")
150 inline void forward(complex_type* dst, scalar_type* src, MKL_LONG nfft) {
152 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
154 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst),
"DftiComputeForward failed.")
157 inline void inverse(scalar_type* dst, complex_type* src, MKL_LONG nfft) {
159 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
161 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst),
"DftiComputeBackward failed.")
164 inline void forward2(complex_type* dst, complex_type* src,
int n0,
int n1) {
166 MKL_LONG sizes[2] = {n0, n1};
167 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
169 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst),
"DftiComputeForward failed.")
172 inline void inverse2(complex_type* dst, complex_type* src,
int n0,
int n1) {
174 MKL_LONG sizes[2] = {n0, n1};
175 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
177 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst),
"DftiComputeBackward failed.")
// Template header of the MKL-backed FFT front end that owns the per-shape plan cache.
// NOTE(review): the `struct ... {` line this template header introduces is not present in
// this excerpt; the members that follow belong to that struct.
181template <
typename Scalar_>
// Real scalar type of the transform (plan<Scalar> specializations exist for float and double).
183 typedef Scalar_ Scalar;
// Complex sample type accepted by the fwd/inv entry points.
184 typedef std::complex<Scalar> Complex;
// Destroys every cached plan object held in m_plans.
186 inline void clear() { m_plans.clear(); }
189 inline void fwd(Complex* dst,
const Complex* src,
int nfft) {
190 MKL_LONG size = nfft;
191 get_plan(nfft, dst, src).forward(complex_cast(dst), complex_cast(src), size);
195 inline void fwd(Complex* dst,
const Scalar* src,
int nfft) {
196 MKL_LONG size = nfft;
197 get_plan(nfft, dst, src).forward(complex_cast(dst),
const_cast<Scalar*
>(src), nfft);
201 inline void fwd2(Complex* dst,
const Complex* src,
int n0,
int n1) {
202 get_plan(n0, n1, dst, src).forward2(complex_cast(dst), complex_cast(src), n0, n1);
206 inline void inv(Complex* dst,
const Complex* src,
int nfft) {
207 MKL_LONG size = nfft;
208 get_plan(nfft, dst, src).inverse(complex_cast(dst), complex_cast(src), nfft);
212 inline void inv(Scalar* dst,
const Complex* src,
int nfft) {
213 MKL_LONG size = nfft;
214 get_plan(nfft, dst, src).inverse(
const_cast<Scalar*
>(dst), complex_cast(src), nfft);
218 inline void inv2(Complex* dst,
const Complex* src,
int n0,
int n1) {
219 get_plan(n0, n1, dst, src).inverse2(complex_cast(dst), complex_cast(src), n0, n1);
223 std::map<int64_t, plan<Scalar>> m_plans;
225 inline plan<Scalar>& get_plan(
int nfft,
void* dst,
const void* src) {
226 int inplace = dst == src ? 1 : 0;
227 int aligned = ((
reinterpret_cast<size_t>(src) & 15) | (
reinterpret_cast<size_t>(dst) & 15)) == 0 ? 1 : 0;
228 int64_t key = ((nfft << 2) | (inplace << 1) | aligned) << 1;
234 inline plan<Scalar>& get_plan(
int n0,
int n1,
void* dst,
const void* src) {
235 int inplace = (dst == src) ? 1 : 0;
236 int aligned = ((
reinterpret_cast<size_t>(src) & 15) | (
reinterpret_cast<size_t>(dst) & 15)) == 0 ? 1 : 0;
237 int64_t key = (((((int64_t)n0) << 31) | (n1 << 2) | (inplace << 1) | aligned) << 1) + 1;
Namespace containing all symbols from the Eigen library.
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_inverse_op< typename Derived::Scalar >, const Derived > inverse(const Eigen::ArrayBase< Derived > &x)