-
-
Notifications
You must be signed in to change notification settings - Fork 220
Fix Bayesian R2 computation for normal and Bernoulli #1880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
6e5c73b
0a26d7f
0d6c462
b1270da
7e8ecb1
227b063
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,10 @@ | |
| #' | ||
| #' @details For an introduction to the approach, see Gelman et al. (2019) | ||
| #' and \url{https://github.com/jgabry/bayes_R2/}. | ||
| #' For Gaussian and Bernoulli models, \code{bayes_R2} uses model-based | ||
| #' residual variances as proposed in Gelman et al. (2019), with Bernoulli | ||
| #' models using Tjur's pseudo-variance. For other families, \code{bayes_R2} | ||
| #' warns and falls back to residual-based variances. | ||
| #' | ||
| #' @references Andrew Gelman, Ben Goodrich, Jonah Gabry & Aki Vehtari. (2019). | ||
| #' R-squared for Bayesian regression models, \emph{The American Statistician}, | ||
|
|
@@ -70,18 +74,38 @@ bayes_R2.brmsfit <- function(object, resp = NULL, summary = TRUE, | |
| "'bayes_R2' which is likely invalid for ordinal families." | ||
| ) | ||
| } | ||
| args_y <- list(object, warn = TRUE, ...) | ||
|
paul-buerkner marked this conversation as resolved.
|
||
| args_ypred <- list(object, sort = TRUE, ...) | ||
| R2 <- named_list(paste0("R2", resp)) | ||
| warned_families <- character(0) | ||
| for (i in seq_along(R2)) { | ||
| # assumes expectations of different responses to be independent | ||
| args_ypred$resp <- args_y$resp <- resp[i] | ||
| y <- do_call(get_y, args_y) | ||
| args_ypred$resp <- resp[i] | ||
| ypred <- do_call(posterior_epred, args_ypred) | ||
| if (is_ordinal(family(object, resp = resp[i]))) { | ||
| ypred <- ordinal_probs_continuous(ypred) | ||
| } | ||
| R2[[i]] <- .bayes_R2(y, ypred) | ||
| family_name <- family(object, resp = resp[i])$family | ||
| if (family_name %in% "gaussian") { | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this pattern works for now. If we later on support the model-based approach for more than two families, we should provide separate functions per family, similar to how we do it with a lot of other post-processing funcftions.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I refactored the code and outsourced the family-specific computations for better maintainability and easier future updating. Thank you for pointing this out. |
||
| args_sigma <- args_ypred | ||
| args_sigma$dpar <- "sigma" | ||
| sigma <- do_call(posterior_epred, args_sigma) | ||
| var_res <- .bayes_R2_var_res_gaussian(sigma) | ||
| R2[[i]] <- .bayes_R2_model(ypred, var_res) | ||
| } else if (family_name %in% "bernoulli") { | ||
| var_res <- .bayes_R2_var_res_bernoulli(ypred) | ||
| R2[[i]] <- .bayes_R2_model(ypred, var_res) | ||
| } else { | ||
| if (!family_name %in% warned_families) { | ||
| warning2( | ||
| "No model-based residual variance is currently implemented for ", | ||
| "family '", family_name, "' in 'bayes_R2'. Falling back ", | ||
| "to residual-based R2 computation." | ||
| ) | ||
| warned_families <- c(warned_families, family_name) | ||
| } | ||
| y <- do_call(get_y, c(list(object, warn = TRUE, resp = resp[i]), list(...))) | ||
| R2[[i]] <- .bayes_R2_res(y, ypred) | ||
| } | ||
| } | ||
| R2 <- do_call(cbind, R2) | ||
| colnames(R2) <- paste0("R2", resp) | ||
|
|
@@ -93,9 +117,23 @@ bayes_R2.brmsfit <- function(object, resp = NULL, summary = TRUE, | |
|
|
||
| # internal function of bayes_R2.brmsfit | ||
| # see https://github.com/jgabry/bayes_R2/blob/master/bayes_R2.pdf | ||
| .bayes_R2 <- function(y, ypred, ...) { | ||
| .bayes_R2_res <- function(y, ypred, ...) { | ||
| e <- -1 * sweep(ypred, 2, y) | ||
| var_ypred <- matrixStats::rowVars(ypred) | ||
| var_e <- matrixStats::rowVars(e) | ||
| as.matrix(var_ypred / (var_ypred + var_e)) | ||
| } | ||
|
|
||
| .bayes_R2_model <- function(ypred, var_res, ...) { | ||
| var_ypred <- matrixStats::rowVars(ypred) | ||
| as.matrix(var_ypred / (var_ypred + var_res)) | ||
| } | ||
|
|
||
| .bayes_R2_var_res_gaussian <- function(sigma, ...) { | ||
| sigma <- as.matrix(sigma) | ||
| matrixStats::rowMeans2(sigma^2) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand why there is rowMeans. The paper and the online appendix provided equations only for scalar sigma, and I think just taking mean for vector sigma in heteroscedastic case is not correct. Even if more correct approach would be too complicated to implement, at least this approximation should be documented.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for pointing this out. I decided now for the following approach: check whether sigma is heteroscedastic. If true, a warning is provided and R2 computation falls back to residual-based R2.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After discussing this point again, we (Aki, Florence) decided to use also for the heteroscedastic case the mean of sigma as an approximation as this is consistent with the discussion provided in Tjur (2009). A corresponding note is added in the function documentation for the user as well as in the code as a comment for the developer. To sum up, we use model-based variances for (homo- and heteroscedastic) Gaussian and Bernoulli models. Otherwise residual-based variances are used and the user warned. |
||
| } | ||
|
|
||
| .bayes_R2_var_res_bernoulli <- function(ypred, ...) { | ||
| matrixStats::rowMeans2(ypred * (1 - ypred)) | ||
| } | ||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Uh oh!
There was an error while loading. Please reload this page.