Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Improve plotting contours for PlotlySurfaceFigure objects by adding add_contours method #3949

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@
"pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
"nistats": ("https://nistats.github.io", None),
"joblib": ("https://joblib.readthedocs.io/en/latest/", None),
"plotly": ("https://plotly.com/python-api-reference/", None),
}

extlinks = {
Expand Down
35 changes: 34 additions & 1 deletion examples/01_plotting/plot_3d_map_to_surface_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,15 @@
##############################################################################
# Display outlines of the regions of interest on top of a statistical map
# -----------------------------------------------------------------------
#
# Regions can be outlined using both engines.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a line clarifying that for the default matplotlib engine used by plot_surf_stat_map, plotting.plot_surf_contours is used.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes me also think we should raise an error in plot_surf_contours if plotly object is passed. Now we get the error AttributeError: 'PlotlySurfaceFigure' object has no attribute 'axes' but it would be better to raise one specifying that plot_surf_contours only works with matplotlib. WDYT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes a more explicit error message would be a good thing for users


figure = plotting.plot_surf_stat_map(fsaverage.infl_right,
texture, hemi='right',
title='Surface right hemisphere',
colorbar=True, threshold=1.,
colorbar=True,
threshold=1.,
bg_on_data=True,
bg_map=fsaverage.sulc_right)

plotting.plot_surf_contours(fsaverage.infl_right, parcellation, labels=labels,
Expand All @@ -147,6 +151,35 @@
colors=['g', 'k'])
plotting.show()

##############################################################################
# The plotly engine allows for enhanced customization of the contours. In
# particular, the lines' width can be modified.
#
# Note that the contours are added with a method of the object that is
# returned by :func:`~nilearn.plotting.plot_surf_stat_map`:
# :meth:`~nilearn.plotting.displays.PlotlySurfaceFigure.add_contours`.
Comment on lines +158 to +160
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Note that the contours are added with a method of the object that is
# returned by :func:`~nilearn.plotting.plot_surf_stat_map`:
# :meth:`~nilearn.plotting.displays.PlotlySurfaceFigure.add_contours`.
# Note that the contours are added with
# :meth:`~nilearn.plotting.displays.PlotlySurfaceFigure.add_contours`
# method of a :class:`~nilearn.plotting.displays.PlotlySurfaceFigure` object
# that is returned by :func:`~nilearn.plotting.plot_surf_stat_map`
# when engine is set to "plotly".


figure = plotting.plot_surf_stat_map(fsaverage.infl_right,
texture,
hemi='right',
title='Surface right hemisphere',
colorbar=True,
threshold=1.,
bg_map=fsaverage.sulc_right,
bg_on_data=True,
engine="plotly")

figure.add_contours(roi_map=parcellation,
levels=regions_indices,
labels=labels,
lines=[
{"color": "green", "width": 10},
{"color": "purple"}
]
)

figure.show()

##############################################################################
# Plot with higher-resolution mesh
# --------------------------------
Expand Down
159 changes: 156 additions & 3 deletions nilearn/plotting/displays/_figures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
import numpy as np
from scipy.spatial import distance_matrix

from nilearn.surface.surface import load_surf_data

try:
import plotly.graph_objects as go
except ImportError:
PLOTLY_INSTALLED = False

Check warning on line 9 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L8-L9

Added lines #L8 - L9 were not covered by tests
else:
PLOTLY_INSTALLED = True


class SurfaceFigure:
"""Abstract class for surface figures.

Expand Down Expand Up @@ -36,6 +49,10 @@
else:
self.output_file = output_file

def add_contours(self):
"""Draw boundaries around roi."""
raise NotImplementedError


class PlotlySurfaceFigure(SurfaceFigure):
"""Implementation of a surface figure obtained with `plotly` engine.
Expand All @@ -61,9 +78,7 @@
"""

def __init__(self, figure=None, output_file=None):
try:
import plotly.graph_objects as go
except ImportError:
if not PLOTLY_INSTALLED:
raise ImportError(
"Plotly is required to use `PlotlySurfaceFigure`."
)
Expand Down Expand Up @@ -103,3 +118,141 @@
self._check_output_file(output_file=output_file)
if self.figure is not None:
self.figure.write_image(self.output_file)

def add_contours(self, roi_map, levels=None, labels=None, lines=None):
Remi-Gau marked this conversation as resolved.
Show resolved Hide resolved
"""
Draw boundaries around roi.
Comment on lines +123 to +124
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for consistency

Suggested change
"""
Draw boundaries around roi.
"""Draw boundaries around roi.


Parameters
----------
roi_map : :obj:`str` or :class:`numpy.ndarray` or :obj:`list` of \
:class:`numpy.ndarray`
ROI map to be displayed on the surface
mesh, can be a file (valid formats are .gii, .mgz, .nii,
.nii.gz, or FreeSurfer specific files such as .annot or .label),
or a Numpy array with a value for each vertex of the surf_mesh.
The value at each vertex one inside the ROI and zero inside ROI,
or an :obj:`int` giving the label number for atlases.
Comment on lines +134 to +135
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

??

Suggested change
The value at each vertex one inside the ROI and zero inside ROI,
or an :obj:`int` giving the label number for atlases.
The value at each vertex is one inside the ROI and zero outside
the ROI, or an :obj:`int` giving the label number for atlases.


levels : :obj:`list` of :obj:`int`, or :obj:`None`, default=None
A :obj:`list` of indices of the regions that are to be outlined.
Every index needs to correspond to one index in roi_map.
If :obj:`None`, all regions in roi_map are used.

labels : :obj:`list` of :obj:`str` or :obj:`None`, default=None
A :obj:`list` of labels for the individual regions of interest.
Provide :obj:`None` as list entry to skip showing the label of
that region. If :obj:`None`, no labels are used.

lines : :obj:`list` of :obj:`dict` giving the properties of the \
contours, or :obj:`None`, default=None
For valid keys, see :attr:`plotly.graph_objects.Scatter3d.line`.
If length 1, the properties defined in that element will be used
to draw all requested contours.
"""
if levels is None:
levels = np.unique(roi_map)
if labels is None:
labels = [f"Region {i}" for i, _ in enumerate(levels)]
if lines is None:
lines = [None] * len(levels)
elif len(lines) == 1 and len(levels) > 1:
lines *= len(levels)
if not (len(levels) == len(labels)):
raise ValueError(
"levels and labels need to be either the same length or None."
)
if not (len(levels) == len(lines)):
raise ValueError(
"levels and lines need to be either the same length or None."
)
roi = load_surf_data(roi_map)

traces = []
for level, label, line in zip(levels, labels, lines):
parc_idx = np.where(roi == level)[0]
sorted_vertices = self._get_vertices_on_edge(parc_idx)
traces.append(
go.Scatter3d(
x=sorted_vertices[:, 0],
y=sorted_vertices[:, 1],
z=sorted_vertices[:, 2],
mode="lines",
line=line,
name=label,
)
)
self.figure.add_traces(data=traces)

def _get_vertices_on_edge(self, parc_idx):
Remi-Gau marked this conversation as resolved.
Show resolved Hide resolved
"""
Identify which vertices lie on the outer edge of a parcellation.
Comment on lines +188 to +189
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
Identify which vertices lie on the outer edge of a parcellation.
"""Identify which vertices lie on the outer edge of a parcellation.


Parameters
----------
parc_idx : :class:`numpy.ndarray`
Indices of the vertices of the region to be plotted.

Returns
-------
data : :class:`numpy.ndarray`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
data : :class:`numpy.ndarray`
sorted_vertices : :class:`numpy.ndarray`

(n_vertices, s) x,y,z coordinates of vertices that trace region
of interest.

"""
faces = np.vstack(
[self.figure._data[0].get(d) for d in ["i", "j", "k"]]
).T

# count how many vertices belong to the given parcellation in each face
verts_per_face = np.isin(faces, parc_idx).sum(axis=1)

# test if parcellation forms regions
if np.all(verts_per_face < 2):
raise ValueError("Vertices in parcellation do not form region.")

vertices_on_edge = np.intersect1d(
np.unique(faces[verts_per_face == 2]), parc_idx
)

# now that we know where to draw the lines, we need to know in which
# order. If we pick a vertex to start and move to the closest one, and
# then to the closest remaining one and so forth, we should get the
# whole ROI
coords = np.vstack(
[self.figure._data[0].get(d) for d in ["x", "y", "z"]]
).T
vertices = coords[vertices_on_edge]

# Start with the first vertex
current_vertex = 0
visited_vertices = {current_vertex}

sorted_vertices = [vertices[0]]

# Loop over the remaining vertices in order of distance from the
# current vertex
while len(visited_vertices) < len(vertices):
remaining_vertices = np.array(
[
vertex
for vertex in range(len(vertices))
if vertex not in visited_vertices
]
)
remaining_distances = distance_matrix(
vertices[current_vertex].reshape(1, -1),
vertices[remaining_vertices],
)
closest_index = np.argmin(remaining_distances)
closest_vertex = remaining_vertices[closest_index]
visited_vertices.add(closest_vertex)
sorted_vertices.append(vertices[closest_vertex])
# Move to the closest vertex and repeat the process
current_vertex = closest_vertex

# at the end we append the first one again to close the outline
sorted_vertices.append(vertices[0])
sorted_vertices = np.asarray(sorted_vertices)

return sorted_vertices