@@ -363,36 +363,15 @@ bayesian_make <- function(modes = c("classification", "regression")) {
363363 )
364364 )
365365 )
366- } else {
366+
367367 parsnip :: set_pred(
368368 model = model ,
369369 eng = engine ,
370370 mode = mode ,
371- type = " numeric " ,
371+ type = " conf_int " ,
372372 value = list (
373373 pre = NULL ,
374374 post = function (results , object ) {
375- tibble :: tibble(.pred = results [, 1 ])
376- },
377- func = predfunc ,
378- args = list (
379- object = rlang :: expr(object $ fit ),
380- newdata = rlang :: expr(new_data ),
381- summary = TRUE
382- )
383- )
384- )
385- }
386-
387- parsnip :: set_pred(
388- model = model ,
389- eng = engine ,
390- mode = mode ,
391- type = " conf_int" ,
392- value = list (
393- pre = NULL ,
394- post = function (results , object ) {
395- if (mode == " classification" ) {
396375 res_2 <-
397376 tibble :: tibble(
398377 lo = parsnip :: convert_stan_interval(
@@ -412,43 +391,28 @@ bayesian_make <- function(modes = c("classification", "regression")) {
412391 lo_nms <- paste0(" .pred_lower_" , object $ lvl )
413392 hi_nms <- paste0(" .pred_upper_" , object $ lvl )
414393 colnames(res ) <- c(lo_nms [1 ], hi_nms [1 ], lo_nms [2 ], hi_nms [2 ])
415- } else {
416- res <-
417- tibble :: tibble(
418- .pred_lower = parsnip :: convert_stan_interval(
419- results ,
420- level = object $ spec $ method $ pred $ conf_int $ extras $ level
421- ),
422- .pred_upper = parsnip :: convert_stan_interval(
423- results ,
424- level = object $ spec $ method $ pred $ conf_int $ extras $ level ,
425- lower = FALSE
426- ),
427- )
428- }
429- if (object $ spec $ method $ pred $ conf_int $ extras $ std_error ) {
430- res $ .std_error <- apply(results , 2 , stats :: sd , na.rm = TRUE )
431- }
432- res
433- },
434- func = predfunc ,
435- args = list (
436- object = rlang :: expr(object $ fit ),
437- newdata = rlang :: expr(new_data ),
438- summary = FALSE
394+ if (object $ spec $ method $ pred $ conf_int $ extras $ std_error ) {
395+ res $ .std_error <- apply(results , 2 , stats :: sd , na.rm = TRUE )
396+ }
397+ res
398+ },
399+ func = predfunc ,
400+ args = list (
401+ object = rlang :: expr(object $ fit ),
402+ newdata = rlang :: expr(new_data ),
403+ summary = FALSE
404+ )
439405 )
440406 )
441- )
442407
443- parsnip :: set_pred(
444- model = model ,
445- eng = engine ,
446- mode = mode ,
447- type = " pred_int" ,
448- value = list (
449- pre = NULL ,
450- post = function (results , object ) {
451- if (mode == " classification" ) {
408+ parsnip :: set_pred(
409+ model = model ,
410+ eng = engine ,
411+ mode = mode ,
412+ type = " pred_int" ,
413+ value = list (
414+ pre = NULL ,
415+ post = function (results , object ) {
452416 res_2 <-
453417 tibble :: tibble(
454418 lo = parsnip :: convert_stan_interval(
@@ -468,7 +432,83 @@ bayesian_make <- function(modes = c("classification", "regression")) {
468432 lo_nms <- paste0(" .pred_lower_" , object $ lvl )
469433 hi_nms <- paste0(" .pred_upper_" , object $ lvl )
470434 colnames(res ) <- c(lo_nms [1 ], hi_nms [1 ], lo_nms [2 ], hi_nms [2 ])
471- } else {
435+
436+ if (object $ spec $ method $ pred $ pred_int $ extras $ std_error ) {
437+ res $ .std_error <- apply(results , 2 , stats :: sd , na.rm = TRUE )
438+ }
439+ res
440+ },
441+ func = predfunc ,
442+ args = list (
443+ object = rlang :: expr(object $ fit ),
444+ newdata = rlang :: expr(new_data ),
445+ summary = FALSE
446+ )
447+ )
448+ )
449+ } else {
450+ parsnip :: set_pred(
451+ model = model ,
452+ eng = engine ,
453+ mode = mode ,
454+ type = " numeric" ,
455+ value = list (
456+ pre = NULL ,
457+ post = function (results , object ) {
458+ tibble :: tibble(.pred = results [, 1 ])
459+ },
460+ func = predfunc ,
461+ args = list (
462+ object = rlang :: expr(object $ fit ),
463+ newdata = rlang :: expr(new_data ),
464+ summary = TRUE
465+ )
466+ )
467+ )
468+
469+ parsnip :: set_pred(
470+ model = model ,
471+ eng = engine ,
472+ mode = mode ,
473+ type = " conf_int" ,
474+ value = list (
475+ pre = NULL ,
476+ post = function (results , object ) {
477+ res <-
478+ tibble :: tibble(
479+ .pred_lower = parsnip :: convert_stan_interval(
480+ results ,
481+ level = object $ spec $ method $ pred $ conf_int $ extras $ level
482+ ),
483+ .pred_upper = parsnip :: convert_stan_interval(
484+ results ,
485+ level = object $ spec $ method $ pred $ conf_int $ extras $ level ,
486+ lower = FALSE
487+ ),
488+ )
489+
490+ if (object $ spec $ method $ pred $ conf_int $ extras $ std_error ) {
491+ res $ .std_error <- apply(results , 2 , stats :: sd , na.rm = TRUE )
492+ }
493+ res
494+ },
495+ func = predfunc ,
496+ args = list (
497+ object = rlang :: expr(object $ fit ),
498+ newdata = rlang :: expr(new_data ),
499+ summary = FALSE
500+ )
501+ )
502+ )
503+
504+ parsnip :: set_pred(
505+ model = model ,
506+ eng = engine ,
507+ mode = mode ,
508+ type = " pred_int" ,
509+ value = list (
510+ pre = NULL ,
511+ post = function (results , object ) {
472512 res <-
473513 tibble :: tibble(
474514 .pred_lower = parsnip :: convert_stan_interval(
@@ -481,20 +521,21 @@ bayesian_make <- function(modes = c("classification", "regression")) {
481521 lower = FALSE
482522 ),
483523 )
484- }
485- if (object $ spec $ method $ pred $ pred_int $ extras $ std_error ) {
486- res $ .std_error <- apply(results , 2 , stats :: sd , na.rm = TRUE )
487- }
488- res
489- },
490- func = predfunc ,
491- args = list (
492- object = rlang :: expr(object $ fit ),
493- newdata = rlang :: expr(new_data ),
494- summary = FALSE
524+
525+ if (object $ spec $ method $ pred $ pred_int $ extras $ std_error ) {
526+ res $ .std_error <- apply(results , 2 , stats :: sd , na.rm = TRUE )
527+ }
528+ res
529+ },
530+ func = predfunc ,
531+ args = list (
532+ object = rlang :: expr(object $ fit ),
533+ newdata = rlang :: expr(new_data ),
534+ summary = FALSE
535+ )
495536 )
496537 )
497- )
538+ }
498539
499540 parsnip :: set_pred(
500541 model = model ,
0 commit comments