Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] change creation of region_names and add test #4360

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
47 changes: 35 additions & 12 deletions nilearn/maskers/nifti_labels_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,39 @@ def fit(self, imgs=None, y=None):
msg = f"loading data from {repr}"
_utils.logger.log(msg=msg, verbose=self.verbose)
self.labels_img_ = _utils.check_niimg_3d(self.labels_img)

# create region_id_name_ dictionary
self.region_id_name_ = None
if self.labels is not None:
known_backgrounds = {"background", "Background"}
initial_region_ids = [
region_id
for region_id in np.unique(
_utils.niimg.safe_get_data(self.labels_img_)
)
if region_id != self.background_label
]
initial_region_names = [
region_name
for region_name in self.labels
if region_name not in known_backgrounds
]

if len(initial_region_ids) != len(initial_region_names):
warnings.warn(
"Number of regions in the labels image "
"does not match the number of labels provided.",
stacklevel=3,
)
# if number of regions in the labels image is more
# than the number of labels provided, then we cannot
# create region_id_name_ dictionary
if len(initial_region_ids) <= len(initial_region_names):
self.region_id_name_ = {
region_id: initial_region_names[i]
for i, region_id in enumerate(initial_region_ids)
}

if self.mask_img is not None:
repr = _utils._repr_niimgs(
self.mask_img, shorten=(not self.verbose)
Expand Down Expand Up @@ -761,20 +794,10 @@ def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
self.labels_, tolerant=True, resampling_done=True
)

if self.labels is not None:

# Keep track if background was explicitly passed as a label
# background should always be explicitly passed in the labels
# to avoid this.
lower_case_labels = {x.lower() for x in self.labels}
known_backgrounds = {"background"}
background_in_labels = any(
known_backgrounds.intersection(lower_case_labels)
)
offset = 1 if background_in_labels else 0
if self.region_id_name_ is not None:

self.region_names_ = {
key: self.labels[key + offset]
key: self.region_id_name_[region_id]
for key, region_id in region_ids.items()
if region_id != self.background_label
}
Expand Down
101 changes: 101 additions & 0 deletions nilearn/maskers/tests/test_nifti_labels_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,107 @@ def check_region_names_after_fit(
assert region_names_after_fit == region_names


@pytest.mark.parametrize(
"background",
[
None,
"background",
"Background",
],
)
@pytest.mark.parametrize(
"affine_data",
[
None, # no resampling
np.diag(
(4, 4, 4, 4) # with resampling
), # region_names_ matches signals after resampling drops labels
],
)
@pytest.mark.parametrize(
"masking",
[
False, # no masking
True, # with masking
],
)
@pytest.mark.parametrize(
"keep_masked_labels",
[
False,
True,
],
)
def test_region_names_ids_match_after_fit(
shape_3d_default,
affine_eye,
background,
affine_data,
n_regions,
masking,
keep_masked_labels,
):
"""Test that the same region names and ids correspond after fit."""
if affine_data is None:
# no resampling
affine_data = affine_eye
fmri_img, _ = generate_random_img(shape_3d_default, affine=affine_data)
labels_img = generate_labeled_regions(
shape_3d_default[:3],
affine=affine_eye,
n_regions=n_regions,
)

region_names = generate_labels(n_regions, background=background)
region_ids = [region_id for region_id in np.unique(get_data(labels_img))]

if masking:
# create a mask_img with 3 regions
labels_data = get_data(labels_img)
mask_data = (
(labels_data == 1) + (labels_data == 2) + (labels_data == 5)
)
mask_img = Nifti1Image(mask_data.astype(np.int8), labels_img.affine)
else:
mask_img = None

masker = NiftiLabelsMasker(
labels_img,
labels=region_names,
resampling_target="data",
mask_img=mask_img,
keep_masked_labels=keep_masked_labels,
)

_ = masker.fit().transform(fmri_img)

check_region_names_ids_match_after_fit(
masker, region_names, region_ids, background
)


def check_region_names_ids_match_after_fit(
masker, region_names, region_ids, background
):
"""Check the region names and ids correspondence.

Check that the same region names and ids correspond to each other
after fit by comparing with before fit.
"""
# region_ids includes background, so we make
# sure that the region_names also include it
if not background:
region_names.insert(0, "background")
# if they don't have the same length, we can't compare them
if len(region_names) == len(region_ids):
region_id_names = {
region_id: region_names[i]
for i, region_id in enumerate(region_ids)
}
for key, region_name in masker.region_names_.items():
assert region_id_names[masker.region_ids_[key]] == region_name


def generate_labels(n_regions, background=True):
labels = []
if background:
Expand Down