Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ Thanks to Sermet Pekin. (#1790)

### Bug Fixes

* `bayes_R2` now uses model-based residual variances for Gaussian and Bernoulli
Comment thread
florence-bockting marked this conversation as resolved.
models and falls back to residual-based computation for other families. This
change may lead to changes in plotting results. (#1815)
* Avoid the creation of zombie workers when executing `log_lik`
in parallel thanks to Aki Vehtari and Noa Kallioinen.
For now, `log_lik` will use PSOCK clusters if run
Expand Down
48 changes: 43 additions & 5 deletions R/bayes_R2.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,21 @@
#'
#' @details For an introduction to the approach, see Gelman et al. (2019)
#' and \url{https://github.com/jgabry/bayes_R2/}.
#' For Gaussian and Bernoulli models, \code{bayes_R2} uses model-based residual
#' variances as proposed in Gelman et al. (2019), with Bernoulli models using
#' Tjur's pseudo-variance. For Gaussian models with heteroscedastic
#' sigma, the mean residual variance is used as an approximation (see Tjur
#' (2009) for discussion on this approximation). For other families,
#' \code{bayes_R2} warns and falls back to residual-based variances.
#'
#' @references Andrew Gelman, Ben Goodrich, Jonah Gabry & Aki Vehtari. (2019).
#' R-squared for Bayesian regression models, \emph{The American Statistician},
#' 73(3):307-309. \code{10.1080/00031305.2018.1549100} (Preprint available at
#' \url{https://acris.aalto.fi/ws/portalfiles/portal/34206843/bayes_R2_v3.pdf})
#'
#' Tue Tjur. (2009). Coefficient of determination in logistic regression
#' models - A new proposal: The coefficient of discrimination, \emph{The
#' American Statistician}, 63:366-372. \code{10.1198/tast.2009.08210}
#'
#' @examples
#' \dontrun{
Expand Down Expand Up @@ -70,18 +80,41 @@ bayes_R2.brmsfit <- function(object, resp = NULL, summary = TRUE,
"'bayes_R2' which is likely invalid for ordinal families."
)
}
args_y <- list(object, warn = TRUE, ...)
Comment thread
paul-buerkner marked this conversation as resolved.
args_ypred <- list(object, sort = TRUE, ...)
R2 <- named_list(paste0("R2", resp))
warned_families <- character(0)

for (i in seq_along(R2)) {
# assumes expectations of different responses to be independent
args_ypred$resp <- args_y$resp <- resp[i]
y <- do_call(get_y, args_y)
args_ypred$resp <- resp[i]
ypred <- do_call(posterior_epred, args_ypred)
if (is_ordinal(family(object, resp = resp[i]))) {
ypred <- ordinal_probs_continuous(ypred)
}
R2[[i]] <- .bayes_R2(y, ypred)
family_name <- family(object, resp = resp[i])$family
if (family_name %in% "gaussian") {
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this pattern works for now. If we later on support the model-based approach for more than two families, we should provide separate functions per family, similar to how we do it with a lot of other post-processing funcftions.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refactored the code and outsourced the family-specific computations for better maintainability and easier future updating. Thank you for pointing this out.

args_sigma <- args_ypred
args_sigma$dpar <- "sigma"
sigma <- do_call(posterior_epred, args_sigma)
# use mean of heteroscedastic sigma as approximate
# (see Tjur (2009) for discussion)
var_res <- matrixStats::rowMeans2(sigma)^2
R2[[i]] <- .bayes_R2_model(ypred, var_res)
} else if (family_name %in% "bernoulli") {
var_res <- matrixStats::rowMeans2(ypred * (1 - ypred))
R2[[i]] <- .bayes_R2_model(ypred, var_res)
} else {
if (!family_name %in% warned_families) {
warning2(
"No model-based residual variance is currently implemented for ",
"family '", family_name, "'\nin 'bayes_R2'. ",
"Falling back to residual-based R2 computation."
)
warned_families <- c(warned_families, family_name)
}
y <- do_call(get_y, c(list(object, warn = TRUE, resp = resp[i]), list(...)))
R2[[i]] <- .bayes_R2_res(y, ypred)
}
}
R2 <- do_call(cbind, R2)
colnames(R2) <- paste0("R2", resp)
Expand All @@ -93,9 +126,14 @@ bayes_R2.brmsfit <- function(object, resp = NULL, summary = TRUE,

# internal function of bayes_R2.brmsfit
# see https://github.com/jgabry/bayes_R2/blob/master/bayes_R2.pdf
.bayes_R2 <- function(y, ypred, ...) {
.bayes_R2_res <- function(y, ypred, ...) {
e <- -1 * sweep(ypred, 2, y)
var_ypred <- matrixStats::rowVars(ypred)
var_e <- matrixStats::rowVars(e)
as.matrix(var_ypred / (var_ypred + var_e))
}

.bayes_R2_model <- function(ypred, var_res, ...) {
var_ypred <- matrixStats::rowVars(ypred)
as.matrix(var_ypred / (var_ypred + var_res))
}
10 changes: 10 additions & 0 deletions man/bayes_R2.brmsfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 58 additions & 0 deletions tests/local/tests.models-5.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,63 @@
source("setup_tests_local.R")

old_bayes_R2 <- function(fit) {
y <- brms:::get_y(fit, warn = TRUE)
ypred <- posterior_epred(fit, sort = TRUE)
e <- -1 * sweep(ypred, 2, y)
var_ypred <- matrixStats::rowVars(ypred)
var_e <- matrixStats::rowVars(e)
var_ypred / (var_ypred + var_e)
}

test_that("bayes_R2 prior-only draws match model-based formula", {
set.seed(2801)
dat <- data.frame(
x = rnorm(40),
y = rnorm(40)
)
fit <- brm(
y ~ x, data = dat, family = gaussian(),
prior = c(
prior(normal(0, 0.5), class = b),
prior(normal(0, 0.5), class = Intercept),
prior(exponential(1), class = sigma)
),
sample_prior = "only",
chains = 1, iter = 600, warmup = 200,
seed = 2802, refresh = 0
)
ypred <- posterior_epred(fit, sort = TRUE)
sigma <- posterior_epred(fit, dpar = "sigma", sort = TRUE)
var_mupred <- matrixStats::rowVars(ypred)
var_res <- matrixStats::rowMeans2(sigma^2)
expected <- var_mupred / (var_mupred + var_res)
expect_equal(bayes_R2(fit, summary = FALSE)[, 1], expected)
})

test_that("bayes_R2 posterior difference shrinks with larger n", {
make_fit <- function(n, seed) {
set.seed(seed)
dat <- data.frame(x = rnorm(n))
dat$y <- 1 + 1.5 * dat$x + rnorm(n, sd = 1)
brm(
y ~ x, data = dat, family = gaussian(),
chains = 1, iter = 700, warmup = 300,
seed = seed, refresh = 0
)
}
fit_small <- make_fit(8, 2803)
fit_large <- make_fit(250, 2804)

model_small <- bayes_R2(fit_small, summary = FALSE)[, 1]
model_large <- bayes_R2(fit_large, summary = FALSE)[, 1]
residual_small <- old_bayes_R2(fit_small)
residual_large <- old_bayes_R2(fit_large)

diff_small <- abs(stats::median(model_small) - stats::median(residual_small))
diff_large <- abs(stats::median(model_large) - stats::median(residual_large))
expect_gt(diff_small, diff_large)
})

test_that("global shrinkage priors work correctly", suppressWarnings({
set.seed(2563)
dat <- epilepsy
Expand Down
11 changes: 7 additions & 4 deletions tests/testthat/tests.brmsfit-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,15 @@ test_that("autocor has reasonable ouputs", {
})

test_that("bayes_R2 has reasonable ouputs", {
fit1 <- add_criterion(fit1, "bayes_R2")
R2 <- bayes_R2(fit1, summary = FALSE)
expect_warning(
R2 <- bayes_R2(fit1, summary = FALSE),
"No model-based residual variance is currently implemented"
)
expect_equal(dim(R2), c(ndraws(fit1), 1))
R2 <- bayes_R2(fit2, newdata = model.frame(fit2)[1:5, ], re_formula = NA)
fit1 <- SW(add_criterion(fit1, "bayes_R2"))
R2 <- SW(bayes_R2(fit2, newdata = model.frame(fit2)[1:5, ], re_formula = NA))
expect_equal(dim(R2), c(1, 4))
R2 <- bayes_R2(fit6)
R2 <- SW(bayes_R2(fit6))
expect_equal(dim(R2), c(2, 4))
})

Expand Down
Loading