# Load the pre-built modelling inputs from models/data (paths resolved with
# here()): the training split, the full analysis data set, and the resampling
# folds used for tuning and cross-validation further down.
analysis_train <- readRDS(here("models", "data", "analysis_train.rds"))
analysis_data <- readRDS(here("models", "data", "analysis_data.rds"))
analysis_folds <- readRDS(here("models", "data", "analysis_folds.rds"))
Classification Tree
##Library
# Preprocessing recipe for the tree model: WeaponCarryingSchool is the outcome
# and every other column of analysis_train is a predictor.
carry_weapon_recipe_tree <- recipe(
  formula = WeaponCarryingSchool ~ .,
  data = analysis_train
)
# Fill missing nominal predictors with their mode ...
carry_weapon_recipe_tree <- step_impute_mode(
  carry_weapon_recipe_tree,
  all_nominal_predictors()
)
# ... and missing numeric predictors with their mean.
carry_weapon_recipe_tree <- step_impute_mean(
  carry_weapon_recipe_tree,
  all_numeric_predictors()
)
carry_weapon_recipe_tree
── Recipe ──────────────────────────────────────────────────────────────────────
── Inputs
Number of variables by role
outcome: 1
predictor: 10
── Operations
• Mode imputation for: all_nominal_predictors()
• Mean imputation for: all_numeric_predictors()
# Classification tree specification fitted with the rpart engine. All three
# main arguments are left as tune() placeholders so they can be chosen by the
# grid search further down.
carry_weapon_spec_tree <-
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune(),
    min_n = tune()
  ) |>
  set_engine("rpart") |>
  set_mode("classification")
carry_weapon_spec_tree
Decision Tree Model Specification (classification)
Main Arguments:
cost_complexity = tune()
tree_depth = tune()
min_n = tune()
Computational engine: rpart
# Bundle the imputation recipe and the tunable tree specification into a
# single workflow so preprocessing and model are tuned and fitted together.
carry_weapon_workflow_tree <-
  workflow() |>
  add_recipe(carry_weapon_recipe_tree) |>
  add_model(carry_weapon_spec_tree)
carry_weapon_workflow_tree
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()
── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps
• step_impute_mode()
• step_impute_mean()
── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (classification)
Main Arguments:
cost_complexity = tune()
tree_depth = tune()
min_n = tune()
Computational engine: rpart
## Tree tuning
# Regular tuning grid with two levels per hyperparameter (2 x 2 x 2 = 8
# candidates). tree_depth is restricted to the range 2-5; cost_complexity and
# min_n use their default ranges.
tree_grid <- grid_regular(
  cost_complexity(),
  tree_depth(range = c(2, 5)),
  min_n(),
  levels = 2
)
tree_grid
# A tibble: 8 × 3
cost_complexity tree_depth min_n
<dbl> <int> <int>
1 0.0000000001 2 2
2 0.1 2 2
3 0.0000000001 5 2
4 0.1 5 2
5 0.0000000001 2 40
6 0.1 2 40
7 0.0000000001 5 40
8 0.1 5 40
# Evaluate every candidate in tree_grid on the analysis_folds resamples,
# scoring by ROC AUC only and keeping the out-of-fold predictions.
cart_tune <-
  carry_weapon_workflow_tree |>
  tune_grid(
    resamples = analysis_folds,
    grid = tree_grid,
    metrics = metric_set(roc_auc),
    control = control_grid(save_pred = TRUE)
  )
cart_tune
# Write the tuning results to disk.
saveRDS(cart_tune, here("model_outputs", "tree_tune.rds"))

# Plot the tuning results.
bestPlot_cart <-
  autoplot(cart_tune)
bestPlot_cart
# Hyperparameter combination with the best resampled ROC AUC.
best_cart <- cart_tune |>
  select_best(metric = "roc_auc")
best_cart
# A tibble: 1 × 4
cost_complexity tree_depth min_n .config
<dbl> <int> <int> <chr>
1 0.0000000001 5 2 Preprocessor1_Model3
# Replace the tune() placeholders in the workflow with the selected values.
cart_final_wf <- finalize_workflow(carry_weapon_workflow_tree, best_cart)
cart_final_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()
── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps
• step_impute_mode()
• step_impute_mean()
── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (classification)
Main Arguments:
cost_complexity = 1e-10
tree_depth = 5
min_n = 2
Computational engine: rpart
# Fit the finalized workflow on the training data and write the fit to disk.
cart_fit <- fit(
  cart_final_wf,
  analysis_train
)
cart_fit
saveRDS(cart_fit, here("model_outputs", "tree_fit.rds"))

# Attach class and probability predictions to the data, keeping only the
# outcome and the prediction columns.
# NOTE(review): predictions are made on analysis_train, the same data the
# model was fitted on, so the ROC curve and AUC derived from tree_pred are
# training-set figures, not held-out estimates -- confirm this is intended.
tree_pred <-
  augment(cart_fit, analysis_train) |>
  select(WeaponCarryingSchool, .pred_class, .pred_1, .pred_0)
tree_pred
# A tibble: 14,696 × 4
WeaponCarryingSchool .pred_class .pred_1 .pred_0
<fct> <fct> <dbl> <dbl>
1 0 0 0.0396 0.960
2 0 0 0.0396 0.960
3 0 0 0.0396 0.960
4 0 0 0.0396 0.960
5 0 0 0.0396 0.960
6 0 0 0.0396 0.960
7 0 0 0.0396 0.960
8 0 0 0.0396 0.960
9 0 0 0.0396 0.960
10 0 0 0.0396 0.960
# ℹ 14,686 more rows
# ROC curve plot for the tree predictions. event_level = "second" treats the
# second level of WeaponCarryingSchool as the event, scored with .pred_1
# (presumably the "1" level -- verify the factor's level order).
roc_tree <- autoplot(
  roc_curve(
    tree_pred,
    truth = WeaponCarryingSchool,
    .pred_1,
    event_level = "second"
  )
)
# Store the plot object.
saveRDS(roc_tree, here("models", "roc_graphs", "tree.rds"))
# second part
# Area under the ROC curve for the same predictions, using the same
# event-level convention as the ROC curve (second factor level is the event).
roc_auc(
  tree_pred,
  truth = WeaponCarryingSchool,
  .pred_1,
  event_level = "second"
)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 roc_auc binary 0.544
# NOTE(review): roc_tree is already written to this same path right after it
# is created; this second save looks redundant -- confirm before removing.
saveRDS(roc_tree, here("models", "roc_graphs", "tree.rds"))

# Resampled performance of the finalized workflow across analysis_folds,
# summarised per metric.
fit_resamples(cart_final_wf, resamples = analysis_folds) |>
  collect_metrics()
# A tibble: 3 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.955 5 0.00141 Preprocessor1_Model1
2 brier_class binary 0.0412 5 0.00135 Preprocessor1_Model1
3 roc_auc binary 0.592 5 0.0226 Preprocessor1_Model1
library(rpart.plot)
Warning: package 'rpart.plot' was built under R version 4.3.3
# Pull the underlying engine fit out of the fitted workflow and draw the tree.
# roundint = FALSE turns off rpart.plot's rounding of split values for
# integer-valued predictors.
cart_fit |>
  extract_fit_engine() |>
  rpart.plot::rpart.plot(roundint = FALSE)