Skip to content

Commit 4187942

Browse files
committed
bayesian: Reorganize settings for model predictions
1 parent 22bebd4 commit 4187942

1 file changed

Lines changed: 111 additions & 70 deletions

File tree

R/bayesian_make.R

Lines changed: 111 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -363,36 +363,15 @@ bayesian_make <- function(modes = c("classification", "regression")) {
363363
)
364364
)
365365
)
366-
} else {
366+
367367
parsnip::set_pred(
368368
model = model,
369369
eng = engine,
370370
mode = mode,
371-
type = "numeric",
371+
type = "conf_int",
372372
value = list(
373373
pre = NULL,
374374
post = function(results, object) {
375-
tibble::tibble(.pred = results[, 1])
376-
},
377-
func = predfunc,
378-
args = list(
379-
object = rlang::expr(object$fit),
380-
newdata = rlang::expr(new_data),
381-
summary = TRUE
382-
)
383-
)
384-
)
385-
}
386-
387-
parsnip::set_pred(
388-
model = model,
389-
eng = engine,
390-
mode = mode,
391-
type = "conf_int",
392-
value = list(
393-
pre = NULL,
394-
post = function(results, object) {
395-
if (mode == "classification") {
396375
res_2 <-
397376
tibble::tibble(
398377
lo = parsnip::convert_stan_interval(
@@ -412,43 +391,28 @@ bayesian_make <- function(modes = c("classification", "regression")) {
412391
lo_nms <- paste0(".pred_lower_", object$lvl)
413392
hi_nms <- paste0(".pred_upper_", object$lvl)
414393
colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
415-
} else {
416-
res <-
417-
tibble::tibble(
418-
.pred_lower = parsnip::convert_stan_interval(
419-
results,
420-
level = object$spec$method$pred$conf_int$extras$level
421-
),
422-
.pred_upper = parsnip::convert_stan_interval(
423-
results,
424-
level = object$spec$method$pred$conf_int$extras$level,
425-
lower = FALSE
426-
),
427-
)
428-
}
429-
if (object$spec$method$pred$conf_int$extras$std_error) {
430-
res$.std_error <- apply(results, 2, stats::sd, na.rm = TRUE)
431-
}
432-
res
433-
},
434-
func = predfunc,
435-
args = list(
436-
object = rlang::expr(object$fit),
437-
newdata = rlang::expr(new_data),
438-
summary = FALSE
394+
if (object$spec$method$pred$conf_int$extras$std_error) {
395+
res$.std_error <- apply(results, 2, stats::sd, na.rm = TRUE)
396+
}
397+
res
398+
},
399+
func = predfunc,
400+
args = list(
401+
object = rlang::expr(object$fit),
402+
newdata = rlang::expr(new_data),
403+
summary = FALSE
404+
)
439405
)
440406
)
441-
)
442407

443-
parsnip::set_pred(
444-
model = model,
445-
eng = engine,
446-
mode = mode,
447-
type = "pred_int",
448-
value = list(
449-
pre = NULL,
450-
post = function(results, object) {
451-
if (mode == "classification") {
408+
parsnip::set_pred(
409+
model = model,
410+
eng = engine,
411+
mode = mode,
412+
type = "pred_int",
413+
value = list(
414+
pre = NULL,
415+
post = function(results, object) {
452416
res_2 <-
453417
tibble::tibble(
454418
lo = parsnip::convert_stan_interval(
@@ -468,7 +432,83 @@ bayesian_make <- function(modes = c("classification", "regression")) {
468432
lo_nms <- paste0(".pred_lower_", object$lvl)
469433
hi_nms <- paste0(".pred_upper_", object$lvl)
470434
colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
471-
} else {
435+
436+
if (object$spec$method$pred$pred_int$extras$std_error) {
437+
res$.std_error <- apply(results, 2, stats::sd, na.rm = TRUE)
438+
}
439+
res
440+
},
441+
func = predfunc,
442+
args = list(
443+
object = rlang::expr(object$fit),
444+
newdata = rlang::expr(new_data),
445+
summary = FALSE
446+
)
447+
)
448+
)
449+
} else {
450+
parsnip::set_pred(
451+
model = model,
452+
eng = engine,
453+
mode = mode,
454+
type = "numeric",
455+
value = list(
456+
pre = NULL,
457+
post = function(results, object) {
458+
tibble::tibble(.pred = results[, 1])
459+
},
460+
func = predfunc,
461+
args = list(
462+
object = rlang::expr(object$fit),
463+
newdata = rlang::expr(new_data),
464+
summary = TRUE
465+
)
466+
)
467+
)
468+
469+
parsnip::set_pred(
470+
model = model,
471+
eng = engine,
472+
mode = mode,
473+
type = "conf_int",
474+
value = list(
475+
pre = NULL,
476+
post = function(results, object) {
477+
res <-
478+
tibble::tibble(
479+
.pred_lower = parsnip::convert_stan_interval(
480+
results,
481+
level = object$spec$method$pred$conf_int$extras$level
482+
),
483+
.pred_upper = parsnip::convert_stan_interval(
484+
results,
485+
level = object$spec$method$pred$conf_int$extras$level,
486+
lower = FALSE
487+
),
488+
)
489+
490+
if (object$spec$method$pred$conf_int$extras$std_error) {
491+
res$.std_error <- apply(results, 2, stats::sd, na.rm = TRUE)
492+
}
493+
res
494+
},
495+
func = predfunc,
496+
args = list(
497+
object = rlang::expr(object$fit),
498+
newdata = rlang::expr(new_data),
499+
summary = FALSE
500+
)
501+
)
502+
)
503+
504+
parsnip::set_pred(
505+
model = model,
506+
eng = engine,
507+
mode = mode,
508+
type = "pred_int",
509+
value = list(
510+
pre = NULL,
511+
post = function(results, object) {
472512
res <-
473513
tibble::tibble(
474514
.pred_lower = parsnip::convert_stan_interval(
@@ -481,20 +521,21 @@ bayesian_make <- function(modes = c("classification", "regression")) {
481521
lower = FALSE
482522
),
483523
)
484-
}
485-
if (object$spec$method$pred$pred_int$extras$std_error) {
486-
res$.std_error <- apply(results, 2, stats::sd, na.rm = TRUE)
487-
}
488-
res
489-
},
490-
func = predfunc,
491-
args = list(
492-
object = rlang::expr(object$fit),
493-
newdata = rlang::expr(new_data),
494-
summary = FALSE
524+
525+
if (object$spec$method$pred$pred_int$extras$std_error) {
526+
res$.std_error <- apply(results, 2, stats::sd, na.rm = TRUE)
527+
}
528+
res
529+
},
530+
func = predfunc,
531+
args = list(
532+
object = rlang::expr(object$fit),
533+
newdata = rlang::expr(new_data),
534+
summary = FALSE
535+
)
495536
)
496537
)
497-
)
538+
}
498539

499540
parsnip::set_pred(
500541
model = model,

0 commit comments

Comments
 (0)