Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
SpecialFunctionsImpl.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPECIAL_FUNCTIONS_H
11#define EIGEN_SPECIAL_FUNCTIONS_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17namespace internal {
18
19// Parts of this code are based on the Cephes Math Library.
20//
21// Cephes Math Library Release 2.8: June, 2000
22// Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier
23//
24// Permission has been kindly provided by the original author
25// to incorporate the Cephes software into the Eigen codebase:
26//
27// From: Stephen Moshier
28// To: Eugene Brevdo
29// Subject: Re: Permission to wrap several cephes functions in Eigen
30//
31// Hello Eugene,
32//
33// Thank you for writing.
34//
35// If your licensing is similar to BSD, the formal way that has been
36// handled is simply to add a statement to the effect that you are incorporating
37// the Cephes software by permission of the author.
38//
39// Good luck with your project,
40// Steve
41
42/****************************************************************************
43 * Implementation of lgamma, requires C++11/C99 *
44 ****************************************************************************/
45
46template <typename Scalar>
47struct lgamma_impl {
48 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), THIS_TYPE_IS_NOT_SUPPORTED)
49
50 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(const Scalar) { return Scalar(0); }
51};
52
53template <typename Scalar>
54struct lgamma_retval {
55 typedef Scalar type;
56};
57
#if EIGEN_HAS_C99_MATH
// Decide whether the reentrant ::lgamma_r / ::lgammaf_r entry points can be
// used. They are only declared by glibc under certain feature-test macros,
// which changed in glibc 2.19.
// NOTE(review): presumably preferred because they hand back the sign of
// Gamma(x) through an out-parameter instead of global state — confirm.
// Since glibc 2.19
#if defined(__GLIBC__) && ((__GLIBC__ >= 2 && __GLIBC_MINOR__ >= 19) || __GLIBC__ > 2) && \
    (defined(_DEFAULT_SOURCE) || defined(_BSD_SOURCE) || defined(_SVID_SOURCE))
#define EIGEN_HAS_LGAMMA_R
#endif

// Glibc versions before 2.19
#if defined(__GLIBC__) && ((__GLIBC__ == 2 && __GLIBC_MINOR__ < 19) || __GLIBC__ < 2) && \
    (defined(_BSD_SOURCE) || defined(_SVID_SOURCE))
#define EIGEN_HAS_LGAMMA_R
#endif

// lgamma for float. Backend selection:
//  - ::lgammaf_r outside the GPU compile phase when EIGEN_HAS_LGAMMA_R is set
//    (never on Apple platforms);
//  - cl::sycl::lgamma in SYCL device code;
//  - ::lgammaf otherwise.
template <>
struct lgamma_impl<float> {
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float run(float x) {
#if !defined(EIGEN_GPU_COMPILE_PHASE) && defined(EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__)
    int dummy;  // out-parameter required by lgammaf_r; its value is ignored
    return ::lgammaf_r(x, &dummy);
#elif defined(SYCL_DEVICE_ONLY)
    return cl::sycl::lgamma(x);
#else
    return ::lgammaf(x);
#endif
  }
};

// lgamma for double; same backend selection as the float version above,
// using ::lgamma_r / ::lgamma.
template <>
struct lgamma_impl<double> {
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE double run(double x) {
#if !defined(EIGEN_GPU_COMPILE_PHASE) && defined(EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__)
    int dummy;  // out-parameter required by lgamma_r; its value is ignored
    return ::lgamma_r(x, &dummy);
#elif defined(SYCL_DEVICE_ONLY)
    return cl::sycl::lgamma(x);
#else
    return ::lgamma(x);
#endif
  }
};

// The detection macro is only needed within this section.
#undef EIGEN_HAS_LGAMMA_R
#endif  // EIGEN_HAS_C99_MATH
101
102/****************************************************************************
103 * Implementation of digamma (psi), based on Cephes *
104 ****************************************************************************/
105
106template <typename Scalar>
107struct digamma_retval {
108 typedef Scalar type;
109};
110
111/*
112 *
113 * Polynomial evaluation helper for the Psi (digamma) function.
114 *
115 * digamma_impl_maybe_poly::run(s) evaluates the asymptotic Psi expansion for
116 * input Scalar s, assuming s is above 10.0.
117 *
118 * If s is above a certain threshold for the given Scalar type, zero
119 * is returned. Otherwise the polynomial is evaluated with enough
120 * coefficients for results matching Scalar machine precision.
121 *
122 *
123 */
124template <typename Scalar>
125struct digamma_impl_maybe_poly {
126 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), THIS_TYPE_IS_NOT_SUPPORTED)
127
128 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(const Scalar) { return Scalar(0); }
129};
130
131template <>
132struct digamma_impl_maybe_poly<float> {
133 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float run(const float s) {
134 constexpr float A[] = {-4.16666666666666666667E-3f, 3.96825396825396825397E-3f, -8.33333333333333333333E-3f,
135 8.33333333333333333333E-2f};
136
137 float z;
138 if (s < 1.0e8f) {
139 z = 1.0f / (s * s);
140 return z * internal::ppolevl<float, 3>::run(z, A);
141 } else
142 return 0.0f;
143 }
144};
145
146template <>
147struct digamma_impl_maybe_poly<double> {
148 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE double run(const double s) {
149 constexpr double A[] = {8.33333333333333333333E-2, -2.10927960927960927961E-2, 7.57575757575757575758E-3,
150 -4.16666666666666666667E-3, 3.96825396825396825397E-3, -8.33333333333333333333E-3,
151 8.33333333333333333333E-2};
152
153 double z;
154 if (s < 1.0e17) {
155 z = 1.0 / (s * s);
156 return z * internal::ppolevl<double, 6>::run(z, A);
157 } else
158 return 0.0;
159 }
160};
161
162template <typename Scalar>
163struct digamma_impl {
164 EIGEN_DEVICE_FUNC static Scalar run(Scalar x) {
165 /*
166 *
167 * Psi (digamma) function (modified for Eigen)
168 *
169 *
170 * SYNOPSIS:
171 *
172 * double x, y, psi();
173 *
174 * y = psi( x );
175 *
176 *
177 * DESCRIPTION:
178 *
179 * d -
180 * psi(x) = -- ln | (x)
181 * dx
182 *
183 * is the logarithmic derivative of the gamma function.
184 * For integer x,
185 * n-1
186 * -
187 * psi(n) = -EUL + > 1/k.
188 * -
189 * k=1
190 *
191 * If x is negative, it is transformed to a positive argument by the
192 * reflection formula psi(1-x) = psi(x) + pi cot(pi x).
193 * For general positive x, the argument is made greater than 10
194 * using the recurrence psi(x+1) = psi(x) + 1/x.
195 * Then the following asymptotic expansion is applied:
196 *
197 * inf. B
198 * - 2k
199 * psi(x) = log(x) - 1/2x - > -------
200 * - 2k
201 * k=1 2k x
202 *
203 * where the B2k are Bernoulli numbers.
204 *
205 * ACCURACY (float):
206 * Relative error (except absolute when |psi| < 1):
207 * arithmetic domain # trials peak rms
208 * IEEE 0,30 30000 1.3e-15 1.4e-16
209 * IEEE -30,0 40000 1.5e-15 2.2e-16
210 *
211 * ACCURACY (double):
212 * Absolute error, relative when |psi| > 1 :
213 * arithmetic domain # trials peak rms
214 * IEEE -33,0 30000 8.2e-7 1.2e-7
215 * IEEE 0,33 100000 7.3e-7 7.7e-8
216 *
217 * ERROR MESSAGES:
218 * message condition value returned
219 * psi singularity x integer <=0 INFINITY
220 */
221
222 Scalar p, q, nz, s, w, y;
223 bool negative = false;
224
225 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
226 const Scalar m_pi = Scalar(EIGEN_PI);
227
228 const Scalar zero = Scalar(0);
229 const Scalar one = Scalar(1);
230 const Scalar half = Scalar(0.5);
231 nz = zero;
232
233 if (x <= zero) {
234 negative = true;
235 q = x;
236 p = numext::floor(q);
237 if (p == q) {
238 return nan;
239 }
240 /* Remove the zeros of tan(m_pi x)
241 * by subtracting the nearest integer from x
242 */
243 nz = q - p;
244 if (nz != half) {
245 if (nz > half) {
246 p += one;
247 nz = q - p;
248 }
249 nz = m_pi / numext::tan(m_pi * nz);
250 } else {
251 nz = zero;
252 }
253 x = one - x;
254 }
255
256 /* use the recurrence psi(x+1) = psi(x) + 1/x. */
257 s = x;
258 w = zero;
259 while (s < Scalar(10)) {
260 w += one / s;
261 s += one;
262 }
263
264 y = digamma_impl_maybe_poly<Scalar>::run(s);
265
266 y = numext::log(s) - (half / s) - y - w;
267
268 return (negative) ? y - nz : y;
269 }
270};
271
272/***************************************************************************
273 * Implementation of erfc.
274 ****************************************************************************/
275template <typename Scalar>
276struct generic_fast_erfc {
277 template <typename T>
278 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T run(const T& x_in);
279};
280
281template <>
282template <typename T>
283EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erfc<float>::run(const T& x_in) {
284 constexpr float kClamp = 11.0f;
285 const T x = pmin(pmax(x_in, pset1<T>(-kClamp)), pset1<T>(kClamp));
286
287 // erfc(x) = 1 + x * S(x^2), |x| <= 1.
288 //
289 // Coefficients for S and T generated with Rminimax command:
290 // ./ratapprox --function="erfc(x)-1" --dom='[-1,1]' --type=[11,0] --num="odd"
291 // --numF="[SG]" --denF="[SG]" --log --dispCoeff="dec"
292 constexpr float alpha[] = {5.61802298761904239654541015625e-04, -4.91381669417023658752441406250e-03,
293 2.67075151205062866210937500000e-02, -1.12800106406211853027343750000e-01,
294 3.76122951507568359375000000000e-01, -1.12837910652160644531250000000e+00};
295 const T x2 = pmul(x, x);
296 const T one = pset1<T>(1.0f);
297 const T erfc_small = pmadd(x, ppolevl<T, 5>::run(x2, alpha), one);
298
299 // Return early if we don't need the more expensive approximation for any
300 // entry in a.
301 const T x_abs_gt_one_mask = pcmp_lt(one, x2);
302 if (!predux_any(x_abs_gt_one_mask)) return erfc_small;
303
304 // erfc(x) = exp(-x^2) * 1/x * P(1/x^2) / Q(1/x^2), 1 < x < 9.
305 //
306 // Coefficients for P and Q generated with Rminimax command:
307 // ./ratapprox --function="erfc(1/sqrt(x))*exp(1/x)/sqrt(x)"
308 // --dom='[0.01,1]' --type=[3,4] --numF="[SG]" --denF="[SG]" --log
309 // --dispCoeff="dec"
310 constexpr float gamma[] = {1.0208116471767425537109375e-01f, 4.2920666933059692382812500e-01f,
311 3.2379078865051269531250000e-01f, 5.3971976041793823242187500e-02f};
312 constexpr float delta[] = {1.7251677811145782470703125e-02f, 3.9137163758277893066406250e-01f,
313 1.0000000000000000000000000e+00f, 6.2173241376876831054687500e-01f,
314 9.5662862062454223632812500e-02f};
315 const T x2_lo = twoprod_low(x, x, x2);
316 // Here we use that
317 // exp(-x^2) = exp(-(x2+x2_lo)^2) ~= exp(-x2)*exp(-x2_lo) ~= exp(-x2)*(1-x2_lo)
318 // since x2_lo < kClamp * eps << 1 in the region we care about. This trick reduces the max error
319 // from 34 ulps to below 5 ulps.
320 const T exp2_hi = pexp(pnegate(x2));
321 const T z = pnmadd(exp2_hi, x2_lo, exp2_hi);
322 const T q2 = preciprocal(x2);
323 const T num = ppolevl<T, 3>::run(q2, gamma);
324 const T denom = pmul(x, ppolevl<T, 4>::run(q2, delta));
325 const T r = pdiv(num, denom);
326 const T maybe_two = pselect(pcmp_lt(x, pset1<T>(0.0f)), pset1<T>(2.0f), pset1<T>(0.0f));
327 const T erfc_large = pmadd(z, r, maybe_two);
328 return pselect(x_abs_gt_one_mask, erfc_large, erfc_small);
329}
330
331// Computes erf(x)/x for |x| <= 1. Used by both erf and erfc implementations.
332// Takes x2 = x^2 as input.
333//
334// PRECONDITION: x2 <= 1.
335template <typename T>
336EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T erf_over_x_double_small(const T& x2) {
337 // erf(x)/x = S(x^2) / T(x^2), x^2 <= 1.
338 //
339 // Coefficients for S and T generated with Rminimax command:
340 // ./ratapprox --function="erf(x)" --dom='[-1,1]' --type=[9,10]
341 // --num="odd" --numF="[D]" --den="even" --denF="[D]" --log --dispCoeff="dec"
342 constexpr double alpha[] = {1.9493725660006057018823477644531294572516344487667083740234375e-04,
343 1.8272566210022942682217328425053892715368419885635375976562500e-03,
344 4.5303363351690106863856044583371840417385101318359375000000000e-02,
345 1.4215015503619179981775744181504705920815467834472656250000000e-01,
346 1.1283791670955125585606992899556644260883331298828125000000000e+00};
347 constexpr double beta[] = {2.0294484101083099089526257108317963684385176748037338256835938e-05,
348 6.8117805899186819641732970609382391558028757572174072265625000e-04,
349 1.0582026056098614921752165685120417037978768348693847656250000e-02,
350 9.3252603143757495374188692949246615171432495117187500000000000e-02,
351 4.5931062818368939559832142549566924571990966796875000000000000e-01,
352 1.0};
353 const T num_small = ppolevl<T, 4>::run(x2, alpha);
354 const T denom_small = ppolevl<T, 5>::run(x2, beta);
355 return pdiv(num_small, denom_small);
356}
357
358// erfc(x) = exp(-x^2) * 1/x * P(1/x^2) / Q(1/x^2), 1 < x < 28.
359//
360// Coefficients for P and Q generated with Rminimax command:
361// ./ratapprox --function="erfc(1/sqrt(x))*exp(1/x)/sqrt(x)" --dom='[0.0013717,1]' --type=[9,9] --numF="[D]"
362// --denF="[D]" --log --dispCoeff="dec"
363//
364// PRECONDITION: 1 < x < 28.
365template <typename T>
366EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T erfc_double_large(const T& x, const T& x2) {
367 constexpr double gamma[] = {1.5252844933226974316088642158462107545346952974796295166015625e-04,
368 1.0909912393738931124520519233556115068495273590087890625000000e-02,
369 1.0628604636755033252537572252549580298364162445068359375000000e-01,
370 3.3492472973137982217295416376146022230386734008789062500000000e-01,
371 4.5065776215933289750026347064704168587923049926757812500000000e-01,
372 2.9433039130294824659017649537418037652969360351562500000000000e-01,
373 9.8792676360600226170838311645638896152377128601074218750000000e-02,
374 1.7095935395503719655962981960328761488199234008789062500000000e-02,
375 1.4249109729504577659398023570247460156679153442382812500000000e-03,
376 4.4567378313647954771875570045835956989321857690811157226562500e-05};
377 constexpr double delta[] = {2.041985103115789845773520028160419315099716186523437500000000e-03,
378 5.316030659946043707142493417450168635696172714233398437500000e-02,
379 3.426242193784684864077405563875799998641014099121093750000000e-01,
380 8.565637124308049799026321124983951449394226074218750000000000e-01,
381 1.000000000000000000000000000000000000000000000000000000000000e+00,
382 5.968805280570776972126623149961233139038085937500000000000000e-01,
383 1.890922854723317836356244470152887515723705291748046875000000e-01,
384 3.152505418656005586885981983868987299501895904541015625000000e-02,
385 2.565085751861882583380047861965067568235099315643310546875000e-03,
386 7.899362131678837697403017248376499992446042597293853759765625e-05};
387 // Compute exp(-x^2).
388 const T x2_lo = twoprod_low(x, x, x2);
389 // Here we use that
390 // exp(-x^2) = exp(-(x2+x2_lo)^2) ~= exp(-x2)*exp(-x2_lo) ~= exp(-x2)*(1-x2_lo)
391 // since x2_lo < kClamp *eps << 1 in the region we care about. This trick reduces the max error
392 // from 258 ulps to below 7 ulps.
393 const T exp2_hi = pexp(pnegate(x2));
394 const T z = pnmadd(exp2_hi, x2_lo, exp2_hi);
395 // Compute r = P / Q.
396 const T q2 = preciprocal(x2);
397 const T num_large = ppolevl<T, 9>::run(q2, gamma);
398 const T denom_large = pmul(x, ppolevl<T, 9>::run(q2, delta));
399 const T r = pdiv(num_large, denom_large);
400 const T maybe_two = pselect(pcmp_lt(x, pset1<T>(0.0)), pset1<T>(2.0), pset1<T>(0.0));
401 return pmadd(z, r, maybe_two);
402}
403
404template <>
405template <typename T>
406EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erfc<double>::run(const T& x_in) {
407 // Clamp x to [-28:28] beyond which erfc(x) is either two or zero (below the underflow threshold).
408 // This avoids having to deal with twoprod(x,x) producing NaN for sufficiently large x.
409 constexpr double kClamp = 28.0;
410 const T x = pmin(pmax(x_in, pset1<T>(-kClamp)), pset1<T>(kClamp));
411
412 // For |x| < 1, we use erfc(x) = 1 - erf(x).
413 const T x2 = pmul(x, x);
414 const T one = pset1<T>(1.0);
415 const T erfc_small = pnmadd(x, erf_over_x_double_small(x2), one);
416
417 // Return early if we don't need the more expensive approximation for any
418 // entry in a.
419 const T x_abs_gt_one_mask = pcmp_lt(one, x2);
420 if (!predux_any(x_abs_gt_one_mask)) return erfc_small;
421
422 const T erfc_large = erfc_double_large(x, x2);
423 return pselect(x_abs_gt_one_mask, erfc_large, erfc_small);
424}
425
426template <typename T>
427struct erfc_impl {
428 typedef typename unpacket_traits<T>::type Scalar;
429 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE T run(const T& x) { return generic_fast_erfc<Scalar>::run(x); }
430};
431
432template <typename Scalar>
433struct erfc_retval {
434 typedef Scalar type;
435};
436
#if EIGEN_HAS_C99_MATH
// Scalar erfc: the SYCL builtin in SYCL device code, otherwise the generic
// fast kernel instantiated on the plain scalar type.
template <>
struct erfc_impl<float> {
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float run(const float x) {
#ifdef SYCL_DEVICE_ONLY
    return cl::sycl::erfc(x);
#else
    return generic_fast_erfc<float>::run(x);
#endif
  }
};

template <>
struct erfc_impl<double> {
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE double run(const double x) {
#ifdef SYCL_DEVICE_ONLY
    return cl::sycl::erfc(x);
#else
    return generic_fast_erfc<double>::run(x);
#endif
  }
};
#endif  // EIGEN_HAS_C99_MATH
460
461/****************************************************************************
462 * Implementation of erf.
463 ****************************************************************************/
464
465template <typename Scalar>
466struct generic_fast_erf {
467 template <typename T>
468 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T run(const T& x_in);
469};
470
477template <>
478template <typename T>
479EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf<float>::run(const T& x) {
480 // The monomial coefficients of the numerator polynomial (odd).
481 constexpr float alpha[] = {2.123732201653183437883853912353515625e-06f, 2.861979592125862836837768554687500000e-04f,
482 3.658048342913389205932617187500000000e-03f, 5.243302136659622192382812500000000000e-02f,
483 1.874160766601562500000000000000000000e-01f, 1.128379106521606445312500000000000000e+00f};
484
485 // The monomial coefficients of the denominator polynomial (even).
486 constexpr float beta[] = {3.89185734093189239501953125000e-05f, 1.14329601638019084930419921875e-03f,
487 1.47520881146192550659179687500e-02f, 1.12945675849914550781250000000e-01f,
488 4.99425798654556274414062500000e-01f, 1.0f};
489
490 // Since the polynomials are odd/even, we need x^2.
491 // Since erf(4) == 1 in float, we clamp x^2 to 16 to avoid
492 // computing Inf/Inf below.
493 const T x2 = pmin(pset1<T>(16.0f), pmul(x, x));
494
495 // Evaluate the numerator polynomial p.
496 T p = ppolevl<T, 5>::run(x2, alpha);
497 p = pmul(x, p);
498
499 // Evaluate the denominator polynomial p.
500 T q = ppolevl<T, 5>::run(x2, beta);
501 const T r = pdiv(p, q);
502
503 // Clamp to [-1:1].
504 return pmax(pmin(r, pset1<T>(1.0f)), pset1<T>(-1.0f));
505}
506
507template <>
508template <typename T>
509EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf<double>::run(const T& x) {
510 T x2 = pmul(x, x);
511 T erf_small = pmul(x, erf_over_x_double_small(x2));
512
513 // Return early if we don't need the more expensive approximation for any
514 // entry in a.
515 const T one = pset1<T>(1.0);
516 const T x_abs_gt_one_mask = pcmp_lt(one, x2);
517 if (!predux_any(x_abs_gt_one_mask)) return erf_small;
518
519 // For |x| > 1, use erf(x) = 1 - erfc(x).
520 const T erf_large = psub(one, erfc_double_large(x, x2));
521 return pselect(x_abs_gt_one_mask, erf_large, erf_small);
522}
523
524template <typename T>
525struct erf_impl {
526 typedef typename unpacket_traits<T>::type Scalar;
527 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE T run(const T& x) { return generic_fast_erf<Scalar>::run(x); }
528};
529
530template <typename Scalar>
531struct erf_retval {
532 typedef Scalar type;
533};
534
#if EIGEN_HAS_C99_MATH
// Scalar erf: the SYCL builtin in SYCL device code, otherwise the generic fast
// kernel instantiated on the plain scalar type.
template <>
struct erf_impl<float> {
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float run(const float x) {
#ifdef SYCL_DEVICE_ONLY
    return cl::sycl::erf(x);
#else
    return generic_fast_erf<float>::run(x);
#endif
  }
};

template <>
struct erf_impl<double> {
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE double run(const double x) {
#ifdef SYCL_DEVICE_ONLY
    return cl::sycl::erf(x);
#else
    return generic_fast_erf<double>::run(x);
#endif
  }
};
#endif  // EIGEN_HAS_C99_MATH
558
559/***************************************************************************
560 * Implementation of ndtri. *
561 ****************************************************************************/
562
563/* Inverse of Normal distribution function (modified for Eigen).
564 *
565 *
566 * SYNOPSIS:
567 *
568 * double x, y, ndtri();
569 *
570 * x = ndtri( y );
571 *
572 *
573 *
574 * DESCRIPTION:
575 *
576 * Returns the argument, x, for which the area under the
577 * Gaussian probability density function (integrated from
578 * minus infinity to x) is equal to y.
579 *
580 *
581 * For small arguments 0 < y < exp(-2), the program computes
582 * z = sqrt( -2.0 * log(y) ); then the approximation is
583 * x = z - log(z)/z - (1/z) P(1/z) / Q(1/z).
584 * There are two rational functions P/Q, one for 0 < y < exp(-32)
585 * and the other for y up to exp(-2). For larger arguments,
 * w = y - 0.5, and x/sqrt(2pi) = w + w**3 R(w**2)/S(w**2).
587 *
588 *
589 * ACCURACY:
590 *
591 * Relative error:
592 * arithmetic domain # trials peak rms
593 * DEC 0.125, 1 5500 9.5e-17 2.1e-17
594 * DEC 6e-39, 0.135 3500 5.7e-17 1.3e-17
595 * IEEE 0.125, 1 20000 7.2e-16 1.3e-16
596 * IEEE 3e-308, 0.135 50000 4.6e-16 9.8e-17
597 *
598 *
599 * ERROR MESSAGES:
600 *
601 * message condition value returned
602 * ndtri domain x == 0 -INF
603 * ndtri domain x == 1 INF
604 * ndtri domain x < 0, x > 1 NAN
605 */
606/*
607 Cephes Math Library Release 2.2: June, 1992
608 Copyright 1985, 1987, 1992 by Stephen L. Moshier
609 Direct inquiries to 30 Frost Street, Cambridge, MA 02140
610*/
611
612// TODO: Add a cheaper approximation for float.
613
614template <typename T>
615EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T flipsign(const T& should_flipsign, const T& x) {
616 typedef typename unpacket_traits<T>::type Scalar;
617 const T sign_mask = pset1<T>(Scalar(-0.0));
618 T sign_bit = pand<T>(should_flipsign, sign_mask);
619 return pxor<T>(sign_bit, x);
620}
621
622template <>
623EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double flipsign<double>(const double& should_flipsign, const double& x) {
624 return should_flipsign == 0 ? x : -x;
625}
626
627template <>
628EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float flipsign<float>(const float& should_flipsign, const float& x) {
629 return should_flipsign == 0 ? x : -x;
630}
631
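// We split this computation into two helpers, one per interval, with the
// intent that in the scalar path only one branch is evaluated (due to our
// template specialization of pselect being an if statement).
// NOTE(review): generic_ndtri passes both helper calls as arguments to pselect,
// and C++ evaluates every function argument before the call, so both helpers
// currently run; only the selected result is kept.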
635
636template <typename T, typename ScalarType>
637EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_gt_exp_neg_two(const T& b) {
638 const ScalarType p0[] = {ScalarType(-5.99633501014107895267e1), ScalarType(9.80010754185999661536e1),
639 ScalarType(-5.66762857469070293439e1), ScalarType(1.39312609387279679503e1),
640 ScalarType(-1.23916583867381258016e0)};
641 const ScalarType q0[] = {ScalarType(1.0),
642 ScalarType(1.95448858338141759834e0),
643 ScalarType(4.67627912898881538453e0),
644 ScalarType(8.63602421390890590575e1),
645 ScalarType(-2.25462687854119370527e2),
646 ScalarType(2.00260212380060660359e2),
647 ScalarType(-8.20372256168333339912e1),
648 ScalarType(1.59056225126211695515e1),
649 ScalarType(-1.18331621121330003142e0)};
650 const T sqrt2pi = pset1<T>(ScalarType(2.50662827463100050242e0));
651 const T half = pset1<T>(ScalarType(0.5));
652 T c, c2, ndtri_gt_exp_neg_two;
653
654 c = psub(b, half);
655 c2 = pmul(c, c);
656 ndtri_gt_exp_neg_two =
657 pmadd(c, pmul(c2, pdiv(internal::ppolevl<T, 4>::run(c2, p0), internal::ppolevl<T, 8>::run(c2, q0))), c);
658 return pmul(ndtri_gt_exp_neg_two, sqrt2pi);
659}
660
661template <typename T, typename ScalarType>
662EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_lt_exp_neg_two(const T& b, const T& should_flipsign) {
663 /* Approximation for interval z = sqrt(-2 log a ) between 2 and 8
664 * i.e., a between exp(-2) = .135 and exp(-32) = 1.27e-14.
665 */
666 const ScalarType p1[] = {ScalarType(4.05544892305962419923e0), ScalarType(3.15251094599893866154e1),
667 ScalarType(5.71628192246421288162e1), ScalarType(4.40805073893200834700e1),
668 ScalarType(1.46849561928858024014e1), ScalarType(2.18663306850790267539e0),
669 ScalarType(-1.40256079171354495875e-1), ScalarType(-3.50424626827848203418e-2),
670 ScalarType(-8.57456785154685413611e-4)};
671 const ScalarType q1[] = {ScalarType(1.0),
672 ScalarType(1.57799883256466749731e1),
673 ScalarType(4.53907635128879210584e1),
674 ScalarType(4.13172038254672030440e1),
675 ScalarType(1.50425385692907503408e1),
676 ScalarType(2.50464946208309415979e0),
677 ScalarType(-1.42182922854787788574e-1),
678 ScalarType(-3.80806407691578277194e-2),
679 ScalarType(-9.33259480895457427372e-4)};
680 /* Approximation for interval z = sqrt(-2 log a ) between 8 and 64
681 * i.e., a between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
682 */
683 const ScalarType p2[] = {ScalarType(3.23774891776946035970e0), ScalarType(6.91522889068984211695e0),
684 ScalarType(3.93881025292474443415e0), ScalarType(1.33303460815807542389e0),
685 ScalarType(2.01485389549179081538e-1), ScalarType(1.23716634817820021358e-2),
686 ScalarType(3.01581553508235416007e-4), ScalarType(2.65806974686737550832e-6),
687 ScalarType(6.23974539184983293730e-9)};
688 const ScalarType q2[] = {ScalarType(1.0),
689 ScalarType(6.02427039364742014255e0),
690 ScalarType(3.67983563856160859403e0),
691 ScalarType(1.37702099489081330271e0),
692 ScalarType(2.16236993594496635890e-1),
693 ScalarType(1.34204006088543189037e-2),
694 ScalarType(3.28014464682127739104e-4),
695 ScalarType(2.89247864745380683936e-6),
696 ScalarType(6.79019408009981274425e-9)};
697 const T eight = pset1<T>(ScalarType(8.0));
698 const T neg_two = pset1<T>(ScalarType(-2));
699 T x, x0, x1, z;
700
701 x = psqrt(pmul(neg_two, plog(b)));
702 x0 = psub(x, pdiv(plog(x), x));
703 z = preciprocal(x);
704 x1 =
705 pmul(z, pselect(pcmp_lt(x, eight), pdiv(internal::ppolevl<T, 8>::run(z, p1), internal::ppolevl<T, 8>::run(z, q1)),
706 pdiv(internal::ppolevl<T, 8>::run(z, p2), internal::ppolevl<T, 8>::run(z, q2))));
707 return flipsign(should_flipsign, psub(x0, x1));
708}
709
710template <typename T, typename ScalarType>
711EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T generic_ndtri(const T& a) {
712 const T maxnum = pset1<T>(NumTraits<ScalarType>::infinity());
713 const T neg_maxnum = pset1<T>(-NumTraits<ScalarType>::infinity());
714
715 const T zero = pset1<T>(ScalarType(0));
716 const T one = pset1<T>(ScalarType(1));
717 // exp(-2)
718 const T exp_neg_two = pset1<T>(ScalarType(0.13533528323661269189));
719 T b, ndtri, should_flipsign;
720
721 should_flipsign = pcmp_le(a, psub(one, exp_neg_two));
722 b = pselect(should_flipsign, a, psub(one, a));
723
724 ndtri = pselect(pcmp_lt(exp_neg_two, b), generic_ndtri_gt_exp_neg_two<T, ScalarType>(b),
725 generic_ndtri_lt_exp_neg_two<T, ScalarType>(b, should_flipsign));
726
727 return pselect(pcmp_eq(a, zero), neg_maxnum, pselect(pcmp_eq(one, a), maxnum, ndtri));
728}
729
730template <typename Scalar>
731struct ndtri_retval {
732 typedef Scalar type;
733};
734
735#if !EIGEN_HAS_C99_MATH
736
737template <typename Scalar>
738struct ndtri_impl {
739 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), THIS_TYPE_IS_NOT_SUPPORTED)
740
741 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(const Scalar) { return Scalar(0); }
742};
743
744#else
745
746template <typename Scalar>
747struct ndtri_impl {
748 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(const Scalar x) { return generic_ndtri<Scalar, Scalar>(x); }
749};
750
751#endif // EIGEN_HAS_C99_MATH
752
753/**************************************************************************************************************
754 * Implementation of igammac (complemented incomplete gamma integral), based on Cephes but requires C++11/C99 *
755 **************************************************************************************************************/
756
757template <typename Scalar>
758struct igammac_retval {
759 typedef Scalar type;
760};
761
762// NOTE: cephes_helper is also used to implement zeta
763template <typename Scalar>
764struct cephes_helper {
765 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar machep() {
766 eigen_assert(false && "machep not supported for this type");
767 return 0.0;
768 }
769 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar big() {
770 eigen_assert(false && "big not supported for this type");
771 return 0.0;
772 }
773 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar biginv() {
774 eigen_assert(false && "biginv not supported for this type");
775 return 0.0;
776 }
777};
778
779template <>
780struct cephes_helper<float> {
781 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float machep() {
782 return NumTraits<float>::epsilon() / 2; // 1.0 - machep == 1.0
783 }
784 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float big() {
785 // use epsneg (1.0 - epsneg == 1.0)
786 return 1.0f / (NumTraits<float>::epsilon() / 2);
787 }
788 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float biginv() {
789 // epsneg
790 return machep();
791 }
792};
793
794template <>
795struct cephes_helper<double> {
796 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE double machep() {
797 return NumTraits<double>::epsilon() / 2; // 1.0 - machep == 1.0
798 }
799 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE double big() { return 1.0 / NumTraits<double>::epsilon(); }
800 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE double biginv() {
801 // inverse of eps
802 return NumTraits<double>::epsilon();
803 }
804};
805
806enum IgammaComputationMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
807
808template <typename Scalar>
809EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar main_igamma_term(Scalar a, Scalar x) {
810 /* Compute x**a * exp(-x) / gamma(a) */
811 Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
812 if (logax < -numext::log(NumTraits<Scalar>::highest()) ||
813 // Assuming x and a aren't Nan.
814 (numext::isnan)(logax)) {
815 return Scalar(0);
816 }
817 return numext::exp(logax);
818}
819
820template <typename Scalar, IgammaComputationMode mode>
821EIGEN_DEVICE_FUNC int igamma_num_iterations() {
822 /* Returns the maximum number of internal iterations for igamma computation.
823 */
824 if (mode == VALUE) {
825 return 2000;
826 }
827
828 if (internal::is_same<Scalar, float>::value) {
829 return 200;
830 } else if (internal::is_same<Scalar, double>::value) {
831 return 500;
832 } else {
833 return 2000;
834 }
835}
836
template <typename Scalar, IgammaComputationMode mode>
struct igammac_cf_impl {
  /* Computes igamc(a, x) or derivative (depending on the mode)
   * using the continued fraction expansion of the complementary
   * incomplete Gamma function.
   *
   * Preconditions:
   *   a > 0
   *   x >= 1
   *   x >= a
   *
   * The continued fraction is evaluated through the three-term recurrences
   *   pk = pkm1 * z - pkm2 * y * c,   qk = qkm1 * z - qkm2 * y * c,
   * whose ratio ans = pk / qk, multiplied by ax = x**a * exp(-x) / gamma(a)
   * (main_igamma_term), gives igamc(a, x). The partial derivatives of every
   * recurrence term with respect to a are carried alongside for the
   * derivative modes.
   */
  EIGEN_DEVICE_FUNC static Scalar run(Scalar a, Scalar x) {
    const Scalar zero = 0;
    const Scalar one = 1;
    const Scalar two = 2;
    const Scalar machep = cephes_helper<Scalar>::machep();
    const Scalar big = cephes_helper<Scalar>::big();
    const Scalar biginv = cephes_helper<Scalar>::biginv();

    // x == +/-inf: return zero in every mode.
    if ((numext::isinf)(x)) {
      return zero;
    }

    Scalar ax = main_igamma_term<Scalar>(a, x);
    // This is independent of mode. If this value is zero,
    // then the function value is zero. If the function value is zero,
    // then we are in a neighborhood where the function value evaluates to zero,
    // so the derivative is zero.
    if (ax == zero) {
      return zero;
    }

    // continued fraction
    // y starts at 1 - a and z at x + 2 - a, so dy/da = dz/da = -1; this is
    // what the derivative recurrences below rely on.
    Scalar y = one - a;
    Scalar z = x + y + one;
    Scalar c = zero;
    Scalar pkm2 = one;
    Scalar qkm2 = x;
    Scalar pkm1 = x + one;
    Scalar qkm1 = z * x;
    Scalar ans = pkm1 / qkm1;

    // d/da of the initial terms (qkm1 = z * x, so dqkm1/da = -x).
    Scalar dpkm2_da = zero;
    Scalar dqkm2_da = zero;
    Scalar dpkm1_da = zero;
    Scalar dqkm1_da = -x;
    Scalar dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;

    for (int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
      c += one;
      y += one;
      z += two;

      Scalar yc = y * c;
      Scalar pk = pkm1 * z - pkm2 * yc;
      Scalar qk = qkm1 * z - qkm2 * yc;

      // Product rule on the recurrences, using dz/da = -1 and d(yc)/da = -c.
      Scalar dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
      Scalar dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;

      // Only update the estimate (and test convergence) for a nonzero denominator.
      if (qk != zero) {
        Scalar ans_prev = ans;
        ans = pk / qk;

        Scalar dans_da_prev = dans_da;
        dans_da = (dpk_da - ans * dqk_da) / qk;

        if (mode == VALUE) {
          // Relative tolerance on the value.
          if (numext::abs(ans_prev - ans) <= machep * numext::abs(ans)) {
            break;
          }
        } else {
          // Absolute tolerance on d(ans)/da.
          if (numext::abs(dans_da - dans_da_prev) <= machep) {
            break;
          }
        }
      }

      // Shift the recurrence window by one term.
      pkm2 = pkm1;
      pkm1 = pk;
      qkm2 = qkm1;
      qkm1 = qk;

      dpkm2_da = dpkm1_da;
      dpkm1_da = dpk_da;
      dqkm2_da = dqkm1_da;
      dqkm1_da = dqk_da;

      // Rescale every term by biginv to avoid overflow. The ratios pk/qk and
      // (dpk_da - ans * dqk_da)/qk are invariant under this common scaling.
      if (numext::abs(pk) > big) {
        pkm2 *= biginv;
        pkm1 *= biginv;
        qkm2 *= biginv;
        qkm1 *= biginv;

        dpkm2_da *= biginv;
        dpkm1_da *= biginv;
        dqkm2_da *= biginv;
        dqkm1_da *= biginv;
      }
    }

    // d/da log(ax) = log(x) - digamma(a), where ax = x**a * exp(-x) / gamma(a).
    Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a);
    Scalar dax_da = ax * dlogax_da;

    switch (mode) {
      case VALUE:
        return ans * ax;
      case DERIVATIVE:
        // Product rule: d/da (ans * ax).
        return ans * dax_da + dans_da * ax;
      case SAMPLE_DERIVATIVE:
      default:  // this is needed to suppress clang warning
        return -(dans_da + ans * dlogax_da) * x;
    }
  }
};
953
template <typename Scalar, IgammaComputationMode mode>
struct igamma_series_impl {
  /* Computes igam(a, x) or its derivative (depending on the mode)
   * using the series expansion of the incomplete Gamma function.
   *
   * Preconditions:
   *  x > 0
   *  a > 0
   *  !(x > 1 && x > a)
   *
   * The series evaluated is
   *   igam(a, x) = ax / a * sum_{k>=0} x^k / ((a+1)(a+2)...(a+k)),
   * where ax = main_igamma_term(a, x). The derivative of the sum with respect
   * to `a` is accumulated term by term in forward mode.
   */
  EIGEN_DEVICE_FUNC static Scalar run(Scalar a, Scalar x) {
    const Scalar zero = 0;
    const Scalar one = 1;
    const Scalar machep = cephes_helper<Scalar>::machep();

    Scalar ax = main_igamma_term<Scalar>(a, x);

    // This is independent of mode. If this value is zero,
    // then the function value is zero. If the function value is zero,
    // then we are in a neighborhood where the function value evaluates to zero,
    // so the derivative is zero.
    if (ax == zero) {
      return zero;
    }

    // ax now holds the series prefactor ax / a.
    ax /= a;

    /* power series */
    // r runs through a+1, a+2, ...; c is the current term, ans the partial sum.
    Scalar r = a;
    Scalar c = one;
    Scalar ans = one;

    // Forward-mode derivatives of c and ans with respect to a.
    Scalar dc_da = zero;
    Scalar dans_da = zero;

    for (int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
      r += one;
      Scalar term = x / r;
      Scalar dterm_da = -x / (r * r);
      // Product rule for c_new = c * term (computed before c is updated).
      dc_da = term * dc_da + dterm_da * c;
      dans_da += dc_da;
      c *= term;
      ans += c;

      if (mode == VALUE) {
        // The terms are positive here, so c is compared to the sum without abs().
        if (c <= machep * ans) {
          break;
        }
      } else {
        if (numext::abs(dc_da) <= machep * numext::abs(dans_da)) {
          break;
        }
      }
    }

    // d/da log(ax / a) = log(x) - digamma(a) - 1/a = log(x) - digamma(a + 1).
    Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a + one);
    Scalar dax_da = ax * dlogax_da;

    switch (mode) {
      case VALUE:
        return ans * ax;
      case DERIVATIVE:
        return ans * dax_da + dans_da * ax;
      case SAMPLE_DERIVATIVE:
      default:  // this is needed to suppress clang warning
        return -(dans_da + ans * dlogax_da) * x / a;
    }
  }
};
1023
1024#if !EIGEN_HAS_C99_MATH
1025
1026template <typename Scalar>
1027struct igammac_impl {
1028 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), THIS_TYPE_IS_NOT_SUPPORTED)
1029
1030 EIGEN_DEVICE_FUNC static Scalar run(Scalar a, Scalar x) { return Scalar(0); }
1031};
1032
1033#else
1034
1035template <typename Scalar>
1036struct igammac_impl {
1037 EIGEN_DEVICE_FUNC static Scalar run(Scalar a, Scalar x) {
1038 /* igamc()
1039 *
1040 * Incomplete gamma integral (modified for Eigen)
1041 *
1042 *
1043 *
1044 * SYNOPSIS:
1045 *
1046 * double a, x, y, igamc();
1047 *
1048 * y = igamc( a, x );
1049 *
1050 * DESCRIPTION:
1051 *
1052 * The function is defined by
1053 *
1054 *
1055 * igamc(a,x) = 1 - igam(a,x)
1056 *
1057 * inf.
1058 * -
1059 * 1 | | -t a-1
1060 * = ----- | e t dt.
1061 * - | |
1062 * | (a) -
1063 * x
1064 *
1065 *
1066 * In this implementation both arguments must be positive.
1067 * The integral is evaluated by either a power series or
1068 * continued fraction expansion, depending on the relative
1069 * values of a and x.
1070 *
1071 * ACCURACY (float):
1072 *
1073 * Relative error:
1074 * arithmetic domain # trials peak rms
1075 * IEEE 0,30 30000 7.8e-6 5.9e-7
1076 *
1077 *
1078 * ACCURACY (double):
1079 *
1080 * Tested at random a, x.
1081 * a x Relative error:
1082 * arithmetic domain domain # trials peak rms
1083 * IEEE 0.5,100 0,100 200000 1.9e-14 1.7e-15
1084 * IEEE 0.01,0.5 0,100 200000 1.4e-13 1.6e-15
1085 *
1086 */
1087 /*
1088 Cephes Math Library Release 2.2: June, 1992
1089 Copyright 1985, 1987, 1992 by Stephen L. Moshier
1090 Direct inquiries to 30 Frost Street, Cambridge, MA 02140
1091 */
1092 const Scalar zero = 0;
1093 const Scalar one = 1;
1094 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1095
1096 if ((x < zero) || (a <= zero)) {
1097 // domain error
1098 return nan;
1099 }
1100
1101 if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
1102 return nan;
1103 }
1104
1105 if ((x < one) || (x < a)) {
1106 return (one - igamma_series_impl<Scalar, VALUE>::run(a, x));
1107 }
1108
1109 return igammac_cf_impl<Scalar, VALUE>::run(a, x);
1110 }
1111};
1112
1113#endif // EIGEN_HAS_C99_MATH
1114
1115/************************************************************************************************
1116 * Implementation of igamma (incomplete gamma integral), based on Cephes but requires C++11/C99 *
1117 ************************************************************************************************/
1118
1119#if !EIGEN_HAS_C99_MATH
1120
1121template <typename Scalar, IgammaComputationMode mode>
1122struct igamma_generic_impl {
1123 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), THIS_TYPE_IS_NOT_SUPPORTED)
1124
1125 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) { return Scalar(0); }
1126};
1127
1128#else
1129
1130template <typename Scalar, IgammaComputationMode mode>
1131struct igamma_generic_impl {
1132 EIGEN_DEVICE_FUNC static Scalar run(Scalar a, Scalar x) {
1133 /* Depending on the mode, returns
1134 * - VALUE: incomplete Gamma function igamma(a, x)
1135 * - DERIVATIVE: derivative of incomplete Gamma function d/da igamma(a, x)
1136 * - SAMPLE_DERIVATIVE: implicit derivative of a Gamma random variable
1137 * x ~ Gamma(x | a, 1), dx/da = -1 / Gamma(x | a, 1) * d igamma(a, x) / dx
1138 *
1139 * Derivatives are implemented by forward-mode differentiation.
1140 */
1141 const Scalar zero = 0;
1142 const Scalar one = 1;
1143 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1144
1145 if (x == zero) return zero;
1146
1147 if ((x < zero) || (a <= zero)) { // domain error
1148 return nan;
1149 }
1150
1151 if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
1152 return nan;
1153 }
1154
1155 if ((x > one) && (x > a)) {
1156 Scalar ret = igammac_cf_impl<Scalar, mode>::run(a, x);
1157 if (mode == VALUE) {
1158 return one - ret;
1159 } else {
1160 return -ret;
1161 }
1162 }
1163
1164 return igamma_series_impl<Scalar, mode>::run(a, x);
1165 }
1166};
1167
1168#endif // EIGEN_HAS_C99_MATH
1169
// Result type of igamma: the argument scalar type itself.
template <typename Scalar>
struct igamma_retval {
  using type = Scalar;
};
1174
template <typename Scalar>
struct igamma_impl : igamma_generic_impl<Scalar, VALUE> {
  /* igam()
   * Incomplete gamma integral.
   *
   * The CDF of Gamma(a, 1) random variable at the point x.
   * All work is done by igamma_generic_impl<Scalar, VALUE>::run.
   *
   * Accuracy estimation. For each a in [10^-2, 10^-1...10^3] we sample
   * 50 Gamma random variables x ~ Gamma(x | a, 1), a total of 300 points.
   * The ground truth is computed by mpmath. Mean absolute error:
   * float: 1.26713e-05
   * double: 2.33606e-12
   *
   * Cephes documentation below.
   *
   * SYNOPSIS:
   *
   *   double a, x, y, igam();
   *   y = igam( a, x );
   *
   * DESCRIPTION:
   *
   * The function is defined by
   *
   *   igam(a, x) = 1 / Gamma(a) * integral_{t=0}^{x} exp(-t) t^(a-1) dt.
   *
   * In this implementation both arguments must be positive.
   * The integral is evaluated by either a power series or
   * continued fraction expansion, depending on the relative
   * values of a and x.
   *
   * ACCURACY (double):
   *
   *   Relative error:
   *   arithmetic   domain     # trials      peak         rms
   *      IEEE      0,30       200000       3.6e-14     2.9e-15
   *      IEEE      0,100      300000       9.9e-14     1.5e-14
   *
   * ACCURACY (float):
   *
   *   Relative error:
   *   arithmetic   domain     # trials      peak         rms
   *      IEEE      0,30        20000       7.8e-6      5.9e-7
   */
  /*
    Cephes Math Library Release 2.2: June, 1992
    Copyright 1985, 1987, 1992 by Stephen L. Moshier
    Direct inquiries to 30 Frost Street, Cambridge, MA 02140
  */

  /* left tail of incomplete gamma function:
   *
   *   igam(a, x) = x^a exp(-x) * sum_{k=0}^{inf} x^k / Gamma(a + k + 1)
   */
};
1245
1246template <typename Scalar>
1247struct igamma_der_a_retval : igamma_retval<Scalar> {};
1248
template <typename Scalar>
struct igamma_der_a_impl : igamma_generic_impl<Scalar, DERIVATIVE> {
  /* Derivative of the incomplete Gamma function with respect to a.
   *
   * Computes d/da igamma(a, x) by forward differentiation of the igamma code
   * (igamma_generic_impl in DERIVATIVE mode provides run()).
   *
   * Accuracy estimation. For each a in [10^-2, 10^-1...10^3] we sample
   * 50 Gamma random variables x ~ Gamma(x | a, 1), a total of 300 points.
   * The ground truth is computed by mpmath. Mean absolute error:
   * float: 6.17992e-07
   * double: 4.60453e-12
   *
   * Reference:
   * R. Moore. "Algorithm AS 187: Derivatives of the incomplete gamma
   * integral". Journal of the Royal Statistical Society. 1982
   */
};
1266
1267template <typename Scalar>
1268struct gamma_sample_der_alpha_retval : igamma_retval<Scalar> {};
1269
template <typename Scalar>
struct gamma_sample_der_alpha_impl : igamma_generic_impl<Scalar, SAMPLE_DERIVATIVE> {
  /* Derivative of a Gamma random variable sample with respect to alpha.
   * run() is provided by igamma_generic_impl in SAMPLE_DERIVATIVE mode.
   *
   * Consider a sample of a Gamma random variable with the concentration
   * parameter alpha: sample ~ Gamma(alpha, 1). The reparameterization
   * derivative that we want to compute is dsample / dalpha =
   * d igammainv(alpha, u) / dalpha, where u = igamma(alpha, sample).
   * However, this formula is numerically unstable and expensive, so instead
   * we use implicit differentiation:
   *
   *   igamma(alpha, sample) = u, where u ~ Uniform(0, 1).
   * Apply d / dalpha to both sides:
   *   d igamma(alpha, sample) / dalpha
   *       + d igamma(alpha, sample) / dsample * dsample / dalpha = 0
   *   d igamma(alpha, sample) / dalpha
   *       + Gamma(sample | alpha, 1) * dsample / dalpha = 0
   *   dsample / dalpha = - (d igamma(alpha, sample) / dalpha)
   *                      / Gamma(sample | alpha, 1)
   *
   * Here Gamma(sample | alpha, 1) is the PDF of the Gamma distribution
   * (note that the derivative of the CDF w.r.t. sample is the PDF).
   * See the reference below for more details.
   *
   * The derivative of igamma(alpha, sample) is computed by forward
   * differentiation of the igamma code. Division by the Gamma PDF is performed
   * in the same code, increasing the accuracy and speed due to cancellation
   * of some terms.
   *
   * Accuracy estimation. For each alpha in [10^-2, 10^-1...10^3] we sample
   * 50 Gamma random variables sample ~ Gamma(sample | alpha, 1), a total of 300
   * points. The ground truth is computed by mpmath. Mean absolute error:
   * float: 2.1686e-06
   * double: 1.4774e-12
   *
   * Reference:
   * M. Figurnov, S. Mohamed, A. Mnih "Implicit Reparameterization Gradients".
   * 2018
   */
};
1310
1311/*****************************************************************************
1312 * Implementation of Riemann zeta function of two arguments, based on Cephes *
1313 *****************************************************************************/
1314
// Result type of the two-argument zeta function: the argument scalar type.
template <typename Scalar>
struct zeta_retval {
  using type = Scalar;
};
1319
1320template <typename Scalar>
1321struct zeta_impl_series {
1322 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), THIS_TYPE_IS_NOT_SUPPORTED)
1323
1324 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(const Scalar) { return Scalar(0); }
1325};
1326
1327template <>
1328struct zeta_impl_series<float> {
1329 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE bool run(float& a, float& b, float& s, const float x,
1330 const float machep) {
1331 int i = 0;
1332 while (i < 9) {
1333 i += 1;
1334 a += 1.0f;
1335 b = numext::pow(a, -x);
1336 s += b;
1337 if (numext::abs(b / s) < machep) return true;
1338 }
1339
1340 // Return whether we are done
1341 return false;
1342 }
1343};
1344
1345template <>
1346struct zeta_impl_series<double> {
1347 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE bool run(double& a, double& b, double& s, const double x,
1348 const double machep) {
1349 int i = 0;
1350 while ((i < 9) || (a <= 9.0)) {
1351 i += 1;
1352 a += 1.0;
1353 b = numext::pow(a, -x);
1354 s += b;
1355 if (numext::abs(b / s) < machep) return true;
1356 }
1357
1358 // Return whether we are done
1359 return false;
1360 }
1361};
1362
template <typename Scalar>
struct zeta_impl {
  EIGEN_DEVICE_FUNC static Scalar run(Scalar x, Scalar q) {
    /* zeta.c
     *
     * Riemann zeta function of two arguments (Hurwitz zeta function)
     *
     * SYNOPSIS:
     *
     *   double x, q, y, zeta();
     *   y = zeta( x, q );
     *
     * DESCRIPTION:
     *
     *   zeta(x, q) = sum_{k=0}^{inf} (k + q)^(-x)
     *
     * where x > 1 and q is not a negative integer or zero.
     * The Euler-Maclaurin summation formula is used to obtain
     * the expansion (with N = n + q)
     *
     *   zeta(x, q) = sum_{k=0}^{n} (k + q)^(-x)
     *              + N^(1-x) / (x - 1) - N^(-x) / 2
     *              + sum_{j>=1} B_{2j} / (2j)! * x(x+1)...(x+2j-2) / N^(x+2j-1)
     *
     * where the B_{2j} are Bernoulli numbers; the table A[] below holds
     * (2j)! / B_{2j} for j = 1..12. Note that (see zetac.c)
     * zeta(x,1) = zetac(x) + 1.
     *
     * Special values returned here:
     *   x == 1                              -> +inf
     *   x < 1                               -> NaN
     *   q a non-positive integer            -> +inf if x is an even integer, NaN otherwise
     *   q < 0 non-integer with x non-integer -> NaN
     *
     * ACCURACY:
     *
     *   Relative error for single precision:
     *   arithmetic   domain     # trials      peak         rms
     *      IEEE      0,25        10000       6.9e-7      1.0e-7
     *
     * Large arguments may produce underflow in powf(), in which
     * case the results are inaccurate.
     *
     * REFERENCE:
     *
     * Gradshteyn, I. S., and I. M. Ryzhik, Tables of Integrals,
     * Series, and Products, p. 1073; Academic Press, 1980.
     */

    int i;
    Scalar p, r, a, b, k, s, t, w;

    // (2j)! / B_{2j}, j = 1..12: coefficients of the Euler-Maclaurin tail.
    const Scalar A[] = {
        Scalar(12.0),
        Scalar(-720.0),
        Scalar(30240.0),
        Scalar(-1209600.0),
        Scalar(47900160.0),
        Scalar(-1.8924375803183791606e9), /*1.307674368e12/691*/
        Scalar(7.47242496e10),
        Scalar(-2.950130727918164224e12),  /*1.067062284288e16/3617*/
        Scalar(1.1646782814350067249e14),  /*5.109094217170944e18/43867*/
        Scalar(-4.5979787224074726105e15), /*8.028576626982912e20/174611*/
        Scalar(1.8152105401943546773e17),  /*1.5511210043330985984e23/854513*/
        Scalar(-7.1661652561756670113e18)  /*1.6938241367317436694528e27/236364091*/
    };

    const Scalar maxnum = NumTraits<Scalar>::infinity();
    const Scalar zero = Scalar(0.0), half = Scalar(0.5), one = Scalar(1.0);
    const Scalar machep = cephes_helper<Scalar>::machep();
    const Scalar nan = NumTraits<Scalar>::quiet_NaN();

    // Pole at x == 1.
    if (x == one) return maxnum;

    if (x < one) {
      return nan;
    }

    if (q <= zero) {
      // q is a pole of the summand: +inf for even integer x, NaN otherwise.
      if (q == numext::floor(q)) {
        if (numext::rint(Scalar(0.5) * x) == Scalar(0.5) * x) {
          return maxnum;
        } else {
          return nan;
        }
      }
      // Negative non-integer q is only accepted for integer x.
      p = x;
      r = numext::floor(p);
      if (p != r) return nan;
    }

    /* Permit negative q but continue sum until n+q > +9 .
     * This case should be handled by a reflection formula.
     * If q<0 and x is an integer, there is a relation to
     * the polygamma function.
     * (Only the double series helper enforces n+q > 9; the float helper adds
     * at most 9 terms regardless of q.)
     */
    s = numext::pow(q, -x);
    a = q;
    b = zero;
    // Run the summation in a helper function that is specific to the floating precision.
    // On return, a = n + q and b = (n + q)^(-x) for the last term added.
    if (zeta_impl_series<Scalar>::run(a, b, s, x, machep)) {
      return s;
    }

    // If b is zero, then the tail sum will also end up being zero.
    // Exiting early here can prevent NaNs for some large inputs, where
    // the tail sum computed below has term `a` which can overflow to `inf`.
    if (numext::equal_strict(b, zero)) {
      return s;
    }

    // Euler-Maclaurin tail with N = w = n + q.
    w = a;
    s += b * w / (x - one);
    s -= half * b;
    a = one;
    k = zero;

    // Each iteration adds one Bernoulli term; a accumulates the rising product
    // x(x+1)... and b is divided by w twice per iteration.
    for (i = 0; i < 12; i++) {
      a *= x + k;
      b /= w;
      t = a * b / A[i];
      s = s + t;
      t = numext::abs(t / s);
      if (t < machep) {
        break;
      }
      k += one;
      a *= x + k;
      b /= w;
      k += one;
    }
    return s;
  }
};
1512
1513/****************************************************************************
1514 * Implementation of polygamma function, requires C++11/C99 *
1515 ****************************************************************************/
1516
// Result type of polygamma: the argument scalar type itself.
template <typename Scalar>
struct polygamma_retval {
  using type = Scalar;
};
1521
1522#if !EIGEN_HAS_C99_MATH
1523
1524template <typename Scalar>
1525struct polygamma_impl {
1526 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), THIS_TYPE_IS_NOT_SUPPORTED)
1527
1528 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(Scalar n, Scalar x) { return Scalar(0); }
1529};
1530
1531#else
1532
1533template <typename Scalar>
1534struct polygamma_impl {
1535 EIGEN_DEVICE_FUNC static Scalar run(Scalar n, Scalar x) {
1536 Scalar zero = 0.0, one = 1.0;
1537 Scalar nplus = n + one;
1538 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1539
1540 // Check that n is a non-negative integer
1541 if (numext::floor(n) != n || n < zero) {
1542 return nan;
1543 }
1544 // Just return the digamma function for n = 0
1545 else if (n == zero) {
1546 return digamma_impl<Scalar>::run(x);
1547 }
1548 // Use the same implementation as scipy
1549 else {
1550 Scalar factorial = numext::exp(lgamma_impl<Scalar>::run(nplus));
1551 return numext::pow(-one, nplus) * factorial * zeta_impl<Scalar>::run(nplus, x);
1552 }
1553 }
1554};
1555
1556#endif // EIGEN_HAS_C99_MATH
1557
1558/************************************************************************************************
1559 * Implementation of betainc (incomplete beta integral), based on Cephes but requires C++11/C99 *
1560 ************************************************************************************************/
1561
// Result type of betainc: the argument scalar type itself.
template <typename Scalar>
struct betainc_retval {
  using type = Scalar;
};
1566
1567#if !EIGEN_HAS_C99_MATH
1568
1569template <typename Scalar>
1570struct betainc_impl {
1571 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), THIS_TYPE_IS_NOT_SUPPORTED)
1572
1573 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x) { return Scalar(0); }
1574};
1575
1576#else
1577
template <typename Scalar>
struct betainc_impl {
  // Primary template: only the float and double specializations are
  // implemented; any other Scalar fails this static assertion.
  EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), THIS_TYPE_IS_NOT_SUPPORTED)

  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(Scalar, Scalar, Scalar) {
    /* betaincf.c
     *
     * Incomplete beta integral
     *
     * SYNOPSIS:
     *
     *   float a, b, x, y, betaincf();
     *   y = betaincf( a, b, x );
     *
     * DESCRIPTION:
     *
     * Returns incomplete beta integral of the arguments, evaluated
     * from zero to x. The function is defined as
     *
     *   I_x(a, b) = Gamma(a+b) / (Gamma(a) Gamma(b))
     *               * integral_{t=0}^{x} t^(a-1) (1-t)^(b-1) dt.
     *
     * The domain of definition is 0 <= x <= 1. In this
     * implementation a and b are restricted to positive values.
     * The integral from x to 1 may be obtained by the symmetry
     * relation
     *
     *   1 - betainc( a, b, x ) = betainc( b, a, 1-x ).
     *
     * The integral is evaluated by a continued fraction expansion.
     * If a < 1, the function calls itself recursively after a
     * transformation to increase a to a+1. (This is the Cephes description;
     * the float and double specializations select their branches somewhat
     * differently.)
     *
     * ACCURACY (float):
     *
     * Tested at random points (a,b,x) with a and b in the indicated
     * interval and x between 0 and 1.
     *
     *   arithmetic   domain     # trials      peak         rms
     *   Relative error:
     *      IEEE       0,30       10000       3.7e-5      5.1e-6
     *      IEEE       0,100      10000       1.7e-4      2.5e-5
     * The useful domain for relative error is limited by underflow
     * of the single precision exponential function.
     *   Absolute error:
     *      IEEE       0,30      100000       2.2e-5      9.6e-7
     *      IEEE       0,100      10000       6.5e-5      3.7e-6
     *
     * Larger errors may occur for extreme ratios of a and b.
     *
     * ACCURACY (double):
     *   arithmetic   domain     # trials      peak         rms
     *      IEEE      0,5         10000       6.9e-15     4.5e-16
     *      IEEE      0,85       250000       2.2e-13     1.7e-14
     *      IEEE      0,1000      30000       5.3e-12     6.3e-13
     *      IEEE      0,10000    250000       9.3e-11     7.1e-12
     *      IEEE      0,100000    10000       8.7e-10     4.8e-11
     * Outputs smaller than the IEEE gradual underflow threshold
     * were excluded from these statistics.
     *
     * ERROR MESSAGES:
     *   message         condition      value returned
     *   incbet domain   x<0, x>1       nan
     *   incbet underflow               nan
     */
    return Scalar(0);
  }
};
1654
/* Continued fraction expansions for the incomplete beta integral.
 *   small_branch == true:  expansion #1 (Cephes incbcf), in terms of x.
 *   small_branch == false: expansion #2 (Cephes incbd), in terms of
 *                          z = x / (1 - x).
 * Returns the value of the continued fraction only; the caller multiplies it
 * by the x^a (1-x)^b / (a B(a, b)) prefactor.
 */
template <typename Scalar>
struct incbeta_cfe {
  EIGEN_STATIC_ASSERT((internal::is_same<Scalar, float>::value || internal::is_same<Scalar, double>::value),
                      THIS_TYPE_IS_NOT_SUPPORTED)

  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x, bool small_branch) {
    const Scalar big = cephes_helper<Scalar>::big();
    const Scalar machep = cephes_helper<Scalar>::machep();
    const Scalar biginv = cephes_helper<Scalar>::biginv();

    const Scalar zero = 0;
    const Scalar one = 1;
    const Scalar two = 2;

    Scalar xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
    Scalar k1, k2, k3, k4, k5, k6, k7, k8, k26update;
    Scalar ans;
    int n;

    // float: up to 100 iterations with tolerance machep;
    // double: up to 300 iterations with tolerance 3 * machep.
    const int num_iters = (internal::is_same<Scalar, float>::value) ? 100 : 300;
    const Scalar thresh = (internal::is_same<Scalar, float>::value) ? machep : Scalar(3) * machep;
    // The initial value of r is never read: r is assigned before every use.
    Scalar r = (internal::is_same<Scalar, float>::value) ? zero : one;

    // k1..k8 are the running factors of the two partial numerators used per
    // iteration; k2 and k6 move in opposite directions by k26update.
    if (small_branch) {
      k1 = a;
      k2 = a + b;
      k3 = a;
      k4 = a + one;
      k5 = one;
      k6 = b - one;
      k7 = k4;
      k8 = a + two;
      k26update = one;
    } else {
      k1 = a;
      k2 = b - one;
      k3 = a;
      k4 = a + one;
      k5 = one;
      k6 = a + b;
      k7 = a + one;
      k8 = a + two;
      k26update = -one;
      x = x / (one - x);
    }

    pkm2 = zero;
    qkm2 = one;
    pkm1 = one;
    qkm1 = one;
    ans = one;
    n = 0;

    do {
      // First partial numerator of this step (odd term).
      xk = -(x * k1 * k2) / (k3 * k4);
      pk = pkm1 + pkm2 * xk;
      qk = qkm1 + qkm2 * xk;
      pkm2 = pkm1;
      pkm1 = pk;
      qkm2 = qkm1;
      qkm1 = qk;

      // Second partial numerator of this step (even term).
      xk = (x * k5 * k6) / (k7 * k8);
      pk = pkm1 + pkm2 * xk;
      qk = qkm1 + qkm2 * xk;
      pkm2 = pkm1;
      pkm1 = pk;
      qkm2 = qkm1;
      qkm1 = qk;

      // Convergence test on the relative change of the convergent pk / qk.
      if (qk != zero) {
        r = pk / qk;
        if (numext::abs(ans - r) < numext::abs(r) * thresh) {
          return r;
        }
        ans = r;
      }

      k1 += one;
      k2 += k26update;
      k3 += two;
      k4 += two;
      k5 += one;
      k6 -= k26update;
      k7 += two;
      k8 += two;

      // Rescale to keep the recurrence away from overflow / underflow; the
      // ratio pk / qk is unaffected.
      if ((numext::abs(qk) + numext::abs(pk)) > big) {
        pkm2 *= biginv;
        pkm1 *= biginv;
        qkm2 *= biginv;
        qkm1 *= biginv;
      }
      if ((numext::abs(qk) < biginv) || (numext::abs(pk) < biginv)) {
        pkm2 *= big;
        pkm1 *= big;
        qkm2 *= big;
        qkm1 *= big;
      }
    } while (++n < num_iters);

    // Iteration budget exhausted: return the last convergent.
    return ans;
  }
};
1762
/* Precision-specific kernels used by betainc_impl. The primary template has
 * no members; only the float and double specializations provide kernels. */
template <class Scalar>
struct betainc_helper {};
1766
template <>
struct betainc_helper<float> {
  /* Core implementation, assumes a large (> 1.0).
   * betainc_impl<float> only calls this with aa > 1. After the swap below,
   * `a` is the original bb, whose only guarantee is bb > 0.
   * NOTE(review): a - 1 can therefore be <= 0 in the branch-selection ratio
   * below; this mirrors Cephes incbetf, but confirm it is intended.
   */
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float incbsa(float aa, float bb, float xx) {
    float ans, a, b, t, x, onemx;
    bool reversed_a_b = false;

    onemx = 1.0f - xx;

    /* see if x is greater than the mean */
    // If so, compute 1 - I_{1-x}(b, a) instead. t holds 1 - x in the
    // (possibly swapped) coordinates.
    if (xx > (aa / (aa + bb))) {
      reversed_a_b = true;
      a = bb;
      b = aa;
      t = xx;
      x = onemx;
    } else {
      a = aa;
      b = bb;
      t = onemx;
      x = xx;
    }

    /* Choose expansion for optimal convergence */
    // Large b with small b*x/a: the power series converges quickly.
    if (b > 10.0f) {
      if (numext::abs(b * x / a) < 0.3f) {
        t = betainc_helper<float>::incbps(a, b, x);
        if (reversed_a_b) t = 1.0f - t;
        return t;
      }
    }

    // Otherwise pick one of the two continued fractions. t becomes the log of
    // the matching (1-x) power of the prefactor.
    ans = x * (a + b - 2.0f) / (a - 1.0f);
    if (ans < 1.0f) {
      ans = incbeta_cfe<float>::run(a, b, x, true /* small_branch */);
      t = b * numext::log(t);
    } else {
      ans = incbeta_cfe<float>::run(a, b, x, false /* small_branch */);
      t = (b - 1.0f) * numext::log(t);
    }

    // Multiply by x^a Gamma(a+b) / (a Gamma(a) Gamma(b)) in log space.
    t += a * numext::log(x) + lgamma_impl<float>::run(a + b) - lgamma_impl<float>::run(a) - lgamma_impl<float>::run(b);
    t += numext::log(ans / a);
    t = numext::exp(t);

    if (reversed_a_b) t = 1.0f - t;
    return t;
  }

  /* Power series for the incomplete beta integral (Cephes incbpsf).
   * Sums u_k with u_k = u_{k-1} * t * (b-k) / (a+k) and t = x / (1-x). The sum
   * stops early if b - k reaches exactly zero (integer b), and otherwise once
   * |u_k| <= machep. The result is scaled by
   * x^a (1-x)^(b-1) Gamma(a+b) / (a Gamma(a) Gamma(b)).
   */
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float incbps(float a, float b, float x) {
    float t, u, y, s;
    const float machep = cephes_helper<float>::machep();

    // y = log of the prefactor described above.
    y = a * numext::log(x) + (b - 1.0f) * numext::log1p(-x) - numext::log(a);
    y -= lgamma_impl<float>::run(a) + lgamma_impl<float>::run(b);
    y += lgamma_impl<float>::run(a + b);

    t = x / (1.0f - x);
    s = 0.0f;
    u = 1.0f;
    do {
      b -= 1.0f;
      if (b == 0.0f) {
        break;
      }
      a += 1.0f;
      u *= t * b / a;
      s += u;
    } while (numext::abs(u) > machep);

    return numext::exp(y) * (1.0f + s);
  }
};
1840
1841template <>
1842struct betainc_impl<float> {
1843 EIGEN_DEVICE_FUNC static float run(float a, float b, float x) {
1844 const float nan = NumTraits<float>::quiet_NaN();
1845 float ans, t;
1846
1847 if (a <= 0.0f) return nan;
1848 if (b <= 0.0f) return nan;
1849 if ((x <= 0.0f) || (x >= 1.0f)) {
1850 if (x == 0.0f) return 0.0f;
1851 if (x == 1.0f) return 1.0f;
1852 // mtherr("betaincf", DOMAIN);
1853 return nan;
1854 }
1855
1856 /* transformation for small aa */
1857 if (a <= 1.0f) {
1858 ans = betainc_helper<float>::incbsa(a + 1.0f, b, x);
1859 t = a * numext::log(x) + b * numext::log1p(-x) + lgamma_impl<float>::run(a + b) -
1860 lgamma_impl<float>::run(a + 1.0f) - lgamma_impl<float>::run(b);
1861 return (ans + numext::exp(t));
1862 } else {
1863 return betainc_helper<float>::incbsa(a, b, x);
1864 }
1865 }
1866};
1867
1868template <>
1869struct betainc_helper<double> {
1870 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE double incbps(double a, double b, double x) {
1871 const double machep = cephes_helper<double>::machep();
1872
1873 double s, t, u, v, n, t1, z, ai;
1874
1875 ai = 1.0 / a;
1876 u = (1.0 - b) * x;
1877 v = u / (a + 1.0);
1878 t1 = v;
1879 t = u;
1880 n = 2.0;
1881 s = 0.0;
1882 z = machep * ai;
1883 while (numext::abs(v) > z) {
1884 u = (n - b) * x / n;
1885 t *= u;
1886 v = t / (a + n);
1887 s += v;
1888 n += 1.0;
1889 }
1890 s += t1;
1891 s += ai;
1892
1893 u = a * numext::log(x);
1894 // TODO: gamma() is not directly implemented in Eigen.
1895 /*
1896 if ((a + b) < maxgam && numext::abs(u) < maxlog) {
1897 t = gamma(a + b) / (gamma(a) * gamma(b));
1898 s = s * t * pow(x, a);
1899 }
1900 */
1901 t = lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) - lgamma_impl<double>::run(b) + u +
1902 numext::log(s);
1903 return s = numext::exp(t);
1904 }
1905};
1906
template <>
struct betainc_impl<double> {
  // Regularized incomplete beta function I_x(a, b) in double precision
  // (port of Cephes incbet).
  EIGEN_DEVICE_FUNC static double run(double aa, double bb, double xx) {
    const double nan = NumTraits<double>::quiet_NaN();
    const double machep = cephes_helper<double>::machep();
    // const double maxgam = 171.624376956302725;

    double a, b, t, x, xc, w, y;
    bool reversed_a_b = false;

    // Shape parameters must be strictly positive.
    if (aa <= 0.0 || bb <= 0.0) {
      return nan;  // goto domerr;
    }

    // Exact end points; any other x outside (0, 1) is a domain error.
    if ((xx <= 0.0) || (xx >= 1.0)) {
      if (xx == 0.0) return (0.0);
      if (xx == 1.0) return (1.0);
      // mtherr("incbet", DOMAIN);
      return nan;
    }

    // The power series converges quickly when b*x <= 1 and x is not near 1.
    if ((bb * xx) <= 1.0 && xx <= 0.95) {
      return betainc_helper<double>::incbps(aa, bb, xx);
    }

    w = 1.0 - xx;

    /* Reverse a and b if x is greater than the mean. */
    // Then I_x(a, b) = 1 - I_{1-x}(b, a); xc always holds 1 - x in the
    // (possibly swapped) coordinates.
    if (xx > (aa / (aa + bb))) {
      reversed_a_b = true;
      a = bb;
      b = aa;
      xc = xx;
      x = w;
    } else {
      a = aa;
      b = bb;
      xc = w;
      x = xx;
    }

    // Same power-series shortcut in swapped coordinates. The complement is
    // capped at 1 - machep.
    if (reversed_a_b && (b * x) <= 1.0 && x <= 0.95) {
      t = betainc_helper<double>::incbps(a, b, x);
      if (t <= machep) {
        t = 1.0 - machep;
      } else {
        t = 1.0 - t;
      }
      return t;
    }

    /* Choose expansion for better convergence. */
    // Expansion #2 yields a value that still needs division by (1 - x) to
    // match the common prefactor below.
    y = x * (a + b - 2.0) - (a - 1.0);
    if (y < 0.0) {
      w = incbeta_cfe<double>::run(a, b, x, true /* small_branch */);
    } else {
      w = incbeta_cfe<double>::run(a, b, x, false /* small_branch */) / xc;
    }

    /* Multiply w by the factor
         x^a (1-x)^b Gamma(a+b) / (a Gamma(a) Gamma(b)). */

    y = a * numext::log(x);
    t = b * numext::log(xc);
    // TODO: gamma is not directly implemented in Eigen.
    // Cephes computes the factor directly with pow() and gamma() when
    // (a + b) < maxgam and |y|, |t| < maxlog; here the logarithmic form below
    // is always used.
    /* Resort to logarithms. */
    y += t + lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) - lgamma_impl<double>::run(b);
    y += numext::log(w / a);
    t = numext::exp(y);

    // done:

    // Undo the a <-> b swap via the complement, capped at 1 - machep.
    if (reversed_a_b) {
      if (t <= machep) {
        t = 1.0 - machep;
      } else {
        t = 1.0 - t;
      }
    }
    return t;
  }
};
2001
2002#endif // EIGEN_HAS_C99_MATH
2003
2004} // end namespace internal
2005
namespace numext {

// Scalar entry points for the special functions. Each one forwards to the
// scalar implementation named by EIGEN_MATHFUNC_IMPL(<name>, Scalar) and
// returns EIGEN_MATHFUNC_RETVAL(<name>, Scalar). Both macros are defined
// elsewhere in Eigen; they presumably resolve to internal::<name>_impl /
// <name>_retval (confirm against the macro definitions).

/** Log-gamma function; forwards to the `lgamma` scalar implementation. */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(lgamma, Scalar) lgamma(const Scalar& x) {
  return EIGEN_MATHFUNC_IMPL(lgamma, Scalar)::run(x);
}

/** Digamma function; forwards to the `digamma` scalar implementation. */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(digamma, Scalar) digamma(const Scalar& x) {
  return EIGEN_MATHFUNC_IMPL(digamma, Scalar)::run(x);
}

/** Two-argument (Hurwitz) zeta function zeta(x, q) = sum_{k>=0} (k + q)^(-x). */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(zeta, Scalar) zeta(const Scalar& x, const Scalar& q) {
  return EIGEN_MATHFUNC_IMPL(zeta, Scalar)::run(x, q);
}

/** Polygamma function of order n at x (n = 0 gives digamma; n must be a non-negative integer). */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(polygamma, Scalar) polygamma(const Scalar& n, const Scalar& x) {
  return EIGEN_MATHFUNC_IMPL(polygamma, Scalar)::run(n, x);
}

/** Error function; forwards to the `erf` scalar implementation. */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(erf, Scalar) erf(const Scalar& x) {
  return EIGEN_MATHFUNC_IMPL(erf, Scalar)::run(x);
}

/** Complementary error function; forwards to the `erfc` scalar implementation. */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(erfc, Scalar) erfc(const Scalar& x) {
  return EIGEN_MATHFUNC_IMPL(erfc, Scalar)::run(x);
}

/** ndtri; forwards to the `ndtri` scalar implementation (presumably the inverse normal CDF). */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(ndtri, Scalar) ndtri(const Scalar& x) {
  return EIGEN_MATHFUNC_IMPL(ndtri, Scalar)::run(x);
}

/** Regularized lower incomplete gamma function igamma(a, x) (the CDF of Gamma(a, 1) at x). */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma, Scalar) igamma(const Scalar& a, const Scalar& x) {
  return EIGEN_MATHFUNC_IMPL(igamma, Scalar)::run(a, x);
}

/** Derivative d/da igamma(a, x). */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma_der_a, Scalar) igamma_der_a(const Scalar& a, const Scalar& x) {
  return EIGEN_MATHFUNC_IMPL(igamma_der_a, Scalar)::run(a, x);
}

/** Reparameterization derivative dsample/dalpha of a Gamma(alpha, 1) sample. */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(gamma_sample_der_alpha, Scalar)
    gamma_sample_der_alpha(const Scalar& a, const Scalar& x) {
  return EIGEN_MATHFUNC_IMPL(gamma_sample_der_alpha, Scalar)::run(a, x);
}

/** Regularized upper incomplete gamma function igammac(a, x) = 1 - igamma(a, x). */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igammac, Scalar) igammac(const Scalar& a, const Scalar& x) {
  return EIGEN_MATHFUNC_IMPL(igammac, Scalar)::run(a, x);
}

/** Regularized incomplete beta function I_x(a, b). */
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(betainc, Scalar)
    betainc(const Scalar& a, const Scalar& b, const Scalar& x) {
  return EIGEN_MATHFUNC_IMPL(betainc, Scalar)::run(a, b, x);
}

}  // end namespace numext
2071} // end namespace Eigen
2072
2073#endif // EIGEN_SPECIAL_FUNCTIONS_H
Namespace containing all symbols from the Eigen library.
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igammac_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igammac(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:93
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_der_a_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma_der_a(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:52
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_lgamma_op< typename Derived::Scalar >, const Derived > lgamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erf_op< typename Derived::Scalar >, const Derived > erf(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_gamma_sample_der_alpha_op< typename AlphaDerived::Scalar >, const AlphaDerived, const SampleDerived > gamma_sample_der_alpha(const Eigen::ArrayBase< AlphaDerived > &alpha, const Eigen::ArrayBase< SampleDerived > &sample)
Definition SpecialFunctionsArrayAPI.h:75
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erfc_op< typename Derived::Scalar >, const Derived > erfc(const Eigen::ArrayBase< Derived > &x)
const TensorCwiseTernaryOp< internal::scalar_betainc_op< typename XDerived::Scalar >, const ADerived, const BDerived, const XDerived > betainc(const Eigen::TensorBase< ADerived, ReadOnlyAccessors > &a, const Eigen::TensorBase< BDerived, ReadOnlyAccessors > &b, const Eigen::TensorBase< XDerived, ReadOnlyAccessors > &x)
Definition TensorGlobalFunctions.h:26
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_ndtri_op< typename Derived::Scalar >, const Derived > ndtri(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_digamma_op< typename Derived::Scalar >, const Derived > digamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_polygamma_op< typename DerivedX::Scalar >, const DerivedN, const DerivedX > polygamma(const Eigen::ArrayBase< DerivedN > &n, const Eigen::ArrayBase< DerivedX > &x)
Definition SpecialFunctionsArrayAPI.h:113
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:31
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_zeta_op< typename DerivedX::Scalar >, const DerivedX, const DerivedQ > zeta(const Eigen::ArrayBase< DerivedX > &x, const Eigen::ArrayBase< DerivedQ > &q)
Definition SpecialFunctionsArrayAPI.h:152