Skip to contents

Generic S3 Function for rpwf_workflow_set and rpwf_data_set

Usage

rpwf_augment(obj, ...)

# S3 method for rpwf_workflow_set
rpwf_augment(
  obj,
  db_con,
  grid_fun = NULL,
  ...,
  range = c(1L, 5000L),
  seed = 1234L
)

Arguments

obj

a rpwf_workflow_set or rpwf_data_set object.

...

additional arguments passed on to the grid_fun function.

db_con

an rpwf_connect_db() object.

grid_fun

a dials::grid_*() function, e.g., dials::grid_random(), dials::grid_latin_hypercube(). The default NULL assumes that no grid search is performed and sklearn defaults are used.

range

range of seeds to sample from.

seed

random seed.

Value

A data frame with necessary columns for export into the database.

Examples

# Create the database
board <- pins::board_temp()
tmp_dir <- tempdir()
db_con <- rpwf_connect_db(paste(tmp_dir, "db.SQLite", sep = "/"), board)

# Create a `workflow_set`
d <- mtcars
d$id <- seq_len(nrow(d))
m1 <- parsnip::boost_tree() |>
  parsnip::set_engine("xgboost") |>
  parsnip::set_mode("classification") |>
  set_py_engine(py_module = "xgboost", py_base_learner = "XGBClassifier")
r1 <- d |>
  recipes::recipe(vs ~ .) |>
  # "pd.index" is the special column that used for indexing in pandas
  recipes::update_role(id, new_role = "pd.index")
wf <- rpwf_workflow_set(list(r1), list(m1), "neg_log_loss")

to_export <- wf |>
  rpwf_augment(db_con, dials::grid_latin_hypercube, size = 10)
#> No hyper param tuning specified
#> Adding id as pandas idx
print(to_export)
#> # A tibble: 1 × 11
#>   preprocs models    costs py_ba…¹ model…² recip…³ model…⁴ rando…⁵ grids Rgrid  
#>   <list>   <list>    <chr> <chr>   <chr>   <chr>     <int>   <int> <lis> <list> 
#> 1 <recipe> <spec[+]> neg_… NA      XGBCla… NA            5    1004 <lgl> <RGrid>
#> # … with 1 more variable: TrainDf <list>, and abbreviated variable names
#> #   ¹​py_base_learner_args, ²​model_tag, ³​recipe_tag, ⁴​model_type_id,
#> #   ⁵​random_state