Skip to content

Commit

Permalink
Merge pull request #712 from mrariden/export_multiproc
Browse files Browse the repository at this point in the history
add multiprocessing to outlines_list
  • Loading branch information
carsen-stringer committed May 24, 2023
2 parents d92bc6a + c8697b9 commit 8b53a57
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
32 changes: 31 additions & 1 deletion cellpose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import colorsys
import io
from multiprocessing import Pool, cpu_count

from . import metrics

Expand Down Expand Up @@ -219,7 +220,15 @@ def masks_to_outlines(masks):
outlines[vr, vc] = 1
return outlines

def outlines_list(masks):
def outlines_list(masks, multiprocessing=True):
""" get outlines of masks as a list to loop over for plotting
This function is a wrapper for outlines_list_single and outlines_list_multi """
if multiprocessing:
return outlines_list_multi(masks)
else:
return outlines_list_single(masks)

def outlines_list_single(masks):
""" get outlines of masks as a list to loop over for plotting """
outpix=[]
for n in np.unique(masks)[1:]:
Expand All @@ -235,6 +244,27 @@ def outlines_list(masks):
outpix.append(np.zeros((0,2)))
return outpix

def outlines_list_multi(masks, num_processes=None):
""" get outlines of masks as a list to loop over for plotting """
if num_processes is None:
num_processes = cpu_count()

unique_masks = np.unique(masks)[1:]
with Pool(processes=num_processes) as pool:
outpix = pool.map(get_outline_multi, [(masks, n) for n in unique_masks])
return outpix

def get_outline_multi(args):
masks, n = args
mn = masks == n
if mn.sum() > 0:
contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_NONE)
contours = contours[-2]
cmax = np.argmax([c.shape[0] for c in contours])
pix = contours[cmax].astype(int).squeeze()
return pix if len(pix) > 4 else np.zeros((0, 2))
return np.zeros((0, 2))

def get_perimeter(points):
""" perimeter of points - npoints x ndim """
if points.shape[0]>4:
Expand Down
35 changes: 33 additions & 2 deletions tests/test_output.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from cellpose import io, models, metrics, plot
from cellpose import io, models, metrics, plot, utils
from pathlib import Path
from subprocess import check_output, STDOUT
import os, shutil
Expand Down Expand Up @@ -100,7 +100,38 @@ def test_cli_3D(data_dir, image_names):
raise ValueError(e)
compare_masks(data_dir, image_names, '3D', model_type)
clear_output(data_dir, image_names)



def test_outlines_list(data_dir, image_names):
""" test both single and multithreaded by comparing them"""
clear_output(data_dir, image_names)
model_type = 'cyto'
channels = [2, 1]
image_name = 'rgb_2D.png'

file_name = str(data_dir.joinpath('2D').joinpath(image_name))
img = io.imread(file_name)

model = models.Cellpose(model_type=model_type)
masks, _, _, _ = model.eval(img, diameter=30, channels=channels, net_avg=False)
outlines_single = utils.outlines_list(masks, multiprocessing=False)
outlines_multi = utils.outlines_list(masks, multiprocessing=True)

assert len(outlines_single) == len(outlines_multi)

# Check that the outlines are the same, but not necessarily in the same order
outlines_matched = [False] * len(outlines_single)
for i, outline_single in enumerate(outlines_single):
for j, outline_multi in enumerate(outlines_multi):
if not outlines_matched[j] and np.array_equal(outline_single, outline_multi):
outlines_matched[j] = True
break
else:
assert False, "Outline not found in outlines_multi: {}".format(outline_single)

assert all(outlines_matched), "Not all outlines in outlines_multi were matched"


def compare_masks(data_dir, image_names, runtype, model_type):
"""
Helper function to check if outputs given by a test are exactly the same
Expand Down

0 comments on commit 8b53a57

Please sign in to comment.