Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] add plotting functions to the experimental surface module #4235

Merged
merged 10 commits into from
Jun 3, 2024
94 changes: 44 additions & 50 deletions examples/08_experimental/plot_surface_image_and_maskers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,40 +21,10 @@
raise RuntimeError("This script needs the matplotlib library")

# %%
import numpy as np

from typing import Optional, Sequence

from nilearn import plotting
from nilearn.experimental import surface


def plot_surf_img(
    img: surface.SurfaceImage,
    parts: Optional[Sequence[str]] = None,
    mesh: Optional[surface.PolyMesh] = None,
    **kwargs,
) -> plt.Figure:
    """Plot each part of a SurfaceImage in a row of 3D subplots.

    Parameters
    ----------
    img : surface.SurfaceImage
        Image whose per-part data arrays are plotted.
    parts : sequence of str or None, default=None
        Names of the parts to plot; defaults to every key of ``img.data``.
    mesh : surface.PolyMesh or None, default=None
        Mesh to plot the data on; defaults to ``img.mesh``.
    **kwargs
        Forwarded to ``plotting.plot_surf``.

    Returns
    -------
    plt.Figure
        The figure holding one 3D axis per plotted part.
    """
    if mesh is None:
        mesh = img.mesh
    if parts is None:
        parts = list(img.data.keys())
    fig, axes = plt.subplots(
        1,
        len(parts),
        subplot_kw={"projection": "3d"},
        figsize=(4 * len(parts), 4),
    )
    # BUGFIX: plt.subplots returns a bare Axes (not an array) when a single
    # subplot is requested, which would break the zip below for one part.
    for ax, mesh_part in zip(np.atleast_1d(axes), parts):
        plotting.plot_surf(
            mesh[mesh_part],
            img.data[mesh_part],
            axes=ax,
            title=mesh_part,
            **kwargs,
        )
    assert isinstance(fig, plt.Figure)
    return fig

from nilearn.experimental import plotting, surface
from nilearn.plotting import plot_matrix

