Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
imklfft_impl.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// This Source Code Form is subject to the terms of the Mozilla
5// Public License v. 2.0. If a copy of the MPL was not distributed
6// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
7
8#include <mkl_dfti.h>
9
10// IWYU pragma: private
11#include "./InternalHeaderCheck.h"
12
13#include <complex>
14#include <memory>
15
16namespace Eigen {
17namespace internal {
18namespace imklfft {
19
// Evaluates EXPR exactly once, stores the returned MKL status and asserts
// (through eigen_assert) that it equals DFTI_NO_ERROR. ERROR_MSG is and-ed into
// the asserted condition so that it shows up in the assertion text.
//
// The expansion is a braced block followed by ';', so call sites do not add a
// trailing semicolon of their own.
//
// The explicit void cast keeps `status` "used" when eigen_assert is compiled
// out, which avoids unused-variable warnings in such builds.
#define RUN_OR_ASSERT(EXPR, ERROR_MSG)                    \
  {                                                       \
    MKL_LONG status = (EXPR);                             \
    eigen_assert(status == DFTI_NO_ERROR && (ERROR_MSG)); \
    static_cast<void>(status);                            \
  };
25
26inline MKL_Complex16* complex_cast(const std::complex<double>* p) {
27 return const_cast<MKL_Complex16*>(reinterpret_cast<const MKL_Complex16*>(p));
28}
29
30inline MKL_Complex8* complex_cast(const std::complex<float>* p) {
31 return const_cast<MKL_Complex8*>(reinterpret_cast<const MKL_Complex8*>(p));
32}
33
34/*
35 * Parameters:
36 * precision: enum, Precision of the transform: DFTI_SINGLE or DFTI_DOUBLE.
37 * forward_domain: enum, Forward domain of the transform: DFTI_COMPLEX or
38 * DFTI_REAL. dimension: MKL_LONG Dimension of the transform. sizes: MKL_LONG if
39 * dimension = 1.Length of the transform for a one-dimensional transform. sizes:
40 * Array of type MKL_LONG otherwise. Lengths of each dimension for a
41 * multi-dimensional transform.
42 */
43inline void configure_descriptor(std::shared_ptr<DFTI_DESCRIPTOR>& handl, enum DFTI_CONFIG_VALUE precision,
44 enum DFTI_CONFIG_VALUE forward_domain, MKL_LONG dimension, MKL_LONG* sizes) {
45 eigen_assert(dimension == 1 || dimension == 2 && "Transformation dimension must be less than 3.");
46
47 DFTI_DESCRIPTOR_HANDLE res = nullptr;
48 if (dimension == 1) {
49 RUN_OR_ASSERT(DftiCreateDescriptor(&res, precision, forward_domain, dimension, *sizes),
50 "DftiCreateDescriptor failed.")
51 handl.reset(res, [](DFTI_DESCRIPTOR_HANDLE handle) { DftiFreeDescriptor(&handle); });
52 if (forward_domain == DFTI_REAL) {
53 // Set CCE storage
54 RUN_OR_ASSERT(DftiSetValue(handl.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX),
55 "DftiSetValue failed.")
56 }
57 } else {
58 RUN_OR_ASSERT(DftiCreateDescriptor(&res, precision, DFTI_COMPLEX, dimension, sizes), "DftiCreateDescriptor failed.")
59 handl.reset(res, [](DFTI_DESCRIPTOR_HANDLE handle) { DftiFreeDescriptor(&handle); });
60 }
61
62 RUN_OR_ASSERT(DftiSetValue(handl.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE), "DftiSetValue failed.")
63 RUN_OR_ASSERT(DftiCommitDescriptor(handl.get()), "DftiCommitDescriptor failed.")
64}
65
// Primary template: intentionally empty. The usable plans are the explicit
// specializations for float and double.
template <class T>
struct plan {
};
68
69template <>
70struct plan<float> {
71 typedef float scalar_type;
72 typedef MKL_Complex8 complex_type;
73
74 std::shared_ptr<DFTI_DESCRIPTOR> m_plan;
75
76 plan() = default;
77
78 enum DFTI_CONFIG_VALUE precision = DFTI_SINGLE;
79
80 inline void forward(complex_type* dst, complex_type* src, MKL_LONG nfft) {
81 if (m_plan == 0) {
82 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
83 }
84 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
85 }
86
87 inline void inverse(complex_type* dst, complex_type* src, MKL_LONG nfft) {
88 if (m_plan == 0) {
89 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
90 }
91 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
92 }
93
94 inline void forward(complex_type* dst, scalar_type* src, MKL_LONG nfft) {
95 if (m_plan == 0) {
96 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
97 }
98 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
99 }
100
101 inline void inverse(scalar_type* dst, complex_type* src, MKL_LONG nfft) {
102 if (m_plan == 0) {
103 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
104 }
105 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
106 }
107
108 inline void forward2(complex_type* dst, complex_type* src, int n0, int n1) {
109 if (m_plan == 0) {
110 MKL_LONG sizes[2] = {n0, n1};
111 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
112 }
113 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
114 }
115
116 inline void inverse2(complex_type* dst, complex_type* src, int n0, int n1) {
117 if (m_plan == 0) {
118 MKL_LONG sizes[2] = {n0, n1};
119 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
120 }
121 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
122 }
123};
124
125template <>
126struct plan<double> {
127 typedef double scalar_type;
128 typedef MKL_Complex16 complex_type;
129
130 std::shared_ptr<DFTI_DESCRIPTOR> m_plan;
131
132 plan() = default;
133
134 enum DFTI_CONFIG_VALUE precision = DFTI_DOUBLE;
135
136 inline void forward(complex_type* dst, complex_type* src, MKL_LONG nfft) {
137 if (m_plan == 0) {
138 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
139 }
140 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
141 }
142
143 inline void inverse(complex_type* dst, complex_type* src, MKL_LONG nfft) {
144 if (m_plan == 0) {
145 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
146 }
147 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
148 }
149
150 inline void forward(complex_type* dst, scalar_type* src, MKL_LONG nfft) {
151 if (m_plan == 0) {
152 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
153 }
154 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
155 }
156
157 inline void inverse(scalar_type* dst, complex_type* src, MKL_LONG nfft) {
158 if (m_plan == 0) {
159 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
160 }
161 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
162 }
163
164 inline void forward2(complex_type* dst, complex_type* src, int n0, int n1) {
165 if (m_plan == 0) {
166 MKL_LONG sizes[2] = {n0, n1};
167 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
168 }
169 RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
170 }
171
172 inline void inverse2(complex_type* dst, complex_type* src, int n0, int n1) {
173 if (m_plan == 0) {
174 MKL_LONG sizes[2] = {n0, n1};
175 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
176 }
177 RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
178 }
179};
180
181template <typename Scalar_>
182struct imklfft_impl {
183 typedef Scalar_ Scalar;
184 typedef std::complex<Scalar> Complex;
185
186 inline void clear() { m_plans.clear(); }
187
188 // complex-to-complex forward FFT
189 inline void fwd(Complex* dst, const Complex* src, int nfft) {
190 MKL_LONG size = nfft;
191 get_plan(nfft, dst, src).forward(complex_cast(dst), complex_cast(src), size);
192 }
193
194 // real-to-complex forward FFT
195 inline void fwd(Complex* dst, const Scalar* src, int nfft) {
196 MKL_LONG size = nfft;
197 get_plan(nfft, dst, src).forward(complex_cast(dst), const_cast<Scalar*>(src), nfft);
198 }
199
200 // 2-d complex-to-complex
201 inline void fwd2(Complex* dst, const Complex* src, int n0, int n1) {
202 get_plan(n0, n1, dst, src).forward2(complex_cast(dst), complex_cast(src), n0, n1);
203 }
204
205 // inverse complex-to-complex
206 inline void inv(Complex* dst, const Complex* src, int nfft) {
207 MKL_LONG size = nfft;
208 get_plan(nfft, dst, src).inverse(complex_cast(dst), complex_cast(src), nfft);
209 }
210
211 // half-complex to scalar
212 inline void inv(Scalar* dst, const Complex* src, int nfft) {
213 MKL_LONG size = nfft;
214 get_plan(nfft, dst, src).inverse(const_cast<Scalar*>(dst), complex_cast(src), nfft);
215 }
216
217 // 2-d complex-to-complex
218 inline void inv2(Complex* dst, const Complex* src, int n0, int n1) {
219 get_plan(n0, n1, dst, src).inverse2(complex_cast(dst), complex_cast(src), n0, n1);
220 }
221
222 private:
223 std::map<int64_t, plan<Scalar>> m_plans;
224
225 inline plan<Scalar>& get_plan(int nfft, void* dst, const void* src) {
226 int inplace = dst == src ? 1 : 0;
227 int aligned = ((reinterpret_cast<size_t>(src) & 15) | (reinterpret_cast<size_t>(dst) & 15)) == 0 ? 1 : 0;
228 int64_t key = ((nfft << 2) | (inplace << 1) | aligned) << 1;
229
230 // Create element if key does not exist.
231 return m_plans[key];
232 }
233
234 inline plan<Scalar>& get_plan(int n0, int n1, void* dst, const void* src) {
235 int inplace = (dst == src) ? 1 : 0;
236 int aligned = ((reinterpret_cast<size_t>(src) & 15) | (reinterpret_cast<size_t>(dst) & 15)) == 0 ? 1 : 0;
237 int64_t key = (((((int64_t)n0) << 31) | (n1 << 2) | (inplace << 1) | aligned) << 1) + 1;
238
239 // Create element if key does not exist.
240 return m_plans[key];
241 }
242};
243
244#undef RUN_OR_ASSERT
245
246} // namespace imklfft
247} // namespace internal
248} // namespace Eigen
Namespace containing all symbols from the Eigen library.
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_inverse_op< typename Derived::Scalar >, const Derived > inverse(const Eigen::ArrayBase< Derived > &x)