Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Improve plotting contours for PlotlySurfaceFigure objects by adding add_contours method #3949

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 34 additions & 1 deletion examples/01_plotting/plot_3d_map_to_surface_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,15 @@
##############################################################################
# Display outlines of the regions of interest on top of a statistical map
# -----------------------------------------------------------------------
#
# Regions can be outlined using both engines.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a line clarifying that for the default matplotlib engine used by plot_surf_stat_map, plotting.plot_surf_contours is used.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes me also think we should raise an error in plot_surf_contours if plotly object is passed. Now we get the error AttributeError: 'PlotlySurfaceFigure' object has no attribute 'axes' but it would be better to raise one specifying that plot_surf_contours only works with matplotlib. WDYT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes a more explicit error message would be a good thing for users


figure = plotting.plot_surf_stat_map(fsaverage.infl_right,
texture, hemi='right',
title='Surface right hemisphere',
colorbar=True, threshold=1.,
colorbar=True,
threshold=1.,
bg_on_data=True,
bg_map=fsaverage.sulc_right)

plotting.plot_surf_contours(fsaverage.infl_right, parcellation, labels=labels,
Expand All @@ -147,6 +151,35 @@
colors=['g', 'k'])
plotting.show()

##############################################################################
# The plotly engine allows for enhanced customization of the contours. In
# particular, the lines' width can be modified.
#
# Note that the contours are added with a method of the object that is
# returned by :func:`~nilearn.plotting.plot_surf_stat_map`:
# :meth:`~nilearn.plotting.displays.PlotlySurfaceFigure.add_contours`.
Comment on lines +158 to +160
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Note that the contours are added with a method of the object that is
# returned by :func:`~nilearn.plotting.plot_surf_stat_map`:
# :meth:`~nilearn.plotting.displays.PlotlySurfaceFigure.add_contours`.
# Note that the contours are added with
# :meth:`~nilearn.plotting.displays.PlotlySurfaceFigure.add_contours`
# method of a :class:`~nilearn.plotting.displays.PlotlySurfaceFigure` object
# that is returned by :func:`~nilearn.plotting.plot_surf_stat_map`
# when engine is set to "plotly".


figure = plotting.plot_surf_stat_map(fsaverage.infl_right,
texture,
hemi='right',
title='Surface right hemisphere',
colorbar=True,
threshold=1.,
bg_map=fsaverage.sulc_right,
bg_on_data=True,
engine="plotly")

figure.add_contours(roi_map=parcellation,
levels=regions_indices,
labels=labels,
lines=[
{"color": "green", "width": 10},
{"color": "purple"}
]
)

figure.show()

##############################################################################
# Plot with higher-resolution mesh
# --------------------------------
Expand Down
148 changes: 148 additions & 0 deletions nilearn/plotting/displays/_figures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import numpy as np
from scipy.spatial import distance_matrix

Check warning on line 2 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L1-L2

Added lines #L1 - L2 were not covered by tests

from nilearn.surface.surface import load_surf_data

Check warning on line 4 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L4

Added line #L4 was not covered by tests


class SurfaceFigure:
"""Abstract class for surface figures.

Expand Down Expand Up @@ -36,6 +42,10 @@
else:
self.output_file = output_file

def add_contours(self):

Check warning on line 45 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L45

Added line #L45 was not covered by tests
"""Draw boundaries around roi."""
raise NotImplementedError

Check warning on line 47 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L47

Added line #L47 was not covered by tests


class PlotlySurfaceFigure(SurfaceFigure):
"""Implementation of a surface figure obtained with `plotly` engine.
Expand Down Expand Up @@ -103,3 +113,141 @@
self._check_output_file(output_file=output_file)
if self.figure is not None:
self.figure.write_image(self.output_file)

def add_contours(self, roi_map, levels=None, labels=None, lines=None):

Check warning on line 117 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L117

Added line #L117 was not covered by tests
Remi-Gau marked this conversation as resolved.
Show resolved Hide resolved
"""
Draw boundaries around roi.
Comment on lines +123 to +124
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for consistency

Suggested change
"""
Draw boundaries around roi.
"""Draw boundaries around roi.


Parameters
----------
roi_map : str or :class:`numpy.ndarray` or list of
:class:`numpy.ndarray` ROI map to be displayed on the surface
mesh, can be a file (valid formats are .gii, .mgz, .nii,
.nii.gz, or Freesurfer specific files such as .annot or .label),
or a Numpy array with a value for each vertex of the surf_mesh.
The value at each vertex one inside the ROI and zero inside ROI,
or an integer giving the label number for atlases.
ymzayek marked this conversation as resolved.
Show resolved Hide resolved

levels : list of integers, or None, optional
psadil marked this conversation as resolved.
Show resolved Hide resolved
A list of indices of the regions that are to be outlined.
Every index needs to correspond to one index in roi_map.
If None, all regions in roi_map are used.

labels : list of strings or None, or None, optional
psadil marked this conversation as resolved.
Show resolved Hide resolved
A list of labels for the individual regions of interest. Provide
None as list entry to skip showing the label of that region. If
None, no labels are used.

lines : list of dict giving the properties of the contours, or None,
optional. For valid keys, see
psadil marked this conversation as resolved.
Show resolved Hide resolved
:attr:`plotly.graph_objects.Scatter3d.line`. If length 1, the
psadil marked this conversation as resolved.
Show resolved Hide resolved
properties defined in that element will be used to draw all
requested contours.
"""
import plotly.graph_objects as go

