3 changes: 3 additions & 0 deletions NEWS.md
@@ -19,6 +19,9 @@ Thanks to Sermet Pekin. (#1790)

### Bug Fixes

* `bayes_R2` now uses model-based residual variances for Gaussian and Bernoulli
models and falls back to residual-based computation for other families. This
change may lead to changes in plotting results. (#1815)
* Avoid the creation of zombie workers when executing `log_lik`
in parallel thanks to Aki Vehtari and Noa Kallioinen.
For now, `log_lik` will use PSOCK clusters if run
48 changes: 43 additions & 5 deletions R/bayes_R2.R
@@ -15,6 +15,10 @@
#'
#' @details For an introduction to the approach, see Gelman et al. (2019)
#' and \url{https://github.com/jgabry/bayes_R2/}.
#' For Gaussian and Bernoulli models, \code{bayes_R2} uses model-based
#' residual variances as proposed in Gelman et al. (2019), with Bernoulli
#' models using Tjur's pseudo-variance. For other families, \code{bayes_R2}
#' warns and falls back to residual-based variances.
#'
#' @references Andrew Gelman, Ben Goodrich, Jonah Gabry & Aki Vehtari. (2019).
#' R-squared for Bayesian regression models, \emph{The American Statistician},
@@ -70,18 +74,38 @@ bayes_R2.brmsfit <- function(object, resp = NULL, summary = TRUE,
"'bayes_R2' which is likely invalid for ordinal families."
)
}
args_y <- list(object, warn = TRUE, ...)
args_ypred <- list(object, sort = TRUE, ...)
R2 <- named_list(paste0("R2", resp))
warned_families <- character(0)
for (i in seq_along(R2)) {
# assumes expectations of different responses to be independent
args_ypred$resp <- args_y$resp <- resp[i]
y <- do_call(get_y, args_y)
args_ypred$resp <- resp[i]
ypred <- do_call(posterior_epred, args_ypred)
if (is_ordinal(family(object, resp = resp[i]))) {
ypred <- ordinal_probs_continuous(ypred)
}
R2[[i]] <- .bayes_R2(y, ypred)
family_name <- family(object, resp = resp[i])$family
if (family_name %in% "gaussian") {
Review comment (Owner):

This pattern works for now. If we later support the model-based approach for more than two families, we should provide separate functions per family, similar to how we do it for many other post-processing functions.

Reply (Collaborator, Author):

I refactored the code and moved the family-specific computations into separate helper functions for better maintainability and easier future updates. Thank you for pointing this out.

args_sigma <- args_ypred
args_sigma$dpar <- "sigma"
sigma <- do_call(posterior_epred, args_sigma)
var_res <- .bayes_R2_var_res_gaussian(sigma)
R2[[i]] <- .bayes_R2_model(ypred, var_res)
} else if (family_name %in% "bernoulli") {
var_res <- .bayes_R2_var_res_bernoulli(ypred)
R2[[i]] <- .bayes_R2_model(ypred, var_res)
} else {
if (!family_name %in% warned_families) {
warning2(
"No model-based residual variance is currently implemented for ",
"family '", family_name, "' in 'bayes_R2'. Falling back ",
"to residual-based R2 computation."
)
warned_families <- c(warned_families, family_name)
}
y <- do_call(get_y, c(list(object, warn = TRUE, resp = resp[i]), list(...)))
R2[[i]] <- .bayes_R2_res(y, ypred)
}
}
R2 <- do_call(cbind, R2)
colnames(R2) <- paste0("R2", resp)
@@ -93,9 +117,23 @@ bayes_R2.brmsfit <- function(object, resp = NULL, summary = TRUE,

# internal function of bayes_R2.brmsfit
# see https://github.com/jgabry/bayes_R2/blob/master/bayes_R2.pdf
.bayes_R2 <- function(y, ypred, ...) {
.bayes_R2_res <- function(y, ypred, ...) {
e <- -1 * sweep(ypred, 2, y)
var_ypred <- matrixStats::rowVars(ypred)
var_e <- matrixStats::rowVars(e)
as.matrix(var_ypred / (var_ypred + var_e))
}

.bayes_R2_model <- function(ypred, var_res, ...) {
var_ypred <- matrixStats::rowVars(ypred)
as.matrix(var_ypred / (var_ypred + var_res))
}

.bayes_R2_var_res_gaussian <- function(sigma, ...) {
sigma <- as.matrix(sigma)
matrixStats::rowMeans2(sigma^2)
Review comment (Contributor):

I don't understand why rowMeans is used here. The paper and the online appendix provide equations only for scalar sigma, and I think that simply taking the mean for vector sigma in the heteroscedastic case is not correct. Even if a more correct approach would be too complicated to implement, this approximation should at least be documented.

Reply (Collaborator, Author):

Thank you for pointing this out. I have now decided on the following approach: check whether sigma is heteroscedastic. If it is, a warning is issued and the R2 computation falls back to the residual-based R2.

Reply (Collaborator, Author):

After discussing this point again, we (Aki, Florence) decided to use the mean of sigma as an approximation in the heteroscedastic case as well, as this is consistent with the discussion in Tjur (2009). A corresponding note has been added to the function documentation for users and as a code comment for developers.

To sum up, we use model-based variances for (homo- and heteroscedastic) Gaussian and Bernoulli models. Otherwise, residual-based variances are used and the user is warned.
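
To make the heteroscedastic approximation discussed in this thread concrete, here is a minimal base-R sketch. It does not use brms: the draws of the mean and of sigma are simulated, all variable names are hypothetical, and `apply(m, 1, var)` and `rowMeans` stand in for `matrixStats::rowVars` and `matrixStats::rowMeans2`. As in `.bayes_R2_var_res_gaussian` in the diff, the residual variance of each draw is taken as the mean of sigma^2 over observations.

```r
# Simulated stand-ins for posterior draws (assumption: not brms output).
# Rows are draws, columns are observations.
set.seed(1)
n_draws <- 200
n_obs <- 50
x <- rnorm(n_obs)

# hypothetical draws of the regression coefficients for the mean
b0 <- rnorm(n_draws, 1, 0.1)
b1 <- rnorm(n_draws, 1.5, 0.1)
# b0 has length n_draws, so it is recycled down each column (one value per row)
mu <- b0 + outer(b1, x)

# heteroscedastic sigma: log(sigma) is linear in x, so sigma varies by observation
g0 <- rnorm(n_draws, 0, 0.05)
g1 <- rnorm(n_draws, 0.3, 0.05)
sigma <- exp(g0 + outer(g1, x))

# per-draw variance of the fitted means across observations
var_mu <- apply(mu, 1, var)
# per-draw residual variance: average sigma^2 over observations
var_res <- rowMeans(sigma^2)

R2 <- var_mu / (var_mu + var_res)
summary(R2)
```

When sigma is constant across observations within a draw, `rowMeans(sigma^2)` reduces to the scalar `sigma^2` of that draw, so the homoscedastic case is unchanged by the averaging.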

}

.bayes_R2_var_res_bernoulli <- function(ypred, ...) {
matrixStats::rowMeans2(ypred * (1 - ypred))
}
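
The helpers in this file can be compared side by side on a Bernoulli example. The following is a rough base-R sketch under stated assumptions: the data and the draws of the success probabilities are simulated rather than taken from a brms fit, `row_vars` is a hypothetical stand-in for `matrixStats::rowVars`, and the variable names are illustrative only. It computes the model-based R2 with the pseudo-variance `p * (1 - p)` averaged over observations, and the residual-based R2 from `y - ypred`, mirroring `.bayes_R2_model` and `.bayes_R2_res`.

```r
# Simulated data and draws (assumption: not brms output).
set.seed(2)
n_draws <- 100
n_obs <- 40
x <- rnorm(n_obs)
y <- rbinom(n_obs, 1, plogis(0.5 + x))

# hypothetical draws of logistic regression coefficients
b0 <- rnorm(n_draws, 0.5, 0.2)
b1 <- rnorm(n_draws, 1, 0.2)
# expected values (success probabilities); rows are draws, columns observations
ypred <- plogis(b0 + outer(b1, x))

# base-R stand-in for matrixStats::rowVars
row_vars <- function(m) apply(m, 1, var)

var_ypred <- row_vars(ypred)

# model-based: residual variance from the model, p * (1 - p) averaged per draw
var_res <- rowMeans(ypred * (1 - ypred))
R2_model <- var_ypred / (var_ypred + var_res)

# residual-based: variance of the observed residuals y - ypred per draw
e <- -1 * sweep(ypred, 2, y)
var_e <- row_vars(e)
R2_res <- var_ypred / (var_ypred + var_e)

round(c(model = median(R2_model), residual = median(R2_res)), 3)
```

The model-based version never touches `y`, which is why it can also be evaluated on prior-only draws, as the local test in this pull request does for the Gaussian case.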
4 changes: 4 additions & 0 deletions man/bayes_R2.brmsfit.Rd

Some generated files are not rendered by default.

58 changes: 58 additions & 0 deletions tests/local/tests.models-5.R
@@ -1,5 +1,63 @@
source("setup_tests_local.R")

old_bayes_R2 <- function(fit) {
y <- brms:::get_y(fit, warn = TRUE)
ypred <- posterior_epred(fit, sort = TRUE)
e <- -1 * sweep(ypred, 2, y)
var_ypred <- matrixStats::rowVars(ypred)
var_e <- matrixStats::rowVars(e)
var_ypred / (var_ypred + var_e)
}

test_that("bayes_R2 prior-only draws match model-based formula", {
set.seed(2801)
dat <- data.frame(
x = rnorm(40),
y = rnorm(40)
)
fit <- brm(
y ~ x, data = dat, family = gaussian(),
prior = c(
prior(normal(0, 0.5), class = b),
prior(normal(0, 0.5), class = Intercept),
prior(exponential(1), class = sigma)
),
sample_prior = "only",
chains = 1, iter = 600, warmup = 200,
seed = 2802, refresh = 0
)
ypred <- posterior_epred(fit, sort = TRUE)
sigma <- posterior_epred(fit, dpar = "sigma", sort = TRUE)
var_mupred <- matrixStats::rowVars(ypred)
var_res <- matrixStats::rowMeans2(sigma^2)
expected <- var_mupred / (var_mupred + var_res)
expect_equal(bayes_R2(fit, summary = FALSE)[, 1], expected)
})

test_that("bayes_R2 posterior difference shrinks with larger n", {
make_fit <- function(n, seed) {
set.seed(seed)
dat <- data.frame(x = rnorm(n))
dat$y <- 1 + 1.5 * dat$x + rnorm(n, sd = 1)
brm(
y ~ x, data = dat, family = gaussian(),
chains = 1, iter = 700, warmup = 300,
seed = seed, refresh = 0
)
}
fit_small <- make_fit(8, 2803)
fit_large <- make_fit(250, 2804)

model_small <- bayes_R2(fit_small, summary = FALSE)[, 1]
model_large <- bayes_R2(fit_large, summary = FALSE)[, 1]
residual_small <- old_bayes_R2(fit_small)
residual_large <- old_bayes_R2(fit_large)

diff_small <- abs(stats::median(model_small) - stats::median(residual_small))
diff_large <- abs(stats::median(model_large) - stats::median(residual_large))
expect_gt(diff_small, diff_large)
})

test_that("global shrinkage priors work correctly", suppressWarnings({
set.seed(2563)
dat <- epilepsy
11 changes: 7 additions & 4 deletions tests/testthat/tests.brmsfit-methods.R
@@ -114,12 +114,15 @@ test_that("autocor has reasonable ouputs", {
})

test_that("bayes_R2 has reasonable ouputs", {
fit1 <- add_criterion(fit1, "bayes_R2")
R2 <- bayes_R2(fit1, summary = FALSE)
expect_warning(
R2 <- bayes_R2(fit1, summary = FALSE),
"No model-based residual variance is currently implemented"
)
expect_equal(dim(R2), c(ndraws(fit1), 1))
R2 <- bayes_R2(fit2, newdata = model.frame(fit2)[1:5, ], re_formula = NA)
fit1 <- SW(add_criterion(fit1, "bayes_R2"))
R2 <- SW(bayes_R2(fit2, newdata = model.frame(fit2)[1:5, ], re_formula = NA))
expect_equal(dim(R2), c(1, 4))
R2 <- bayes_R2(fit6)
R2 <- SW(bayes_R2(fit6))
expect_equal(dim(R2), c(2, 4))
})
