10#ifndef EIGEN_SPECIAL_FUNCTIONS_H
11#define EIGEN_SPECIAL_FUNCTIONS_H
14#include "./InternalHeaderCheck.h"
// Primary template of a scalar special-function implementation. Its struct
// header (original line 47) is not visible in this listing; presumably it is
// lgamma_impl -- TODO confirm. The static assert's condition is always false
// but depends on Scalar, so it fires only if this unspecialized template is
// instantiated. Supported types must provide a specialization.
template <
typename Scalar>
48 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false), THIS_TYPE_IS_NOT_SUPPORTED)
// Placeholder body so the declaration is well-formed; it just returns 0.
50 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
return Scalar(0); }
// Template header whose declaration is not visible in this listing (the lines
// that followed it were dropped).
template <
typename Scalar>
// Define EIGEN_HAS_LGAMMA_R, which gates the reentrant ::lgammaf_r/::lgamma_r
// calls in the lgamma_impl specializations. It is defined in two cases:
//  1. glibc >= 2.19 (or a later major version) with _DEFAULT_SOURCE,
//     _BSD_SOURCE or _SVID_SOURCE defined;
60#if defined(__GLIBC__) && ((__GLIBC__ >= 2 && __GLIBC_MINOR__ >= 19) || __GLIBC__ > 2) && \
61 (defined(_DEFAULT_SOURCE) || defined(_BSD_SOURCE) || defined(_SVID_SOURCE))
62#define EIGEN_HAS_LGAMMA_R
//  2. glibc older than 2.19 with _BSD_SOURCE or _SVID_SOURCE defined.
66#if defined(__GLIBC__) && ((__GLIBC__ == 2 && __GLIBC_MINOR__ < 19) || __GLIBC__ < 2) && \
67 (defined(_BSD_SOURCE) || defined(_SVID_SOURCE))
68#define EIGEN_HAS_LGAMMA_R
// lgamma_impl<float>: natural log of |Gamma(x)| for float.
//  * Host code (not a GPU compile phase) with EIGEN_HAS_LGAMMA_R and not on
//    Apple: the reentrant ::lgammaf_r. It writes the sign of Gamma(x) into
//    `dummy`, which is declared on a line not visible here; the visible code
//    does not use that sign.
//  * SYCL device code: cl::sycl::lgamma.
// Any remaining fallback branch is on lines not visible in this listing.
72struct lgamma_impl<float> {
73 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
float run(
float x) {
74#if !defined(EIGEN_GPU_COMPILE_PHASE) && defined(EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__)
76 return ::lgammaf_r(x, &dummy);
77#elif defined(SYCL_DEVICE_ONLY)
78 return cl::sycl::lgamma(x);
// lgamma_impl<double>: natural log of |Gamma(x)| for double.
// The dispatch mirrors the float specialization:
//  * host: ::lgamma_r, with the sign written to `dummy` (declared on a line
//    not visible here);
//  * SYCL device code: cl::sycl::lgamma.
86struct lgamma_impl<double> {
87 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
double run(
double x) {
88#if !defined(EIGEN_GPU_COMPILE_PHASE) && defined(EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__)
90 return ::lgamma_r(x, &dummy);
91#elif defined(SYCL_DEVICE_ONLY)
92 return cl::sycl::lgamma(x);
// EIGEN_HAS_LGAMMA_R is a local configuration switch for the lgamma_impl
// specializations; undefine it so it does not leak to code that includes
// this header.
99#undef EIGEN_HAS_LGAMMA_R
// Result-type trait for digamma. Its members are not visible in this listing.
template <
typename Scalar>
107struct digamma_retval {
// Asymptotic-series helper for digamma. It is evaluated once the digamma
// argument has been shifted to s >= 10 (see the `while (s < 10)` loop in the
// digamma evaluation). Only float and double are specialized. The primary
// template's static assert is always false but depends on Scalar, so it fires
// only when another type is used.
template <
typename Scalar>
125struct digamma_impl_maybe_poly {
126 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false), THIS_TYPE_IS_NOT_SUPPORTED)
// Placeholder body so the declaration is well-formed; it just returns 0.
128 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
return Scalar(0); }
// float: returns z * P(z), where P has degree 3.
// A[] holds the Bernoulli-number terms B_2k/(2k) for k = 4..1:
//   -1/240, 1/252, -1/120, 1/12,
// i.e. highest power first, assuming ppolevl's highest-degree-first
// convention.
// z is computed on lines not visible here; presumably z = 1/(s*s) -- confirm.
132struct digamma_impl_maybe_poly<float> {
133 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
float run(
const float s) {
134 constexpr float A[] = {-4.16666666666666666667E-3f, 3.96825396825396825397E-3f, -8.33333333333333333333E-3f,
135 8.33333333333333333333E-2f};
140 return z * internal::ppolevl<float, 3>::run(z, A);
// double: returns z * P(z), where P has degree 6.
// A[] holds the Bernoulli-number terms B_2k/(2k) for k = 7..1:
//   1/12, -691/32760, 1/132, -1/240, 1/252, -1/120, 1/12,
// i.e. highest power first, assuming ppolevl's highest-degree-first
// convention.
// z is computed on lines not visible here; presumably z = 1/(s*s) -- confirm.
147struct digamma_impl_maybe_poly<double> {
148 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
double run(
const double s) {
149 constexpr double A[] = {8.33333333333333333333E-2, -2.10927960927960927961E-2, 7.57575757575757575758E-3,
150 -4.16666666666666666667E-3, 3.96825396825396825397E-3, -8.33333333333333333333E-3,
151 8.33333333333333333333E-2};
156 return z * internal::ppolevl<double, 6>::run(z, A);
// Generic digamma (psi) evaluation. The struct header (original line 163,
// presumably `struct digamma_impl {`) and much of the body are not visible
// in this listing. Visible outline:
//  * Negative arguments: a reflection term nz = pi / tan(pi * nz) is computed
//    and subtracted at the end when `negative` is set.
//  * The argument s is stepped upward while s < 10; the recurrence
//    contribution is presumably accumulated in w.
//  * Final value: log(s) - 1/(2s) - series(s) - w, where the series comes
//    from digamma_impl_maybe_poly.
template <
typename Scalar>
164 EIGEN_DEVICE_FUNC
static Scalar run(Scalar x) {
222 Scalar p, q, nz, s, w, y;
223 bool negative =
false;
225 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
226 const Scalar m_pi = Scalar(EIGEN_PI);
228 const Scalar zero = Scalar(0);
229 const Scalar one = Scalar(1);
230 const Scalar half = Scalar(0.5);
236 p = numext::floor(q);
// Reflection term used for negative arguments.
249 nz = m_pi / numext::tan(m_pi * nz);
// Recurrence: shift the argument up into the range where the asymptotic
// series is accurate.
259 while (s < Scalar(10)) {
// Asymptotic expansion, valid once s >= 10.
264 y = digamma_impl_maybe_poly<Scalar>::run(s);
266 y = numext::log(s) - (half / s) - y - w;
268 return (negative) ? y - nz : y;
// Vectorizable erfc kernel. T is the evaluation type: either a packet (see the
// packet erfc entry point) or the plain scalar (see erfc_impl<float/double>).
// In this file only the float and double specializations of run() are
// defined.
template <
typename Scalar>
276struct generic_fast_erfc {
277 template <
typename T>
278 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T run(
const T& x_in);
// generic_fast_erfc<float>::run. Its template header lines are not visible in
// this listing.
283EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erfc<float>::run(
const T& x_in) {
// Clamp the input to [-11, 11]. Presumably erfc has already saturated to 0
// (or 2) in float there, and clamping keeps exp(-x^2) and 1/x^2 well behaved.
284 constexpr float kClamp = 11.0f;
285 const T x = pmin(pmax(x_in, pset1<T>(-kClamp)), pset1<T>(kClamp));
// x^2 <= 1 branch: erfc(x) ~= 1 + x * P(x^2), with P of degree 5.
292 constexpr float alpha[] = {5.61802298761904239654541015625e-04, -4.91381669417023658752441406250e-03,
293 2.67075151205062866210937500000e-02, -1.12800106406211853027343750000e-01,
294 3.76122951507568359375000000000e-01, -1.12837910652160644531250000000e+00};
295 const T x2 = pmul(x, x);
296 const T one = pset1<T>(1.0f);
297 const T erfc_small = pmadd(x, ppolevl<T, 5>::run(x2, alpha), one);
// Return early when no lane needs the more expensive x^2 > 1 approximation.
301 const T x_abs_gt_one_mask = pcmp_lt(one, x2);
302 if (!predux_any(x_abs_gt_one_mask))
return erfc_small;
// x^2 > 1 branch:
//   erfc(x) ~= exp(-x^2) * N(1/x^2) / (x * D(1/x^2)) + (x < 0 ? 2 : 0).
// The +2 term implements erfc(x) = 2 - erfc(-x) for negative x
// (r is negative there because of the division by x).
310 constexpr float gamma[] = {1.0208116471767425537109375e-01f, 4.2920666933059692382812500e-01f,
311 3.2379078865051269531250000e-01f, 5.3971976041793823242187500e-02f};
312 constexpr float delta[] = {1.7251677811145782470703125e-02f, 3.9137163758277893066406250e-01f,
313 1.0000000000000000000000000e+00f, 6.2173241376876831054687500e-01f,
314 9.5662862062454223632812500e-02f};
// x2_lo is presumably the rounding error of x*x, so that x*x ~= x2 + x2_lo.
// Assuming pnmadd(a, b, c) = c - a*b, the correction below gives
//   z = exp(-x2) * (1 - x2_lo) ~= exp(-(x2 + x2_lo)),
// recovering accuracy lost when rounding x^2.
315 const T x2_lo = twoprod_low(x, x, x2);
320 const T exp2_hi = pexp(pnegate(x2));
321 const T z = pnmadd(exp2_hi, x2_lo, exp2_hi);
322 const T q2 = preciprocal(x2);
323 const T num = ppolevl<T, 3>::run(q2, gamma);
324 const T denom = pmul(x, ppolevl<T, 4>::run(q2, delta));
325 const T r = pdiv(num, denom);
326 const T maybe_two = pselect(pcmp_lt(x, pset1<T>(0.0f)), pset1<T>(2.0f), pset1<T>(0.0f));
327 const T erfc_large = pmadd(z, r, maybe_two);
// Per lane, choose the approximation that matches that lane's |x|.
328 return pselect(x_abs_gt_one_mask, erfc_large, erfc_small);
// Rational approximation of erf(x)/x for double, expressed in x2 = x*x:
// num(x2) / den(x2), with a degree-4 numerator and a degree-5 denominator.
// The double kernels use its result on lanes where x^2 <= 1:
//   erf ~= x * f,   erfc ~= 1 - x * f.
// The template header and the last `beta` entry (original line 352) are not
// visible in this listing.
336EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T erf_over_x_double_small(
const T& x2) {
342 constexpr double alpha[] = {1.9493725660006057018823477644531294572516344487667083740234375e-04,
343 1.8272566210022942682217328425053892715368419885635375976562500e-03,
344 4.5303363351690106863856044583371840417385101318359375000000000e-02,
345 1.4215015503619179981775744181504705920815467834472656250000000e-01,
346 1.1283791670955125585606992899556644260883331298828125000000000e+00};
347 constexpr double beta[] = {2.0294484101083099089526257108317963684385176748037338256835938e-05,
348 6.8117805899186819641732970609382391558028757572174072265625000e-04,
349 1.0582026056098614921752165685120417037978768348693847656250000e-02,
350 9.3252603143757495374188692949246615171432495117187500000000000e-02,
351 4.5931062818368939559832142549566924571990966796875000000000000e-01,
353 const T num_small = ppolevl<T, 4>::run(x2, alpha);
354 const T denom_small = ppolevl<T, 5>::run(x2, beta);
355 return pdiv(num_small, denom_small);
// erfc approximation for double on lanes with x^2 > 1:
//   erfc(x) ~= exp(-x^2) * N(1/x^2) / (x * D(1/x^2)) + (x < 0 ? 2 : 0),
// with N and D of degree 9.
// x2 is expected to be the rounded product x*x (callers in this file pass
// pmul(x, x)), because twoprod_low(x, x, x2) recovers its rounding error.
// The template header is not visible in this listing.
366EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T erfc_double_large(
const T& x,
const T& x2) {
367 constexpr double gamma[] = {1.5252844933226974316088642158462107545346952974796295166015625e-04,
368 1.0909912393738931124520519233556115068495273590087890625000000e-02,
369 1.0628604636755033252537572252549580298364162445068359375000000e-01,
370 3.3492472973137982217295416376146022230386734008789062500000000e-01,
371 4.5065776215933289750026347064704168587923049926757812500000000e-01,
372 2.9433039130294824659017649537418037652969360351562500000000000e-01,
373 9.8792676360600226170838311645638896152377128601074218750000000e-02,
374 1.7095935395503719655962981960328761488199234008789062500000000e-02,
375 1.4249109729504577659398023570247460156679153442382812500000000e-03,
376 4.4567378313647954771875570045835956989321857690811157226562500e-05};
377 constexpr double delta[] = {2.041985103115789845773520028160419315099716186523437500000000e-03,
378 5.316030659946043707142493417450168635696172714233398437500000e-02,
379 3.426242193784684864077405563875799998641014099121093750000000e-01,
380 8.565637124308049799026321124983951449394226074218750000000000e-01,
381 1.000000000000000000000000000000000000000000000000000000000000e+00,
382 5.968805280570776972126623149961233139038085937500000000000000e-01,
383 1.890922854723317836356244470152887515723705291748046875000000e-01,
384 3.152505418656005586885981983868987299501895904541015625000000e-02,
385 2.565085751861882583380047861965067568235099315643310546875000e-03,
386 7.899362131678837697403017248376499992446042597293853759765625e-05};
// x2_lo is presumably the rounding error of x*x. Assuming
// pnmadd(a, b, c) = c - a*b, the correction below gives
//   z = exp(-x2) * (1 - x2_lo) ~= exp(-(x2 + x2_lo)).
388 const T x2_lo = twoprod_low(x, x, x2);
393 const T exp2_hi = pexp(pnegate(x2));
394 const T z = pnmadd(exp2_hi, x2_lo, exp2_hi);
396 const T q2 = preciprocal(x2);
397 const T num_large = ppolevl<T, 9>::run(q2, gamma);
398 const T denom_large = pmul(x, ppolevl<T, 9>::run(q2, delta));
399 const T r = pdiv(num_large, denom_large);
// erfc(x) = 2 - erfc(-x): add 2 on negative lanes, where r is negative.
400 const T maybe_two = pselect(pcmp_lt(x, pset1<T>(0.0)), pset1<T>(2.0), pset1<T>(0.0));
401 return pmadd(z, r, maybe_two);
// generic_fast_erfc<double>::run. Its template header lines are not visible
// here. Structure: clamp the input, take a cheap x^2 <= 1 branch, and
// evaluate the x^2 > 1 branch only when some lane needs it.
406EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erfc<double>::run(
const T& x_in) {
// Clamp to [-28, 28]; presumably erfc has saturated in double beyond this.
409 constexpr double kClamp = 28.0;
410 const T x = pmin(pmax(x_in, pset1<T>(-kClamp)), pset1<T>(kClamp));
// x^2 <= 1 branch: erfc(x) = 1 - x * (erf(x)/x).
413 const T x2 = pmul(x, x);
414 const T one = pset1<T>(1.0);
415 const T erfc_small = pnmadd(x, erf_over_x_double_small(x2), one);
// Return early when no lane has x^2 > 1.
419 const T x_abs_gt_one_mask = pcmp_lt(one, x2);
420 if (!predux_any(x_abs_gt_one_mask))
return erfc_small;
422 const T erfc_large = erfc_double_large(x, x2);
423 return pselect(x_abs_gt_one_mask, erfc_large, erfc_small);
// Packet erfc entry point. Its struct header is not visible here; presumably
// `template <typename T> struct erfc_impl`. It forwards to generic_fast_erfc
// on the packet's scalar type.
428 typedef typename unpacket_traits<T>::type Scalar;
429 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(
const T& x) {
return generic_fast_erfc<Scalar>::run(x); }
// Template header whose declaration (original lines 433-436) is not visible.
432template <
typename Scalar>
// The scalar float/double specializations that follow are guarded by
// EIGEN_HAS_C99_MATH.
437#if EIGEN_HAS_C99_MATH
// float: cl::sycl::erfc in SYCL device code; otherwise the
// generic_fast_erfc<float> kernel (the #else line is not visible here).
439struct erfc_impl<float> {
440 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
float run(
const float x) {
441#if defined(SYCL_DEVICE_ONLY)
442 return cl::sycl::erfc(x);
444 return generic_fast_erfc<float>::run(x);
// double: same dispatch as the float specialization.
450struct erfc_impl<double> {
451 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
double run(
const double x) {
452#if defined(SYCL_DEVICE_ONLY)
453 return cl::sycl::erfc(x);
455 return generic_fast_erfc<double>::run(x);
// Vectorizable erf kernel. T is either a packet or the plain scalar (see
// erf_impl<float/double>). In this file only the float and double
// specializations of run() are defined.
template <
typename Scalar>
466struct generic_fast_erf {
467 template <
typename T>
468 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T run(
const T& x_in);
// generic_fast_erf<float>::run (template header lines not visible here).
// Rational approximation erf(x) ~= x * P(x^2) / Q(x^2), clamped to [-1, 1].
479EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf<float>::run(
const T& x) {
481 constexpr float alpha[] = {2.123732201653183437883853912353515625e-06f, 2.861979592125862836837768554687500000e-04f,
482 3.658048342913389205932617187500000000e-03f, 5.243302136659622192382812500000000000e-02f,
483 1.874160766601562500000000000000000000e-01f, 1.128379106521606445312500000000000000e+00f};
// Denominator coefficients (the last entry is 1.0f).
486 constexpr float beta[] = {3.89185734093189239501953125000e-05f, 1.14329601638019084930419921875e-03f,
487 1.47520881146192550659179687500e-02f, 1.12945675849914550781250000000e-01f,
488 4.99425798654556274414062500000e-01f, 1.0f};
// Cap x^2 at 16 (|x| = 4). Presumably erf is already +-1 to float precision
// there, and the cap keeps the polynomial values bounded.
493 const T x2 = pmin(pset1<T>(16.0f), pmul(x, x));
// Numerator polynomial. The multiplication by x is presumably on the lines
// not visible here (original 497-499) -- confirm.
496 T p = ppolevl<T, 5>::run(x2, alpha);
500 T q = ppolevl<T, 5>::run(x2, beta);
501 const T r = pdiv(p, q);
// Guard against the approximation slightly overshooting the range of erf.
504 return pmax(pmin(r, pset1<T>(1.0f)), pset1<T>(-1.0f));
// generic_fast_erf<double>::run (template header lines not visible here).
// Lanes with x^2 <= 1 use x * erf_over_x_double_small(x^2); lanes with
// x^2 > 1 use 1 - erfc_double_large(x, x2).
// x2 is computed on a line not visible here; presumably pmul(x, x), as
// erfc_double_large requires.
509EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf<double>::run(
const T& x) {
511 T erf_small = pmul(x, erf_over_x_double_small(x2));
// Return early when no lane needs the large-|x| branch.
515 const T one = pset1<T>(1.0);
516 const T x_abs_gt_one_mask = pcmp_lt(one, x2);
517 if (!predux_any(x_abs_gt_one_mask))
return erf_small;
520 const T erf_large = psub(one, erfc_double_large(x, x2));
521 return pselect(x_abs_gt_one_mask, erf_large, erf_small);
// Packet erf entry point. Its struct header is not visible here; presumably
// `template <typename T> struct erf_impl`. It forwards to generic_fast_erf
// on the packet's scalar type.
526 typedef typename unpacket_traits<T>::type Scalar;
527 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(
const T& x) {
return generic_fast_erf<Scalar>::run(x); }
// Template header whose declaration (original lines 531-534) is not visible.
530template <
typename Scalar>
// The scalar float/double specializations that follow are guarded by
// EIGEN_HAS_C99_MATH.
535#if EIGEN_HAS_C99_MATH
// float: cl::sycl::erf in SYCL device code; otherwise the
// generic_fast_erf<float> kernel (the #else line is not visible here).
537struct erf_impl<float> {
538 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
float run(
const float x) {
539#if defined(SYCL_DEVICE_ONLY)
540 return cl::sycl::erf(x);
542 return generic_fast_erf<float>::run(x);
// double: same dispatch as the float specialization.
548struct erf_impl<double> {
549 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
double run(
const double x) {
550#if defined(SYCL_DEVICE_ONLY)
551 return cl::sycl::erf(x);
553 return generic_fast_erf<double>::run(x);
// Conditionally negates x (generic/packet version; its template header is not
// visible here).
// `should_flipsign` is used as a bit mask: callers in this file pass a
// pcmp_* result, whose lanes are all-ones or all-zeros. AND-ing it with -0.0
// isolates the sign bit on the selected lanes, and XOR-ing that into x flips
// the sign of exactly those lanes.
615EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T flipsign(
const T& should_flipsign,
const T& x) {
616 typedef typename unpacket_traits<T>::type Scalar;
617 const T sign_mask = pset1<T>(Scalar(-0.0));
618 T sign_bit = pand<T>(should_flipsign, sign_mask);
619 return pxor<T>(sign_bit, x);
// Scalar double overload: any nonzero should_flipsign negates x.
623EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double flipsign<double>(
const double& should_flipsign,
const double& x) {
624 return should_flipsign == 0 ? x : -x;
// Scalar float overload: any nonzero should_flipsign negates x.
628EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float flipsign<float>(
const float& should_flipsign,
const float& x) {
629 return should_flipsign == 0 ? x : -x;
// ndtri on the central region b > exp(-2), using a rational approximation:
//   ndtri(b) ~= sqrt(2*pi) * (c + c * c2 * P0(c2) / Q0(c2)).
// P0 has degree 4; Q0 has degree 8 and a leading 1.0 coefficient.
// c and c2 are computed on lines not visible here; presumably c = b - 0.5
// (`half` is defined below) and c2 = c*c -- confirm.
636template <
typename T,
typename ScalarType>
637EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_gt_exp_neg_two(
const T& b) {
638 const ScalarType p0[] = {ScalarType(-5.99633501014107895267e1), ScalarType(9.80010754185999661536e1),
639 ScalarType(-5.66762857469070293439e1), ScalarType(1.39312609387279679503e1),
640 ScalarType(-1.23916583867381258016e0)};
641 const ScalarType q0[] = {ScalarType(1.0),
642 ScalarType(1.95448858338141759834e0),
643 ScalarType(4.67627912898881538453e0),
644 ScalarType(8.63602421390890590575e1),
645 ScalarType(-2.25462687854119370527e2),
646 ScalarType(2.00260212380060660359e2),
647 ScalarType(-8.20372256168333339912e1),
648 ScalarType(1.59056225126211695515e1),
649 ScalarType(-1.18331621121330003142e0)};
650 const T sqrt2pi = pset1<T>(ScalarType(2.50662827463100050242e0));
651 const T half = pset1<T>(ScalarType(0.5));
652 T c, c2, ndtri_gt_exp_neg_two;
// c + c * c2 * (P0/Q0), written as a single fused multiply-add.
656 ndtri_gt_exp_neg_two =
657 pmadd(c, pmul(c2, pdiv(internal::ppolevl<T, 4>::run(c2, p0), internal::ppolevl<T, 8>::run(c2, q0))), c);
658 return pmul(ndtri_gt_exp_neg_two, sqrt2pi);
// ndtri tail region (b <= exp(-2)), based on x = sqrt(-2 * log(b)):
//   result = x0 - x1,
//   x0 = x - log(x) / x,
//   x1 = z * (x < 8 ? P1(z)/Q1(z) : P2(z)/Q2(z)).
// z is computed on a line not visible here; presumably z = 1/x -- confirm.
// Lanes selected by should_flipsign are then negated (see flipsign).
661template <
typename T,
typename ScalarType>
662EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_lt_exp_neg_two(
const T& b,
const T& should_flipsign) {
// Coefficients used when x < 8.
666 const ScalarType p1[] = {ScalarType(4.05544892305962419923e0), ScalarType(3.15251094599893866154e1),
667 ScalarType(5.71628192246421288162e1), ScalarType(4.40805073893200834700e1),
668 ScalarType(1.46849561928858024014e1), ScalarType(2.18663306850790267539e0),
669 ScalarType(-1.40256079171354495875e-1), ScalarType(-3.50424626827848203418e-2),
670 ScalarType(-8.57456785154685413611e-4)};
671 const ScalarType q1[] = {ScalarType(1.0),
672 ScalarType(1.57799883256466749731e1),
673 ScalarType(4.53907635128879210584e1),
674 ScalarType(4.13172038254672030440e1),
675 ScalarType(1.50425385692907503408e1),
676 ScalarType(2.50464946208309415979e0),
677 ScalarType(-1.42182922854787788574e-1),
678 ScalarType(-3.80806407691578277194e-2),
679 ScalarType(-9.33259480895457427372e-4)};
// Coefficients used when x >= 8.
683 const ScalarType p2[] = {ScalarType(3.23774891776946035970e0), ScalarType(6.91522889068984211695e0),
684 ScalarType(3.93881025292474443415e0), ScalarType(1.33303460815807542389e0),
685 ScalarType(2.01485389549179081538e-1), ScalarType(1.23716634817820021358e-2),
686 ScalarType(3.01581553508235416007e-4), ScalarType(2.65806974686737550832e-6),
687 ScalarType(6.23974539184983293730e-9)};
688 const ScalarType q2[] = {ScalarType(1.0),
689 ScalarType(6.02427039364742014255e0),
690 ScalarType(3.67983563856160859403e0),
691 ScalarType(1.37702099489081330271e0),
692 ScalarType(2.16236993594496635890e-1),
693 ScalarType(1.34204006088543189037e-2),
694 ScalarType(3.28014464682127739104e-4),
695 ScalarType(2.89247864745380683936e-6),
696 ScalarType(6.79019408009981274425e-9)};
697 const T eight = pset1<T>(ScalarType(8.0));
698 const T neg_two = pset1<T>(ScalarType(-2));
701 x = psqrt(pmul(neg_two, plog(b)));
702 x0 = psub(x, pdiv(plog(x), x));
// Both rational approximations are evaluated and pselect picks one per lane,
// which keeps the code branch-free for packets.
705 pmul(z, pselect(pcmp_lt(x, eight), pdiv(internal::ppolevl<T, 8>::run(z, p1), internal::ppolevl<T, 8>::run(z, q1)),
706 pdiv(internal::ppolevl<T, 8>::run(z, p2), internal::ppolevl<T, 8>::run(z, q2))));
707 return flipsign(should_flipsign, psub(x0, x1));
// Inverse of the standard normal CDF (ndtri), branch-free over packets.
//  * should_flipsign marks lanes with a <= 1 - exp(-2). Those lanes use
//    b = a; the others use b = 1 - a, which is then < exp(-2).
//  * b > exp(-2): central rational approximation.
//  * Otherwise: tail approximation. Its sign is flipped on should_flipsign
//    lanes, i.e. where a <= exp(-2) and the result is negative.
//  * a == 0 -> -inf and a == 1 -> +inf. These selects are applied last, so
//    they override the approximations.
// Inputs outside [0, 1] are not special-cased in the visible code.
710template <
typename T,
typename ScalarType>
711EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T generic_ndtri(
const T& a) {
// Despite the names, these hold +infinity and -infinity.
712 const T maxnum = pset1<T>(NumTraits<ScalarType>::infinity());
713 const T neg_maxnum = pset1<T>(-NumTraits<ScalarType>::infinity());
715 const T zero = pset1<T>(ScalarType(0));
716 const T one = pset1<T>(ScalarType(1));
// exp(-2)
718 const T exp_neg_two = pset1<T>(ScalarType(0.13533528323661269189));
719 T b,
ndtri, should_flipsign;
721 should_flipsign = pcmp_le(a, psub(one, exp_neg_two));
722 b = pselect(should_flipsign, a, psub(one, a));
// Both regions are evaluated for every lane; pselect keeps the right one.
724 ndtri = pselect(pcmp_lt(exp_neg_two, b), generic_ndtri_gt_exp_neg_two<T, ScalarType>(b),
725 generic_ndtri_lt_exp_neg_two<T, ScalarType>(b, should_flipsign));
727 return pselect(pcmp_eq(a, zero), neg_maxnum, pselect(pcmp_eq(one, a), maxnum,
ndtri));
// Template header whose declaration is not visible here; presumably
// ndtri_retval.
730template <
typename Scalar>
// Without C99 math support, the ndtri implementation is a stub that fails to
// compile when instantiated.
735#if !EIGEN_HAS_C99_MATH
// Stub (struct header not visible). The static assert is always false but
// depends on Scalar, so it rejects every type on instantiation.
737template <
typename Scalar>
739 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false), THIS_TYPE_IS_NOT_SUPPORTED)
741 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
return Scalar(0); }
// Generic implementation (its struct header and the surrounding #else are not
// visible). It runs generic_ndtri with the scalar itself as the "packet" type.
746template <
typename Scalar>
748 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(
const Scalar x) {
return generic_ndtri<Scalar, Scalar>(x); }
// Result-type trait for igammac. Its members are not visible in this listing.
757template <
typename Scalar>
758struct igammac_retval {
// Per-type constants for the Cephes-derived algorithms in this file:
//  machep      -- convergence tolerance for the series and continued
//                 fractions (half of NumTraits<>::epsilon());
//  big/biginv  -- thresholds at which continued-fraction numerators and
//                 denominators are rescaled to stay in range.
// The primary template only asserts (at runtime, via eigen_assert); float
// and double provide the real values.
763template <
typename Scalar>
764struct cephes_helper {
765 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar machep() {
766 eigen_assert(
false &&
"machep not supported for this type");
769 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar big() {
770 eigen_assert(
false &&
"big not supported for this type");
773 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar biginv() {
774 eigen_assert(
false &&
"biginv not supported for this type");
// float: machep = epsilon/2 = 2^-24, big = 1/machep = 2^24.
// The biginv body is not visible in this listing.
780struct cephes_helper<float> {
781 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
float machep() {
782 return NumTraits<float>::epsilon() / 2;
784 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
float big() {
786 return 1.0f / (NumTraits<float>::epsilon() / 2);
788 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
float biginv() {
// double: machep = epsilon/2 = 2^-53.
// Unlike float, big and biginv use the full epsilon:
//   big = 1/epsilon = 2^52,  biginv = epsilon = 2^-52.
795struct cephes_helper<double> {
796 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
double machep() {
797 return NumTraits<double>::epsilon() / 2;
799 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
double big() {
return 1.0 / NumTraits<double>::epsilon(); }
800 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
double biginv() {
802 return NumTraits<double>::epsilon();
// Selects what the incomplete-gamma kernels return:
//  VALUE             -- the function value;
//  DERIVATIVE        -- the derivative w.r.t. a (backs igamma_der_a);
//  SAMPLE_DERIVATIVE -- the quantity behind gamma_sample_der_alpha.
806enum IgammaComputationMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
// Common prefactor x^a * exp(-x) / Gamma(a) used by the incomplete-gamma
// series and continued fraction. It is evaluated in log space:
//   logax = a*log(x) - x - lgamma(a).
// An early-return branch (body not visible) is taken when logax is below
// -log(highest), i.e. exp(logax) would underflow, or when logax is NaN. The
// condition continues on a line not visible here. Otherwise exp(logax) is
// returned.
808template <
typename Scalar>
809EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar main_igamma_term(Scalar a, Scalar x) {
811 Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
812 if (logax < -numext::log(NumTraits<Scalar>::highest()) ||
814 (numext::isnan)(logax)) {
817 return numext::exp(logax);
// Iteration cap for the igamma series and continued fraction, chosen per
// Scalar (float vs double) and mode. The actual counts and the fallback for
// other types are on lines not visible here.
820template <
typename Scalar, IgammaComputationMode mode>
821EIGEN_DEVICE_FUNC
int igamma_num_iterations() {
828 if (internal::is_same<Scalar, float>::value) {
830 }
else if (internal::is_same<Scalar, double>::value) {
// Continued-fraction evaluation used by igammac for large x, and by
// igamma_generic_impl when x > 1 and x > a. `mode` selects the result:
//  * VALUE: the branch is on lines not visible here;
//  * DERIVATIVE: d/da (ans * ax) = ans * dax_da + dans_da * ax;
//  * SAMPLE_DERIVATIVE: -(dans_da + ans * dlogax_da) * x.
// The pk/qk recurrences are differentiated term by term (dpk_da, dqk_da), so
// the a-derivative is produced alongside the value.
837template <
typename Scalar, IgammaComputationMode mode>
838struct igammac_cf_impl {
848 EIGEN_DEVICE_FUNC
static Scalar run(Scalar a, Scalar x) {
849 const Scalar zero = 0;
850 const Scalar one = 1;
851 const Scalar two = 2;
852 const Scalar machep = cephes_helper<Scalar>::machep();
853 const Scalar big = cephes_helper<Scalar>::big();
854 const Scalar biginv = cephes_helper<Scalar>::biginv();
// Infinite x takes an early-return branch (body not visible here).
856 if ((numext::isinf)(x)) {
// Common prefactor x^a e^-x / Gamma(a), evaluated in log space.
860 Scalar ax = main_igamma_term<Scalar>(a, x);
871 Scalar z = x + y + one;
875 Scalar pkm1 = x + one;
877 Scalar ans = pkm1 / qkm1;
// a-derivatives of the recurrence state. Only qkm1 starts with a nonzero
// derivative (-x).
879 Scalar dpkm2_da = zero;
880 Scalar dqkm2_da = zero;
881 Scalar dpkm1_da = zero;
882 Scalar dqkm1_da = -x;
883 Scalar dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
// The iteration cap depends on Scalar and mode.
885 for (
int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
// Recurrence for the convergents pk/qk and their a-derivatives.
891 Scalar pk = pkm1 * z - pkm2 * yc;
892 Scalar qk = qkm1 * z - qkm2 * yc;
894 Scalar dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
895 Scalar dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
898 Scalar ans_prev = ans;
901 Scalar dans_da_prev = dans_da;
// Quotient rule for d(pk/qk)/da, assuming ans = pk/qk was updated on the
// lines not visible here.
902 dans_da = (dpk_da - ans * dqk_da) / qk;
// Convergence tests: relative change of the value <= machep, and absolute
// change of the derivative <= machep. Which tests apply in each mode is
// decided on lines not visible here.
905 if (numext::abs(ans_prev - ans) <= machep * numext::abs(ans)) {
909 if (numext::abs(dans_da - dans_da_prev) <= machep) {
// Rescale when the convergents exceed `big`, to avoid overflow. The body is
// not visible; presumably it scales by biginv.
925 if (numext::abs(pk) > big) {
// d/da log(ax) = log(x) - digamma(a).
939 Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a);
940 Scalar dax_da = ax * dlogax_da;
946 return ans * dax_da + dans_da * ax;
947 case SAMPLE_DERIVATIVE:
949 return -(dans_da + ans * dlogax_da) * x;
// Power-series evaluation of the regularized lower incomplete gamma P(a, x).
// igamma uses it directly; igammac uses 1 - P when x is small. Its
// a-derivative is returned when `mode` asks for one.
954template <
typename Scalar, IgammaComputationMode mode>
955struct igamma_series_impl {
964 EIGEN_DEVICE_FUNC
static Scalar run(Scalar a, Scalar x) {
965 const Scalar zero = 0;
966 const Scalar one = 1;
967 const Scalar machep = cephes_helper<Scalar>::machep();
// Common prefactor x^a e^-x / Gamma(a), evaluated in log space.
969 Scalar ax = main_igamma_term<Scalar>(a, x);
987 Scalar dans_da = zero;
989 for (
int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
// d/da of the per-step factor x / r. r presumably advances by one per step
// from a; its update is not visible here. The next line applies the product
// rule for c *= term.
992 Scalar dterm_da = -x / (r * r);
993 dc_da = term * dc_da + dterm_da * c;
// Stop once the newest term is negligible relative to the running sum...
999 if (c <= machep * ans) {
// ...and, presumably only in the derivative modes, once the derivative term
// is negligible too.
1003 if (numext::abs(dc_da) <= machep * numext::abs(dans_da)) {
// The digamma(a + 1) term indicates the prefactor has been normalized by
// Gamma(a + 1) by this point. That adjustment is on lines not visible here.
1009 Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a + one);
1010 Scalar dax_da = ax * dlogax_da;
1016 return ans * dax_da + dans_da * ax;
1017 case SAMPLE_DERIVATIVE:
1019 return -(dans_da + ans * dlogax_da) * x / a;
// Without C99 math, igammac_impl is a stub. Its static assert is always false
// but depends on Scalar, so it rejects every type on instantiation.
1024#if !EIGEN_HAS_C99_MATH
1026template <
typename Scalar>
1027struct igammac_impl {
1028 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false), THIS_TYPE_IS_NOT_SUPPORTED)
1030 EIGEN_DEVICE_FUNC
static Scalar run(Scalar a, Scalar x) {
return Scalar(0); }
// Regularized upper incomplete gamma Q(a, x) = 1 - P(a, x).
//  * x < 0 or a <= 0, and NaN inputs, take early-return branches. Their
//    bodies are not visible; presumably they return nan.
//  * Small x (x < 1 or x < a): 1 - the power series for P.
//  * Otherwise: the continued fraction.
1035template <
typename Scalar>
1036struct igammac_impl {
1037 EIGEN_DEVICE_FUNC
static Scalar run(Scalar a, Scalar x) {
1092 const Scalar zero = 0;
1093 const Scalar one = 1;
1094 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1096 if ((x < zero) || (a <= zero)) {
1101 if ((numext::isnan)(a) || (numext::isnan)(x)) {
1105 if ((x < one) || (x < a)) {
1106 return (one - igamma_series_impl<Scalar, VALUE>::run(a, x));
1109 return igammac_cf_impl<Scalar, VALUE>::run(a, x);
// Without C99 math, igamma_generic_impl is a stub that rejects every type on
// instantiation.
1119#if !EIGEN_HAS_C99_MATH
1121template <
typename Scalar, IgammaComputationMode mode>
1122struct igamma_generic_impl {
1123 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false), THIS_TYPE_IS_NOT_SUPPORTED)
1125 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) {
return Scalar(0); }
// Shared driver for igamma (VALUE), igamma_der_a (DERIVATIVE) and
// gamma_sample_der_alpha (SAMPLE_DERIVATIVE).
//  * x == 0 returns 0 immediately. Note that this check comes before the
//    domain and NaN checks, so x == 0 yields 0 even for a <= 0 or NaN a.
//  * x < 0 or a <= 0, and NaN inputs, take early returns (bodies not
//    visible).
//  * x > 1 and x > a: continued fraction. A VALUE-specific adjustment
//    follows on lines not visible here; presumably 1 - ret.
//  * Otherwise: the power series.
1130template <
typename Scalar, IgammaComputationMode mode>
1131struct igamma_generic_impl {
1132 EIGEN_DEVICE_FUNC
static Scalar run(Scalar a, Scalar x) {
1141 const Scalar zero = 0;
1142 const Scalar one = 1;
1143 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1145 if (x == zero)
return zero;
1147 if ((x < zero) || (a <= zero)) {
1151 if ((numext::isnan)(a) || (numext::isnan)(x)) {
1155 if ((x > one) && (x > a)) {
1156 Scalar ret = igammac_cf_impl<Scalar, mode>::run(a, x);
1157 if (mode == VALUE) {
1164 return igamma_series_impl<Scalar, mode>::run(a, x);
// Result type of igamma and its derivatives: the argument type itself.
1170template <
typename Scalar>
1171struct igamma_retval {
1172 typedef Scalar type;
// igamma: regularized lower incomplete gamma (the VALUE mode of the shared
// driver).
1175template <
typename Scalar>
1176struct igamma_impl : igamma_generic_impl<Scalar, VALUE> {
// igamma_der_a: derivative of igamma w.r.t. a (DERIVATIVE mode).
1246template <
typename Scalar>
1247struct igamma_der_a_retval : igamma_retval<Scalar> {};
1249template <
typename Scalar>
1250struct igamma_der_a_impl : igamma_generic_impl<Scalar, DERIVATIVE> {
// gamma_sample_der_alpha: SAMPLE_DERIVATIVE mode of the shared driver.
1267template <
typename Scalar>
1268struct gamma_sample_der_alpha_retval : igamma_retval<Scalar> {};
1270template <
typename Scalar>
1271struct gamma_sample_der_alpha_impl : igamma_generic_impl<Scalar, SAMPLE_DERIVATIVE> {
// Result-type trait (struct header not visible; presumably zeta_retval): the
// result type is the argument type itself.
1315template <
typename Scalar>
1317 typedef Scalar type;
// Direct-summation phase of the zeta evaluation. It adds terms b = a^-x into
// s and returns true once the latest term is below machep relative to the
// sum, i.e. the sum can be returned without the Euler-Maclaurin correction.
// a, b and s are updated in place. The primary template rejects every type
// except float and double.
1320template <
typename Scalar>
1321struct zeta_impl_series {
1322 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false), THIS_TYPE_IS_NOT_SUPPORTED)
1324 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
return Scalar(0); }
// float: the loop bounds are on lines not visible here.
1328struct zeta_impl_series<float> {
1329 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
bool run(
float& a,
float& b,
float& s,
const float x,
1330 const float machep) {
1335 b = numext::pow(a, -x);
1337 if (numext::abs(b / s) < machep)
return true;
// double: sums at least 9 terms and continues while a <= 9.
1346struct zeta_impl_series<double> {
1347 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
bool run(
double& a,
double& b,
double& s,
const double x,
1348 const double machep) {
1350 while ((i < 9) || (a <= 9.0)) {
1353 b = numext::pow(a, -x);
1355 if (numext::abs(b / s) < machep)
return true;
// Hurwitz zeta function zeta(x, q). The struct header is not visible here;
// presumably zeta_impl.
// Visible outline:
//  * x == 1 -> +inf (pole). Integer q and non-integer p take special-case
//    returns; some of their conditions are on lines not visible here.
//  * s starts at q^-x, and the direct-summation phase (zeta_impl_series) runs.
//    If that converges, the sum is returned (lines not visible).
//  * Otherwise an integral tail b * w / (x - 1) is added, followed by up to 12
//    correction terms built from A[]. Presumably this is the Euler-Maclaurin
//    remainder, with A[k] holding (2k)!/B_2k-type constants -- confirm.
1363template <
typename Scalar>
1365 EIGEN_DEVICE_FUNC
static Scalar run(Scalar x, Scalar q) {
1428 Scalar p, r, a, b, k, s, t, w;
1430 const Scalar A[] = {
1436 Scalar(-1.8924375803183791606e9),
1437 Scalar(7.47242496e10),
1438 Scalar(-2.950130727918164224e12),
1439 Scalar(1.1646782814350067249e14),
1440 Scalar(-4.5979787224074726105e15),
1441 Scalar(1.8152105401943546773e17),
1442 Scalar(-7.1661652561756670113e18)
// Despite the name, maxnum holds +infinity.
1445 const Scalar maxnum = NumTraits<Scalar>::infinity();
1446 const Scalar zero = Scalar(0.0), half = Scalar(0.5), one = Scalar(1.0);
1447 const Scalar machep = cephes_helper<Scalar>::machep();
1448 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1450 if (x == one)
return maxnum;
// Integer q: the result depends on whether x is even (x/2 is an integer).
1457 if (q == numext::floor(q)) {
1458 if (numext::rint(Scalar(0.5) * x) == Scalar(0.5) * x) {
1465 r = numext::floor(p);
1466 if (p != r)
return nan;
1474 s = numext::pow(q, -x);
1478 if (zeta_impl_series<Scalar>::run(a, b, s, x, machep)) {
1485 if (numext::equal_strict(b, zero)) {
1490 s += b * w / (x - one);
1495 for (i = 0; i < 12; i++) {
1500 t = numext::abs(t / s);
// Result type of polygamma: the argument type itself.
1517template <
typename Scalar>
1518struct polygamma_retval {
1519 typedef Scalar type;
// Without C99 math, polygamma_impl is a stub that rejects every type on
// instantiation.
1522#if !EIGEN_HAS_C99_MATH
1524template <
typename Scalar>
1525struct polygamma_impl {
1526 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false), THIS_TYPE_IS_NOT_SUPPORTED)
1528 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(Scalar n, Scalar x) {
return Scalar(0); }
// polygamma(n, x): the n-th derivative of digamma.
//  * n must be a non-negative integer; otherwise an early-return branch is
//    taken (body not visible).
//  * n == 0: digamma(x).
//  * n > 0: (-1)^(n+1) * n! * zeta(n + 1, x), with n! = exp(lgamma(n + 1)).
1533template <
typename Scalar>
1534struct polygamma_impl {
1535 EIGEN_DEVICE_FUNC
static Scalar run(Scalar n, Scalar x) {
1536 Scalar zero = 0.0, one = 1.0;
1537 Scalar nplus = n + one;
1538 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1541 if (numext::floor(n) != n || n < zero) {
1545 else if (n == zero) {
1546 return digamma_impl<Scalar>::run(x);
1550 Scalar factorial = numext::exp(lgamma_impl<Scalar>::run(nplus));
1551 return numext::pow(-one, nplus) * factorial * zeta_impl<Scalar>::run(nplus, x);
// Result type of betainc: the argument type itself.
1562template <
typename Scalar>
1563struct betainc_retval {
1564 typedef Scalar type;
// Without C99 math, betainc_impl is a stub that rejects every type on
// instantiation.
1567#if !EIGEN_HAS_C99_MATH
1569template <
typename Scalar>
1570struct betainc_impl {
1571 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false), THIS_TYPE_IS_NOT_SUPPORTED)
1573 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x) {
return Scalar(0); }
// With C99 math the primary template still rejects every type; only the
// float and double specializations provide an implementation. The rest of
// this body is not visible in the listing.
1578template <
typename Scalar>
1579struct betainc_impl {
1580 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false), THIS_TYPE_IS_NOT_SUPPORTED)
1582 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(Scalar, Scalar, Scalar) {
// Continued-fraction expansion for the incomplete beta integral (struct
// header not visible; called as incbeta_cfe<Scalar>::run). Only float and
// double are accepted, enforced by the static assert.
// small_branch selects one of two expansions; the callers decide which. Each
// loop iteration applies two recurrence steps (k1..k4 and k5..k8) to the
// convergents pk/qk.
// Per-type tuning:
//  * iterations: 100 (float) vs 300 (double);
//  * tolerance: machep (float) vs 3 * machep (double);
//  * initial r: 0 (float) vs 1 (double).
1658template <
typename Scalar>
1660 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, float>::value || internal::is_same<Scalar, double>::value),
1661 THIS_TYPE_IS_NOT_SUPPORTED)
1663 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x,
bool small_branch) {
1664 const Scalar big = cephes_helper<Scalar>::big();
1665 const Scalar machep = cephes_helper<Scalar>::machep();
1666 const Scalar biginv = cephes_helper<Scalar>::biginv();
1668 const Scalar zero = 0;
1669 const Scalar one = 1;
1670 const Scalar two = 2;
1672 Scalar xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
1673 Scalar k1, k2, k3, k4, k5, k6, k7, k8, k26update;
1677 const int num_iters = (internal::is_same<Scalar, float>::value) ? 100 : 300;
1678 const Scalar thresh = (internal::is_same<Scalar, float>::value) ? machep : Scalar(3) * machep;
1679 Scalar r = (internal::is_same<Scalar, float>::value) ? zero : one;
// First recurrence step of the iteration.
1712 xk = -(x * k1 * k2) / (k3 * k4);
1713 pk = pkm1 + pkm2 * xk;
1714 qk = qkm1 + qkm2 * xk;
// Second recurrence step of the iteration.
1720 xk = (x * k5 * k6) / (k7 * k8);
1721 pk = pkm1 + pkm2 * xk;
1722 qk = qkm1 + qkm2 * xk;
// Converged when the estimate changes by less than thresh relative to r.
1730 if (numext::abs(ans - r) < numext::abs(r) * thresh) {
// Rescale when the convergents grow past `big` or shrink below `biginv`, to
// avoid overflow or underflow (bodies not visible here).
1745 if ((numext::abs(qk) + numext::abs(pk)) > big) {
1751 if ((numext::abs(qk) < biginv) || (numext::abs(pk) < biginv)) {
1757 }
while (++n < num_iters);
// Per-type helpers for betainc. The primary template is intentionally empty.
1764template <
typename Scalar>
1765struct betainc_helper {};
1768struct betainc_helper<float> {
// incbsa: incomplete beta for float.
//  * If xx > aa / (aa + bb), the roles of a and b are swapped
//    (reversed_a_b); the swapping assignments are not visible here. The
//    result is then complemented (1 - t).
//  * If |b * x / a| < 0.3: power series (incbps).
//  * Otherwise: the continued fraction. small_branch is chosen on lines not
//    visible here, apparently based on x * (a + b - 2) / (a - 1).
//  * The log prefactor adds a*log(x) + lgamma(a+b) - lgamma(a) - lgamma(b)
//    and log(ans / a).
1770 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
float incbsa(
float aa,
float bb,
float xx) {
1771 float ans, a, b, t, x, onemx;
1772 bool reversed_a_b =
false;
1777 if (xx > (aa / (aa + bb))) {
1778 reversed_a_b =
true;
1792 if (numext::abs(b * x / a) < 0.3f) {
1793 t = betainc_helper<float>::incbps(a, b, x);
1794 if (reversed_a_b) t = 1.0f - t;
1799 ans = x * (a + b - 2.0f) / (a - 1.0f);
1801 ans = incbeta_cfe<float>::run(a, b, x,
true );
1802 t = b * numext::log(t);
1804 ans = incbeta_cfe<float>::run(a, b, x,
false );
1805 t = (b - 1.0f) * numext::log(t);
1808 t += a * numext::log(x) + lgamma_impl<float>::run(a + b) - lgamma_impl<float>::run(a) - lgamma_impl<float>::run(b);
1809 t += numext::log(ans / a);
1812 if (reversed_a_b) t = 1.0f - t;
// incbps: power-series evaluation.
// y = log of the prefactor
//   x^a (1-x)^(b-1) Gamma(a+b) / (a Gamma(a) Gamma(b)).
// The series terms u accumulate into s until |u| <= machep.
// Returns exp(y) * (1 + s).
1816 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
float incbps(
float a,
float b,
float x) {
1818 const float machep = cephes_helper<float>::machep();
1820 y = a * numext::log(x) + (b - 1.0f) * numext::log1p(-x) - numext::log(a);
1821 y -= lgamma_impl<float>::run(a) + lgamma_impl<float>::run(b);
1822 y += lgamma_impl<float>::run(a + b);
1835 }
while (numext::abs(u) > machep);
1837 return numext::exp(y) * (1.0f + s);
// Regularized incomplete beta I_x(a, b) for float.
//  * a <= 0 or b <= 0 -> NaN.
//  * x == 0 -> 0 and x == 1 -> 1. Other x outside (0, 1) are handled on lines
//    not visible here.
//  * In one branch (condition not visible) the recurrence
//      I_x(a, b) = I_x(a + 1, b)
//                  + x^a (1-x)^b Gamma(a+b) / (Gamma(a+1) Gamma(b))
//    is applied.
//  * Otherwise incbsa is called directly.
1842struct betainc_impl<float> {
1843 EIGEN_DEVICE_FUNC
static float run(
float a,
float b,
float x) {
1844 const float nan = NumTraits<float>::quiet_NaN();
1847 if (a <= 0.0f)
return nan;
1848 if (b <= 0.0f)
return nan;
1849 if ((x <= 0.0f) || (x >= 1.0f)) {
1850 if (x == 0.0f)
return 0.0f;
1851 if (x == 1.0f)
return 1.0f;
1858 ans = betainc_helper<float>::incbsa(a + 1.0f, b, x);
1859 t = a * numext::log(x) + b * numext::log1p(-x) + lgamma_impl<float>::run(a + b) -
1860 lgamma_impl<float>::run(a + 1.0f) - lgamma_impl<float>::run(b);
1861 return (ans + numext::exp(t));
1863 return betainc_helper<float>::incbsa(a, b, x);
// Power-series evaluation of the incomplete beta integral for double,
// intended for small b * x (see the callers in betainc_impl<double>).
// The series loop runs while |v| > z; the log prefactor u = a * log(x) is
// combined with the lgamma terms and log(s) before exponentiating.
1869struct betainc_helper<double> {
1870 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
double incbps(
double a,
double b,
double x) {
1871 const double machep = cephes_helper<double>::machep();
1873 double s, t, u, v, n, t1, z, ai;
1883 while (numext::abs(v) > z) {
1884 u = (n - b) * x / n;
1893 u = a * numext::log(x);
1901 t = lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) - lgamma_impl<double>::run(b) + u +
// NOTE(review): assigning to the local s here has no effect beyond the value
// being returned.
1903 return s = numext::exp(t);
// Regularized incomplete beta I_x(a, b) for double.
//  * a <= 0 or b <= 0: early return (body not visible).
//  * x == 0 -> 0 and x == 1 -> 1.
//  * b * x <= 1 and x <= 0.95: power series (incbps) directly.
//  * If xx > aa / (aa + bb), swap the roles of a and b (reversed_a_b). In the
//    swapped case the power series may be used again under the same test.
//  * Otherwise: the continued fraction. small_branch is chosen on lines not
//    visible here, apparently from the sign of x * (a + b - 2) - (a - 1).
//  * The log prefactor is a*log(x) + b*log(xc) + lgamma(a+b) - lgamma(a)
//    - lgamma(b) + log(w / a).
1908struct betainc_impl<double> {
1909 EIGEN_DEVICE_FUNC
static double run(
double aa,
double bb,
double xx) {
1910 const double nan = NumTraits<double>::quiet_NaN();
1911 const double machep = cephes_helper<double>::machep();
1914 double a, b, t, x, xc, w, y;
1915 bool reversed_a_b =
false;
1917 if (aa <= 0.0 || bb <= 0.0) {
1921 if ((xx <= 0.0) || (xx >= 1.0)) {
1922 if (xx == 0.0)
return (0.0);
1923 if (xx == 1.0)
return (1.0);
1928 if ((bb * xx) <= 1.0 && xx <= 0.95) {
1929 return betainc_helper<double>::incbps(aa, bb, xx);
1935 if (xx > (aa / (aa + bb))) {
1936 reversed_a_b =
true;
1948 if (reversed_a_b && (b * x) <= 1.0 && x <= 0.95) {
1949 t = betainc_helper<double>::incbps(a, b, x);
1959 y = x * (a + b - 2.0) - (a - 1.0);
1961 w = incbeta_cfe<double>::run(a, b, x,
true );
1963 w = incbeta_cfe<double>::run(a, b, x,
false ) / xc;
1970 y = a * numext::log(x);
1971 t = b * numext::log(xc);
1984 y += t + lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) - lgamma_impl<double>::run(b);
1985 y += numext::log(w / a);
// Public scalar entry points. The enclosing namespace is not visible in this
// listing; presumably numext. Each one forwards to the matching internal
// *_impl via EIGEN_MATHFUNC_IMPL, with the return type given by
// EIGEN_MATHFUNC_RETVAL.

// Log-gamma: log|Gamma(x)|.
2008template <
typename Scalar>
2009EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
lgamma, Scalar)
lgamma(
const Scalar& x) {
2010 return EIGEN_MATHFUNC_IMPL(
lgamma, Scalar)::run(x);
// Digamma (psi): derivative of lgamma.
2013template <
typename Scalar>
2014EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
digamma, Scalar)
digamma(
const Scalar& x) {
2015 return EIGEN_MATHFUNC_IMPL(
digamma, Scalar)::run(x);
// Hurwitz zeta function zeta(x, q).
2018template <
typename Scalar>
2019EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
zeta, Scalar)
zeta(
const Scalar& x,
const Scalar& q) {
2020 return EIGEN_MATHFUNC_IMPL(
zeta, Scalar)::run(x, q);
// Polygamma: n-th derivative of digamma at x.
2023template <
typename Scalar>
2024EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
polygamma, Scalar)
polygamma(
const Scalar& n,
const Scalar& x) {
2025 return EIGEN_MATHFUNC_IMPL(
polygamma, Scalar)::run(n, x);
// Error function.
2028template <
typename Scalar>
2029EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
erf, Scalar)
erf(
const Scalar& x) {
2030 return EIGEN_MATHFUNC_IMPL(
erf, Scalar)::run(x);
// Complementary error function.
2033template <
typename Scalar>
2034EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
erfc, Scalar)
erfc(
const Scalar& x) {
2035 return EIGEN_MATHFUNC_IMPL(
erfc, Scalar)::run(x);
// Inverse standard normal CDF.
2038template <
typename Scalar>
2039EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
ndtri, Scalar)
ndtri(
const Scalar& x) {
2040 return EIGEN_MATHFUNC_IMPL(
ndtri, Scalar)::run(x);
// Regularized lower incomplete gamma P(a, x).
2043template <
typename Scalar>
2044EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
igamma, Scalar)
igamma(
const Scalar& a,
const Scalar& x) {
2045 return EIGEN_MATHFUNC_IMPL(
igamma, Scalar)::run(a, x);
// Derivative of igamma with respect to a.
2048template <
typename Scalar>
2049EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
igamma_der_a, Scalar)
igamma_der_a(
const Scalar& a,
const Scalar& x) {
2050 return EIGEN_MATHFUNC_IMPL(
igamma_der_a, Scalar)::run(a, x);
// Template header of a wrapper whose declaration (original lines 2054-2058)
// is not visible here; presumably gamma_sample_der_alpha.
2053template <
typename Scalar>
// Regularized upper incomplete gamma Q(a, x).
2059template <
typename Scalar>
2060EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
igammac, Scalar)
igammac(
const Scalar& a,
const Scalar& x) {
2061 return EIGEN_MATHFUNC_IMPL(
igammac, Scalar)::run(a, x);
// Regularized incomplete beta I_x(a, b).
2064template <
typename Scalar>
2065EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
betainc, Scalar)
2066 betainc(
const Scalar& a,
const Scalar& b,
const Scalar& x) {
2067 return EIGEN_MATHFUNC_IMPL(
betainc, Scalar)::run(a, b, x);
Namespace containing all symbols from the Eigen library.
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igammac_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igammac(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:93
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_der_a_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma_der_a(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:52
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_lgamma_op< typename Derived::Scalar >, const Derived > lgamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erf_op< typename Derived::Scalar >, const Derived > erf(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_gamma_sample_der_alpha_op< typename AlphaDerived::Scalar >, const AlphaDerived, const SampleDerived > gamma_sample_der_alpha(const Eigen::ArrayBase< AlphaDerived > &alpha, const Eigen::ArrayBase< SampleDerived > &sample)
Definition SpecialFunctionsArrayAPI.h:75
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erfc_op< typename Derived::Scalar >, const Derived > erfc(const Eigen::ArrayBase< Derived > &x)
const TensorCwiseTernaryOp< internal::scalar_betainc_op< typename XDerived::Scalar >, const ADerived, const BDerived, const XDerived > betainc(const Eigen::TensorBase< ADerived, ReadOnlyAccessors > &a, const Eigen::TensorBase< BDerived, ReadOnlyAccessors > &b, const Eigen::TensorBase< XDerived, ReadOnlyAccessors > &x)
Definition TensorGlobalFunctions.h:26
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_ndtri_op< typename Derived::Scalar >, const Derived > ndtri(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_digamma_op< typename Derived::Scalar >, const Derived > digamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_polygamma_op< typename DerivedX::Scalar >, const DerivedN, const DerivedX > polygamma(const Eigen::ArrayBase< DerivedN > &n, const Eigen::ArrayBase< DerivedX > &x)
Definition SpecialFunctionsArrayAPI.h:113
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:31
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_zeta_op< typename DerivedX::Scalar >, const DerivedX, const DerivedQ > zeta(const Eigen::ArrayBase< DerivedX > &x, const Eigen::ArrayBase< DerivedQ > &q)
Definition SpecialFunctionsArrayAPI.h:152