Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][MAINT] implement walrus operator? #4202

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ repos:
hooks:
- id: flynt

- repo: https://github.com/MarcoGorelli/auto-walrus
rev: v0.2.2
hooks:
- id: auto-walrus

- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:
Expand Down
18 changes: 6 additions & 12 deletions nilearn/datasets/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,8 @@ def fetch_atlas_difumo(
512: "9b76y",
1024: "34792",
}
valid_dimensions = [64, 128, 256, 512, 1024]
valid_resolution_mm = [2, 3]
if dimension not in valid_dimensions:
if dimension not in (valid_dimensions := [64, 128, 256, 512, 1024]):
raise ValueError(
f"Requested dimension={dimension} is not available. "
f"Valid options: {valid_dimensions}"
Expand Down Expand Up @@ -1268,8 +1267,7 @@ def fetch_atlas_aal(
Licence: unknown.

"""
versions = ["SPM5", "SPM8", "SPM12"]
if version not in versions:
if version not in (versions := ["SPM5", "SPM8", "SPM12"]):
raise ValueError(
f"The version of AAL requested '{version}' does not exist."
f"Please choose one among {versions}."
Expand Down Expand Up @@ -1405,8 +1403,7 @@ def fetch_atlas_basc_multiscale_2015(
https://figshare.com/articles/basc/1285615

"""
versions = ["sym", "asym"]
if version not in versions:
if version not in (versions := ["sym", "asym"]):
raise ValueError(
f"The version of Brain parcellations requested '{version}' "
"does not exist. "
Expand Down Expand Up @@ -2113,20 +2110,17 @@ def fetch_atlas_schaefer_2018(
Licence: MIT.

"""
valid_n_rois = list(range(100, 1100, 100))
valid_yeo_networks = [7, 17]
valid_resolution_mm = [1, 2]
if n_rois not in valid_n_rois:
if n_rois not in (valid_n_rois := list(range(100, 1100, 100))):
raise ValueError(
f"Requested n_rois={n_rois} not available. "
f"Valid options: {valid_n_rois}"
)
if yeo_networks not in valid_yeo_networks:
if yeo_networks not in (valid_yeo_networks := [7, 17]):
raise ValueError(
f"Requested yeo_networks={yeo_networks} not available. "
f"Valid options: {valid_yeo_networks}"
)
if resolution_mm not in valid_resolution_mm:
if resolution_mm not in (valid_resolution_mm := [1, 2]):
raise ValueError(
f"Requested resolution_mm={resolution_mm} not available. "
f"Valid options: {valid_resolution_mm}"
Expand Down
12 changes: 4 additions & 8 deletions nilearn/datasets/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,8 +1492,7 @@ def fetch_megatrawls_netmats(
"Invalid {0} input is provided: {1}, choose one of them {2}"
)
# standard dataset terms
dimensionalities = [25, 50, 100, 200, 300]
if dimensionality not in dimensionalities:
if dimensionality not in (dimensionalities := [25, 50, 100, 200, 300]):
raise ValueError(
error_message.format(
"dimensionality", dimensionality, dimensionalities
Expand Down Expand Up @@ -2092,8 +2091,7 @@ def fetch_development_fmri(

def _filter_func_regressors_by_participants(participants, age_group):
"""Filter functional and regressors based on participants."""
valid_age_groups = ("both", "child", "adult")
if age_group not in valid_age_groups:
if age_group not in (valid_age_groups := ("both", "child", "adult")):
raise ValueError(
f"Wrong value for age_group={age_group}. "
f"Valid arguments are: {valid_age_groups}"
Expand Down Expand Up @@ -2983,8 +2981,7 @@ def fetch_spm_multimodal_fmri(
subject_dir = os.path.join(data_dir, subject_id)

# maybe data_dir already contains the data ?
data = _glob_spm_multimodal_fmri_data(subject_dir)
if data is not None:
if (data := _glob_spm_multimodal_fmri_data(subject_dir)) is not None:
Remi-Gau marked this conversation as resolved.
Show resolved Hide resolved
return data

# No. Download the data
Expand Down Expand Up @@ -3040,8 +3037,7 @@ def _glob_fiac_data():
return Bunch(**_subject_data)

# maybe data_dir already contains the data ?
data = _glob_fiac_data()
if data is not None:
if (data := _glob_fiac_data()) is not None:
Remi-Gau marked this conversation as resolved.
Show resolved Hide resolved
return data

# No. Download the data
Expand Down
3 changes: 1 addition & 2 deletions nilearn/glm/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ def full_rank(X, cmax=1e15):
"""
U, s, V = spl.svd(X, full_matrices=False)
smax, smin = s.max(), s.min()
cond = smax / smin
if cond < cmax:
if (cond := smax / smin) < cmax:
return X, cond

warn("Matrix is singular at working precision, regularizing...")
Expand Down
3 changes: 1 addition & 2 deletions nilearn/glm/contrasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def compute_contrast(labels, regression_result, con_val, stat_type=None):
if stat_type is None:
stat_type = "t" if dim == 1 else "F"

acceptable_stat_types = ["t", "F"]
if stat_type not in acceptable_stat_types:
if stat_type not in (acceptable_stat_types := ["t", "F"]):
raise ValueError(
f"'{stat_type}' is not a known contrast type. "
f"Allowed types are {acceptable_stat_types}."
Expand Down
5 changes: 3 additions & 2 deletions nilearn/glm/first_level/experimental_paradigm.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ def _handle_modulation(events):

def _check_unexpected_columns(events):
"""Warn for each unexpected column that will not be used afterwards."""
unexpected_columns = list(set(events.columns).difference(VALID_FIELDS))
if unexpected_columns:
if unexpected_columns := list(
set(events.columns).difference(VALID_FIELDS)
):
warnings.warn(
"The following unexpected columns "
"in events data will be ignored: "
Expand Down
6 changes: 2 additions & 4 deletions nilearn/interfaces/fmriprep/load_confounds_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def _load_high_pass(confounds_raw):
DataFrame of high pass filter regressors.
If not present in file, return an empty DataFrame.
"""
high_pass_params = find_confounds(confounds_raw, ["cosine"])
if high_pass_params:
if high_pass_params := find_confounds(confounds_raw, ["cosine"]):
return confounds_raw[high_pass_params]
else:
return pd.DataFrame()
Expand Down Expand Up @@ -287,8 +286,7 @@ def _load_non_steady_state(confounds_raw):
DataFrame of non steady state regressors generated by fMRIPrep.
If none were found, return an empty DataFrame.
"""
nss_outliers = find_confounds(confounds_raw, ["non_steady_state"])
if nss_outliers:
if nss_outliers := find_confounds(confounds_raw, ["non_steady_state"]):
return confounds_raw[nss_outliers]
else:
return pd.DataFrame()
3 changes: 1 addition & 2 deletions nilearn/plotting/img_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,8 +906,7 @@ def plot_roi(
(4D images)

"""
valid_view_types = ["continuous", "contours"]
if view_type not in valid_view_types:
if view_type not in (valid_view_types := ["continuous", "contours"]):
raise ValueError(
f"Unknown view type: {view_type}. "
f"Valid view types are {valid_view_types}."
Expand Down
3 changes: 1 addition & 2 deletions nilearn/plotting/matrix_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ def _sanitize_labels(mat_shape, labels):

def _sanitize_tri(tri):
"""Help for plot_matrix."""
VALID_TRI_VALUES = ("full", "lower", "diag")
if tri not in VALID_TRI_VALUES:
if tri not in (VALID_TRI_VALUES := ("full", "lower", "diag")):
raise ValueError(
"Parameter tri needs to be one of: "
f"{', '.join(VALID_TRI_VALUES)}."
Expand Down
3 changes: 1 addition & 2 deletions nilearn/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,9 +1014,8 @@ def _sanitize_sample_mask(n_time, n_runs, runs, sample_mask):
def _check_sample_mask_index(i, n_runs, runs, current_mask):
"""Ensure the index in sample mask is valid."""
len_run = sum(i == runs)
len_current_mask = len(current_mask)
# sample_mask longer than signal
if len_current_mask > len_run:
if (len_current_mask := len(current_mask)) > len_run:
raise IndexError(
f"sample_mask {i + 1} of {n_runs} is has more timepoints "
f"than the current run ;sample_mask contains {len_current_mask} "
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ doc = [
"sphinx-design",
"sphinx-gallery",
"sphinxcontrib-bibtex",
"sphinxext-opengraph",
"sphinxext-opengraph"
]
# the following is kept for "backward compatibility"
plotly = ["nilearn[plotting]"]
Expand Down