Skip to content

Commit

Permalink
adding option to run user model from gui without full path
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Apr 6, 2022
1 parent 072b1d5 commit 5aebd43
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 36 deletions.
68 changes: 34 additions & 34 deletions cellpose/__main__.py
Expand Up @@ -160,18 +160,29 @@ def main():

device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu)

#define available model names, right now we have three broad categories
builtin_model = np.any([args.pretrained_model==s for s in models.MODEL_NAMES])
cytoplasmic = 'cyto' in args.pretrained_model
nuclear = 'nuclei' in args.pretrained_model
if args.pretrained_model is None or args.pretrained_model == 'None' or args.pretrained_model == 'False' or args.pretrained_model == '0':
pretrained_model = False
else:
pretrained_model = args.pretrained_model

model_type = None
if pretrained_model and not os.path.exists(pretrained_model):
model_type = pretrained_model if pretrained_model is not None else 'cyto'
model_strings = models.get_user_models()
all_models = models.MODEL_NAMES.copy()
all_models.extend(model_strings)
if ~np.any([model_type == s for s in all_models]):
model_type = 'cyto'
logger.warning('pretrained model has incorrect path')

if model_type=='nuclei':
szmean = 17.
else:
szmean = 30.
builtin_size = model_type == 'cyto' or model_type == 'cyto2' or model_type == 'nuclei'

if not args.train and not args.train_size:
tic = time.time()
if not builtin_model:
cpmodel_path = args.pretrained_model
if not os.path.exists(cpmodel_path):
logger.warning('model path does not exist, using cyto model')
args.pretrained_model = 'cyto'

image_names = io.get_image_files(args.dir,
args.mask_filter,
Expand All @@ -185,24 +196,26 @@ def main():
(nimg, cstr0[channels[0]], cstr1[channels[1]]))

# handle built-in model exceptions; bacterial ones get no size model
if builtin_model:
model = models.Cellpose(gpu=gpu, device=device, model_type=args.pretrained_model,
if builtin_size:
model = models.Cellpose(gpu=gpu, device=device, model_type=model_type,
net_avg=(not args.fast_mode or args.net_avg))

else:
if args.all_channels:
channels = None
pretrained_model = None if model_type is not None else pretrained_model
model = models.CellposeModel(gpu=gpu, device=device,
pretrained_model=cpmodel_path,
pretrained_model=pretrained_model,
model_type=model_type,
net_avg=False)

# handle diameters
if args.diameter==0:
if builtin_model:
if builtin_size:
diameter = None
logger.info('>>>> estimating diameter for each image')
else:
logger.info('>>>> not using cyto or nuclei model, cannot auto-estimate diameter')
logger.info('>>>> not using cyto, cyto2, or nuclei model, cannot auto-estimate diameter')
diameter = model.diam_labels
logger.info('>>>> using diameter %0.3f for all images'%diameter)
else:
Expand Down Expand Up @@ -244,14 +257,6 @@ def main():
save_txt=args.save_txt,in_folders=args.in_folders)
logger.info('>>>> completed in %0.3f sec'%(time.time()-tic))
else:
if builtin_model:
cpmodel_path = models.model_path(args.pretrained_model, 0)
if cytoplasmic:
szmean = 30.
elif nuclear:
szmean = 17.
else:
cpmodel_path = os.fspath(args.pretrained_model)

test_dir = None if len(args.test_dir)==0 else args.test_dir
output = io.load_train_test_data(args.dir, test_dir, imf, args.mask_filter, args.unet, args.look_one_level_down)
Expand All @@ -270,28 +275,22 @@ def main():


# model path
if not os.path.exists(cpmodel_path):
szmean = args.diam_mean
if not os.path.exists(pretrained_model) and model_type is None:
if not args.train:
error_message = 'ERROR: model path missing or incorrect - cannot train size model'
logger.critical(error_message)
raise ValueError(error_message)
cpmodel_path = False
pretrained_model = False
logger.info('>>>> training from scratch')

szmean = args.diam_mean
else:
args.diam_mean = szmean
logger.info('>>>> pretrained model %s is being used'%cpmodel_path)
args.residual_on = 1
args.style_on = 1
args.concatenation = 0

if args.train:
logger.info('>>>> during training rescaling images to fixed diameter of %0.1f pixels'%args.diam_mean)

# initialize model
if args.unet:
model = core.UnetModel(device=device,
pretrained_model=cpmodel_path,
pretrained_model=pretrained_model,
diam_mean=szmean,
residual_on=args.residual_on,
style_on=args.style_on,
Expand All @@ -300,7 +299,8 @@ def main():
nchan=nchan)
else:
model = models.CellposeModel(device=device,
pretrained_model=cpmodel_path,
pretrained_model=pretrained_model if model_type is None else None,
model_type=model_type,
diam_mean=szmean,
residual_on=args.residual_on,
style_on=args.style_on,
Expand Down
3 changes: 2 additions & 1 deletion cellpose/models.py
Expand Up @@ -346,13 +346,14 @@ def __init__(self, gpu=False, pretrained_model=False,
residual_on, style_on, concatenation = True, True, False

else:
builtin = False
if pretrained_model:
pretrained_model_string = pretrained_model[0]
params = parse_model_string(pretrained_model_string)
if params is not None:
_, residual_on, style_on, concatenation = params
builtin = False
models_logger.info(f'>>>> loading model {pretrained_model_string}')



# initialize network
Expand Down
2 changes: 1 addition & 1 deletion tests/test_train.py
Expand Up @@ -28,7 +28,7 @@ def test_cli_train(data_dir):
train_dir = str(data_dir.joinpath('2D').joinpath('train'))
model_dir = str(data_dir.joinpath('2D').joinpath('train').joinpath('models'))
shutil.rmtree(model_dir, ignore_errors=True)
cmd = 'python -m cellpose --train --train_size --n_epochs 3 --dir %s --mask_filter _cyto_masks --pretrained_model None --chan 2 --chan2 1 --diameter 40'%train_dir
cmd = 'python -m cellpose --train --train_size --n_epochs 3 --dir %s --mask_filter _cyto_masks --pretrained_model None --chan 2 --chan2 1 --diam_mean 40'%train_dir
try:
cmd_stdout = check_output(cmd, stderr=STDOUT, shell=True).decode()
except Exception as e:
Expand Down

0 comments on commit 5aebd43

Please sign in to comment.