10#ifndef EIGEN_SPECIAL_FUNCTIONS_H
11#define EIGEN_SPECIAL_FUNCTIONS_H
44template <
typename Scalar>
47 static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
48 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
49 THIS_TYPE_IS_NOT_SUPPORTED);
54template <
typename Scalar>
61#if defined(__GLIBC__) && ((__GLIBC__>=2 && __GLIBC_MINOR__ >= 19) || __GLIBC__>2) \
62 && (defined(_DEFAULT_SOURCE) || defined(_BSD_SOURCE) || defined(_SVID_SOURCE))
63#define EIGEN_HAS_LGAMMA_R
67#if defined(__GLIBC__) && ((__GLIBC__==2 && __GLIBC_MINOR__ < 19) || __GLIBC__<2) \
68 && (defined(_BSD_SOURCE) || defined(_SVID_SOURCE))
69#define EIGEN_HAS_LGAMMA_R
73struct lgamma_impl<float> {
75 static EIGEN_STRONG_INLINE
float run(
float x) {
76#if !defined(EIGEN_GPU_COMPILE_PHASE) && defined (EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__)
78 return ::lgammaf_r(x, &dummy);
79#elif defined(SYCL_DEVICE_ONLY)
80 return cl::sycl::lgamma(x);
88struct lgamma_impl<double> {
90 static EIGEN_STRONG_INLINE
double run(
double x) {
91#if !defined(EIGEN_GPU_COMPILE_PHASE) && defined(EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__)
93 return ::lgamma_r(x, &dummy);
94#elif defined(SYCL_DEVICE_ONLY)
95 return cl::sycl::lgamma(x);
102#undef EIGEN_HAS_LGAMMA_R
109template <
typename Scalar>
110struct digamma_retval {
127template <
typename Scalar>
128struct digamma_impl_maybe_poly {
130 static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
131 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
132 THIS_TYPE_IS_NOT_SUPPORTED);
139struct digamma_impl_maybe_poly<float> {
141 static EIGEN_STRONG_INLINE
float run(
const float s) {
143 -4.16666666666666666667E-3f,
144 3.96825396825396825397E-3f,
145 -8.33333333333333333333E-3f,
146 8.33333333333333333333E-2f
152 return z * internal::ppolevl<float, 3>::run(z, A);
158struct digamma_impl_maybe_poly<double> {
160 static EIGEN_STRONG_INLINE
double run(
const double s) {
162 8.33333333333333333333E-2,
163 -2.10927960927960927961E-2,
164 7.57575757575757575758E-3,
165 -4.16666666666666666667E-3,
166 3.96825396825396825397E-3,
167 -8.33333333333333333333E-3,
168 8.33333333333333333333E-2
174 return z * internal::ppolevl<double, 6>::run(z, A);
180template <
typename Scalar>
183 static Scalar run(Scalar x) {
241 Scalar p, q, nz, s, w, y;
242 bool negative =
false;
244 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
245 const Scalar m_pi = Scalar(EIGEN_PI);
247 const Scalar zero = Scalar(0);
248 const Scalar one = Scalar(1);
249 const Scalar half = Scalar(0.5);
255 p = numext::floor(q);
268 nz = m_pi / numext::tan(m_pi * nz);
279 while (s < Scalar(10)) {
284 y = digamma_impl_maybe_poly<Scalar>::run(s);
286 y = numext::log(s) - (half / s) - y - w;
288 return (negative) ? y - nz : y;
304EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(
const T& x) {
305 const float kErfInvOneMinusHalfULP = 3.832506856900711f;
306 const T clamp = pcmp_le(pset1<T>(kErfInvOneMinusHalfULP), pabs(x));
308 const T alpha_1 = pset1<T>(-1.60960333262415e-02f);
309 const T alpha_3 = pset1<T>(-2.95459980854025e-03f);
310 const T alpha_5 = pset1<T>(-7.34990630326855e-04f);
311 const T alpha_7 = pset1<T>(-5.69250639462346e-05f);
312 const T alpha_9 = pset1<T>(-2.10102402082508e-06f);
313 const T alpha_11 = pset1<T>(2.77068142495902e-08f);
314 const T alpha_13 = pset1<T>(-2.72614225801306e-10f);
317 const T beta_0 = pset1<T>(-1.42647390514189e-02f);
318 const T beta_2 = pset1<T>(-7.37332916720468e-03f);
319 const T beta_4 = pset1<T>(-1.68282697438203e-03f);
320 const T beta_6 = pset1<T>(-2.13374055278905e-04f);
321 const T beta_8 = pset1<T>(-1.45660718464996e-05f);
324 const T x2 = pmul(x, x);
327 T p = pmadd(x2, alpha_13, alpha_11);
328 p = pmadd(x2, p, alpha_9);
329 p = pmadd(x2, p, alpha_7);
330 p = pmadd(x2, p, alpha_5);
331 p = pmadd(x2, p, alpha_3);
332 p = pmadd(x2, p, alpha_1);
336 T q = pmadd(x2, beta_8, beta_6);
337 q = pmadd(x2, q, beta_4);
338 q = pmadd(x2, q, beta_2);
339 q = pmadd(x2, q, beta_0);
342 const T
sign = pselect(pcmp_le(x, pset1<T>(0.0f)), pset1<T>(-1.0f), pset1<T>(1.0f));
343 return pselect(clamp,
sign, pdiv(p, q));
349 static EIGEN_STRONG_INLINE T run(
const T& x) {
350 return generic_fast_erf_float(x);
354template <
typename Scalar>
359#if EIGEN_HAS_C99_MATH
361struct erf_impl<float> {
363 static EIGEN_STRONG_INLINE
float run(
float x) {
364#if defined(SYCL_DEVICE_ONLY)
365 return cl::sycl::erf(x);
367 return generic_fast_erf_float(x);
373struct erf_impl<double> {
375 static EIGEN_STRONG_INLINE
double run(
double x) {
376#if defined(SYCL_DEVICE_ONLY)
377 return cl::sycl::erf(x);
389template <
typename Scalar>
392 static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
393 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
394 THIS_TYPE_IS_NOT_SUPPORTED);
399template <
typename Scalar>
404#if EIGEN_HAS_C99_MATH
406struct erfc_impl<float> {
408 static EIGEN_STRONG_INLINE
float run(
const float x) {
409#if defined(SYCL_DEVICE_ONLY)
410 return cl::sycl::erfc(x);
418struct erfc_impl<double> {
420 static EIGEN_STRONG_INLINE
double run(
const double x) {
421#if defined(SYCL_DEVICE_ONLY)
422 return cl::sycl::erfc(x);
489EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T flipsign(
490 const T& should_flipsign,
const T& x) {
491 typedef typename unpacket_traits<T>::type Scalar;
492 const T sign_mask = pset1<T>(Scalar(-0.0));
493 T sign_bit = pand<T>(should_flipsign, sign_mask);
494 return pxor<T>(sign_bit, x);
498EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double flipsign<double>(
499 const double& should_flipsign,
const double& x) {
500 return should_flipsign == 0 ? x : -x;
504EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float flipsign<float>(
505 const float& should_flipsign,
const float& x) {
506 return should_flipsign == 0 ? x : -x;
513template <
typename T,
typename ScalarType>
514EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_gt_exp_neg_two(
const T& b) {
515 const ScalarType p0[] = {
516 ScalarType(-5.99633501014107895267e1),
517 ScalarType(9.80010754185999661536e1),
518 ScalarType(-5.66762857469070293439e1),
519 ScalarType(1.39312609387279679503e1),
520 ScalarType(-1.23916583867381258016e0)
522 const ScalarType q0[] = {
524 ScalarType(1.95448858338141759834e0),
525 ScalarType(4.67627912898881538453e0),
526 ScalarType(8.63602421390890590575e1),
527 ScalarType(-2.25462687854119370527e2),
528 ScalarType(2.00260212380060660359e2),
529 ScalarType(-8.20372256168333339912e1),
530 ScalarType(1.59056225126211695515e1),
531 ScalarType(-1.18331621121330003142e0)
533 const T sqrt2pi = pset1<T>(ScalarType(2.50662827463100050242e0));
534 const T half = pset1<T>(ScalarType(0.5));
535 T c, c2, ndtri_gt_exp_neg_two;
539 ndtri_gt_exp_neg_two = pmadd(c, pmul(
541 internal::ppolevl<T, 4>::run(c2, p0),
542 internal::ppolevl<T, 8>::run(c2, q0))), c);
543 return pmul(ndtri_gt_exp_neg_two, sqrt2pi);
546template <
typename T,
typename ScalarType>
547EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_lt_exp_neg_two(
548 const T& b,
const T& should_flipsign) {
552 const ScalarType p1[] = {
553 ScalarType(4.05544892305962419923e0),
554 ScalarType(3.15251094599893866154e1),
555 ScalarType(5.71628192246421288162e1),
556 ScalarType(4.40805073893200834700e1),
557 ScalarType(1.46849561928858024014e1),
558 ScalarType(2.18663306850790267539e0),
559 ScalarType(-1.40256079171354495875e-1),
560 ScalarType(-3.50424626827848203418e-2),
561 ScalarType(-8.57456785154685413611e-4)
563 const ScalarType q1[] = {
565 ScalarType(1.57799883256466749731e1),
566 ScalarType(4.53907635128879210584e1),
567 ScalarType(4.13172038254672030440e1),
568 ScalarType(1.50425385692907503408e1),
569 ScalarType(2.50464946208309415979e0),
570 ScalarType(-1.42182922854787788574e-1),
571 ScalarType(-3.80806407691578277194e-2),
572 ScalarType(-9.33259480895457427372e-4)
577 const ScalarType p2[] = {
578 ScalarType(3.23774891776946035970e0),
579 ScalarType(6.91522889068984211695e0),
580 ScalarType(3.93881025292474443415e0),
581 ScalarType(1.33303460815807542389e0),
582 ScalarType(2.01485389549179081538e-1),
583 ScalarType(1.23716634817820021358e-2),
584 ScalarType(3.01581553508235416007e-4),
585 ScalarType(2.65806974686737550832e-6),
586 ScalarType(6.23974539184983293730e-9)
588 const ScalarType q2[] = {
590 ScalarType(6.02427039364742014255e0),
591 ScalarType(3.67983563856160859403e0),
592 ScalarType(1.37702099489081330271e0),
593 ScalarType(2.16236993594496635890e-1),
594 ScalarType(1.34204006088543189037e-2),
595 ScalarType(3.28014464682127739104e-4),
596 ScalarType(2.89247864745380683936e-6),
597 ScalarType(6.79019408009981274425e-9)
599 const T eight = pset1<T>(ScalarType(8.0));
600 const T one = pset1<T>(ScalarType(1));
601 const T neg_two = pset1<T>(ScalarType(-2));
604 x = psqrt(pmul(neg_two, plog(b)));
605 x0 = psub(x, pdiv(plog(x), x));
610 pdiv(internal::ppolevl<T, 8>::run(z, p1),
611 internal::ppolevl<T, 8>::run(z, q1)),
612 pdiv(internal::ppolevl<T, 8>::run(z, p2),
613 internal::ppolevl<T, 8>::run(z, q2))));
614 return flipsign(should_flipsign, psub(x0, x1));
617template <
typename T,
typename ScalarType>
618EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
619T generic_ndtri(
const T& a) {
620 const T maxnum = pset1<T>(NumTraits<ScalarType>::infinity());
621 const T neg_maxnum = pset1<T>(-NumTraits<ScalarType>::infinity());
623 const T zero = pset1<T>(ScalarType(0));
624 const T one = pset1<T>(ScalarType(1));
626 const T exp_neg_two = pset1<T>(ScalarType(0.13533528323661269189));
627 T b,
ndtri, should_flipsign;
629 should_flipsign = pcmp_le(a, psub(one, exp_neg_two));
630 b = pselect(should_flipsign, a, psub(one, a));
633 pcmp_lt(exp_neg_two, b),
634 generic_ndtri_gt_exp_neg_two<T, ScalarType>(b),
635 generic_ndtri_lt_exp_neg_two<T, ScalarType>(b, should_flipsign));
638 pcmp_eq(a, zero), neg_maxnum,
639 pselect(pcmp_eq(one, a), maxnum,
ndtri));
642template <
typename Scalar>
647#if !EIGEN_HAS_C99_MATH
649template <
typename Scalar>
652 static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
653 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
654 THIS_TYPE_IS_NOT_SUPPORTED);
661template <
typename Scalar>
664 static EIGEN_STRONG_INLINE Scalar run(
const Scalar x) {
665 return generic_ndtri<Scalar, Scalar>(x);
676template <
typename Scalar>
677struct igammac_retval {
682template <
typename Scalar>
683struct cephes_helper {
685 static EIGEN_STRONG_INLINE Scalar machep() { assert(
false &&
"machep not supported for this type");
return 0.0; }
687 static EIGEN_STRONG_INLINE Scalar big() { assert(
false &&
"big not supported for this type");
return 0.0; }
689 static EIGEN_STRONG_INLINE Scalar biginv() { assert(
false &&
"biginv not supported for this type");
return 0.0; }
693struct cephes_helper<float> {
695 static EIGEN_STRONG_INLINE
float machep() {
696 return NumTraits<float>::epsilon() / 2;
699 static EIGEN_STRONG_INLINE
float big() {
701 return 1.0f / (NumTraits<float>::epsilon() / 2);
704 static EIGEN_STRONG_INLINE
float biginv() {
711struct cephes_helper<double> {
713 static EIGEN_STRONG_INLINE
double machep() {
714 return NumTraits<double>::epsilon() / 2;
717 static EIGEN_STRONG_INLINE
double big() {
718 return 1.0 / NumTraits<double>::epsilon();
721 static EIGEN_STRONG_INLINE
double biginv() {
723 return NumTraits<double>::epsilon();
// Selects which quantity the shared incomplete-gamma kernels compute
// (igamma_num_iterations, igammac_cf_impl, igamma_series_impl and
// igamma_generic_impl are all templated on this mode):
//   VALUE             - the incomplete gamma value itself (used by igamma_impl).
//   DERIVATIVE        - the derivative with respect to `a` (igamma_der_a_impl).
//   SAMPLE_DERIVATIVE - the quantity returned by gamma_sample_der_alpha_impl.
// Kept as an unscoped enum because the enumerators are used unqualified as
// non-type template arguments elsewhere in this file.
enum IgammaComputationMode {
  VALUE = 0,
  DERIVATIVE = 1,
  SAMPLE_DERIVATIVE = 2
};
729template <
typename Scalar>
731static EIGEN_STRONG_INLINE Scalar main_igamma_term(Scalar a, Scalar x) {
733 Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
734 if (logax < -numext::log(NumTraits<Scalar>::highest()) ||
736 (numext::isnan)(logax)) {
739 return numext::exp(logax);
742template <
typename Scalar, IgammaComputationMode mode>
744int igamma_num_iterations() {
751 if (internal::is_same<Scalar, float>::value) {
753 }
else if (internal::is_same<Scalar, double>::value) {
760template <
typename Scalar, IgammaComputationMode mode>
761struct igammac_cf_impl {
772 static Scalar run(Scalar a, Scalar x) {
773 const Scalar zero = 0;
774 const Scalar one = 1;
775 const Scalar two = 2;
776 const Scalar machep = cephes_helper<Scalar>::machep();
777 const Scalar big = cephes_helper<Scalar>::big();
778 const Scalar biginv = cephes_helper<Scalar>::biginv();
780 if ((numext::isinf)(x)) {
784 Scalar ax = main_igamma_term<Scalar>(a, x);
795 Scalar z = x + y + one;
799 Scalar pkm1 = x + one;
801 Scalar ans = pkm1 / qkm1;
803 Scalar dpkm2_da = zero;
804 Scalar dqkm2_da = zero;
805 Scalar dpkm1_da = zero;
806 Scalar dqkm1_da = -x;
807 Scalar dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
809 for (
int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
815 Scalar pk = pkm1 * z - pkm2 * yc;
816 Scalar qk = qkm1 * z - qkm2 * yc;
818 Scalar dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
819 Scalar dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
822 Scalar ans_prev = ans;
825 Scalar dans_da_prev = dans_da;
826 dans_da = (dpk_da - ans * dqk_da) / qk;
829 if (numext::abs(ans_prev - ans) <= machep * numext::abs(ans)) {
833 if (numext::abs(dans_da - dans_da_prev) <= machep) {
849 if (numext::abs(pk) > big) {
863 Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a);
864 Scalar dax_da = ax * dlogax_da;
870 return ans * dax_da + dans_da * ax;
871 case SAMPLE_DERIVATIVE:
873 return -(dans_da + ans * dlogax_da) * x;
878template <
typename Scalar, IgammaComputationMode mode>
879struct igamma_series_impl {
889 static Scalar run(Scalar a, Scalar x) {
890 const Scalar zero = 0;
891 const Scalar one = 1;
892 const Scalar machep = cephes_helper<Scalar>::machep();
894 Scalar ax = main_igamma_term<Scalar>(a, x);
912 Scalar dans_da = zero;
914 for (
int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
917 Scalar dterm_da = -x / (r * r);
918 dc_da = term * dc_da + dterm_da * c;
924 if (c <= machep * ans) {
928 if (numext::abs(dc_da) <= machep * numext::abs(dans_da)) {
934 Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a + one);
935 Scalar dax_da = ax * dlogax_da;
941 return ans * dax_da + dans_da * ax;
942 case SAMPLE_DERIVATIVE:
944 return -(dans_da + ans * dlogax_da) * x / a;
949#if !EIGEN_HAS_C99_MATH
951template <
typename Scalar>
954 static Scalar run(Scalar a, Scalar x) {
955 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
956 THIS_TYPE_IS_NOT_SUPPORTED);
963template <
typename Scalar>
966 static Scalar run(Scalar a, Scalar x) {
1021 const Scalar zero = 0;
1022 const Scalar one = 1;
1023 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1025 if ((x < zero) || (a <= zero)) {
1030 if ((numext::isnan)(a) || (numext::isnan)(x)) {
1034 if ((x < one) || (x < a)) {
1035 return (one - igamma_series_impl<Scalar, VALUE>::run(a, x));
1038 return igammac_cf_impl<Scalar, VALUE>::run(a, x);
1048#if !EIGEN_HAS_C99_MATH
1050template <
typename Scalar, IgammaComputationMode mode>
1051struct igamma_generic_impl {
1053 static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) {
1054 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
1055 THIS_TYPE_IS_NOT_SUPPORTED);
1062template <
typename Scalar, IgammaComputationMode mode>
1063struct igamma_generic_impl {
1065 static Scalar run(Scalar a, Scalar x) {
1074 const Scalar zero = 0;
1075 const Scalar one = 1;
1076 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1078 if (x == zero)
return zero;
1080 if ((x < zero) || (a <= zero)) {
1084 if ((numext::isnan)(a) || (numext::isnan)(x)) {
1088 if ((x > one) && (x > a)) {
1089 Scalar ret = igammac_cf_impl<Scalar, mode>::run(a, x);
1090 if (mode == VALUE) {
1097 return igamma_series_impl<Scalar, mode>::run(a, x);
1103template <
typename Scalar>
1104struct igamma_retval {
1105 typedef Scalar type;
1108template <
typename Scalar>
1109struct igamma_impl : igamma_generic_impl<Scalar, VALUE> {
1179template <
typename Scalar>
1180struct igamma_der_a_retval : igamma_retval<Scalar> {};
1182template <
typename Scalar>
1183struct igamma_der_a_impl : igamma_generic_impl<Scalar, DERIVATIVE> {
1200template <
typename Scalar>
1201struct gamma_sample_der_alpha_retval : igamma_retval<Scalar> {};
1203template <
typename Scalar>
1204struct gamma_sample_der_alpha_impl
1205 : igamma_generic_impl<Scalar, SAMPLE_DERIVATIVE> {
1249template <
typename Scalar>
1251 typedef Scalar type;
1254template <
typename Scalar>
1255struct zeta_impl_series {
1257 static EIGEN_STRONG_INLINE Scalar run(
const Scalar) {
1258 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
1259 THIS_TYPE_IS_NOT_SUPPORTED);
1265struct zeta_impl_series<float> {
1267 static EIGEN_STRONG_INLINE
bool run(
float& a,
float& b,
float& s,
const float x,
const float machep) {
1273 b = numext::pow( a, -x );
1275 if( numext::abs(b/s) < machep )
1285struct zeta_impl_series<double> {
1287 static EIGEN_STRONG_INLINE
bool run(
double& a,
double& b,
double& s,
const double x,
const double machep) {
1289 while( (i < 9) || (a <= 9.0) )
1293 b = numext::pow( a, -x );
1295 if( numext::abs(b/s) < machep )
1304template <
typename Scalar>
1307 static Scalar run(Scalar x, Scalar q) {
1370 Scalar p, r, a, b, k, s, t, w;
1372 const Scalar A[] = {
1378 Scalar(-1.8924375803183791606e9),
1379 Scalar(7.47242496e10),
1380 Scalar(-2.950130727918164224e12),
1381 Scalar(1.1646782814350067249e14),
1382 Scalar(-4.5979787224074726105e15),
1383 Scalar(1.8152105401943546773e17),
1384 Scalar(-7.1661652561756670113e18)
1387 const Scalar maxnum = NumTraits<Scalar>::infinity();
1388 const Scalar zero = Scalar(0.0), half = Scalar(0.5), one = Scalar(1.0);
1389 const Scalar machep = cephes_helper<Scalar>::machep();
1390 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1402 if(q == numext::floor(q))
1404 if (x == numext::floor(x) &&
long(x) % 2 == 0) {
1412 r = numext::floor(p);
1422 s = numext::pow( q, -x );
1426 if (zeta_impl_series<Scalar>::run(a, b, s, x, machep)) {
1433 if (numext::equal_strict(b, zero)) {
1443 for( i=0; i<12; i++ )
1449 t = numext::abs(t/s);
1466template <
typename Scalar>
1467struct polygamma_retval {
1468 typedef Scalar type;
1471#if !EIGEN_HAS_C99_MATH
1473template <
typename Scalar>
1474struct polygamma_impl {
1476 static EIGEN_STRONG_INLINE Scalar run(Scalar n, Scalar x) {
1477 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
1478 THIS_TYPE_IS_NOT_SUPPORTED);
1485template <
typename Scalar>
1486struct polygamma_impl {
1488 static Scalar run(Scalar n, Scalar x) {
1489 Scalar zero = 0.0, one = 1.0;
1490 Scalar nplus = n + one;
1491 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1494 if (numext::floor(n) != n || n < zero) {
1498 else if (n == zero) {
1499 return digamma_impl<Scalar>::run(x);
1503 Scalar factorial = numext::exp(lgamma_impl<Scalar>::run(nplus));
1504 return numext::pow(-one, nplus) * factorial * zeta_impl<Scalar>::run(nplus, x);
1515template <
typename Scalar>
1516struct betainc_retval {
1517 typedef Scalar type;
1520#if !EIGEN_HAS_C99_MATH
1522template <
typename Scalar>
1523struct betainc_impl {
1525 static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x) {
1526 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
1527 THIS_TYPE_IS_NOT_SUPPORTED);
1534template <
typename Scalar>
1535struct betainc_impl {
1537 static EIGEN_STRONG_INLINE Scalar run(Scalar, Scalar, Scalar) {
1607 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value ==
false),
1608 THIS_TYPE_IS_NOT_SUPPORTED);
1616template <
typename Scalar>
1619 static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x,
bool small_branch) {
1620 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, float>::value ||
1621 internal::is_same<Scalar, double>::value),
1622 THIS_TYPE_IS_NOT_SUPPORTED);
1623 const Scalar big = cephes_helper<Scalar>::big();
1624 const Scalar machep = cephes_helper<Scalar>::machep();
1625 const Scalar biginv = cephes_helper<Scalar>::biginv();
1627 const Scalar zero = 0;
1628 const Scalar one = 1;
1629 const Scalar two = 2;
1631 Scalar xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
1632 Scalar k1, k2, k3, k4, k5, k6, k7, k8, k26update;
1636 const int num_iters = (internal::is_same<Scalar, float>::value) ? 100 : 300;
1637 const Scalar thresh =
1638 (internal::is_same<Scalar, float>::value) ? machep : Scalar(3) * machep;
1639 Scalar r = (internal::is_same<Scalar, float>::value) ? zero : one;
1672 xk = -(x * k1 * k2) / (k3 * k4);
1673 pk = pkm1 + pkm2 * xk;
1674 qk = qkm1 + qkm2 * xk;
1680 xk = (x * k5 * k6) / (k7 * k8);
1681 pk = pkm1 + pkm2 * xk;
1682 qk = qkm1 + qkm2 * xk;
1690 if (numext::abs(ans - r) < numext::abs(r) * thresh) {
1705 if ((numext::abs(qk) + numext::abs(pk)) > big) {
1711 if ((numext::abs(qk) < biginv) || (numext::abs(pk) < biginv)) {
1717 }
while (++n < num_iters);
// Primary template is deliberately empty: only the float and double
// specializations provide the incbps/incbsa helpers that betainc_impl calls.
// Referencing those members for any other Scalar therefore fails to compile.
template <class Scalar>
struct betainc_helper {};
1728struct betainc_helper<float> {
1730 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE
float incbsa(
float aa,
float bb,
1732 float ans, a, b, t, x, onemx;
1733 bool reversed_a_b =
false;
1738 if (xx > (aa / (aa + bb))) {
1739 reversed_a_b =
true;
1753 if (numext::abs(b * x / a) < 0.3f) {
1754 t = betainc_helper<float>::incbps(a, b, x);
1755 if (reversed_a_b) t = 1.0f - t;
1760 ans = x * (a + b - 2.0f) / (a - 1.0f);
1762 ans = incbeta_cfe<float>::run(a, b, x,
true );
1763 t = b * numext::log(t);
1765 ans = incbeta_cfe<float>::run(a, b, x,
false );
1766 t = (b - 1.0f) * numext::log(t);
1769 t += a * numext::log(x) + lgamma_impl<float>::run(a + b) -
1770 lgamma_impl<float>::run(a) - lgamma_impl<float>::run(b);
1771 t += numext::log(ans / a);
1774 if (reversed_a_b) t = 1.0f - t;
1779 static EIGEN_STRONG_INLINE
float incbps(
float a,
float b,
float x) {
1781 const float machep = cephes_helper<float>::machep();
1783 y = a * numext::log(x) + (b - 1.0f) * numext::log1p(-x) - numext::log(a);
1784 y -= lgamma_impl<float>::run(a) + lgamma_impl<float>::run(b);
1785 y += lgamma_impl<float>::run(a + b);
1798 }
while (numext::abs(u) > machep);
1800 return numext::exp(y) * (1.0f + s);
1805struct betainc_impl<float> {
1807 static float run(
float a,
float b,
float x) {
1808 const float nan = NumTraits<float>::quiet_NaN();
1811 if (a <= 0.0f)
return nan;
1812 if (b <= 0.0f)
return nan;
1813 if ((x <= 0.0f) || (x >= 1.0f)) {
1814 if (x == 0.0f)
return 0.0f;
1815 if (x == 1.0f)
return 1.0f;
1822 ans = betainc_helper<float>::incbsa(a + 1.0f, b, x);
1823 t = a * numext::log(x) + b * numext::log1p(-x) +
1824 lgamma_impl<float>::run(a + b) - lgamma_impl<float>::run(a + 1.0f) -
1825 lgamma_impl<float>::run(b);
1826 return (ans + numext::exp(t));
1828 return betainc_helper<float>::incbsa(a, b, x);
1834struct betainc_helper<double> {
1836 static EIGEN_STRONG_INLINE
double incbps(
double a,
double b,
double x) {
1837 const double machep = cephes_helper<double>::machep();
1839 double s, t, u, v, n, t1, z, ai;
1849 while (numext::abs(v) > z) {
1850 u = (n - b) * x / n;
1859 u = a * numext::log(x);
1867 t = lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
1868 lgamma_impl<double>::run(b) + u + numext::log(s);
1869 return s = numext::exp(t);
1874struct betainc_impl<double> {
1876 static double run(
double aa,
double bb,
double xx) {
1877 const double nan = NumTraits<double>::quiet_NaN();
1878 const double machep = cephes_helper<double>::machep();
1881 double a, b, t, x, xc, w, y;
1882 bool reversed_a_b =
false;
1884 if (aa <= 0.0 || bb <= 0.0) {
1888 if ((xx <= 0.0) || (xx >= 1.0)) {
1889 if (xx == 0.0)
return (0.0);
1890 if (xx == 1.0)
return (1.0);
1895 if ((bb * xx) <= 1.0 && xx <= 0.95) {
1896 return betainc_helper<double>::incbps(aa, bb, xx);
1902 if (xx > (aa / (aa + bb))) {
1903 reversed_a_b =
true;
1915 if (reversed_a_b && (b * x) <= 1.0 && x <= 0.95) {
1916 t = betainc_helper<double>::incbps(a, b, x);
1926 y = x * (a + b - 2.0) - (a - 1.0);
1928 w = incbeta_cfe<double>::run(a, b, x,
true );
1930 w = incbeta_cfe<double>::run(a, b, x,
false ) / xc;
1937 y = a * numext::log(x);
1938 t = b * numext::log(xc);
1951 y += t + lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
1952 lgamma_impl<double>::run(b);
1953 y += numext::log(w / a);
1976template <
typename Scalar>
1977EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
lgamma, Scalar)
1978 lgamma(
const Scalar& x) {
1979 return EIGEN_MATHFUNC_IMPL(
lgamma, Scalar)::run(x);
1982template <
typename Scalar>
1983EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
digamma, Scalar)
1985 return EIGEN_MATHFUNC_IMPL(
digamma, Scalar)::run(x);
1988template <
typename Scalar>
1989EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
zeta, Scalar)
1990zeta(
const Scalar& x,
const Scalar& q) {
1991 return EIGEN_MATHFUNC_IMPL(
zeta, Scalar)::run(x, q);
1994template <
typename Scalar>
1995EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
polygamma, Scalar)
1996polygamma(
const Scalar& n,
const Scalar& x) {
1997 return EIGEN_MATHFUNC_IMPL(
polygamma, Scalar)::run(n, x);
2000template <
typename Scalar>
2001EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
erf, Scalar)
2002 erf(
const Scalar& x) {
2003 return EIGEN_MATHFUNC_IMPL(
erf, Scalar)::run(x);
2006template <
typename Scalar>
2007EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
erfc, Scalar)
2008 erfc(
const Scalar& x) {
2009 return EIGEN_MATHFUNC_IMPL(
erfc, Scalar)::run(x);
2012template <
typename Scalar>
2013EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
ndtri, Scalar)
2014 ndtri(
const Scalar& x) {
2015 return EIGEN_MATHFUNC_IMPL(
ndtri, Scalar)::run(x);
2018template <
typename Scalar>
2019EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
igamma, Scalar)
2020 igamma(
const Scalar& a,
const Scalar& x) {
2021 return EIGEN_MATHFUNC_IMPL(
igamma, Scalar)::run(a, x);
2024template <
typename Scalar>
2025EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
igamma_der_a, Scalar)
2027 return EIGEN_MATHFUNC_IMPL(
igamma_der_a, Scalar)::run(a, x);
2030template <
typename Scalar>
2036template <
typename Scalar>
2037EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
igammac, Scalar)
2038 igammac(
const Scalar& a,
const Scalar& x) {
2039 return EIGEN_MATHFUNC_IMPL(
igammac, Scalar)::run(a, x);
2042template <
typename Scalar>
2043EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(
betainc, Scalar)
2044 betainc(
const Scalar& a,
const Scalar& b,
const Scalar& x) {
2045 return EIGEN_MATHFUNC_IMPL(
betainc, Scalar)::run(a, b, x);
Namespace containing all symbols from the Eigen library.
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igammac_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igammac(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:90
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_der_a_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma_der_a(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:51
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_sign_op< typename Derived::Scalar >, const Derived > sign(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_lgamma_op< typename Derived::Scalar >, const Derived > lgamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erf_op< typename Derived::Scalar >, const Derived > erf(const Eigen::ArrayBase< Derived > &x)
const TensorCwiseTernaryOp< internal::scalar_betainc_op< typename XDerived::Scalar >, const ADerived, const BDerived, const XDerived > betainc(const ADerived &a, const BDerived &b, const XDerived &x)
Definition TensorGlobalFunctions.h:24
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_gamma_sample_der_alpha_op< typename AlphaDerived::Scalar >, const AlphaDerived, const SampleDerived > gamma_sample_der_alpha(const Eigen::ArrayBase< AlphaDerived > &alpha, const Eigen::ArrayBase< SampleDerived > &sample)
Definition SpecialFunctionsArrayAPI.h:72
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erfc_op< typename Derived::Scalar >, const Derived > erfc(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_ndtri_op< typename Derived::Scalar >, const Derived > ndtri(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_digamma_op< typename Derived::Scalar >, const Derived > digamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_polygamma_op< typename DerivedX::Scalar >, const DerivedN, const DerivedX > polygamma(const Eigen::ArrayBase< DerivedN > &n, const Eigen::ArrayBase< DerivedX > &x)
Definition SpecialFunctionsArrayAPI.h:112
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition SpecialFunctionsArrayAPI.h:28
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_zeta_op< typename DerivedX::Scalar >, const DerivedX, const DerivedQ > zeta(const Eigen::ArrayBase< DerivedX > &x, const Eigen::ArrayBase< DerivedQ > &q)
Definition SpecialFunctionsArrayAPI.h:156