-
Notifications
You must be signed in to change notification settings - Fork 575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Improve plotting contours for PlotlySurfaceFigure
objects by adding add_contours
method
#3949
base: main
Are you sure you want to change the base?
Changes from all commits
a573742
ca73877
1c9d024
fc89686
4739ca0
8e793fb
e395ba2
78bdfda
9753884
ef4edf2
b41cd16
b681315
e3fc2bb
9d12271
c256ed9
84f9885
cb7f960
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -134,11 +134,15 @@ | |||||||||||||||||
############################################################################## | ||||||||||||||||||
# Display outlines of the regions of interest on top of a statistical map | ||||||||||||||||||
# ----------------------------------------------------------------------- | ||||||||||||||||||
# | ||||||||||||||||||
# Regions can be outlined using both engines. | ||||||||||||||||||
|
||||||||||||||||||
figure = plotting.plot_surf_stat_map(fsaverage.infl_right, | ||||||||||||||||||
texture, hemi='right', | ||||||||||||||||||
title='Surface right hemisphere', | ||||||||||||||||||
colorbar=True, threshold=1., | ||||||||||||||||||
colorbar=True, | ||||||||||||||||||
threshold=1., | ||||||||||||||||||
bg_on_data=True, | ||||||||||||||||||
bg_map=fsaverage.sulc_right) | ||||||||||||||||||
|
||||||||||||||||||
plotting.plot_surf_contours(fsaverage.infl_right, parcellation, labels=labels, | ||||||||||||||||||
|
@@ -147,6 +151,35 @@ | |||||||||||||||||
colors=['g', 'k']) | ||||||||||||||||||
plotting.show() | ||||||||||||||||||
|
||||||||||||||||||
############################################################################## | ||||||||||||||||||
# The plotly engine allows for enhanced customization of the contours. In | ||||||||||||||||||
# particular, the lines' width can be modified. | ||||||||||||||||||
# | ||||||||||||||||||
# Note that the contours are added with a method of the object that is | ||||||||||||||||||
# returned by :func:`~nilearn.plotting.plot_surf_stat_map`: | ||||||||||||||||||
# :meth:`~nilearn.plotting.displays.PlotlySurfaceFigure.add_contours`. | ||||||||||||||||||
Comment on lines
+158
to
+160
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
|
||||||||||||||||||
figure = plotting.plot_surf_stat_map(fsaverage.infl_right, | ||||||||||||||||||
texture, | ||||||||||||||||||
hemi='right', | ||||||||||||||||||
title='Surface right hemisphere', | ||||||||||||||||||
colorbar=True, | ||||||||||||||||||
threshold=1., | ||||||||||||||||||
bg_map=fsaverage.sulc_right, | ||||||||||||||||||
bg_on_data=True, | ||||||||||||||||||
engine="plotly") | ||||||||||||||||||
|
||||||||||||||||||
figure.add_contours(roi_map=parcellation, | ||||||||||||||||||
levels=regions_indices, | ||||||||||||||||||
labels=labels, | ||||||||||||||||||
lines=[ | ||||||||||||||||||
{"color": "green", "width": 10}, | ||||||||||||||||||
{"color": "purple"} | ||||||||||||||||||
] | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
figure.show() | ||||||||||||||||||
|
||||||||||||||||||
############################################################################## | ||||||||||||||||||
# Plot with higher-resolution mesh | ||||||||||||||||||
# -------------------------------- | ||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -1,3 +1,16 @@ | ||||||||||
import numpy as np | ||||||||||
from scipy.spatial import distance_matrix | ||||||||||
|
||||||||||
from nilearn.surface.surface import load_surf_data | ||||||||||
|
||||||||||
try: | ||||||||||
import plotly.graph_objects as go | ||||||||||
except ImportError: | ||||||||||
PLOTLY_INSTALLED = False | ||||||||||
else: | ||||||||||
PLOTLY_INSTALLED = True | ||||||||||
|
||||||||||
|
||||||||||
class SurfaceFigure: | ||||||||||
"""Abstract class for surface figures. | ||||||||||
|
||||||||||
|
@@ -36,6 +49,10 @@ | |||||||||
else: | ||||||||||
self.output_file = output_file | ||||||||||
|
||||||||||
def add_contours(self): | ||||||||||
"""Draw boundaries around roi.""" | ||||||||||
raise NotImplementedError | ||||||||||
|
||||||||||
|
||||||||||
class PlotlySurfaceFigure(SurfaceFigure): | ||||||||||
"""Implementation of a surface figure obtained with `plotly` engine. | ||||||||||
|
@@ -61,9 +78,7 @@ | |||||||||
""" | ||||||||||
|
||||||||||
def __init__(self, figure=None, output_file=None): | ||||||||||
try: | ||||||||||
import plotly.graph_objects as go | ||||||||||
except ImportError: | ||||||||||
if not PLOTLY_INSTALLED: | ||||||||||
raise ImportError( | ||||||||||
"Plotly is required to use `PlotlySurfaceFigure`." | ||||||||||
) | ||||||||||
|
@@ -103,3 +118,141 @@ | |||||||||
self._check_output_file(output_file=output_file) | ||||||||||
if self.figure is not None: | ||||||||||
self.figure.write_image(self.output_file) | ||||||||||
|
||||||||||
def add_contours(self, roi_map, levels=None, labels=None, lines=None): | ||||||||||
Remi-Gau marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
""" | ||||||||||
Draw boundaries around roi. | ||||||||||
Comment on lines
+123
to
+124
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just for consistency
Suggested change
|
||||||||||
|
||||||||||
Parameters | ||||||||||
---------- | ||||||||||
roi_map : :obj:`str` or :class:`numpy.ndarray` or :obj:`list` of \ | ||||||||||
:class:`numpy.ndarray` | ||||||||||
ROI map to be displayed on the surface | ||||||||||
mesh, can be a file (valid formats are .gii, .mgz, .nii, | ||||||||||
.nii.gz, or FreeSurfer specific files such as .annot or .label), | ||||||||||
or a Numpy array with a value for each vertex of the surf_mesh. | ||||||||||
The value at each vertex one inside the ROI and zero inside ROI, | ||||||||||
or an :obj:`int` giving the label number for atlases. | ||||||||||
Comment on lines
+134
to
+135
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ??
Suggested change
|
||||||||||
|
||||||||||
levels : :obj:`list` of :obj:`int`, or :obj:`None`, default=None | ||||||||||
A :obj:`list` of indices of the regions that are to be outlined. | ||||||||||
Every index needs to correspond to one index in roi_map. | ||||||||||
If :obj:`None`, all regions in roi_map are used. | ||||||||||
|
||||||||||
labels : :obj:`list` of :obj:`str` or :obj:`None`, default=None | ||||||||||
A :obj:`list` of labels for the individual regions of interest. | ||||||||||
Provide :obj:`None` as list entry to skip showing the label of | ||||||||||
that region. If :obj:`None`, no labels are used. | ||||||||||
|
||||||||||
lines : :obj:`list` of :obj:`dict` giving the properties of the \ | ||||||||||
contours, or :obj:`None`, default=None | ||||||||||
For valid keys, see :attr:`plotly.graph_objects.Scatter3d.line`. | ||||||||||
If length 1, the properties defined in that element will be used | ||||||||||
to draw all requested contours. | ||||||||||
""" | ||||||||||
if levels is None: | ||||||||||
levels = np.unique(roi_map) | ||||||||||
if labels is None: | ||||||||||
labels = [f"Region {i}" for i, _ in enumerate(levels)] | ||||||||||
if lines is None: | ||||||||||
lines = [None] * len(levels) | ||||||||||
elif len(lines) == 1 and len(levels) > 1: | ||||||||||
lines *= len(levels) | ||||||||||
if not (len(levels) == len(labels)): | ||||||||||
raise ValueError( | ||||||||||
"levels and labels need to be either the same length or None." | ||||||||||
) | ||||||||||
if not (len(levels) == len(lines)): | ||||||||||
raise ValueError( | ||||||||||
"levels and lines need to be either the same length or None." | ||||||||||
) | ||||||||||
roi = load_surf_data(roi_map) | ||||||||||
|
||||||||||
traces = [] | ||||||||||
for level, label, line in zip(levels, labels, lines): | ||||||||||
parc_idx = np.where(roi == level)[0] | ||||||||||
sorted_vertices = self._get_vertices_on_edge(parc_idx) | ||||||||||
traces.append( | ||||||||||
go.Scatter3d( | ||||||||||
x=sorted_vertices[:, 0], | ||||||||||
y=sorted_vertices[:, 1], | ||||||||||
z=sorted_vertices[:, 2], | ||||||||||
mode="lines", | ||||||||||
line=line, | ||||||||||
name=label, | ||||||||||
) | ||||||||||
) | ||||||||||
self.figure.add_traces(data=traces) | ||||||||||
|
||||||||||
def _get_vertices_on_edge(self, parc_idx): | ||||||||||
Remi-Gau marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
""" | ||||||||||
Identify which vertices lie on the outer edge of a parcellation. | ||||||||||
Comment on lines
+188
to
+189
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
Parameters | ||||||||||
---------- | ||||||||||
parc_idx : :class:`numpy.ndarray` | ||||||||||
Indices of the vertices of the region to be plotted. | ||||||||||
|
||||||||||
Returns | ||||||||||
------- | ||||||||||
data : :class:`numpy.ndarray` | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
(n_vertices, s) x,y,z coordinates of vertices that trace region | ||||||||||
of interest. | ||||||||||
|
||||||||||
""" | ||||||||||
faces = np.vstack( | ||||||||||
[self.figure._data[0].get(d) for d in ["i", "j", "k"]] | ||||||||||
).T | ||||||||||
|
||||||||||
# count how many vertices belong to the given parcellation in each face | ||||||||||
verts_per_face = np.isin(faces, parc_idx).sum(axis=1) | ||||||||||
|
||||||||||
# test if parcellation forms regions | ||||||||||
if np.all(verts_per_face < 2): | ||||||||||
raise ValueError("Vertices in parcellation do not form region.") | ||||||||||
|
||||||||||
vertices_on_edge = np.intersect1d( | ||||||||||
np.unique(faces[verts_per_face == 2]), parc_idx | ||||||||||
) | ||||||||||
|
||||||||||
# now that we know where to draw the lines, we need to know in which | ||||||||||
# order. If we pick a vertex to start and move to the closest one, and | ||||||||||
# then to the closest remaining one and so forth, we should get the | ||||||||||
# whole ROI | ||||||||||
coords = np.vstack( | ||||||||||
[self.figure._data[0].get(d) for d in ["x", "y", "z"]] | ||||||||||
).T | ||||||||||
vertices = coords[vertices_on_edge] | ||||||||||
|
||||||||||
# Start with the first vertex | ||||||||||
current_vertex = 0 | ||||||||||
visited_vertices = {current_vertex} | ||||||||||
|
||||||||||
sorted_vertices = [vertices[0]] | ||||||||||
|
||||||||||
# Loop over the remaining vertices in order of distance from the | ||||||||||
# current vertex | ||||||||||
while len(visited_vertices) < len(vertices): | ||||||||||
remaining_vertices = np.array( | ||||||||||
[ | ||||||||||
vertex | ||||||||||
for vertex in range(len(vertices)) | ||||||||||
if vertex not in visited_vertices | ||||||||||
] | ||||||||||
) | ||||||||||
remaining_distances = distance_matrix( | ||||||||||
vertices[current_vertex].reshape(1, -1), | ||||||||||
vertices[remaining_vertices], | ||||||||||
) | ||||||||||
closest_index = np.argmin(remaining_distances) | ||||||||||
closest_vertex = remaining_vertices[closest_index] | ||||||||||
visited_vertices.add(closest_vertex) | ||||||||||
sorted_vertices.append(vertices[closest_vertex]) | ||||||||||
# Move to the closest vertex and repeat the process | ||||||||||
current_vertex = closest_vertex | ||||||||||
|
||||||||||
# at the end we append the first one again to close the outline | ||||||||||
sorted_vertices.append(vertices[0]) | ||||||||||
sorted_vertices = np.asarray(sorted_vertices) | ||||||||||
|
||||||||||
return sorted_vertices |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add a line clarifying that for the default matplotlib engine used by
plot_surf_stat_map
,plotting.plot_surf_contours
is used.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes me also think we should raise an error in
plot_surf_contours
if plotly object is passed. Now we get the errorAttributeError: 'PlotlySurfaceFigure' object has no attribute 'axes'
but it would be better to raise one specifying thatplot_surf_contours
only works with matplotlib. WDYT?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes a more explicit error message would be a good thing for users