Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] validate input value symmetric_cbar #4339

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions nilearn/_utils/param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,23 @@
return threshold


def check_symmetric_cbar(symmetric_cbar):
"""Check value of symmetric_cbar."""
if symmetric_cbar is None:
symmetric_cbar = "auto"

Check warning on line 89 in nilearn/_utils/param_validation.py

View check run for this annotation

Codecov / codecov/patch

nilearn/_utils/param_validation.py#L89

Added line #L89 was not covered by tests
if not isinstance(symmetric_cbar, (bool, str)):
raise ValueError(
"'symmetric_cbar' must be a boolean or 'auto'.\n"
f"got: {symmetric_cbar}"
)
if isinstance(symmetric_cbar, str) and symmetric_cbar != "auto":
raise ValueError(

Check warning on line 96 in nilearn/_utils/param_validation.py

View check run for this annotation

Codecov / codecov/patch

nilearn/_utils/param_validation.py#L96

Added line #L96 was not covered by tests
"'symmetric_cbar' must be a boolean or 'auto'.\n"
f"got: {symmetric_cbar}"
)
return symmetric_cbar


def get_mask_volume(mask_img):
"""Compute the volume of a brain mask in mm^3.

Expand Down
6 changes: 5 additions & 1 deletion nilearn/plotting/img_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from .._utils.ndimage import get_border_data
from .._utils.niimg import safe_get_data
from .._utils.numpy_conversions import as_ndarray
from .._utils.param_validation import check_threshold
from .._utils.param_validation import check_symmetric_cbar, check_threshold
from ..datasets import load_mni152_template
from ..image import get_data, iter_img, math_img, new_img_like, resample_to_img
from ..masking import apply_mask, compute_epi_mask
Expand Down Expand Up @@ -1313,6 +1313,8 @@ def plot_stat_map(

stat_map_img = _utils.check_niimg_3d(stat_map_img, dtype="auto")

symmetric_cbar = check_symmetric_cbar(symmetric_cbar)

cbar_vmin, cbar_vmax, vmin, vmax = get_colorbar_and_data_ranges(
safe_get_data(stat_map_img, ensure_finite=True),
vmin=vmin,
Expand Down Expand Up @@ -1455,6 +1457,8 @@ def plot_glass_brain(
cmap(np.linspace(0.5, 1, 256)),
)

symmetric_cbar = check_symmetric_cbar(symmetric_cbar)

if stat_map_img:
stat_map_img = _utils.check_niimg_3d(stat_map_img, dtype="auto")
if plot_abs:
Expand Down
4 changes: 4 additions & 0 deletions nilearn/plotting/surf_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from nilearn import image, surface
from nilearn._utils import check_niimg_3d, compare_version, fill_doc
from nilearn._utils.helpers import is_kaleido_installed, is_plotly_installed
from nilearn._utils.param_validation import check_symmetric_cbar
from nilearn.plotting.cm import cold_hot, mix_colormaps
from nilearn.plotting.displays._slicers import _get_cbar_ticks
from nilearn.plotting.html_surface import get_vertexcolor
Expand Down Expand Up @@ -1201,6 +1202,7 @@ def plot_surf_stat_map(surf_mesh, stat_map, bg_map=None,
nilearn.surface.vol_to_surf : For info on the generation of surfaces.

"""
symmetric_cbar = check_symmetric_cbar(symmetric_cbar)
check_extensions(stat_map, DATA_EXTENSIONS, FREESURFER_DATA_EXTENSIONS)
loaded_stat_map = load_surf_data(stat_map)

Expand Down Expand Up @@ -1453,6 +1455,8 @@ def plot_img_on_surf(stat_map, surf_mesh='fsaverage5', mask_img=None,
hemis = _check_hemispheres(hemispheres)
surf_mesh = check_mesh(surf_mesh)

symmetric_cbar = check_symmetric_cbar(symmetric_cbar)

mesh_prefix = "infl" if inflate else "pial"
surf = {
'left': surf_mesh[f'{mesh_prefix}_left'],
Expand Down
11 changes: 11 additions & 0 deletions nilearn/plotting/tests/test_img_plotting/test_plot_stat_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ def test_plot_stat_map_bad_input(img_3d_mni, tmp_path):
plt.close()


def test_plot_stat_map_invalid_symmetric_cbar(img_3d_mni):
"""Test for bad input for symmetric_cbar."""
with pytest.raises(
ValueError, match="'symmetric_cbar' must be a boolean or 'auto'"
):
plot_stat_map(
img_3d_mni,
symmetric_cbar=1,
)


@pytest.mark.parametrize(
"params", [{}, {"display_mode": "x", "cut_coords": 3}]
)
Expand Down