Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
SpecialFunctionsFunctors.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
// Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_SPECIALFUNCTIONS_FUNCTORS_H
12#define EIGEN_SPECIALFUNCTIONS_FUNCTORS_H
13
14// IWYU pragma: private
15#include "./InternalHeaderCheck.h"
16
17namespace Eigen {
18
19namespace internal {
20
26template <typename Scalar>
27struct scalar_igamma_op : binary_op_base<Scalar, Scalar> {
28 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a, const Scalar& x) const {
29 using numext::igamma;
30 return igamma(a, x);
31 }
32 template <typename Packet>
33 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const {
34 return internal::pigamma(a, x);
35 }
36};
37template <typename Scalar>
38struct functor_traits<scalar_igamma_op<Scalar> > {
39 enum {
40 // Guesstimate
41 Cost = 20 * NumTraits<Scalar>::MulCost + 10 * NumTraits<Scalar>::AddCost,
42 PacketAccess = packet_traits<Scalar>::HasIGamma
43 };
44};
45
52template <typename Scalar>
53struct scalar_igamma_der_a_op {
54 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a, const Scalar& x) const {
55 using numext::igamma_der_a;
56 return igamma_der_a(a, x);
57 }
58 template <typename Packet>
59 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const {
60 return internal::pigamma_der_a(a, x);
61 }
62};
63template <typename Scalar>
64struct functor_traits<scalar_igamma_der_a_op<Scalar> > {
65 enum {
66 // 2x the cost of igamma
67 Cost = 40 * NumTraits<Scalar>::MulCost + 20 * NumTraits<Scalar>::AddCost,
68 PacketAccess = packet_traits<Scalar>::HasIGammaDerA
69 };
70};
71
79template <typename Scalar>
80struct scalar_gamma_sample_der_alpha_op {
81 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& alpha, const Scalar& sample) const {
82 using numext::gamma_sample_der_alpha;
83 return gamma_sample_der_alpha(alpha, sample);
84 }
85 template <typename Packet>
86 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& alpha, const Packet& sample) const {
87 return internal::pgamma_sample_der_alpha(alpha, sample);
88 }
89};
90template <typename Scalar>
91struct functor_traits<scalar_gamma_sample_der_alpha_op<Scalar> > {
92 enum {
93 // 2x the cost of igamma, minus the lgamma cost (the lgamma cancels out)
94 Cost = 30 * NumTraits<Scalar>::MulCost + 15 * NumTraits<Scalar>::AddCost,
95 PacketAccess = packet_traits<Scalar>::HasGammaSampleDerAlpha
96 };
97};
98
104template <typename Scalar>
105struct scalar_igammac_op : binary_op_base<Scalar, Scalar> {
106 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a, const Scalar& x) const {
107 using numext::igammac;
108 return igammac(a, x);
109 }
110 template <typename Packet>
111 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const {
112 return internal::pigammac(a, x);
113 }
114};
115template <typename Scalar>
116struct functor_traits<scalar_igammac_op<Scalar> > {
117 enum {
118 // Guesstimate
119 Cost = 20 * NumTraits<Scalar>::MulCost + 10 * NumTraits<Scalar>::AddCost,
120 PacketAccess = packet_traits<Scalar>::HasIGammac
121 };
122};
123
128template <typename Scalar>
129struct scalar_betainc_op {
130 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& x, const Scalar& a,
131 const Scalar& b) const {
132 using numext::betainc;
133 return betainc(x, a, b);
134 }
135 template <typename Packet>
136 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& x, const Packet& a, const Packet& b) const {
137 return internal::pbetainc(x, a, b);
138 }
139};
140template <typename Scalar>
141struct functor_traits<scalar_betainc_op<Scalar> > {
142 enum {
143 // Guesstimate
144 Cost = 400 * NumTraits<Scalar>::MulCost + 400 * NumTraits<Scalar>::AddCost,
145 PacketAccess = packet_traits<Scalar>::HasBetaInc
146 };
147};
148
154template <typename Scalar>
155struct scalar_lgamma_op {
156 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
157 using numext::lgamma;
158 return lgamma(a);
159 }
160 typedef typename packet_traits<Scalar>::type Packet;
161 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::plgamma(a); }
162};
163template <typename Scalar>
164struct functor_traits<scalar_lgamma_op<Scalar> > {
165 enum {
166 // Guesstimate
167 Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
168 PacketAccess = packet_traits<Scalar>::HasLGamma
169 };
170};
171
176template <typename Scalar>
177struct scalar_digamma_op {
178 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
179 using numext::digamma;
180 return digamma(a);
181 }
182 typedef typename packet_traits<Scalar>::type Packet;
183 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::pdigamma(a); }
184};
185template <typename Scalar>
186struct functor_traits<scalar_digamma_op<Scalar> > {
187 enum {
188 // Guesstimate
189 Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
190 PacketAccess = packet_traits<Scalar>::HasDiGamma
191 };
192};
193
198template <typename Scalar>
199struct scalar_zeta_op {
200 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& x, const Scalar& q) const {
201 using numext::zeta;
202 return zeta(x, q);
203 }
204 typedef typename packet_traits<Scalar>::type Packet;
205 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, const Packet& q) const {
206 return internal::pzeta(x, q);
207 }
208};
209template <typename Scalar>
210struct functor_traits<scalar_zeta_op<Scalar> > {
211 enum {
212 // Guesstimate
213 Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
214 PacketAccess = packet_traits<Scalar>::HasZeta
215 };
216};
217
222template <typename Scalar>
223struct scalar_polygamma_op {
224 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& n, const Scalar& x) const {
225 using numext::polygamma;
226 return polygamma(n, x);
227 }
228 typedef typename packet_traits<Scalar>::type Packet;
229 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& n, const Packet& x) const {
230 return internal::ppolygamma(n, x);
231 }
232};
233template <typename Scalar>
234struct functor_traits<scalar_polygamma_op<Scalar> > {
235 enum {
236 // Guesstimate
237 Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
238 PacketAccess = packet_traits<Scalar>::HasPolygamma
239 };
240};
241
246template <typename Scalar>
247struct scalar_erf_op {
248 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { return numext::erf(a); }
249 template <typename Packet>
250 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
251 return perf(x);
252 }
253};
254template <typename Scalar>
255struct functor_traits<scalar_erf_op<Scalar> > {
256 enum {
257 PacketAccess = packet_traits<Scalar>::HasErf,
258 Cost = (PacketAccess
259#ifdef EIGEN_VECTORIZE_FMA
260 // TODO(rmlarsen): Move the FMA cost model to a central location.
261 // Haswell can issue 2 add/mul/madd per cycle.
262 // 10 pmadd, 2 pmul, 1 div, 2 other
263 ? (2 * NumTraits<Scalar>::AddCost + 7 * NumTraits<Scalar>::MulCost +
264 scalar_div_cost<Scalar, packet_traits<Scalar>::HasDiv>::value)
265#else
266 ? (12 * NumTraits<Scalar>::AddCost + 12 * NumTraits<Scalar>::MulCost +
267 scalar_div_cost<Scalar, packet_traits<Scalar>::HasDiv>::value)
268#endif
269 // Assume for simplicity that this is as expensive as an exp().
270 : (functor_traits<scalar_exp_op<Scalar> >::Cost))
271 };
272};
273
279template <typename Scalar>
280struct scalar_erfc_op {
281 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
282 using numext::erfc;
283 return erfc(a);
284 }
285 typedef typename packet_traits<Scalar>::type Packet;
286 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::perfc(a); }
287};
288template <typename Scalar>
289struct functor_traits<scalar_erfc_op<Scalar> > {
290 enum {
291 // Guesstimate
292 Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
293 PacketAccess = packet_traits<Scalar>::HasErfc
294 };
295};
296
302template <typename Scalar>
303struct scalar_ndtri_op {
304 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
305 using numext::ndtri;
306 return ndtri(a);
307 }
308 typedef typename packet_traits<Scalar>::type Packet;
309 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { return internal::pndtri(a); }
310};
311template <typename Scalar>
312struct functor_traits<scalar_ndtri_op<Scalar> > {
313 enum {
314 // On average, We are evaluating rational functions with degree N=9 in the
315 // numerator and denominator. This results in 2*N additions and 2*N
316 // multiplications.
317 Cost = 18 * NumTraits<Scalar>::MulCost + 18 * NumTraits<Scalar>::AddCost,
318 PacketAccess = packet_traits<Scalar>::HasNdtri
319 };
320};
321
322} // end namespace internal
323
324} // end namespace Eigen
325
326#endif // EIGEN_SPECIALFUNCTIONS_FUNCTORS_H
Namespace containing all symbols from the Eigen library.
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igammac_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igammac(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:93
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_der_a_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma_der_a(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:52
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_lgamma_op< typename Derived::Scalar >, const Derived > lgamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_gamma_sample_der_alpha_op< typename AlphaDerived::Scalar >, const AlphaDerived, const SampleDerived > gamma_sample_der_alpha(const Eigen::ArrayBase< AlphaDerived > &alpha, const Eigen::ArrayBase< SampleDerived > &sample)
Definition SpecialFunctionsArrayAPI.h:75
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erfc_op< typename Derived::Scalar >, const Derived > erfc(const Eigen::ArrayBase< Derived > &x)
const TensorCwiseTernaryOp< internal::scalar_betainc_op< typename XDerived::Scalar >, const ADerived, const BDerived, const XDerived > betainc(const Eigen::TensorBase< ADerived, ReadOnlyAccessors > &a, const Eigen::TensorBase< BDerived, ReadOnlyAccessors > &b, const Eigen::TensorBase< XDerived, ReadOnlyAccessors > &x)
Definition TensorGlobalFunctions.h:26
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_ndtri_op< typename Derived::Scalar >, const Derived > ndtri(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_digamma_op< typename Derived::Scalar >, const Derived > digamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_polygamma_op< typename DerivedX::Scalar >, const DerivedN, const DerivedX > polygamma(const Eigen::ArrayBase< DerivedN > &n, const Eigen::ArrayBase< DerivedX > &x)
Definition SpecialFunctionsArrayAPI.h:113
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:31
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_zeta_op< typename DerivedX::Scalar >, const DerivedX, const DerivedQ > zeta(const Eigen::ArrayBase< DerivedX > &x, const Eigen::ArrayBase< DerivedQ > &q)
Definition SpecialFunctionsArrayAPI.h:152