Skip to content

Commit

Permalink
speeding up command line
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Jan 25, 2022
1 parent 22c9562 commit 8896f97
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
12 changes: 8 additions & 4 deletions cellpose/__main__.py
Expand Up @@ -250,20 +250,22 @@ def main():
args.mxnet = False
if not bacterial:
model = models.Cellpose(gpu=gpu, device=device, model_type=args.pretrained_model,
torch=(not args.mxnet),omni=args.omni)
torch=(not args.mxnet),omni=args.omni, net_avg=(not args.fast_mode and not args.no_net_avg))
else:
cpmodel_path = models.model_path(args.pretrained_model, 0, True)
model = models.CellposeModel(gpu=gpu, device=device,
pretrained_model=cpmodel_path,
torch=True,
nclasses=args.nclasses,omni=args.omni)
nclasses=args.nclasses,omni=args.omni,
net_avg=False)
else:
if args.all_channels:
channels = None
model = models.CellposeModel(gpu=gpu, device=device,
pretrained_model=cpmodel_path,
torch=True,
nclasses=args.nclasses,omni=args.omni)
nclasses=args.nclasses,omni=args.omni,
net_avg=False)


# omni changes not implemented for mxnet. Full parity for cpu/gpu in pytorch.
Expand Down Expand Up @@ -309,6 +311,7 @@ def main():


tqdm_out = utils.TqdmToLogger(logger,level=logging.INFO)

for image_name in tqdm(image_names, file=tqdm_out):
image = io.imread(image_name)
out = model.eval(image, channels=channels, diameter=diameter,
Expand All @@ -326,7 +329,8 @@ def main():
z_axis=args.z_axis,
omni=args.omni,
anisotropy=args.anisotropy,
verbose=args.verbose)
verbose=args.verbose,
model_loaded=True)
masks, flows = out[:2]
if len(out) > 3:
diams = out[-1]
Expand Down
26 changes: 17 additions & 9 deletions cellpose/models.py
Expand Up @@ -108,7 +108,8 @@ def __init__(self, gpu=False, model_type='cyto', net_avg=True, device=None, torc

self.cp = CellposeModel(device=self.device, gpu=self.gpu,
pretrained_model=self.pretrained_model,
diam_mean=self.diam_mean, torch=self.torch, omni=self.omni)
diam_mean=self.diam_mean, torch=self.torch, omni=self.omni,
net_avg=net_avg)
self.cp.model_type = model_type

# size model not used for bacterial model
Expand All @@ -125,7 +126,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
net_avg=True, augment=False, tile=True, tile_overlap=0.1, resample=True, interp=True, cluster=False,
flow_threshold=0.4, mask_threshold=0.0, cellprob_threshold=None, dist_threshold=None,
diam_threshold=12., min_size=15, stitch_threshold=0.0,
rescale=None, progress=None, omni=False, verbose=False, transparency=False):
rescale=None, progress=None, omni=False, verbose=False, transparency=False, model_loaded=False):
""" run cellpose and get masks
Parameters
Expand Down Expand Up @@ -224,6 +225,9 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
transparency: bool (optional, default False)
modulate flow opacity by magnitude instead of brightness (can use flows on any color background)
model_loaded: bool (optional, default False)
internal variable for determining if model has been loaded, used in __main__.py
Returns
-------
masks: list of 2D arrays, or single 3D array (if do_3D=True)
Expand Down Expand Up @@ -304,7 +308,8 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
stitch_threshold=stitch_threshold,
omni=omni,
verbose=verbose,
transparency=transparency)
transparency=transparency,
model_loaded=model_loaded)
models_logger.info('>>>> TOTAL TIME %0.2f sec'%(time.time()-tic0))

return masks, flows, styles, diams
Expand Down Expand Up @@ -378,9 +383,6 @@ def __init__(self, gpu=False, pretrained_model=False,

if nuclear:
self.diam_mean = 17.
elif bacterial:
#self.diam_mean = 0.
net_avg = False #'bact' model also has no 1,2,3

# set omni flag to true if the name contains it
self.omni = 'omni' in os.path.splitext(Path(pretrained_model_string).name)[0]
Expand Down Expand Up @@ -413,6 +415,8 @@ def __init__(self, gpu=False, pretrained_model=False,
self.pretrained_model = pretrained_model
if self.pretrained_model and len(self.pretrained_model)==1:
self.net.load_model(self.pretrained_model[0], cpu=(not self.gpu))
if not self.torch:
self.net.collect_params().grad_req = 'null'
ostr = ['off', 'on']
omnistr = ['','_omni'] #toggle by containing omni phrase
self.net_type = 'cellpose_residual_{}_style_{}_concatenation_{}{}'.format(ostr[residual_on],
Expand All @@ -428,7 +432,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None,
flow_threshold=0.4, mask_threshold=0.0, diam_threshold=12.,
cellprob_threshold=None, dist_threshold=None,
compute_masks=True, min_size=15, stitch_threshold=0.0, progress=None, omni=False,
calc_trace=False, verbose=False, transparency=False, loop_run=False):
calc_trace=False, verbose=False, transparency=False, loop_run=False, model_loaded=False):
"""
segment list of images x, or 4D array - Z x nchan x Y x X
Expand Down Expand Up @@ -533,6 +537,9 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None,
loop_run: bool (optional, default False)
internal variable for determining if model has been loaded, stops model loading in loop over images
model_loaded: bool (optional, default False)
internal variable for determining if model has been loaded, used in __main__.py
Returns
-------
masks: list of 2D arrays, or single 3D array (if do_3D=True)
Expand Down Expand Up @@ -597,14 +604,15 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None,
calc_trace=calc_trace,
verbose=verbose,
transparency=transparency,
loop_run=(i>0))
loop_run=(i>0),
model_loaded=model_loaded)
masks.append(maski)
flows.append(flowi)
styles.append(stylei)
return masks, styles, flows

else:
if isinstance(self.pretrained_model, list) and not net_avg and not loop_run:
if not model_loaded and (isinstance(self.pretrained_model, list) and not net_avg and not loop_run):
self.net.load_model(self.pretrained_model[0], cpu=(not self.gpu))
if not self.torch:
self.net.collect_params().grad_req = 'null'
Expand Down

0 comments on commit 8896f97

Please sign in to comment.