Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] add plotting functions to the experimental surface module #4235

Merged
merged 10 commits into from
Jun 3, 2024
66 changes: 19 additions & 47 deletions examples/08_experimental/plot_surface_image_and_maskers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,8 @@

# %%

from typing import Optional, Sequence

from nilearn import plotting
from nilearn.experimental import surface


def plot_surf_img(
img: surface.SurfaceImage,
parts: Optional[Sequence[str]] = None,
mesh: Optional[surface.PolyMesh] = None,
**kwargs,
) -> plt.Figure:
if mesh is None:
mesh = img.mesh
if parts is None:
parts = list(img.data.keys())
fig, axes = plt.subplots(
1,
len(parts),
subplot_kw={"projection": "3d"},
figsize=(4 * len(parts), 4),
)
for ax, mesh_part in zip(axes, parts):
plotting.plot_surf(
mesh[mesh_part],
img.data[mesh_part],
axes=ax,
title=mesh_part,
**kwargs,
)
assert isinstance(fig, plt.Figure)
return fig

from nilearn.experimental import plotting, surface
from nilearn.plotting import plot_matrix

img = surface.fetch_nki()[0]
print(f"NKI image: {img}")
Expand All @@ -66,20 +35,25 @@ def plot_surf_img(
mean_img = masker.inverse_transform(mean_data)
print(f"Image mean: {mean_img}")

plot_surf_img(mean_img)
plotting.show()
plotting.plot_surf(mean_img)
plt.show()

# %%
# Connectivity with a surface atlas and `SurfaceLabelsMasker`
# -----------------------------------------------------------
from nilearn import connectome, plotting
from nilearn import connectome

img = surface.fetch_nki()[0]
print(f"NKI image: {img}")

labels_img, label_names = surface.fetch_destrieux()
print(f"Destrieux image: {labels_img}")
plot_surf_img(labels_img, cmap="gist_ncar", avg_method="median")
plotting.plot_surf(
labels_img,
views=["lateral", "medial"],
cmap="gist_ncar",
avg_method="median",
)

labels_masker = surface.SurfaceLabelsMasker(labels_img, label_names).fit()
masked_data = labels_masker.transform(img)
Expand All @@ -88,17 +62,17 @@ def plot_surf_img(
connectome = (
connectome.ConnectivityMeasure(kind="correlation").fit([masked_data]).mean_
)
plotting.plot_matrix(connectome, labels=labels_masker.label_names_)
plot_matrix(connectome, labels=labels_masker.label_names_)

plotting.show()
plt.show()


# %%
# Using the `Decoder`
# -------------------
import numpy as np

from nilearn import decoding, plotting
from nilearn import decoding
from nilearn._utils import param_validation

# %%
Expand Down Expand Up @@ -131,17 +105,15 @@ def adjust_screening_percentile(screening_percentile, *args, **kwargs):
decoder.fit(img, y)
print("CV scores:", decoder.cv_scores_)

plot_surf_img(decoder.coef_img_[0], threshold=1e-6)
plotting.show()
plotting.plot_surf(decoder.coef_img_[0], threshold=1e-6)
plt.show()

# %%
# Decoding with a scikit-learn `Pipeline`
# ---------------------------------------
import numpy as np
from sklearn import feature_selection, linear_model, pipeline, preprocessing

from nilearn import plotting

img = surface.fetch_nki()[0]
y = np.random.RandomState(0).normal(size=img.shape[0])

Expand All @@ -158,12 +130,12 @@ def adjust_screening_percentile(screening_percentile, *args, **kwargs):
coef_img = decoder[:-1].inverse_transform(np.atleast_2d(decoder[-1].coef_))


vmax = max([np.absolute(dp).max() for dp in coef_img.data.values()])
plot_surf_img(
vmax = max([np.absolute(dp).max() for dp in coef_img.data.parts.values()])
plotting.plot_surf(
coef_img,
cmap="cold_hot",
vmin=-vmax,
vmax=vmax,
threshold=1e-6,
)
plotting.show()
plt.show()
3 changes: 3 additions & 0 deletions nilearn/experimental/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Public plotting API of the experimental surface module.

Re-exports :func:`plot_surf` from the private ``_surface_plotting``
module; ``__all__`` declares it as the only public name.
"""
from ._surface_plotting import plot_surf

__all__ = ["plot_surf"]
34 changes: 34 additions & 0 deletions nilearn/experimental/plotting/_surface_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np
from matplotlib import pyplot as plt

from nilearn import plotting as old_plotting


def plot_surf(img, parts=None, mesh=None, views=None, **kwargs):
    """Plot a SurfaceImage, one subplot per mesh part and per view.

    Parameters
    ----------
    img : SurfaceImage
        Image whose data is plotted; one column of subplots per mesh part.
    parts : sequence of str or None, default=None
        Names of the mesh parts (hemispheres) to plot. When None, all
        parts found in ``img.data.parts`` are plotted.
    mesh : PolyMesh or None, default=None
        Mesh on which to plot the data. When None, ``img.mesh`` is used.
    views : sequence of str or None, default=None
        Views to render, one row of subplots per view; each is forwarded
        as ``view`` to ``nilearn.plotting.plot_surf``. When None, only
        the "lateral" view is plotted.
    **kwargs : dict
        Extra keyword arguments forwarded to ``nilearn.plotting.plot_surf``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing all subplots.
    """
    # Avoid a mutable default argument: a literal ``views=["lateral"]``
    # default would be shared across calls and could be mutated.
    if views is None:
        views = ["lateral"]
    if mesh is None:
        mesh = img.mesh
    if parts is None:
        parts = list(img.data.parts.keys())
    fig, axes = plt.subplots(
        len(views),
        len(parts),
        subplot_kw={"projection": "3d"},
        # One column per part, one row per view; the previous fixed
        # height of 4 squashed the figure for multiple views.
        figsize=(4 * len(parts), 4 * len(views)),
    )
    # plt.subplots returns a bare Axes (1x1) or a 1D array when either
    # dimension is 1; normalize to 2D so the nested loops always work.
    axes = np.atleast_2d(axes)
    for view, ax_row in zip(views, axes):
        for ax, mesh_part in zip(ax_row, parts):
            old_plotting.plot_surf(
                mesh.parts[mesh_part],
                img.data.parts[mesh_part],
                hemi=mesh_part,
                view=view,
                axes=ax,
                title=mesh_part,
                **kwargs,
            )
    return fig

Check warning on line 34 in nilearn/experimental/plotting/_surface_plotting.py

View check run for this annotation

Codecov / codecov/patch

nilearn/experimental/plotting/_surface_plotting.py#L34

Added line #L34 was not covered by tests
16 changes: 8 additions & 8 deletions nilearn/experimental/surface/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from nilearn.experimental.surface import _io
from nilearn.experimental.surface._surface_image import (
FileMesh,
Mesh,
PolyMesh,
SurfaceImage,
)
Expand All @@ -17,14 +16,15 @@
def load_fsaverage(mesh_name: str = "fsaverage5") -> Dict[str, PolyMesh]:
    """Load several fsaverage mesh types for both hemispheres.

    Parameters
    ----------
    mesh_name : str, default="fsaverage5"
        Name of the fsaverage release, passed through to
        ``datasets.fetch_surf_fsaverage``.

    Returns
    -------
    Dict[str, PolyMesh]
        Mapping from mesh type ("pial", "white_matter", "inflated") to a
        PolyMesh built from a "left" and a "right" FileMesh part.
    """
    fsaverage = datasets.fetch_surf_fsaverage(mesh_name)
    meshes: Dict[str, PolyMesh] = {}
    # Map the fetcher's short mesh-type prefixes to the public names used
    # as keys of the returned dict.
    renaming = {"pial": "pial", "white": "white_matter", "infl": "inflated"}
    # NOTE: the loop variable was previously also called ``mesh_name``,
    # shadowing the function parameter; renamed for clarity.
    for mesh_type, public_name in renaming.items():
        parts = {
            hemisphere: FileMesh(fsaverage[f"{mesh_type}_{hemisphere}"])
            for hemisphere in ("left", "right")
        }
        meshes[public_name] = PolyMesh(**parts)
    return meshes


Expand All @@ -41,8 +41,8 @@ def fetch_nki(n_subjects=1) -> Sequence[SurfaceImage]:
img = SurfaceImage(
mesh=fsaverage["pial"],
data={
"left_hemisphere": left_data,
"right_hemisphere": right_data,
"left": left_data,
"right": right_data,
},
)
images.append(img)
Expand All @@ -60,8 +60,8 @@ def fetch_destrieux() -> Tuple[SurfaceImage, Dict[int, str]]:
SurfaceImage(
mesh=fsaverage["pial"],
data={
"left_hemisphere": destrieux["map_left"],
"right_hemisphere": destrieux["map_right"],
"left": destrieux["map_left"],
"right": destrieux["map_right"],
},
),
label_names,
Expand Down
28 changes: 16 additions & 12 deletions nilearn/experimental/surface/_maskers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,19 @@

def check_same_n_vertices(mesh_1: PolyMesh, mesh_2: PolyMesh) -> None:
    """Check that 2 meshes have the same keys and that n vertices match.

    Parameters
    ----------
    mesh_1, mesh_2 : PolyMesh
        Meshes to compare; only ``.parts`` keys and the ``n_vertices`` of
        each part are inspected.

    Raises
    ------
    ValueError
        If the two meshes do not have the same part keys, or if any
        matching part has a different number of vertices.
    """
    keys_1, keys_2 = set(mesh_1.parts.keys()), set(mesh_2.parts.keys())
    if keys_1 != keys_2:
        diff = keys_1.symmetric_difference(keys_2)
        raise ValueError(
            "Meshes do not have the same keys. " f"Offending keys: {diff}"
        )
    for key in keys_1:
        if mesh_1.parts[key].n_vertices != mesh_2.parts[key].n_vertices:
            # Note the space after the period: without it the message
            # previously read "...for 'left'.number of vertices...".
            raise ValueError(
                f"Number of vertices do not match for '{key}'. "
                "number of vertices in mesh_1: "
                f"{mesh_1.parts[key].n_vertices}; "
                f"in mesh_2: {mesh_2.parts[key].n_vertices}"
            )


Expand Down Expand Up @@ -83,7 +84,8 @@
# TODO: don't store a full array of 1 to mean "no masking"; use some
# sentinel value
mask_data = {
k: np.ones(v.n_vertices, dtype=bool) for (k, v) in img.mesh.items()
k: np.ones(v.n_vertices, dtype=bool)
for (k, v) in img.mesh.parts.items()
}
self.mask_img_ = SurfaceImage(mesh=img.mesh, data=mask_data)

Expand All @@ -110,7 +112,7 @@
assert self.mask_img_ is not None
start, stop = 0, 0
self.slices = {}
for part_name, mask in self.mask_img_.data.items():
for part_name, mask in self.mask_img_.data.parts.items():
assert isinstance(mask, np.ndarray)
stop = start + mask.sum()
self.slices[part_name] = start, stop
Expand Down Expand Up @@ -159,9 +161,9 @@
check_same_n_vertices(self.mask_img_.mesh, img.mesh)
output = np.empty((*img.shape[:-1], self.output_dimension_))
for part_name, (start, stop) in self.slices.items():
mask = self.mask_img_.data[part_name]
mask = self.mask_img_.data.parts[part_name]
assert isinstance(mask, np.ndarray)
output[..., start:stop] = img.data[part_name][..., mask]
output[..., start:stop] = img.data.parts[part_name][..., mask]

# signal cleaning here
output = cache(
Expand Down Expand Up @@ -232,7 +234,7 @@
f"last dimension should be {self.output_dimension_}"
)
data = {}
for part_name, mask in self.mask_img_.data.items():
for part_name, mask in self.mask_img_.data.parts.items():
assert isinstance(mask, np.ndarray)
data[part_name] = np.zeros(
(*masked_img.shape[:-1], mask.shape[0]),
Expand Down Expand Up @@ -279,7 +281,9 @@
) -> None:
self.labels_img = labels_img
self.label_names = label_names
self.labels_data_ = np.concatenate(list(labels_img.data.values()))
self.labels_data_ = np.concatenate(

Check warning on line 284 in nilearn/experimental/surface/_maskers.py

View check run for this annotation

Codecov / codecov/patch

nilearn/experimental/surface/_maskers.py#L284

Added line #L284 was not covered by tests
list(labels_img.data.parts.values())
)
all_labels = set(self.labels_data_.ravel())
all_labels.discard(0)
self.labels_ = np.asarray(list(all_labels))
Expand Down Expand Up @@ -328,7 +332,7 @@
shape: (img data shape, total number of vertices)
"""
check_same_n_vertices(self.labels_img.mesh, img.mesh)
img_data = np.concatenate(list(img.data.values()), axis=-1)
img_data = np.concatenate(list(img.data.parts.values()), axis=-1)

Check warning on line 335 in nilearn/experimental/surface/_maskers.py

View check run for this annotation

Codecov / codecov/patch

nilearn/experimental/surface/_maskers.py#L335

Added line #L335 was not covered by tests
output = np.empty((*img_data.shape[:-1], len(self.labels_)))
for i, label in enumerate(self.labels_):
output[..., i] = img_data[..., self.labels_data_ == label].mean(
Expand Down Expand Up @@ -371,7 +375,7 @@
Mesh and data for both hemispheres.
"""
data = {}
for part_name, labels_part in self.labels_img.data.items():
for part_name, labels_part in self.labels_img.data.parts.items():
data[part_name] = np.zeros(
(*masked_img.shape[:-1], labels_part.shape[0]),
dtype=masked_img.dtype,
Expand Down