Check warning on line 147 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L147

Added line #L147 was not covered by tests
ymzayek marked this conversation as resolved.
Show resolved Hide resolved

if levels is None:
levels = np.unique(roi_map)

Check warning on line 150 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L150

Added line #L150 was not covered by tests
if labels is None:
labels = [f"Region {i}" for i, _ in enumerate(levels)]
if lines is None:
lines = [None] * len(levels)

Check warning on line 154 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L154

Added line #L154 was not covered by tests
elif len(lines) == 1 and len(levels) > 1:
lines *= len(levels)

Check warning on line 156 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L156

Added line #L156 was not covered by tests
if not (len(levels) == len(labels)):
raise ValueError(

Check warning on line 158 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L158

Added line #L158 was not covered by tests
"levels and labels need to be either the same length or None."
)
if not (len(levels) == len(lines)):
raise ValueError(

Check warning on line 162 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L162

Added line #L162 was not covered by tests
"levels and lines need to be either the same length or None."
)
roi = load_surf_data(roi_map)

Check warning on line 165 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L165

Added line #L165 was not covered by tests

traces = []

Check warning on line 167 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L167

Added line #L167 was not covered by tests
for level, label, line in zip(levels, labels, lines):
parc_idx = np.where(roi == level)[0]
sorted_vertices = self._get_vertices_on_edge(parc_idx)
traces.append(

Check warning on line 171 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L169-L171

Added lines #L169 - L171 were not covered by tests
go.Scatter3d(
x=sorted_vertices[:, 0],
y=sorted_vertices[:, 1],
z=sorted_vertices[:, 2],
mode="lines",
line=line,
name=label,
)
)
self.figure.add_traces(data=traces)

Check warning on line 181 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L181

Added line #L181 was not covered by tests

def _get_vertices_on_edge(self, parc_idx):

Check warning on line 183 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L183

Added line #L183 was not covered by tests
Remi-Gau marked this conversation as resolved.
Show resolved Hide resolved
"""
Identify which vertices lie on the outer edge of a parcellation.
Comment on lines +188 to +189
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
Identify which vertices lie on the outer edge of a parcellation.
"""Identify which vertices lie on the outer edge of a parcellation.


Parameters
----------
parc_idx : numpy.ndarray, indices of the vertices of the region to be
plotted.

Returns
-------
data : :class:`numpy.ndarray` (n_vertices, s) x,y,z coordinates of
vertices that trace region of interest.

"""
faces = np.vstack(
[self.figure._data[0].get(d) for d in ["i", "j", "k"]]
).T

# count how many vertices belong to the given parcellation in each face
verts_per_face = np.isin(faces, parc_idx).sum(axis=1)

Check warning on line 203 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L203

Added line #L203 was not covered by tests

# test if parcellation forms regions
if np.all(verts_per_face < 2):
raise ValueError("Vertices in parcellation do not form region.")

Check warning on line 207 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L207

Added line #L207 was not covered by tests

vertices_on_edge = np.intersect1d(

Check warning on line 209 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L209

Added line #L209 was not covered by tests
np.unique(faces[verts_per_face == 2]), parc_idx
)

# now that we know where to draw the lines, we need to know in which
# order. If we pick a vertex to start and move to the closest one, and
# then to the closest remaining one and so forth, we should get the
# whole ROI
coords = np.vstack(
[self.figure._data[0].get(d) for d in ["x", "y", "z"]]
).T
vertices = coords[vertices_on_edge]

Check warning on line 220 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L220

Added line #L220 was not covered by tests

# Start with the first vertex
current_vertex = 0
visited_vertices = {current_vertex}

Check warning on line 224 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L223-L224

Added lines #L223 - L224 were not covered by tests

sorted_vertices = [vertices[0]]

Check warning on line 226 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L226

Added line #L226 was not covered by tests

# Loop over the remaining vertices in order of distance from the
# current vertex
while len(visited_vertices) < len(vertices):
remaining_vertices = np.array(
[
vertex
for vertex in range(len(vertices))
if vertex not in visited_vertices
]
)
remaining_distances = distance_matrix(

Check warning on line 238 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L238

Added line #L238 was not covered by tests
vertices[current_vertex].reshape(1, -1),
vertices[remaining_vertices],
)
closest_index = np.argmin(remaining_distances)
closest_vertex = remaining_vertices[closest_index]
visited_vertices.add(closest_vertex)
sorted_vertices.append(vertices[closest_vertex])

Check warning on line 245 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L242-L245

Added lines #L242 - L245 were not covered by tests
# Move to the closest vertex and repeat the process
current_vertex = closest_vertex

Check warning on line 247 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L247

Added line #L247 was not covered by tests

# at the end we append the first one again to close the outline
sorted_vertices.append(vertices[0])
sorted_vertices = np.asarray(sorted_vertices)

Check warning on line 251 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L250-L251

Added lines #L250 - L251 were not covered by tests

return sorted_vertices

Check warning on line 253 in nilearn/plotting/displays/_figures.py

View check run for this annotation

Codecov / codecov/patch

nilearn/plotting/displays/_figures.py#L253

Added line #L253 was not covered by tests