Skip to content

Commit

Permalink
Merge pull request #664 from cmu-delphi/ds/style
Browse files Browse the repository at this point in the history
style: minor styling changes
  • Loading branch information
dshemetov committed Oct 2, 2023
2 parents fdbbb69 + a9afed7 commit 31c8bf8
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 41 deletions.
9 changes: 9 additions & 0 deletions R-packages/evalcast/.lintr
@@ -0,0 +1,9 @@
linters: linters_with_defaults(
line_length_linter(120),
cyclocomp_linter = NULL,
object_length_linter(length = 40L)
)
exclusions: list(
"renv",
"venv"
)
122 changes: 81 additions & 41 deletions R-packages/evalcast/R/get_predictions.R
@@ -1,5 +1,5 @@
#' Get predictions
#'
#'
#' For each of the provided forecast dates, runs a forecaster using the data
#' that would have been available as of that given forecast date. Returns a list
#' of "predictions cards", where each list element corresponds to a different
Expand All @@ -14,13 +14,12 @@
#' probabilities associated with quantile forecasts for that location and
#' ahead. If your forecaster produces point forecasts, then set `quantile=NA`.
#'
#' One argument to `forecaster` must be named `df_list`. It will be
#' populated with the list of historical data returned by a call
#' to COVIDcast. The list will be the same length as the number of rows in
#' the `signals` tibble (see below).
#' The forecaster will also receive a single `forecast_date` as a named argument.
#' Any additional named arguments can be passed via the `forecaster_args`
#' argument below.
#' One argument to `forecaster` must be named `df_list`. It will be populated
#' with the list of historical data returned by a call to COVIDcast. The list
#' will be the same length as the number of rows in the `signals` tibble (see
#' below). The forecaster will also receive a single `forecast_date` as a
#' named argument. Any additional named arguments can be passed via the
#' `forecaster_args` argument below.
#'
#' Thus, the forecaster should have a signature like
#' `forecaster(df_list = data, forecast_data = forecast_date, ...)`
Expand All @@ -29,49 +28,52 @@
#' @template forecast_dates-template
#' @template incidence_period-template
#' @template apply_corrections-template
#' @param response_data_source String indicating the `data_source` of the response.
#' This is used mainly for downstream evaluation. By default, this will be the
#' same as the `data_source` in the first row of the `signals` tibble.
#' @param response_data_source String indicating the `data_source` of the
#' response. This is used mainly for downstream evaluation. By default, this
#' will be the same as the `data_source` in the first row of the `signals`
#' tibble.
#' @param response_data_signal String indicating the `signal` of the response.
#' This is used mainly for downstream evaluation. By default, this will be the
#' same as the `signal` in the first row in the `signals` tibble.
#' @param forecaster_args a list of additional named arguments to be passed
#' to `forecaster()`. A common use case would be to pass the period ahead
#' (e.g. predict 1 day, 2 days, ..., k days ahead). Note that `ahead` is a
#' required component of the forecaster output (see above).
#' @param parallel_execution FALSE (default), TRUE, or a single positive integer. If TRUE,
#' executes each forecast date prediction in parallel on the number of
#' detected cores available minus 1. If FALSE, the code is run on a single
#' core. If integer, runs in parallel on that many cores, clipping to the
#' number of detected cores available if a greater number is requested. Uses
#' [`bettermc::mclapply`] for parallelism if using more than one core.
#' @param parallel_execution FALSE (default), TRUE, or a single positive
#' integer. If TRUE, executes each forecast date prediction in parallel on the
#' number of detected cores available minus 1. If FALSE, the code is run on a
#' single core. If integer, runs in parallel on that many cores, clipping to
#' the number of detected cores available if a greater number is requested.
#' Uses [`bettermc::mclapply`] for parallelism if using more than one core.
#' @param additional_mclapply_args a named list of additional arguments to pass
#' to [`bettermc::mclapply`] (besides `X`, `FUN`, and
#' `mc.cores`.)
#' @param honest_as_of a boolean that, if true, ensures that the forecast_day, end_day,
#' and as_of are all equal when downloading data. Otherwise, as_of is allowed to be freely
#' set (and defaults to current date if not presented).
#' @param honest_as_of a boolean that, if true, ensures that the forecast_day,
#' end_day, and as_of are all equal when downloading data. Otherwise, as_of is
#' allowed to be freely set (and defaults to current date if not presented).
#' @param offline_signal_dir the directory that stores the cached data for each
#' (signal, forecast day) pair. If this is null, no caching is done and the data is
#' downloaded from covidcast.
#' (signal, forecast day) pair. If this is null, no caching is done and the data
#' is downloaded from covidcast.
#'
#' @template predictions_cards-template
#'
#' @examples \dontrun{
#' baby_predictions = get_predictions(
#' baby_predictions <- get_predictions(
#' baseline_forecaster, "baby",
#' tibble::tibble(
#' data_source = "jhu-csse",
#' signal = "deaths_incidence_num",
#' start_day = "2020-08-15",
#' geo_values = "mi",
#' geo_type = "state"),
#' geo_type = "state"
#' ),
#' forecast_dates = "2020-10-01",
#' incidence_period = "epiweek",
#' forecaster_args = list(
#' incidence_period = "epiweek",
#' ahead = 1:4
#' ))
#' )
#' )
#' }
#'
#' @export
Expand All @@ -89,15 +91,40 @@ get_predictions <- function(forecaster,
additional_mclapply_args = list(),
honest_as_of = TRUE,
offline_signal_dir = NULL) {

assert_that(is_tibble(signals), msg = "`signals` should be a tibble.")
assert_that(xor(honest_as_of, "as_of" %in% names(signals)), msg = "`honest_as_of` should be set if and only if `as_of` is not set. Either remove as_of specification or set honest_as_of to FALSE.")
assert_that(
xor(honest_as_of, "as_of" %in% names(signals)),
msg = paste(
"`honest_as_of` should be set if and only if `as_of` is not set. Either",
" remove as_of specification or set honest_as_of to FALSE."
)
)
incidence_period <- match.arg(incidence_period)

if (incidence_period == "epiweek") rlang::warn("incidence_period of 'epiweek' only selects weekly (Saturday) target end dates, actual 'epiweek' signals not supported.", "evalcast::get_predictions")
if ("time_type" %in% names(signals) && any(signals$time_type == "week" | signals$time_type == "epiweek")) rlang::abort("time_type in the signals arg can only be 'day'.", "evalcast::get_predictions")
if ("lag" %in% names(signals)) rlang::warn("lag in the signals arg will be ignored.", "evalcast::get_predictions")
if ("issues" %in% names(signals)) rlang::warn("issues in the signals arg will be ignored.", "evalcast::get_predictions")
if (incidence_period == "epiweek") {
rlang::warn(
paste(
"incidence_period of 'epiweek' only selects weekly (Saturday) target ",
"end dates, actual 'epiweek' signals not supported."
),
"evalcast::get_predictions"
)
}
if (
"time_type" %in% names(signals) &&
any(signals$time_type == "week" | signals$time_type == "epiweek")
) {
rlang::abort(
"time_type in the signals arg can only be 'day'.",
"evalcast::get_predictions"
)
}
if ("lag" %in% names(signals)) {
rlang::warn("lag in the signals arg will be ignored.", "evalcast::get_predictions")
}
if ("issues" %in% names(signals)) {
rlang::warn("issues in the signals arg will be ignored.", "evalcast::get_predictions")
}

if (rlang::is_bool(parallel_execution) && parallel_execution == TRUE) {
num_cores <- max(1L, parallel::detectCores() - 1L)
Expand All @@ -109,11 +136,14 @@ get_predictions <- function(forecaster,
stop("parallel_execution argument is neither boolean nor integer.")
}

assert_that(rlang::is_named2(additional_mclapply_args) &&
all(rlang::names2(additional_mclapply_args) %in% rlang::fn_fmls_names(bettermc::mclapply)))
assert_that(
rlang::is_named2(additional_mclapply_args) &&
all(rlang::names2(additional_mclapply_args) %in% rlang::fn_fmls_names(bettermc::mclapply))
)

get_predictions_single_date_ <- function(forecast_date) {
preds <- do.call(get_predictions_single_date,
preds <- do.call(
get_predictions_single_date,
list(
forecaster = forecaster,
signals = signals,
Expand All @@ -128,7 +158,14 @@ get_predictions <- function(forecaster,
}

if (num_cores > 1L) {
out <- rlang::inject(bettermc::mclapply(forecast_dates, get_predictions_single_date_, mc.cores = num_cores, !!!additional_mclapply_args)) %>% bind_rows()
out <- rlang::inject(
bettermc::mclapply(
forecast_dates,
get_predictions_single_date_,
mc.cores = num_cores,
!!!additional_mclapply_args
)
) %>% bind_rows()
} else {
out <- lapply(forecast_dates, get_predictions_single_date_) %>% bind_rows()
}
Expand All @@ -143,9 +180,10 @@ get_predictions <- function(forecaster,
target_end_date = get_target_period(
.data$forecast_date,
incidence_period,
.data$ahead)$end,
.data$ahead
)$end,
incidence_period = incidence_period
) %>%
) %>%
relocate(.data$forecaster, .before = .data$forecast_date)

class(out) <- c("predictions_cards", class(out))
Expand All @@ -159,7 +197,6 @@ get_predictions_single_date <- function(forecaster,
forecaster_args,
honest_as_of = TRUE,
offline_signal_dir = NULL) {

forecast_date <- lubridate::ymd(forecast_date)
signals <- signal_listcols(signals, forecast_date)

Expand Down Expand Up @@ -189,9 +226,12 @@ get_predictions_single_date <- function(forecaster,

out <- do.call(forecaster, forecaster_args)
assert_that(all(c("ahead", "geo_value", "quantile", "value") %in% names(out)),
msg = paste("Your forecaster must return a data frame with",
"(at least) the columnns `ahead`, `geo_value`,",
"`quantile`, and `value`."))
msg = paste(
"Your forecaster must return a data frame with",
"(at least) the columnns `ahead`, `geo_value`,",
"`quantile`, and `value`."
)
)

# make a predictions card
out$forecast_date <- forecast_date
Expand Down

0 comments on commit 31c8bf8

Please sign in to comment.