Generic S3 Function for rpwf_workflow_set and rpwf_data_set
Usage
rpwf_augment(obj, ...)
# S3 method for rpwf_workflow_set
rpwf_augment(
obj,
db_con,
grid_fun = NULL,
...,
range = c(1L, 5000L),
seed = 1234L
)
Arguments
- obj
  an rpwf_workflow_set or rpwf_data_set object.
- ...
  additional arguments passed to the grid_fun function.
- db_con
  an rpwf_connect_db() object.
- grid_fun
  a dials::grid_<function>, e.g., dials::grid_random() or dials::grid_latin_hypercube(). The default NULL assumes that no grid search is performed and the sklearn defaults are used.
- range
  range of seeds to sample from.
- seed
  random seed.
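When grid_fun is left as its NULL default, no tuning grid is generated and the exported workflows fall back on sklearn's default hyperparameters. A minimal sketch of that variant, assuming wf and db_con have been created as in the example below:

# No grid_fun: skip hyperparameter tuning and use sklearn defaults
to_export_default <- wf |>
  rpwf_augment(db_con)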
Examples
# Create the database
board <- pins::board_temp()
tmp_dir <- tempdir()
db_con <- rpwf_connect_db(paste(tmp_dir, "db.SQLite", sep = "/"), board)
# Create a `workflow_set`
d <- mtcars
d$id <- seq_len(nrow(d))
m1 <- parsnip::boost_tree() |>
parsnip::set_engine("xgboost") |>
parsnip::set_mode("classification") |>
set_py_engine(py_module = "xgboost", py_base_learner = "XGBClassifier")
r1 <- d |>
recipes::recipe(vs ~ .) |>
  # "pd.index" is the special role used for indexing in pandas
recipes::update_role(id, new_role = "pd.index")
wf <- rpwf_workflow_set(list(r1), list(m1), "neg_log_loss")
to_export <- wf |>
rpwf_augment(db_con, dials::grid_latin_hypercube, size = 10)
#> No hyper param tuning specified
#> Adding id as pandas idx
print(to_export)
#> # A tibble: 1 × 11
#> preprocs models costs py_ba…¹ model…² recip…³ model…⁴ rando…⁵ grids Rgrid
#> <list> <list> <chr> <chr> <chr> <chr> <int> <int> <lis> <list>
#> 1 <recipe> <spec[+]> neg_… NA XGBCla… NA 5 1004 <lgl> <RGrid>
#> # … with 1 more variable: TrainDf <list>, and abbreviated variable names
#> # ¹py_base_learner_args, ²model_tag, ³recipe_tag, ⁴model_type_id,
#> # ⁵random_state