Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added option to impute missing values during training #336

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
- New column `lower` and `upper` to report the bounds of the empirical 95% confidence interval from the permutation test.
See `vignette('parallel')` for an example of plotting feature importance with confidence intervals.
- Minor documentation improvements (#323, @kelly-sovacool).
- Added option to impute missing data during training rather than preprocessing (#301, @megancoden and @shah-priyal).
- Added impute_in_training option to `run_ml()`, which defaults to FALSE.
- Added impute_in_preprocessing option to `preprocess()`, which defaults to TRUE.

# mikropml 1.5.0

Expand Down
49 changes: 49 additions & 0 deletions R/impute.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
impute <- function(transformed_cont, n_missing) {
megancoden marked this conversation as resolved.
Show resolved Hide resolved
transformed_cont <- sapply_fn(transformed_cont, function(x) {
if (class(x) %in% c("integer", "numeric")) {
m <- is.na(x)
x[m] <- stats::median(x, na.rm = TRUE)
}
return(x)
}) %>% dplyr::as_tibble()
message(
paste0(
n_missing,
" missing continuous value(s) were imputed using the median value of the feature."
)
)
return (transformed_cont)
}

prep_data <- function(dataset, outcome_colname, prefilter_threshold, method, impute_in_preprocessing, to_numeric) {
dataset[[outcome_colname]] <- replace_spaces(dataset[[outcome_colname]])
dataset <- rm_missing_outcome(dataset, outcome_colname)
split_dat <- split_outcome_features(dataset, outcome_colname)

features <- split_dat$features
removed_feats <- character(0)
if (to_numeric) {
feats <- change_to_num(features) %>%
remove_singleton_columns(threshold = prefilter_threshold)
removed_feats <- feats$removed_feats
features <- feats$dat
}
pbtick(progbar)

nv_feats <- process_novar_feats(features, progbar = progbar)
pbtick(progbar)
split_feats <- process_cat_feats(nv_feats$var_feats, progbar = progbar)
pbtick(progbar)
cont_feats <- process_cont_feats(split_feats$cont_feats, method, impute_in_preprocessing)
pbtick(progbar)
# combine all processed features
processed_feats <- dplyr::bind_cols(
cont_feats$transformed_cont,
split_feats$cat_feats,
nv_feats$novar_feats
)
pbtick(progbar)

processed_data <- list(cont_feats = cont_feats, removed_feats = removed_feats, split_dat = split_dat, processed_feats = processed_feats)
return(processed_data)
}
59 changes: 17 additions & 42 deletions R/preprocess.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# TODO: set this for a generic path (probably using here::here)
library(here)
here("R", "impute.R")
megancoden marked this conversation as resolved.
Show resolved Hide resolved
#' Preprocess data prior to running machine learning
#'
#' Function to preprocess your data for input into [run_ml()].
Expand Down Expand Up @@ -60,7 +63,7 @@ preprocess_data <- function(dataset, outcome_colname,
method = c("center", "scale"),
remove_var = "nzv", collapse_corr_feats = TRUE,
to_numeric = TRUE, group_neg_corr = TRUE,
prefilter_threshold = 1) {
prefilter_threshold = 1, impute_in_preprocessing = TRUE) {
progbar <- NULL
if (isTRUE(check_packages_installed("progressr"))) {
progbar <- progressr::progressor(steps = 20, message = "preprocessing")
Expand All @@ -70,34 +73,15 @@ preprocess_data <- function(dataset, outcome_colname,
check_outcome_column(dataset, outcome_colname, check_values = FALSE)
check_remove_var(remove_var)
pbtick(progbar)
dataset[[outcome_colname]] <- replace_spaces(dataset[[outcome_colname]])
dataset <- rm_missing_outcome(dataset, outcome_colname)
split_dat <- split_outcome_features(dataset, outcome_colname)

features <- split_dat$features
removed_feats <- character(0)
if (to_numeric) {
feats <- change_to_num(features) %>%
remove_singleton_columns(threshold = prefilter_threshold)
removed_feats <- feats$removed_feats
features <- feats$dat
}
pbtick(progbar)

nv_feats <- process_novar_feats(features, progbar = progbar)
pbtick(progbar)
split_feats <- process_cat_feats(nv_feats$var_feats, progbar = progbar)
pbtick(progbar)
cont_feats <- process_cont_feats(split_feats$cont_feats, method)
pbtick(progbar)


processed_data <- prep_data(dataset, outcome_colname, prefilter_threshold, method, impute_in_preprocessing, to_numeric)
megancoden marked this conversation as resolved.
Show resolved Hide resolved
removed_feats <- processed_data$removed_feats
processed_feats <- processed_data$processed_feats
split_dat <- processed_data$split_dat
cont_feats <- processed_data$cont_feats

# combine all processed features
processed_feats <- dplyr::bind_cols(
cont_feats$transformed_cont,
split_feats$cat_feats,
nv_feats$novar_feats
)
pbtick(progbar)


# remove features with (near-)zero variance
feats <- get_caret_processed_df(processed_feats, remove_var)
Expand Down Expand Up @@ -364,7 +348,7 @@ process_cat_feats <- function(features, progbar = NULL) {
#'
#' @examples
#' process_cont_feats(mikropml::otu_small[, 2:ncol(otu_small)], c("center", "scale"))
process_cont_feats <- function(features, method) {
process_cont_feats <- function(features, method, impute_in_preprocessing) {
transformed_cont <- NULL
removed_cont <- NULL

Expand All @@ -388,19 +372,10 @@ process_cont_feats <- function(features, method) {
n_missing <- sum(missing)
if (n_missing > 0) {
# impute missing data using the median value
transformed_cont <- sapply_fn(transformed_cont, function(x) {
if (class(x) %in% c("integer", "numeric")) {
m <- is.na(x)
x[m] <- stats::median(x, na.rm = TRUE)
}
return(x)
}) %>% dplyr::as_tibble()
message(
paste0(
n_missing,
" missing continuous value(s) were imputed using the median value of the feature."
)
)
if (impute_in_preprocessing) {
source("impute.R")
megancoden marked this conversation as resolved.
Show resolved Hide resolved
transformed_cont <- impute(transformed_cont, n_missing)
}
}
}
}
Expand Down
54 changes: 35 additions & 19 deletions R/run_ml.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# TODO: test if calling these functions works
# TODO: figure out if there's a way to only specify option in one place (for runml and for preprocess)
#' Run the machine learning pipeline
#'
#' This function splits the data set into a train & test set,
Expand Down Expand Up @@ -144,6 +146,7 @@ run_ml <-
group_partitions = NULL,
corr_thresh = 1,
seed = NA,
impute_in_training = FALSE,
megancoden marked this conversation as resolved.
Show resolved Hide resolved
...) {
check_all(
dataset,
Expand All @@ -162,7 +165,7 @@ run_ml <-
if (!is.na(seed)) {
set.seed(seed)
}

# `future.apply` is required for `find_feature_importance()`.
# check it here to adhere to the fail fast principle.
if (find_feature_importance) {
Expand All @@ -173,20 +176,20 @@ run_ml <-
if (find_feature_importance) {
check_cat_feats(dataset %>% dplyr::select(-outcome_colname))
}

dataset <- dataset %>%
randomize_feature_order(outcome_colname) %>%
# convert tibble to dataframe to silence warning from caret::train():
# "Warning: Setting row names on a tibble is deprecated.."
as.data.frame()

outcomes_vctr <- dataset %>% dplyr::pull(outcome_colname)

if (length(training_frac) == 1) {
training_inds <- get_partition_indices(outcomes_vctr,
training_frac = training_frac,
groups = groups,
group_partitions = group_partitions
training_frac = training_frac,
groups = groups,
group_partitions = group_partitions
)
} else {
training_inds <- training_frac
Expand All @@ -201,30 +204,43 @@ run_ml <-
}
check_training_frac(training_frac)
check_training_indices(training_inds, dataset)

train_data <- dataset[training_inds, ]
test_data <- dataset[-training_inds, ]
if (impute_in_training == TRUE) {

train_processed_data <- prep_data(train_data, outcome_colname, prefilter_threshold=1, method=c("center", "scale"), impute_in_preprocessing=TRUE, to_numeric=TRUE)
train_processed_feats <- train_processed_data$processed_feats
split_dat <- train_processed_data$split_dat
train_data <- dplyr::bind_cols(split_dat$outcome, train_processed_feats) %>%
dplyr::as_tibble()
test_processed_data <- prep_data(test_data, outcome_colname, prefilter_threshold=1, method=c("center", "scale"), impute_in_preprocessing=TRUE, to_numeric=TRUE)
test_processed_feats <- test_processed_data$processed_feats
split_dat <- test_processed_data$split_dat
test_data <- dplyr::bind_cols(split_dat$outcome, test_processed_feats) %>%
dplyr::as_tibble()
}
megancoden marked this conversation as resolved.
Show resolved Hide resolved
# train_groups & test_groups will be NULL if groups is NULL
train_groups <- groups[training_inds]
test_groups <- groups[-training_inds]

if (is.null(hyperparameters)) {
hyperparameters <- get_hyperparams_list(dataset, method)
}
tune_grid <- get_tuning_grid(hyperparameters, method)


outcome_type <- get_outcome_type(outcomes_vctr)
class_probs <- outcome_type != "continuous"

if (is.null(perf_metric_function)) {
perf_metric_function <- get_perf_metric_fn(outcome_type)
}

if (is.null(perf_metric_name)) {
perf_metric_name <- get_perf_metric_name(outcome_type)
}

if (is.null(cross_val)) {
cross_val <- define_cv(
train_data,
Expand All @@ -238,8 +254,8 @@ run_ml <-
group_partitions = group_partitions
)
}


message("Training the model...")
trained_model_caret <- train_model(
train_data = train_data,
Expand All @@ -254,7 +270,7 @@ run_ml <-
if (!is.na(seed)) {
set.seed(seed)
}

if (calculate_performance) {
performance_tbl <- get_performance_tbl(
trained_model_caret,
Expand All @@ -269,7 +285,7 @@ run_ml <-
} else {
performance_tbl <- "Skipped calculating performance"
}

if (find_feature_importance) {
message("Finding feature importance...")
feature_importance_tbl <- get_feature_importance(
Expand All @@ -287,7 +303,7 @@ run_ml <-
} else {
feature_importance_tbl <- "Skipped feature importance"
}

return(
list(
trained_model = trained_model_caret,
Expand Down