img = surface.fetch_nki()[0]
print(f"NKI image: {img}")
Expand All @@ -67,20 +37,49 @@ def plot_surf_img(
mean_img = masker.inverse_transform(mean_data)
print(f"Image mean: {mean_img}")

plot_surf_img(mean_img)
plotting.show()
# let's create a figure with all the views for both hemispheres
views = ["lateral", "medial", "dorsal", "ventral", "anterior", "posterior"]
hemispheres = ["left", "right"]

fig, axes = plt.subplots(
len(views),
len(hemispheres),
subplot_kw={"projection": "3d"},
figsize=(4 * len(hemispheres), 4),
)
axes = np.atleast_2d(axes)

for view, ax_row in zip(views, axes):
for ax, hemi in zip(ax_row, hemispheres):
plotting.plot_surf(
mean_img,
part=hemi,
view=view,
figure=fig,
axes=ax,
title=f"mean image - {hemi} - {view}",
colorbar=False,
cmap="bwr",
symmetric_cmap=True,
)
fig.set_size_inches(6, 8)

plt.show()

# %%
# Connectivity with a surface atlas and `SurfaceLabelsMasker`
# -----------------------------------------------------------
from nilearn import connectome, plotting
from nilearn import connectome

img = surface.fetch_nki()[0]
print(f"NKI image: {img}")

labels_img, label_names = surface.fetch_destrieux()
print(f"Destrieux image: {labels_img}")
plot_surf_img(labels_img, cmap="gist_ncar", avg_method="median")
plotting.plot_surf(
labels_img,
avg_method="median",
)

labels_masker = surface.SurfaceLabelsMasker(labels_img, label_names).fit()
masked_data = labels_masker.transform(img)
Expand All @@ -89,17 +88,15 @@ def plot_surf_img(
connectome = (
connectome.ConnectivityMeasure(kind="correlation").fit([masked_data]).mean_
)
plotting.plot_matrix(connectome, labels=labels_masker.label_names_)
plot_matrix(connectome, labels=labels_masker.label_names_)

plotting.show()
plt.show()


# %%
# Using the `Decoder`
# -------------------
import numpy as np

from nilearn import decoding, plotting
from nilearn import decoding
from nilearn._utils import param_validation

# %%
Expand Down Expand Up @@ -132,17 +129,14 @@ def adjust_screening_percentile(screening_percentile, *args, **kwargs):
decoder.fit(img, y)
print("CV scores:", decoder.cv_scores_)

plot_surf_img(decoder.coef_img_[0], threshold=1e-6)
plotting.show()
plotting.plot_surf(decoder.coef_img_[0], threshold=1e-6)
plt.show()

# %%
# Decoding with a scikit-learn `Pipeline`
# ---------------------------------------
import numpy as np
from sklearn import feature_selection, linear_model, pipeline, preprocessing

from nilearn import plotting

img = surface.fetch_nki()[0]
y = np.random.RandomState(0).normal(size=img.shape[0])

Expand All @@ -159,12 +153,12 @@ def adjust_screening_percentile(screening_percentile, *args, **kwargs):
coef_img = decoder[:-1].inverse_transform(np.atleast_2d(decoder[-1].coef_))


vmax = max([np.absolute(dp).max() for dp in coef_img.data.values()])
plot_surf_img(
vmax = max([np.absolute(dp).max() for dp in coef_img.data.parts.values()])
plotting.plot_surf(
coef_img,
cmap="cold_hot",
vmin=-vmax,
vmax=vmax,
threshold=1e-6,
)
plotting.show()
plt.show()
2 changes: 1 addition & 1 deletion nilearn/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
try:
    import matplotlib  # noqa: F401
except ImportError:
    # Without matplotlib, skip collection of every test directory that
    # depends on it (including the experimental plotting module).
    collect_ignore.extend(["plotting", "reporting", "experimental/plotting"])
    matplotlib = None


Expand Down
2 changes: 1 addition & 1 deletion nilearn/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Use those features at your own risks!
"""

__all__ = ["surface"]
__all__ = ["surface", "plotting"]
3 changes: 3 additions & 0 deletions nilearn/experimental/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Public entry point re-exporting the experimental surface plotting API."""
from ._surface_plotting import plot_surf

__all__ = ["plot_surf"]
36 changes: 36 additions & 0 deletions nilearn/experimental/plotting/_surface_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

from nilearn import plotting as old_plotting
from nilearn.experimental.surface import SurfaceImage


def plot_surf(
    img, part: str | None = None, mesh=None, view: str | None = None, **kwargs
):
    """Plot a SurfaceImage, or fall back to ``nilearn.plotting.plot_surf``.

    Parameters
    ----------
    img : SurfaceImage or any map accepted by ``nilearn.plotting.plot_surf``
        Image to plot. When it is not a SurfaceImage, all arguments are
        forwarded to the classic plotting function unchanged.
    part : str or None, default=None
        Name of the mesh/data part to plot (forwarded as ``hemi``). For a
        SurfaceImage it defaults to the first key of ``img.data.parts``.
    mesh : PolyMesh or None, default=None
        Mesh to plot the data on; defaults to ``img.mesh`` for a
        SurfaceImage.
    view : str or None, default=None
        View of the surface; defaults to ``"lateral"`` for a SurfaceImage.
    **kwargs
        Forwarded to ``nilearn.plotting.plot_surf``.
    """
    if not isinstance(img, SurfaceImage):
        # BUGFIX: this fallback branch previously dropped ``view`` on the
        # floor; forward it when the caller provided one (leaving the old
        # API's own default in place otherwise).
        if view is not None:
            kwargs["view"] = view
        return old_plotting.plot_surf(
            surf_mesh=mesh,
            surf_map=img,
            hemi=part,
            **kwargs,
        )

    if mesh is None:
        mesh = img.mesh
    if part is None:
        # only take the first hemisphere by default
        part = list(img.data.parts.keys())[0]
    if view is None:
        view = "lateral"

    return old_plotting.plot_surf(
        surf_mesh=mesh.parts[part],
        surf_map=img.data.parts[part],
        hemi=part,
        view=view,
        **kwargs,
    )
16 changes: 8 additions & 8 deletions nilearn/experimental/surface/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from nilearn.experimental.surface import _io
from nilearn.experimental.surface._surface_image import (
FileMesh,
Mesh,
PolyMesh,
SurfaceImage,
)
Expand All @@ -18,14 +17,15 @@
def load_fsaverage(mesh_name: "str" = "fsaverage5") -> "Dict[str, PolyMesh]":
    """Load several fsaverage mesh types for both hemispheres.

    Parameters
    ----------
    mesh_name : str, default="fsaverage5"
        Name passed to ``datasets.fetch_surf_fsaverage``.

    Returns
    -------
    dict of str to PolyMesh
        One PolyMesh (with "left" and "right" parts) per mesh type, keyed
        by the public names "pial", "white_matter" and "inflated".
    """
    fsaverage = datasets.fetch_surf_fsaverage(mesh_name)
    meshes: Dict[str, PolyMesh] = {}
    # Map fsaverage dataset keys to the names exposed by this module.
    renaming = {"pial": "pial", "white": "white_matter", "infl": "inflated"}
    # NOTE: loop variable renamed from ``mesh_name``, which shadowed the
    # function parameter of the same name.
    for mesh_type, public_name in renaming.items():
        parts = {
            hemisphere: FileMesh(fsaverage[f"{mesh_type}_{hemisphere}"])
            for hemisphere in ("left", "right")
        }
        meshes[public_name] = PolyMesh(**parts)
    return meshes


Expand All @@ -42,8 +42,8 @@ def fetch_nki(n_subjects=1) -> Sequence[SurfaceImage]:
img = SurfaceImage(
mesh=fsaverage["pial"],
data={
"left_hemisphere": left_data,
"right_hemisphere": right_data,
"left": left_data,
"right": right_data,
},
)
images.append(img)
Expand All @@ -61,8 +61,8 @@ def fetch_destrieux() -> Tuple[SurfaceImage, Dict[int, str]]:
SurfaceImage(
mesh=fsaverage["pial"],
data={
"left_hemisphere": destrieux["map_left"],
"right_hemisphere": destrieux["map_right"],
"left": destrieux["map_left"],
"right": destrieux["map_right"],
},
),
label_names,
Expand Down
28 changes: 16 additions & 12 deletions nilearn/experimental/surface/_maskers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@

def check_same_n_vertices(mesh_1: "PolyMesh", mesh_2: "PolyMesh") -> None:
    """Check that 2 meshes have the same keys and that n vertices match.

    Parameters
    ----------
    mesh_1, mesh_2 : PolyMesh
        Meshes whose ``parts`` mappings are compared.

    Raises
    ------
    ValueError
        If the two meshes do not have the same part names, or if any
        matching part differs in its number of vertices.
    """
    keys_1 = set(mesh_1.parts.keys())
    keys_2 = set(mesh_2.parts.keys())
    if keys_1 != keys_2:
        diff = keys_1.symmetric_difference(keys_2)
        raise ValueError(
            "Meshes do not have the same keys. " f"Offending keys: {diff}"
        )
    for key in keys_1:
        n_1 = mesh_1.parts[key].n_vertices
        n_2 = mesh_2.parts[key].n_vertices
        if n_1 != n_2:
            raise ValueError(
                f"Number of vertices do not match for '{key}'."
                "number of vertices in mesh_1: "
                f"{n_1}; "
                f"in mesh_2: {n_2}"
            )


Expand Down Expand Up @@ -86,7 +87,8 @@
# TODO: don't store a full array of 1 to mean "no masking"; use some
# sentinel value
mask_data = {
k: np.ones(v.n_vertices, dtype=bool) for (k, v) in img.mesh.items()
k: np.ones(v.n_vertices, dtype=bool)
for (k, v) in img.mesh.parts.items()
}
self.mask_img_ = SurfaceImage(mesh=img.mesh, data=mask_data)

Expand All @@ -113,7 +115,7 @@
assert self.mask_img_ is not None
start, stop = 0, 0
self.slices = {}
for part_name, mask in self.mask_img_.data.items():
for part_name, mask in self.mask_img_.data.parts.items():
assert isinstance(mask, np.ndarray)
stop = start + mask.sum()
self.slices[part_name] = start, stop
Expand Down Expand Up @@ -162,9 +164,9 @@
check_same_n_vertices(self.mask_img_.mesh, img.mesh)
output = np.empty((*img.shape[:-1], self.output_dimension_))
for part_name, (start, stop) in self.slices.items():
mask = self.mask_img_.data[part_name]
mask = self.mask_img_.data.parts[part_name]
assert isinstance(mask, np.ndarray)
output[..., start:stop] = img.data[part_name][..., mask]
output[..., start:stop] = img.data.parts[part_name][..., mask]

# signal cleaning here
output = cache(
Expand Down Expand Up @@ -235,7 +237,7 @@
f"last dimension should be {self.output_dimension_}"
)
data = {}
for part_name, mask in self.mask_img_.data.items():
for part_name, mask in self.mask_img_.data.parts.items():
assert isinstance(mask, np.ndarray)
data[part_name] = np.zeros(
(*masked_img.shape[:-1], mask.shape[0]),
Expand Down Expand Up @@ -282,7 +284,9 @@
) -> None:
self.labels_img = labels_img
self.label_names = label_names
self.labels_data_ = np.concatenate(list(labels_img.data.values()))
self.labels_data_ = np.concatenate(

Check warning on line 287 in nilearn/experimental/surface/_maskers.py

View check run for this annotation

Codecov / codecov/patch

nilearn/experimental/surface/_maskers.py#L287

Added line #L287 was not covered by tests
list(labels_img.data.parts.values())
)
all_labels = set(self.labels_data_.ravel())
all_labels.discard(0)
self.labels_ = np.asarray(list(all_labels))
Expand Down Expand Up @@ -331,7 +335,7 @@
shape: (img data shape, total number of vertices)
"""
check_same_n_vertices(self.labels_img.mesh, img.mesh)
img_data = np.concatenate(list(img.data.values()), axis=-1)
img_data = np.concatenate(list(img.data.parts.values()), axis=-1)

Check warning on line 338 in nilearn/experimental/surface/_maskers.py

View check run for this annotation

Codecov / codecov/patch

nilearn/experimental/surface/_maskers.py#L338

Added line #L338 was not covered by tests
output = np.empty((*img_data.shape[:-1], len(self.labels_)))
for i, label in enumerate(self.labels_):
output[..., i] = img_data[..., self.labels_data_ == label].mean(
Expand Down Expand Up @@ -374,7 +378,7 @@
Mesh and data for both hemispheres.
"""
data = {}
for part_name, labels_part in self.labels_img.data.items():
for part_name, labels_part in self.labels_img.data.parts.items():
data[part_name] = np.zeros(
(*masked_img.shape[:-1], labels_part.shape[0]),
dtype=masked_img.dtype,
Expand Down