diff --git a/cellpose/__main__.py b/cellpose/__main__.py index b9a33218..f874ff97 100644 --- a/cellpose/__main__.py +++ b/cellpose/__main__.py @@ -250,20 +250,22 @@ def main(): args.mxnet = False if not bacterial: model = models.Cellpose(gpu=gpu, device=device, model_type=args.pretrained_model, - torch=(not args.mxnet),omni=args.omni) + torch=(not args.mxnet),omni=args.omni, net_avg=(not args.fast_mode and not args.no_net_avg)) else: cpmodel_path = models.model_path(args.pretrained_model, 0, True) model = models.CellposeModel(gpu=gpu, device=device, pretrained_model=cpmodel_path, torch=True, - nclasses=args.nclasses,omni=args.omni) + nclasses=args.nclasses,omni=args.omni, + net_avg=False) else: if args.all_channels: channels = None model = models.CellposeModel(gpu=gpu, device=device, pretrained_model=cpmodel_path, torch=True, - nclasses=args.nclasses,omni=args.omni) + nclasses=args.nclasses,omni=args.omni, + net_avg=False) # omni changes not implemented for mxnet. Full parity for cpu/gpu in pytorch. @@ -309,6 +311,7 @@ def main(): tqdm_out = utils.TqdmToLogger(logger,level=logging.INFO) + for image_name in tqdm(image_names, file=tqdm_out): image = io.imread(image_name) out = model.eval(image, channels=channels, diameter=diameter, @@ -326,7 +329,8 @@ def main(): z_axis=args.z_axis, omni=args.omni, anisotropy=args.anisotropy, - verbose=args.verbose) + verbose=args.verbose, + model_loaded=True) masks, flows = out[:2] if len(out) > 3: diams = out[-1] diff --git a/cellpose/models.py b/cellpose/models.py index a6385ddc..07ae34a3 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -108,7 +108,8 @@ def __init__(self, gpu=False, model_type='cyto', net_avg=True, device=None, torc self.cp = CellposeModel(device=self.device, gpu=self.gpu, pretrained_model=self.pretrained_model, - diam_mean=self.diam_mean, torch=self.torch, omni=self.omni) + diam_mean=self.diam_mean, torch=self.torch, omni=self.omni, + net_avg=net_avg) self.cp.model_type = model_type # size model not used for bacterial model @@ -125,7 +126,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, net_avg=True, augment=False, tile=True, tile_overlap=0.1, resample=True, interp=True, cluster=False, flow_threshold=0.4, mask_threshold=0.0, cellprob_threshold=None, dist_threshold=None, diam_threshold=12., min_size=15, stitch_threshold=0.0, - rescale=None, progress=None, omni=False, verbose=False, transparency=False): + rescale=None, progress=None, omni=False, verbose=False, transparency=False, model_loaded=False): """ run cellpose and get masks Parameters @@ -224,6 +225,9 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, transparency: bool (optional, default False) modulate flow opacity by magnitude instead of brightness (can use flows on any color background) + model_loaded: bool (optional, default False) + internal variable for determining if model has been loaded, used in __main__.py + Returns ------- masks: list of 2D arrays, or single 3D array (if do_3D=True) @@ -304,7 +308,8 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, stitch_threshold=stitch_threshold, omni=omni, verbose=verbose, - transparency=transparency) + transparency=transparency, + model_loaded=model_loaded) models_logger.info('>>>> TOTAL TIME %0.2f sec'%(time.time()-tic0)) return masks, flows, styles, diams @@ -378,9 +383,6 @@ def __init__(self, gpu=False, pretrained_model=False, if nuclear: self.diam_mean = 17. - elif bacterial: - #self.diam_mean = 0. - net_avg = False #'bact' model also has no 1,2,3 # set omni flag to true if the name contains it self.omni = 'omni' in os.path.splitext(Path(pretrained_model_string).name)[0] @@ -413,6 +415,8 @@ def __init__(self, gpu=False, pretrained_model=False, self.pretrained_model = pretrained_model if self.pretrained_model and len(self.pretrained_model)==1: self.net.load_model(self.pretrained_model[0], cpu=(not self.gpu)) + if not self.torch: + self.net.collect_params().grad_req = 'null' ostr = ['off', 'on'] omnistr = ['','_omni'] #toggle by containing omni phrase self.net_type = 'cellpose_residual_{}_style_{}_concatenation_{}{}'.format(ostr[residual_on], @@ -428,7 +432,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, flow_threshold=0.4, mask_threshold=0.0, diam_threshold=12., cellprob_threshold=None, dist_threshold=None, compute_masks=True, min_size=15, stitch_threshold=0.0, progress=None, omni=False, - calc_trace=False, verbose=False, transparency=False, loop_run=False): + calc_trace=False, verbose=False, transparency=False, loop_run=False, model_loaded=False): """ segment list of images x, or 4D array - Z x nchan x Y x X @@ -533,6 +537,9 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, loop_run: bool (optional, default False) internal variable for determining if model has been loaded, stops model loading in loop over images + model_loaded: bool (optional, default False) + internal variable for determining if model has been loaded, used in __main__.py + Returns ------- masks: list of 2D arrays, or single 3D array (if do_3D=True) @@ -597,14 +604,15 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, calc_trace=calc_trace, verbose=verbose, transparency=transparency, - loop_run=(i>0)) + loop_run=(i>0), + model_loaded=model_loaded) masks.append(maski) flows.append(flowi) styles.append(stylei) return masks, styles, flows else: - if isinstance(self.pretrained_model, list) and not net_avg and not loop_run: + if not model_loaded and (isinstance(self.pretrained_model, list) and not net_avg and not loop_run): self.net.load_model(self.pretrained_model[0], cpu=(not self.gpu)) if not self.torch: self.net.collect_params().grad_req = 'null'