super stable gamma_lccdf #3266
Status: Open
spinkney wants to merge 10 commits into develop from fix-gamma-lccdf-v2
+462 −38
Changes from all commits (10 commits):
026b39a  super stable gamma_lccdf (spinkney)
170bf8d  Merge commit '5df6fc50b02b07109e21134448d1b7f5b2c38444' into HEAD (yashikno)
f0cfa1c  [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 (stan-buildbot)
2125bb6  explicit type (spinkney)
1f112d3  remove fwd and fix templating (spinkney)
9e5ca4b  [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 (stan-buildbot)
df09600  remove internal function in lccdf and add two tests for expanded range (spinkney)
538f55c  [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 (stan-buildbot)
2c6ef87  make order of precision and max_steps the same as grad_reg_lower_inc_… (spinkney)
36337f6  [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 (stan-buildbot)
New file: stan/math/prim/fun/log_gamma_q_dgamma.hpp — @@ -0,0 +1,156 @@

#ifndef STAN_MATH_PRIM_FUN_LOG_GAMMA_Q_DGAMMA_HPP
#define STAN_MATH_PRIM_FUN_LOG_GAMMA_Q_DGAMMA_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/gamma_p.hpp>
#include <stan/math/prim/fun/gamma_q.hpp>
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/tgamma.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
 * Result structure containing log(Q(a,z)) and its gradient with respect to a.
 *
 * @tparam T return type
 */
template <typename T>
struct log_gamma_q_result {
  T log_q;      ///< log(Q(a,z)) where Q is upper regularized incomplete gamma
  T dlog_q_da;  ///< d/da log(Q(a,z))
};

namespace internal {

/**
 * Compute log(Q(a,z)) using continued fraction expansion for upper incomplete
 * gamma function.
 *
 * @tparam T_a Type of shape parameter a (double or fvar types)
 * @tparam T_z Type of value parameter z (double or fvar types)
 * @param a Shape parameter
 * @param z Value at which to evaluate
 * @param precision Convergence threshold
 * @param max_steps Maximum number of continued fraction iterations
 * @return log(Q(a,z)) with same type as T_a and T_z
 */
template <typename T_a, typename T_z>
inline auto log_q_gamma_cf(const T_a& a, const T_z& z, double precision = 1e-16,
                           int max_steps = 250) {
  using stan::math::lgamma;
  using stan::math::log;
  using stan::math::value_of;
  using std::fabs;
  using T_return = return_type_t<T_a, T_z>;

  const T_return a_ret = a;
  const T_return z_ret = z;
  const auto log_prefactor = a_ret * log(z_ret) - z_ret - lgamma(a_ret);

  auto b = z_ret + 1.0 - a_ret;
  auto C = (fabs(value_of(b)) >= EPSILON) ? b : T_return(EPSILON);
  auto D = T_return(0.0);
  auto f = C;

  for (int i = 1; i <= max_steps; ++i) {
    auto an = -i * (i - a_ret);
    b += 2.0;

    D = b + an * D;
    if (fabs(value_of(D)) < EPSILON) {
      D = T_return(EPSILON);
    }
    C = b + an / C;
    if (fabs(value_of(C)) < EPSILON) {
      C = T_return(EPSILON);
    }

    D = 1.0 / D;
    auto delta = C * D;
    f *= delta;

    const double delta_m1 = value_of(fabs(value_of(delta) - 1.0));
    if (delta_m1 < precision) {
      break;
    }
  }

  return log_prefactor - log(f);
}

}  // namespace internal

/**
 * Compute log(Q(a,z)) and its gradient with respect to a using continued
 * fraction expansion, where Q(a,z) = Gamma(a,z) / Gamma(a) is the regularized
 * upper incomplete gamma function.
 *
 * This uses a continued fraction representation for numerical stability when
 * computing the upper incomplete gamma function in log space, along with
 * analytical gradient computation.
 *
 * @tparam T_a type of the shape parameter
 * @tparam T_z type of the value parameter
 * @param a shape parameter (must be positive)
 * @param z value parameter (must be non-negative)
 * @param precision convergence threshold
 * @param max_steps maximum iterations for continued fraction
 * @return structure containing log(Q(a,z)) and d/da log(Q(a,z))
 */
template <typename T_a, typename T_z>
inline log_gamma_q_result<return_type_t<T_a, T_z>> log_gamma_q_dgamma(
    const T_a& a, const T_z& z, double precision = 1e-16, int max_steps = 250) {
  using std::exp;
  using std::log;
  using T_return = return_type_t<T_a, T_z>;

  const double a_dbl = value_of(a);
  const double z_dbl = value_of(z);

  log_gamma_q_result<T_return> result;

  // For z > a + 1, use continued fraction for better numerical stability
  if (z_dbl > a_dbl + 1.0) {
    result.log_q = internal::log_q_gamma_cf(a_dbl, z_dbl, precision, max_steps);

    // For gradient, use: d/da log(Q) = (1/Q) * dQ/da
    // grad_reg_inc_gamma computes dQ/da
    const double Q_val = exp(result.log_q);
    const double dQ_da
        = grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
    result.dlog_q_da = dQ_da / Q_val;

  } else {
    // For z <= a + 1, use log1m(P(a,z)) for better numerical accuracy
    const double P_val = gamma_p(a_dbl, z_dbl);
    result.log_q = log1m(P_val);

    // Gradient: d/da log(Q) = (1/Q) * dQ/da
    // grad_reg_inc_gamma computes dQ/da
    const double Q_val = exp(result.log_q);
    if (Q_val > 0) {
      const double dQ_da
          = grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
      result.dlog_q_da = dQ_da / Q_val;
    } else {
      // Fallback if Q rounds to zero - use asymptotic approximation
      result.dlog_q_da = log(z_dbl) - digamma(a_dbl);
    }
  }

  return result;
}

}  // namespace math
}  // namespace stan

#endif
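As background (this derivation is not stated in the PR itself), the quantity that internal::log_q_gamma_cf accumulates appears to be the standard continued fraction for the upper incomplete gamma function, written here in the notation of the code above, with f(a,z) the denominator built up in the loop:

\[
\log Q(a,z) \;=\; a\log z \;-\; z \;-\; \log\Gamma(a) \;-\; \log f(a,z),
\qquad
f(a,z) \;=\; z+1-a \;-\; \cfrac{1\cdot(1-a)}{\,z+3-a \;-\; \cfrac{2\cdot(2-a)}{\,z+5-a-\cdots}}
\]

where Q(a,z) = Γ(a,z)/Γ(a). Each loop iteration folds in the terms a_n = −n(n−a) and b_n = b_{n−1} + 2, which is why the expansion converges quickly for z > a + 1, the regime in which the caller selects this path.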
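A minimal usage sketch, not part of the PR or its test suite, assuming the header compiles standalone as written; it calls log_gamma_q_dgamma with plain doubles and checks the analytic gradient against a central finite difference:

#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
#include <iostream>

int main() {
  using stan::math::log_gamma_q_dgamma;

  const double a = 3.5;   // shape
  const double z = 40.0;  // z > a + 1, so the continued-fraction path is taken

  // Value and analytic gradient in one call.
  const auto res = log_gamma_q_dgamma(a, z);

  // Central finite-difference check of d/da log Q(a, z).
  const double h = 1e-6;
  const double fd = (log_gamma_q_dgamma(a + h, z).log_q
                     - log_gamma_q_dgamma(a - h, z).log_q)
                    / (2.0 * h);

  std::cout << "log Q(a,z)        = " << res.log_q << '\n'
            << "analytic d/da     = " << res.dlog_q_da << '\n'
            << "finite difference = " << fd << '\n';
  return 0;
}

With these inputs the two gradient estimates should agree to several decimal places; a large discrepancy would point at the continued-fraction branch rather than the log1m branch, since z > a + 1 here.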
Modified file: gamma_lccdf.hpp (removed and added lines appear together in diff order)

@@ -6,28 +6,33 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/gamma_q.hpp>
#include <stan/math/prim/fun/fma.hpp>
#include <stan/math/prim/fun/gamma_p.hpp>
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/tgamma.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
#include <cmath>

namespace stan {
namespace math {

template <typename T_y, typename T_shape, typename T_inv_scale>
inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
    const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
                                                     const T_shape& alpha,
                                                     const T_inv_scale& beta) {
  using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
  using std::exp;
  using std::log;
  using std::pow;
  using T_y_ref = ref_type_t<T_y>;
  using T_alpha_ref = ref_type_t<T_shape>;
  using T_beta_ref = ref_type_t<T_inv_scale>;

@@ -51,61 +56,159 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(

  scalar_seq_view<T_y_ref> y_vec(y_ref);
  scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
  scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
  size_t N = max_size(y, alpha, beta);

  // Explicit return for extreme values
  // The gradients are technically ill-defined, but treated as zero
  for (size_t i = 0; i < stan::math::size(y); i++) {
    if (y_vec.val(i) == 0) {
      // LCCDF(0) = log(P(Y > 0)) = log(1) = 0
      return ops_partials.build(0.0);
    }
  }
  const size_t N = max_size(y, alpha, beta);

  constexpr bool need_y_beta_deriv = !is_constant_all<T_y, T_inv_scale>::value;
  constexpr bool any_fvar = is_fvar<scalar_type_t<T_y>>::value
                            || is_fvar<scalar_type_t<T_shape>>::value
                            || is_fvar<scalar_type_t<T_inv_scale>>::value;
  constexpr bool partials_fvar = is_fvar<T_partials_return>::value;

  for (size_t n = 0; n < N; n++) {
    // Explicit results for extreme values
    // The gradients are technically ill-defined, but treated as zero
    if (y_vec.val(n) == INFTY) {
      // LCCDF(∞) = log(P(Y > ∞)) = log(0) = -∞
    const T_partials_return y_dbl = y_vec.val(n);
    if (y_dbl == 0.0) {
      continue;
    }
    if (y_dbl == INFTY) {
      return ops_partials.build(negative_infinity());
    }

    const T_partials_return y_dbl = y_vec.val(n);
    const T_partials_return alpha_dbl = alpha_vec.val(n);
    const T_partials_return beta_dbl = beta_vec.val(n);
    const T_partials_return beta_y_dbl = beta_dbl * y_dbl;

    // Qn = 1 - Pn
    const T_partials_return Qn = gamma_q(alpha_dbl, beta_y_dbl);
    const T_partials_return log_Qn = log(Qn);
    const T_partials_return beta_y = beta_dbl * y_dbl;
    if (beta_y == INFTY) {
      return ops_partials.build(negative_infinity());
    }

    bool use_cf = beta_y > alpha_dbl + 1.0;
    T_partials_return log_Qn;
    [[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0;
    // Extract double values for the double-only continued fraction path.
    [[maybe_unused]] const double beta_y_dbl = value_of(value_of(beta_y));
    [[maybe_unused]] const double alpha_dbl_val = value_of(value_of(alpha_dbl));

    if (use_cf) {
      [Review comment from a collaborator on this line: "I was having a hard
      time parsing this logic, could we do the branching logic for the autodiff
      types first and then do the …"]
      if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
        // var-only: use analytical gradient with double inputs
        auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
        log_Qn = log_q_result.log_q;
        dlogQ_dalpha = log_q_result.dlog_q_da;
      } else {
        log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
        if constexpr (is_autodiff_v<T_shape>) {
          if constexpr (partials_fvar) {
            auto alpha_unit = alpha_dbl;
            alpha_unit.d_ = 1;
            auto beta_unit = beta_y;
            beta_unit.d_ = 0;
            auto log_Qn_fvar = internal::log_q_gamma_cf(alpha_unit, beta_unit);
            dlogQ_dalpha = log_Qn_fvar.d_;
          } else {
            const T_partials_return Qn = exp(log_Qn);
            dlogQ_dalpha
                = grad_reg_inc_gamma(alpha_dbl, beta_y, tgamma(alpha_dbl),
                                     digamma(alpha_dbl))
                  / Qn;
          }
        }
      }
    } else {
      const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
      log_Qn = log1m(Pn);

      if (!std::isfinite(value_of(value_of(log_Qn)))) {
        use_cf = beta_y > 0.0;
        if (use_cf) {
          // Fallback to continued fraction if log1m fails
          if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
            auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
            log_Qn = log_q_result.log_q;
            dlogQ_dalpha = log_q_result.dlog_q_da;
          } else {
            log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
            if constexpr (is_autodiff_v<T_shape>) {
              if constexpr (partials_fvar) {
                auto alpha_unit = alpha_dbl;
                alpha_unit.d_ = 1;
                auto beta_unit = beta_y;
                beta_unit.d_ = 0;
                auto log_Qn_fvar
                    = internal::log_q_gamma_cf(alpha_unit, beta_unit);
                dlogQ_dalpha = log_Qn_fvar.d_;
              } else {
                const T_partials_return Qn = exp(log_Qn);
                dlogQ_dalpha
                    = grad_reg_inc_gamma(alpha_dbl, beta_y, tgamma(alpha_dbl),
                                         digamma(alpha_dbl))
                      / Qn;
              }
            }
          }
        }
      }
    }

    if constexpr (is_autodiff_v<T_shape>) {
      if (!use_cf) {
        if constexpr (partials_fvar) {
          auto alpha_unit = alpha_dbl;
          alpha_unit.d_ = 1;
          auto beta_unit = beta_y;
          beta_unit.d_ = 0;
          auto log_Qn_fvar = log1m(gamma_p(alpha_unit, beta_unit));
          dlogQ_dalpha = log_Qn_fvar.d_;
        } else {
          const T_partials_return Qn = exp(log_Qn);
          if (Qn > 0.0) {
            dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha_dbl, beta_y) / Qn;
          } else {
            // Fallback to continued fraction if Q rounds to zero
            if constexpr (!any_fvar) {
              auto log_q_result
                  = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
              log_Qn = log_q_result.log_q;
              dlogQ_dalpha = log_q_result.dlog_q_da;
            } else {
              log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
              const T_partials_return Qn_cf = exp(log_Qn);
              dlogQ_dalpha
                  = grad_reg_inc_gamma(alpha_dbl, beta_y, tgamma(alpha_dbl),
                                       digamma(alpha_dbl))
                    / Qn_cf;
            }
            use_cf = true;
          }
        }
      }
    }
    if (!std::isfinite(value_of(value_of(log_Qn)))) {
      return ops_partials.build(negative_infinity());
    }
    P += log_Qn;

    if constexpr (is_any_autodiff_v<T_y, T_inv_scale>) {
      const T_partials_return log_y_dbl = log(y_dbl);
      const T_partials_return log_beta_dbl = log(beta_dbl);
    if constexpr (need_y_beta_deriv) {
      const T_partials_return log_y = log(y_dbl);
      const T_partials_return log_beta = log(beta_dbl);
      const T_partials_return lgamma_alpha = lgamma(alpha_dbl);
      const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y);

      const T_partials_return log_pdf
          = alpha_dbl * log_beta_dbl - lgamma(alpha_dbl)
            + (alpha_dbl - 1.0) * log_y_dbl - beta_y_dbl;
      const T_partials_return common_term = exp(log_pdf - log_Qn);
          = alpha_dbl * log_beta - lgamma_alpha + alpha_minus_one - beta_y;

      const T_partials_return hazard = exp(log_pdf - log_Qn);  // f/Q

      if constexpr (is_autodiff_v<T_y>) {
        // d/dy log(1-F(y)) = -f(y)/(1-F(y))
        partials<0>(ops_partials)[n] -= common_term;
        partials<0>(ops_partials)[n] -= hazard;
      }
      if constexpr (is_autodiff_v<T_inv_scale>) {
        // d/dbeta log(1-F(y)) = -y*f(y)/(beta*(1-F(y)))
        partials<2>(ops_partials)[n] -= y_dbl / beta_dbl * common_term;
        partials<2>(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard;
      }
    }

    if constexpr (is_autodiff_v<T_shape>) {
      const T_partials_return digamma_val = digamma(alpha_dbl);
      const T_partials_return gamma_val = tgamma(alpha_dbl);
      // d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y))
      partials<1>(ops_partials)[n]
          += grad_reg_inc_gamma(alpha_dbl, beta_y_dbl, gamma_val, digamma_val)
             / Qn;
      partials<1>(ops_partials)[n] += dlogQ_dalpha;
    }
  }
  return ops_partials.build(P);
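For orientation, the partial derivatives the loop accumulates can be written compactly; this restates the identities already given in the inline comments, with f the gamma density, F the CDF, and Q = 1 − F:

\[
\frac{\partial}{\partial y}\log\bigl(1-F(y)\bigr) = -\frac{f(y)}{1-F(y)},
\qquad
\frac{\partial}{\partial \beta}\log\bigl(1-F(y)\bigr) = -\frac{y}{\beta}\,\frac{f(y)}{1-F(y)},
\qquad
\frac{\partial}{\partial \alpha}\log\bigl(1-F(y)\bigr) = \frac{1}{Q}\,\frac{\partial Q}{\partial \alpha}
\]

The hazard f/Q is evaluated as exp(log_pdf − log_Qn), and ∂Q/∂α is supplied by grad_reg_inc_gamma, by −grad_reg_lower_inc_gamma, or by a unit-tangent forward-mode pass, depending on which branch is active.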
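A rough illustration of the underflow this PR targets, offered as an informal sketch rather than one of the tests added by the PR (the exact output depends on which branch is built):

#include <stan/math/prim.hpp>
#include <cmath>
#include <iostream>

int main() {
  const double alpha = 2.0;
  const double beta = 1.0;
  const double y = 800.0;  // far beyond where gamma_cdf rounds to 1

  // Naive approach: 1 - CDF underflows to 0, so its log is -inf.
  const double naive = std::log(1.0 - stan::math::gamma_cdf(y, alpha, beta));

  // gamma_lccdf works in log space; with the continued-fraction path it
  // should return roughly log(1 + y) - y here (about -793) instead of -inf.
  const double stable = stan::math::gamma_lccdf(y, alpha, beta);

  std::cout << "log(1 - gamma_cdf): " << naive << "\n";
  std::cout << "gamma_lccdf       : " << stable << "\n";
  return 0;
}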
Review comment: Quick Q for clarity: does this function only work for double types, or should it be able to accept autodiff types as well? Reading the gamma_lccdf code I'm kind of confused by the branching logic.