From 5c877c8e6925d6f6ff8e5dfd81a8427928bf5f72 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sat, 30 Mar 2024 11:16:19 -0400 Subject: [PATCH 1/9] adding support for nchan=3 --- cellpose/dynamics.py | 69 ++++++++++++++++++++-------------------- cellpose/gui/io.py | 8 +++-- cellpose/models.py | 22 +++++++++---- cellpose/resnet_torch.py | 24 ++------------ cellpose/train.py | 16 ++++++---- 5 files changed, 67 insertions(+), 72 deletions(-) diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py index acdf1492..72496cf1 100644 --- a/cellpose/dynamics.py +++ b/cellpose/dynamics.py @@ -74,9 +74,8 @@ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200, device = torch.device("cuda") T = torch.zeros(shape, dtype=torch.double, device=device) - for i in range(n_iter): - T[meds[:, 0], meds[:, 1]] += 1 + T[tuple(meds.T)] += 1 Tneigh = T[tuple(neighbors)] Tneigh *= isneighbor T[tuple(neighbors[:, 0])] = Tneigh.mean(axis=0) @@ -90,11 +89,11 @@ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200, del grads mu_torch = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2) else: - grads = T[:, pt[1:, :, 0], pt[1:, :, 1], pt[1:, :, 2]] - del pt - dz = grads[:, 0] - grads[:, 1] - dy = grads[:, 2] - grads[:, 3] - dx = grads[:, 4] - grads[:, 5] + grads = T[tuple(neighbors[:,1:])] + del neighbors + dz = grads[0] - grads[1] + dy = grads[2] - grads[3] + dx = grads[4] - grads[5] del grads mu_torch = np.stack( (dz.cpu().squeeze(0), dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2) @@ -161,7 +160,7 @@ def masks_to_flows_gpu(masks, device=None, niter=None): neighborsY = torch.stack((y, y - 1, y + 1, y, y, y - 1, y - 1, y + 1, y + 1), dim=0) neighborsX = torch.stack((x, x, x, x - 1, x + 1, x - 1, x + 1, x - 1, x + 1), dim=0) neighbors = torch.stack((neighborsY, neighborsX), dim=0) - neighbor_masks = masks_padded[neighbors[0], neighbors[1]] + neighbor_masks = masks_padded[tuple(neighbors)] isneighbor = neighbor_masks == neighbor_masks[0] ### get center-of-mass within cell @@ -210,17 +209,17 @@ def masks_to_flows_gpu_3d(masks, device=None): Lz0, Ly0, Lx0 = masks.shape Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2 - masks_padded = np.zeros((Lz, Ly, Lx), np.int64) - masks_padded[1:-1, 1:-1, 1:-1] = masks - + masks_padded = torch.from_numpy(masks.astype("int64")).to(device) + masks_padded = F.pad(masks_padded, (1, 1, 1, 1, 1, 1)) + # get mask pixel neighbors - z, y, x = np.nonzero(masks_padded) - neighborsZ = np.stack((z, z + 1, z - 1, z, z, z, z)) - neighborsY = np.stack((y, y, y, y + 1, y - 1, y, y), axis=0) - neighborsX = np.stack((x, x, x, x, x, x + 1, x - 1), axis=0) - - neighbors = np.stack((neighborsZ, neighborsY, neighborsX), axis=-1) + z, y, x = torch.nonzero(masks_padded).T + neighborsZ = torch.stack((z, z + 1, z - 1, z, z, z, z)) + neighborsY = torch.stack((y, y, y, y + 1, y - 1, y, y), axis=0) + neighborsX = torch.stack((x, x, x, x, x, x + 1, x - 1), axis=0) + neighbors = torch.stack((neighborsZ, neighborsY, neighborsX), axis=0) + # get mask centers slices = find_objects(masks) @@ -245,8 +244,7 @@ def masks_to_flows_gpu_3d(masks, device=None): centers[i, 2] = xmed + sx.start # get neighbor validator (not all neighbors are in same mask) - neighbor_masks = masks_padded[neighbors[:, :, 0], neighbors[:, :, 1], - neighbors[:, :, 2]] + neighbor_masks = masks_padded[tuple(neighbors)] isneighbor = neighbor_masks == neighbor_masks[0] ext = np.array( [[sz.stop - sz.start + 1, sy.stop - sy.start + 1, sx.stop - sx.start + 1] @@ -262,7 +260,7 @@ def masks_to_flows_gpu_3d(masks, device=None): # put into original image mu0 = np.zeros((3, Lz0, Ly0, Lx0)) - mu0[:, z - 1, y - 1, x - 1] = mu + mu0[:, z.cpu().numpy() - 1, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu mu_c = np.zeros_like(mu0) return mu0, mu_c @@ -362,7 +360,8 @@ def masks_to_flows(masks, device=None, niter=None): raise ValueError("masks_to_flows only takes 2D or 3D arrays") -def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None): +def labels_to_flows(labels, files=None, device=None, + redo_flows=False, niter=None, return_flows=True): """Converts labels (list of masks or flows) to flows for training model. Args: @@ -384,6 +383,7 @@ def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=Non if labels[0].ndim < 3: labels = [labels[n][np.newaxis, :, :] for n in range(nimg)] + flows = [] # flows need to be recomputed if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows: dynamics_logger.info("computing flows for labels") @@ -392,23 +392,22 @@ def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=Non # make sure labels are unique! labels = [fastremap.renumber(label, in_place=True)[0] for label in labels] iterator = trange if nimg > 1 else range - veci = [ - masks_to_flows(labels[n][0].astype(int), device=device, niter=niter) - for n in iterator(nimg) - ] - - # concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations) - flows = [ - np.concatenate((labels[n], labels[n] > 0.5, veci[n]), - axis=0).astype(np.float32) for n in range(nimg) - ] - if files is not None: - for flow, file in zip(flows, files): - file_name = os.path.splitext(file)[0] + for n in iterator(nimg): + labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0] + vecn = masks_to_flows(labels[n][0].astype(int), device=device, niter=niter) + + # concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations) + flow = np.concatenate((labels[n], labels[n] > 0.5, vecn), + axis=0).astype(np.float32) + if files is not None: + file_name = os.path.splitext(files[n])[0] tifffile.imwrite(file_name + "_flows.tif", flow) + if return_flows: + flows.append(flow) else: dynamics_logger.info("flows precomputed") - flows = [labels[n].astype(np.float32) for n in range(nimg)] + if return_flows: + flows = [labels[n].astype(np.float32) for n in range(nimg)] return flows diff --git a/cellpose/gui/io.py b/cellpose/gui/io.py index d8bef2ed..f65534ca 100644 --- a/cellpose/gui/io.py +++ b/cellpose/gui/io.py @@ -166,9 +166,11 @@ def _initialize_images(parent, image, load_3D=False): c = np.array(image.shape).argmin() image = image.transpose(((c + 1) % 3, (c + 2) % 3, c)) elif load_3D: - # assume smallest dimension is Z and put first - z = np.array(image.shape).argmin() - image = image.transpose((z, (z + 1) % 3, (z + 2) % 3)) + # assume smallest dimension is Z and put first if <3x max dim + shape = np.array(image.shape) + z = shape.argmin() + if shape[z] < shape.max()/3: + image = image.transpose((z, (z + 1) % 3, (z + 2) % 3)) image = image[..., np.newaxis] elif image.ndim == 2: if not load_3D: diff --git a/cellpose/models.py b/cellpose/models.py index 8fad3e7f..9dd8cebf 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -51,12 +51,15 @@ def model_path(model_type, model_index=0): def size_model_path(model_type): - torch_str = "torch" - if model_type == "cyto" or model_type == "nuclei" or model_type == "cyto2": - basename = "size_%s%s_0.npy" % (model_type, torch_str) + if os.path.exists(model_type): + return model_type + "_size.npy" else: - basename = "size_%s.npy" % model_type - return cache_model_path(basename) + torch_str = "torch" + if model_type == "cyto" or model_type == "nuclei" or model_type == "cyto2": + basename = "size_%s%s_0.npy" % (model_type, torch_str) + else: + basename = "size_%s.npy" % model_type + return cache_model_path(basename) def cache_model_path(basename): @@ -99,7 +102,7 @@ class Cellpose(): """ - def __init__(self, gpu=False, model_type="cyto3", device=None): + def __init__(self, gpu=False, model_type="cyto3", nchan=2, device=None): super(Cellpose, self).__init__() # assign device (GPU or CPU) @@ -114,8 +117,13 @@ def __init__(self, gpu=False, model_type="cyto3", device=None): if nuclear: self.diam_mean = 17. + if model_type in ["cyto", "nuclei", "cyto2", "cyto3"] and nchan!=2: + nchan = 2 + models_logger.warning(f"cannot set nchan to other value for {model_type} model") + self.nchan = nchan + self.cp = CellposeModel(device=self.device, gpu=self.gpu, model_type=model_type, - diam_mean=self.diam_mean) + diam_mean=self.diam_mean, nchan=self.nchan) self.cp.model_type = model_type # size model not used for bacterial model diff --git a/cellpose/resnet_torch.py b/cellpose/resnet_torch.py index 0d710d0a..3a952f26 100644 --- a/cellpose/resnet_torch.py +++ b/cellpose/resnet_torch.py @@ -79,7 +79,6 @@ def forward(self, x): class batchconvstyle(nn.Module): - def __init__(self, in_channels, out_channels, style_channels, sz, conv_3D=False): super().__init__() self.concatenation = False @@ -104,8 +103,9 @@ def forward(self, style, x, mkldnn=False, y=None): class resup(nn.Module): def __init__(self, in_channels, out_channels, style_channels, sz, - concatenation=False, conv_3D=False): + conv_3D=False): super().__init__() + self.concatenation = False self.conv = nn.Sequential() self.conv.add_module("conv_0", batchconv(in_channels, out_channels, sz, conv_3D=conv_3D)) @@ -130,24 +130,6 @@ def forward(self, x, y, style, mkldnn=False): return x -class convup(nn.Module): - - def __init__(self, in_channels, out_channels, style_channels, sz, - concatenation=False, conv_3D=False): - super().__init__() - self.conv = nn.Sequential() - self.conv.add_module("conv_0", batchconv(in_channels, out_channels, sz, - conv_3D)) - self.conv.add_module( - "conv_1", - batchconvstyle(out_channels, out_channels, style_channels, sz, - concatenation=concatenation, conv_3D=conv_3D)) - - def forward(self, x, y, style, mkldnn=False): - x = self.conv[1](style, self.conv[0](x), y=y) - return x - - class make_style(nn.Module): def __init__(self, conv_3D=False): @@ -164,7 +146,7 @@ def forward(self, x0): class upsample(nn.Module): - def __init__(self, nbase, sz, residual_on=True, conv_3D=False): + def __init__(self, nbase, sz, conv_3D=False): super().__init__() self.upsampling = nn.Upsample(scale_factor=2, mode="nearest") self.up = nn.Sequential() diff --git a/cellpose/train.py b/cellpose/train.py index c4977d8f..34c27c12 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -298,7 +298,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, n_epochs=2000, weight_decay=1e-5, momentum=0.9, SGD=False, channels=None, channel_axis=None, normalize=True, compute_flows=False, save_path=None, save_every=100, nimg_per_epoch=None, nimg_test_per_epoch=None, - rescale=True, min_train_masks=5, model_name=None): + rescale=True, scale_range=None, bsize=224, + min_train_masks=5, model_name=None): """ Train the network with images for segmentation. @@ -338,7 +339,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, """ device = net.device - scale_range = 0.5 if rescale else 1.0 + scale_range0 = 0.5 if rescale else 1.0 + scale_range = scale_range if scale_range is not None else scale_range0 if isinstance(normalize, dict): normalize_params = {**models.normalize_default, **normalize} @@ -424,7 +426,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, rsc = diams / net.diam_mean.item() # augmentations imgi, lbl = transforms.random_rotate_and_resize(imgs, Y=lbls, rescale=rsc, - scale_range=scale_range)[:2] + scale_range=scale_range, xy=(bsize, bsize))[:2] X = torch.from_numpy(imgi).to(device) y = net(X)[0] @@ -460,7 +462,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, diams = np.array([diam_test[i] for i in inds]) rsc = diams / net.diam_mean.item() imgi, lbl = transforms.random_rotate_and_resize( - imgs, Y=lbls, rescale=rsc, scale_range=scale_range)[:2] + imgs, Y=lbls, rescale=rsc, scale_range=scale_range, + xy=(bsize, bsize))[:2] X = torch.from_numpy(imgi).to(device) y = net(X)[0] loss = _loss_fn_seg(lbl, y, device) @@ -486,6 +489,7 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, test_labels_files=None, test_probs=None, load_files=True, min_train_masks=5, channels=None, channel_axis=None, normalize=True, nimg_per_epoch=None, nimg_test_per_epoch=None, batch_size=128, + scale_range=1.0, bsize=512, l2_regularization=1.0, n_epochs=10): """Train the size model. @@ -564,7 +568,7 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, normalize_params=normalize_params) diami = diam_train[inds].copy() imgi, lbl, scale = transforms.random_rotate_and_resize( - imgs, scale_range=1, xy=(512, 512)) + imgs, scale_range=scale_range, xy=(bsize, bsize)) imgi = torch.from_numpy(imgi).to(device) with torch.no_grad(): feat = net(imgi)[1] @@ -608,7 +612,7 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, normalize_params=normalize_params) diami = diam_test[inds].copy() imgi, lbl, scale = transforms.random_rotate_and_resize( - imgs, Y=lbls, scale_range=1, xy=(512, 512)) + imgs, Y=lbls, scale_range=scale_range, xy=(bsize, bsize)) imgi = torch.from_numpy(imgi).to(device) diamt = np.array([utils.diameters(lbl0[0])[0] for lbl0 in lbl]) diamt = np.maximum(5., diamt) From e09c0e4ce833f1a582d88d281a563078574066f6 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sat, 30 Mar 2024 14:21:33 -0400 Subject: [PATCH 2/9] updating for rgb training --- cellpose/__main__.py | 33 +++++++--- cellpose/cli.py | 10 ++- cellpose/train.py | 145 ++++++++++++++++++++++++++----------------- 3 files changed, 123 insertions(+), 65 deletions(-) diff --git a/cellpose/__main__.py b/cellpose/__main__.py index 072e931f..d0b4abf9 100644 --- a/cellpose/__main__.py +++ b/cellpose/__main__.py @@ -221,13 +221,28 @@ def main(): else: test_dir = None if len(args.test_dir) == 0 else args.test_dir - output = io.load_train_test_data(args.dir, test_dir, imf, args.mask_filter, - args.look_one_level_down) - images, labels, image_names, test_images, test_labels, image_names_test = output + images, labels, image_names, train_probs = None, None, None, None + test_images, test_labels, image_names_test, test_probs = None, None, None, None + if len(args.file_list) > 0: + if os.path.exists(args.file_list): + dat = np.load(args.file_list, allow_pickle=True).item() + image_names = dat["train_files"] + image_names_test = dat["test_files"] + if "train_probs" in dat: + train_probs = dat["train_probs"] + test_probs = dat["test_probs"] + load_files = False + else: + logger.critical(f"ERROR: {args.file_list} does not exist") + else: + output = io.load_train_test_data(args.dir, test_dir, imf, args.mask_filter, + args.look_one_level_down) + images, labels, image_names, test_images, test_labels, image_names_test = output + load_files = True # training with all channels if args.all_channels: - img = images[0] + img = images[0] if images is not None else io.imread(image_names[0]) if img.ndim == 3: nchan = min(img.shape) elif img.ndim == 2: @@ -261,12 +276,16 @@ def main(): cpmodel_path = train.train_seg( model.net, images, labels, train_files=image_names, test_data=test_images, test_labels=test_labels, - test_files=image_names_test, learning_rate=args.learning_rate, + test_files=image_names_test, + train_probs=train_probs, test_probs=test_probs, + load_files=load_files, learning_rate=args.learning_rate, weight_decay=args.weight_decay, channels=channels, - channel_axis=args.channel_axis, + channel_axis=args.channel_axis, rgb=(nchan==3), save_path=os.path.realpath(args.dir), save_every=args.save_every, SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.batch_size, - min_train_masks=args.min_train_masks, + min_train_masks=args.min_train_masks, + nimg_per_epoch=args.nimg_per_epoch, normalize=(not args.no_norm), + nimg_test_per_epoch=args.nimg_test_per_epoch, model_name=args.model_name_out) model.pretrained_model = cpmodel_path logger.info(">>>> model trained and saved to %s" % cpmodel_path) diff --git a/cellpose/cli.py b/cellpose/cli.py index 83532076..78a1196f 100644 --- a/cellpose/cli.py +++ b/cellpose/cli.py @@ -63,7 +63,7 @@ def get_arg_parser(): input_img_args.add_argument( "--all_channels", action="store_true", help= "use all channels in image if using own model and images with special channels") - + # model settings model_args = parser.add_argument_group("Model Arguments") model_args.add_argument("--pretrained_model", required=False, default="cyto", @@ -171,6 +171,10 @@ def get_arg_parser(): help="train size network at end of training") training_args.add_argument("--test_dir", default=[], type=str, help="folder containing test data (optional)") + training_args.add_argument( + "--file_list", default=[], type=str, help= + "path to list of files for training and testing and probabilities for each image (optional)" + ) training_args.add_argument( "--mask_filter", default="_masks", type=str, help= "end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. Default: %(default)s" @@ -187,6 +191,10 @@ def get_arg_parser(): help="number of epochs. Default: %(default)s") training_args.add_argument("--batch_size", default=8, type=int, help="batch size. Default: %(default)s") + training_args.add_argument("--nimg_per_epoch", default=None, type=int, + help="number of train images per epoch. Default is to use all train images.") + training_args.add_argument("--nimg_test_per_epoch", default=None, type=int, + help="number of test images per epoch. Default is to use all test images.") training_args.add_argument( "--min_train_masks", default=5, type=int, help= "minimum number of masks a training image must have to be used. Default: %(default)s" diff --git a/cellpose/train.py b/cellpose/train.py index 34c27c12..9950628a 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -38,7 +38,8 @@ def _loss_fn_seg(lbl, y, device): def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, - channels=None, channel_axis=None, normalize_params={"normalize": False}): + channels=None, channel_axis=None, rgb=False, + normalize_params={"normalize": False}): """ Get a batch of images and labels. @@ -56,26 +57,39 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, tuple: A tuple containing two lists: the batch of images and the batch of labels. """ if data is None: + lbls = None imgs = [io.imread(files[i]) for i in inds] - if channels is not None: - imgs = [ - transforms.convert_image(img, channels=channels, - channel_axis=channel_axis, nchan=None) for img in imgs - ] - imgs = [img.transpose(2, 0, 1) for img in imgs] - if normalize_params["normalize"]: - imgs = [ - transforms.normalize_img(img, normalize=normalize_params, axis=0) - for img in imgs - ] - lbls = [io.imread(labels_files[i])[1:] for i in inds] + imgs = _reshape_norm(imgs, channels=channels, channel_axis=channel_axis, rgb=rgb, + normalize_params=normalize_params) + # if channels is not None: + # imgs = [ + # transforms.convert_image(img, channels=channels, + # channel_axis=channel_axis, nchan=None) for img in imgs + # ] + # imgs = [img.transpose(2, 0, 1) for img in imgs] + # if normalize_params["normalize"]: + # imgs = [ + # transforms.normalize_img(img, normalize=normalize_params, axis=0) + # for img in imgs + # ] + if labels_files is not None: + lbls = [io.imread(labels_files[i])[1:] for i in inds] else: imgs = [data[i] for i in inds] lbls = [labels[i][1:] for i in inds] return imgs, lbls - -def _reshape_norm(data, channels=None, channel_axis=None, +def pad_to_rgb(img): + if img.ndim==2 or (img.shape[0]==2 and np.ptp(img[1])<1e-3): + if img.ndim==2: + img = img[np.newaxis,:,:] + img = np.tile(img[:1], (3,1,1)) + elif img.shape[0] < 3: + nc, Ly, Lx = img.shape + img = np.concatenate((img, np.zeros((3-nc, Ly, Lx), dtype=img.dtype)), axis=0) + return img + +def _reshape_norm(data, channels=None, channel_axis=None, rgb=False, normalize_params={"normalize": False}): """ Reshapes and normalizes the input data. @@ -91,7 +105,7 @@ def _reshape_norm(data, channels=None, channel_axis=None, """ if channels is not None or channel_axis is not None: data = [ - transforms.convert_image(td, channels=channels, channel_axis=channel_axis, nchan=None) + transforms.convert_image(td, channels=channels, channel_axis=channel_axis) for td in data ] data = [td.transpose(2, 0, 1) for td in data] @@ -100,6 +114,8 @@ def _reshape_norm(data, channels=None, channel_axis=None, transforms.normalize_img(td, normalize=normalize_params, axis=0) for td in data ] + if rgb: + data = [pad_to_rgb(td) for td in data] return data @@ -111,7 +127,7 @@ def _reshape_norm_save(files, channels=None, channel_axis=None, td = io.imread(f) if channels is not None: td = transforms.convert_image(td, channels=channels, - channel_axis=channel_axis, nchan=None) + channel_axis=channel_axis) td = td.transpose(2, 0, 1) if normalize_params["normalize"]: td = transforms.normalize_img(td, normalize=normalize_params, axis=0) @@ -129,11 +145,11 @@ def _reshape_norm_save(files, channels=None, channel_axis=None, def _process_train_test(train_data=None, train_labels=None, train_files=None, train_labels_files=None, train_probs=None, test_data=None, - test_labels=None, test_files=None, test_labels_files=None, + test_labels=None, test_files=None, test_labels_files=None, test_probs=None, load_files=True, min_train_masks=5, - compute_flows=False, channels=None, channel_axis=None, - normalize_params={"normalize": False - }, device=torch.device("cuda")): + compute_flows=False, channels=None, channel_axis=None, + rgb=False, normalize_params={"normalize": False}, + device=torch.device("cuda")): """ Process train and test data. @@ -141,18 +157,19 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, train_data (list or None): List of training data arrays. train_labels (list or None): List of training label arrays. train_files (list or None): List of training file paths. - train_labels_files (list or None): List of training label file paths. + #train_labels_files (list or None): List of training label file paths. train_probs (ndarray or None): Array of training probabilities. test_data (list or None): List of test data arrays. test_labels (list or None): List of test label arrays. test_files (list or None): List of test file paths. - test_labels_files (list or None): List of test label file paths. + #test_labels_files (list or None): List of test label file paths. test_probs (ndarray or None): Array of test probabilities. load_files (bool): Whether to load data from files. min_train_masks (int): Minimum number of masks required for training images. compute_flows (bool): Whether to compute flows. channels (list or None): List of channel indices to use. channel_axis (int or None): Axis of channel dimension. + rgb (bool): Convert training/testing images to RGB. normalize_params (dict): Dictionary of normalization parameters. device (torch.device): Device to use for computation. @@ -166,6 +183,16 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, else: # otherwise use files nimg = len(train_files) + if train_labels_files is None: + train_labels_files = [ + os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files + ] + train_labels_files = [tf for tf in train_labels_files if os.path.exists(tf)] + if test_data is not None or test_files is not None and test_labels_files is None: + test_labels_files = [ + os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files + ] + test_labels_files = [tf for tf in test_labels_files if os.path.exists(tf)] if not load_files: train_logger.info(">>> using files instead of loading dataset") else: @@ -209,16 +236,11 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, for k in trange(nimg): tl = dynamics.labels_to_flows(io.imread(train_labels_files), files=train_files, device=device) - train_labels_files = [ - os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files - ] if test_files is not None: for k in trange(nimg_test): tl = dynamics.labels_to_flows(io.imread(test_labels_files), files=test_files, device=device) - test_labels_files = [ - os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files - ] + ### compute diameters nmasks = np.zeros(nimg) @@ -272,6 +294,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, test_probs /= test_probs.sum() ### reshape and normalize train / test data + normed = False if channels is not None or normalize_params["normalize"]: if channels: train_logger.info(f">>> using channels {channels}") @@ -279,24 +302,26 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, train_logger.info(f">>> normalizing {normalize_params}") if train_data is not None: train_data = _reshape_norm(train_data, channels=channels, - channel_axis=channel_axis, + channel_axis=channel_axis, rgb=rgb, normalize_params=normalize_params) + normed = True if test_data is not None: test_data = _reshape_norm(test_data, channels=channels, - channel_axis=channel_axis, + channel_axis=channel_axis, rgb=rgb, normalize_params=normalize_params) return (train_data, train_labels, train_files, train_labels_files, train_probs, diam_train, test_data, test_labels, test_files, test_labels_files, - test_probs, diam_test) + test_probs, diam_test, normed) def train_seg(net, train_data=None, train_labels=None, train_files=None, train_labels_files=None, train_probs=None, test_data=None, - test_labels=None, test_files=None, test_labels_files=None, + test_labels=None, test_files=None, test_labels_files=None, test_probs=None, load_files=True, batch_size=8, learning_rate=0.005, n_epochs=2000, weight_decay=1e-5, momentum=0.9, SGD=False, channels=None, - channel_axis=None, normalize=True, compute_flows=False, save_path=None, + channel_axis=None, rgb=False, normalize=True, + compute_flows=False, save_path=None, save_every=100, nimg_per_epoch=None, nimg_test_per_epoch=None, rescale=True, scale_range=None, bsize=224, min_train_masks=5, model_name=None): @@ -308,12 +333,10 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None. train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None. train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None. - train_labels_files (List[str], optional): List of strings - file names for labels in train_labels. Defaults to None. train_probs (List[float], optional): List of floats - probabilities for each image to be selected during training. Defaults to None. test_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for testing. Defaults to None. test_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for test_data, where 0=no masks; 1,2,...=mask labels. Defaults to None. test_files (List[str], optional): List of strings - file names for images in test_data (to save flows for future runs). Defaults to None. - test_labels_files (List[str], optional): List of strings - file names for labels in test_labels. Defaults to None. test_probs (List[float], optional): List of floats - probabilities for each image to be selected during testing. Defaults to None. load_files (bool, optional): Boolean - whether to load images and labels from files. Defaults to True. batch_size (int, optional): Integer - number of patches to run simultaneously on the GPU. Defaults to 8. @@ -338,7 +361,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, Path: path to saved model weights """ device = net.device - + scale_range0 = 0.5 if rescale else 1.0 scale_range = scale_range if scale_range is not None else scale_range0 @@ -352,14 +375,20 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, out = _process_train_test( train_data=train_data, train_labels=train_labels, train_files=train_files, - train_labels_files=train_labels_files, train_probs=train_probs, + train_probs=train_probs, test_data=test_data, test_labels=test_labels, test_files=test_files, - test_labels_files=test_labels_files, test_probs=test_probs, + test_probs=test_probs, load_files=load_files, min_train_masks=min_train_masks, compute_flows=compute_flows, channels=channels, channel_axis=channel_axis, - normalize_params=normalize_params, device=net.device) + rgb=rgb, normalize_params=normalize_params, device=net.device) (train_data, train_labels, train_files, train_labels_files, train_probs, diam_train, - test_data, test_labels, test_files, test_labels_files, test_probs, diam_test) = out + test_data, test_labels, test_files, test_labels_files, test_probs, diam_test, normed) = out + # already normalized, do not normalize during training + if normed: + kwargs = {} + else: + kwargs = {"normalize_params": normalize_params, "channels": channels, + "channel_axis": channel_axis, "rgb": rgb} net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device) @@ -420,10 +449,9 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, inds = rperm[k:kend] imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels, files=train_files, labels_files=train_labels_files, - channels=channels, channel_axis=channel_axis, - normalize_params=normalize_params) + **kwargs) diams = np.array([diam_train[i] for i in inds]) - rsc = diams / net.diam_mean.item() + rsc = diams / net.diam_mean.item() if rescale else np.ones(len(diams), "float32") # augmentations imgi, lbl = transforms.random_rotate_and_resize(imgs, Y=lbls, rescale=rsc, scale_range=scale_range, xy=(bsize, bsize))[:2] @@ -439,7 +467,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, train_loss *= len(imgi) lavg += train_loss nsum += len(imgi) - + if iepoch == 5 or iepoch % 10 == 0: lavgt = 0. if test_data is not None or test_files is not None: @@ -456,11 +484,9 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, imgs, lbls = _get_batch(inds, data=test_data, labels=test_labels, files=test_files, labels_files=test_labels_files, - channels=channels, - channel_axis=channel_axis, - normalize_params=normalize_params) + **kwargs) diams = np.array([diam_test[i] for i in inds]) - rsc = diams / net.diam_mean.item() + rsc = diams / net.diam_mean.item() if rescale else np.ones(len(diams), "float32") imgi, lbl = transforms.random_rotate_and_resize( imgs, Y=lbls, rescale=rsc, scale_range=scale_range, xy=(bsize, bsize))[:2] @@ -487,9 +513,9 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, train_files=None, train_labels_files=None, train_probs=None, test_data=None, test_labels=None, test_files=None, test_labels_files=None, test_probs=None, load_files=True, - min_train_masks=5, channels=None, channel_axis=None, normalize=True, - nimg_per_epoch=None, nimg_test_per_epoch=None, batch_size=128, - scale_range=1.0, bsize=512, + min_train_masks=5, channels=None, channel_axis=None, rgb=False, + normalize=True, nimg_per_epoch=None, nimg_test_per_epoch=None, + batch_size=128, scale_range=1.0, bsize=512, l2_regularization=1.0, n_epochs=10): """Train the size model. @@ -537,7 +563,14 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, channels=channels, channel_axis=channel_axis, normalize_params=normalize_params, device=net.device) (train_data, train_labels, train_files, train_labels_files, train_probs, diam_train, - test_data, test_labels, test_files, test_labels_files, test_probs, diam_test) = out + test_data, test_labels, test_files, test_labels_files, test_probs, diam_test, normed) = out + + # already normalized, do not normalize during training + if normed: + kwargs = {} + else: + kwargs = {"normalize_params": normalize_params, "channels": channels, + "channel_axis": channel_axis, "rgb": rgb} nimg = len(train_data) if train_data is not None else len(train_files) nimg_test = len(test_data) if test_data is not None else None @@ -564,8 +597,7 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, inds = rperm[inds_batch] imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels, files=train_files, labels_files=train_labels_files, - channels=channels, channel_axis=channel_axis, - normalize_params=normalize_params) + **kwargs) diami = diam_train[inds].copy() imgi, lbl, scale = transforms.random_rotate_and_resize( imgs, scale_range=scale_range, xy=(bsize, bsize)) @@ -608,8 +640,7 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, inds = rperm[inds_batch] imgs, lbls = _get_batch(inds, data=test_data, labels=test_labels, files=test_files, labels_files=test_labels_files, - channels=channels, channel_axis=channel_axis, - normalize_params=normalize_params) + **kwargs) diami = diam_test[inds].copy() imgi, lbl, scale = transforms.random_rotate_and_resize( imgs, Y=lbls, scale_range=scale_range, xy=(bsize, bsize)) From a339b79d6ae3ce1a43d7b0f3085c0bf94f4401f1 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sat, 30 Mar 2024 16:05:28 -0400 Subject: [PATCH 3/9] converting to grayscale --- cellpose/train.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/cellpose/train.py b/cellpose/train.py index 9950628a..bef184a0 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -79,14 +79,20 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, lbls = [labels[i][1:] for i in inds] return imgs, lbls -def pad_to_rgb(img): - if img.ndim==2 or (img.shape[0]==2 and np.ptp(img[1])<1e-3): - if img.ndim==2: - img = img[np.newaxis,:,:] - img = np.tile(img[:1], (3,1,1)) +def convert_to_rgb(img): + if img.ndim==2: + img = img[np.newaxis,:,:] + img = np.tile(img, (3,1,1)) elif img.shape[0] < 3: - nc, Ly, Lx = img.shape - img = np.concatenate((img, np.zeros((3-nc, Ly, Lx), dtype=img.dtype)), axis=0) + img = img.mean(axis=0, keepdims=True) + img = transforms.normalize99(img) + img = np.tile(img, (3,1,1)) + # if img.ndim==2: + # img = img[np.newaxis,:,:] + # img = np.tile(img[:1], (3,1,1)) + # elif img.shape[0] < 3: + # nc, Ly, Lx = img.shape + # img = np.concatenate((img, np.zeros((3-nc, Ly, Lx), dtype=img.dtype)), axis=0) return img def _reshape_norm(data, channels=None, channel_axis=None, rgb=False, @@ -115,7 +121,7 @@ def _reshape_norm(data, channels=None, channel_axis=None, rgb=False, for td in data ] if rgb: - data = [pad_to_rgb(td) for td in data] + data = [convert_to_rgb(td) for td in data] return data From 44bb27720912e8465a2eeafc40cc89fefcee9c05 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sat, 30 Mar 2024 20:15:50 -0400 Subject: [PATCH 4/9] random RGB channels --- cellpose/train.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/cellpose/train.py b/cellpose/train.py index bef184a0..4bb9ff20 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -61,17 +61,6 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, imgs = [io.imread(files[i]) for i in inds] imgs = _reshape_norm(imgs, channels=channels, channel_axis=channel_axis, rgb=rgb, normalize_params=normalize_params) - # if channels is not None: - # imgs = [ - # transforms.convert_image(img, channels=channels, - # channel_axis=channel_axis, nchan=None) for img in imgs - # ] - # imgs = [img.transpose(2, 0, 1) for img in imgs] - # if normalize_params["normalize"]: - # imgs = [ - # transforms.normalize_img(img, normalize=normalize_params, axis=0) - # for img in imgs - # ] if labels_files is not None: lbls = [io.imread(labels_files[i])[1:] for i in inds] else: @@ -79,6 +68,17 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, lbls = [labels[i][1:] for i in inds] return imgs, lbls +def pad_to_rgb(img): + if img.ndim==2 or np.ptp(img[1]) < 1e-3: + if img.ndim==2: + img = img[np.newaxis,:,:] + img = np.tile(img[:1], (3,1,1)) + elif img.shape[0] < 3: + nc, Ly, Lx = img.shape + ic = np.random.randint(3) + img = np.insert(img, ic, np.zeros((3-nc, Ly, Lx), dtype=img.dtype), axis=0) + return img + def convert_to_rgb(img): if img.ndim==2: img = img[np.newaxis,:,:] @@ -87,12 +87,6 @@ def convert_to_rgb(img): img = img.mean(axis=0, keepdims=True) img = transforms.normalize99(img) img = np.tile(img, (3,1,1)) - # if img.ndim==2: - # img = img[np.newaxis,:,:] - # img = np.tile(img[:1], (3,1,1)) - # elif img.shape[0] < 3: - # nc, Ly, Lx = img.shape - # img = np.concatenate((img, np.zeros((3-nc, Ly, Lx), dtype=img.dtype)), axis=0) return img def _reshape_norm(data, channels=None, channel_axis=None, rgb=False, @@ -121,7 +115,7 @@ def _reshape_norm(data, channels=None, channel_axis=None, rgb=False, for td in data ] if rgb: - data = [convert_to_rgb(td) for td in data] + data = [pad_to_rgb(td) for td in data] return data From f8a721a0f3513bed68472e137c29507502d0ce51 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sun, 31 Mar 2024 08:39:02 -0400 Subject: [PATCH 5/9] adding random flipping for all rgb combos --- cellpose/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cellpose/train.py b/cellpose/train.py index 4bb9ff20..715f7478 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -75,6 +75,10 @@ def pad_to_rgb(img): img = np.tile(img[:1], (3,1,1)) elif img.shape[0] < 3: nc, Ly, Lx = img.shape + # randomly flip channels + if np.random.rand() > 0.5: + img = img[::-1] + # randomly insert blank channel ic = np.random.randint(3) img = np.insert(img, ic, np.zeros((3-nc, Ly, Lx), dtype=img.dtype), axis=0) return img From 0b7b9ffb90e4e0a028a32b49255769403be8e7c9 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sun, 31 Mar 2024 12:48:46 -0400 Subject: [PATCH 6/9] adding segformer option --- cellpose/__main__.py | 28 ++++++++++++---- cellpose/cli.py | 4 ++- cellpose/models.py | 52 ++++++++++++++++++----------- cellpose/segformer.py | 78 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 134 insertions(+), 28 deletions(-) create mode 100644 cellpose/segformer.py diff --git a/cellpose/__main__.py b/cellpose/__main__.py index d0b4abf9..4eb60d07 100644 --- a/cellpose/__main__.py +++ b/cellpose/__main__.py @@ -90,6 +90,7 @@ def main(): else: pretrained_model = args.pretrained_model + restore_type = args.restore_type if restore_type is not None: try: @@ -98,6 +99,15 @@ def main(): raise ValueError("restore_type invalid") if args.train or args.train_size: raise ValueError("restore_type cannot be used with training on CLI yet") + + if args.transformer and (restore_type is None): + default_model = "transformer_cp3" + backbone = "transformer" + elif args.transformer and restore_type is not None: + raise ValueError("no transformer based restoration") + else: + default_model = "cyto3" + backbone = "default" model_type = None if pretrained_model and not os.path.exists(pretrained_model): @@ -106,13 +116,15 @@ def main(): all_models = models.MODEL_NAMES.copy() all_models.extend(model_strings) if ~np.any([model_type == s for s in all_models]): - model_type = "cyto" - logger.warning("pretrained model has incorrect path") + model_type = default_model + logger.warning(f"pretrained model has incorrect path, using {default_model}") if model_type == "nuclei": szmean = 17. else: szmean = 30. - builtin_size = model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei" or model_type == "cyto3" + builtin_size = (model_type == "cyto" or model_type == "cyto2" or + model_type == "nuclei" or model_type == "cyto3" or + model_type=="transformer_cp3") if len(args.image_path) > 0 and (args.train or args.train_size): raise ValueError("ERROR: cannot train model with single image input") @@ -138,7 +150,8 @@ def main(): # handle built-in model exceptions if builtin_size and restore_type is None: - model = models.Cellpose(gpu=gpu, device=device, model_type=model_type) + model = models.Cellpose(gpu=gpu, device=device, + model_type=model_type, backbone=backbone) else: builtin_size = False if args.all_channels: @@ -147,7 +160,8 @@ def main(): if restore_type is None: model = models.CellposeModel(gpu=gpu, device=device, pretrained_model=pretrained_model, - model_type=model_type) + model_type=model_type, + backbone=backbone) else: model = denoise.CellposeDenoiseModel(gpu=gpu, device=device, pretrained_model=pretrained_model, @@ -267,9 +281,9 @@ def main(): # initialize model model = models.CellposeModel( - device=device, + device=device, model_type=model_type, diam_mean=szmean, nchan=nchan, pretrained_model=pretrained_model if model_type is None else None, - model_type=model_type, diam_mean=szmean, nchan=nchan) + backbone=backbone) # train segmentation model if args.train: diff --git a/cellpose/cli.py b/cellpose/cli.py index 78a1196f..a6074ebd 100644 --- a/cellpose/cli.py +++ b/cellpose/cli.py @@ -77,7 +77,9 @@ def get_arg_parser(): model_args.add_argument( "--add_model", required=False, default=None, type=str, help="model path to copy model to hidden .cellpose folder for using in GUI/CLI") - + model_args.add_argument("--transformer", action="store_true", + help="use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)") + # algorithm settings algorithm_args = parser.add_argument_group("Algorithm Arguments") algorithm_args.add_argument( diff --git a/cellpose/models.py b/cellpose/models.py index 9dd8cebf..b73a9398 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -24,7 +24,8 @@ MODEL_NAMES = [ "cyto3", "nuclei", "cyto2_cp3", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3", - "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "cyto2", "cyto" + "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "cyto2", "cyto", + "transformer_cp3" ] MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt")) @@ -42,13 +43,15 @@ } def model_path(model_type, model_index=0): - torch_str = "torch" - if model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei": - basename = "%s%s_%d" % (model_type, torch_str, model_index) + if not os.path.exists(model_type): + torch_str = "torch" + if model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei": + basename = "%s%s_%d" % (model_type, torch_str, model_index) + else: + basename = model_type + return cache_model_path(basename) else: - basename = model_type - return cache_model_path(basename) - + return model_type def size_model_path(model_type): if os.path.exists(model_type): @@ -102,13 +105,15 @@ class Cellpose(): """ - def __init__(self, gpu=False, model_type="cyto3", nchan=2, device=None): + def __init__(self, gpu=False, model_type="cyto3", nchan=2, + device=None, backbone="default"): super(Cellpose, self).__init__() # assign device (GPU or CPU) sdevice, gpu = assign_device(use_torch=True, gpu=gpu) self.device = device if device is not None else sdevice self.gpu = gpu + self.backbone = backbone model_type = "cyto3" if model_type is None else model_type @@ -123,7 +128,8 @@ def __init__(self, gpu=False, model_type="cyto3", nchan=2, device=None): self.nchan = nchan self.cp = CellposeModel(device=self.device, gpu=self.gpu, model_type=model_type, - diam_mean=self.diam_mean, nchan=self.nchan) + diam_mean=self.diam_mean, nchan=self.nchan, + backbone=self.backbone) self.cp.model_type = model_type # size model not used for bacterial model @@ -231,7 +237,7 @@ class CellposeModel(): """ def __init__(self, gpu=False, pretrained_model=False, model_type=None, - diam_mean=30., device=None, nchan=2): + diam_mean=30., device=None, nchan=2, backbone="default"): """ Initialize the CellposeModel. @@ -245,19 +251,21 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, """ self.diam_mean = diam_mean builtin = True - + default_model = "cyto3" if backbone=="default" else "transformer_cp3" if model_type is not None or (pretrained_model and not os.path.exists(pretrained_model)): - pretrained_model_string = model_type if model_type is not None else "cyto" + pretrained_model_string = model_type if model_type is not None else default_model model_strings = get_user_models() all_models = MODEL_NAMES.copy() all_models.extend(model_strings) if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]): builtin = False - elif ~np.any([pretrained_model_string == s for s in all_models]): - pretrained_model_string = "cyto3" + if (not os.path.exists(model_type) and + ~np.any([pretrained_model_string == s for s in all_models])): + pretrained_model_string = default_model + models_logger.warning("model_type does not exist / has incorrect path") - if (pretrained_model and not os.path.exists(pretrained_model[0])): + if (pretrained_model and not os.path.exists(pretrained_model)): models_logger.warning("pretrained model has incorrect path") models_logger.info(f">> {pretrained_model_string} << model set to be used") @@ -266,7 +274,6 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, else: self.diam_mean = 30. pretrained_model = model_path(pretrained_model_string) - else: builtin = False if pretrained_model: @@ -290,10 +297,15 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, nbase = [32, 64, 128, 256] self.nbase = [nchan, *nbase] - self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn, - max_pool=True, diam_mean=diam_mean).to(self.device) - self.pretrained_model = pretrained_model + if backbone=="default": + self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn, + max_pool=True, diam_mean=diam_mean).to(self.device) + else: + from .segformer import Transformer + self.net = Transformer(encoder_weights="imagenet" if not self.pretrained_model else None, + diam_mean=diam_mean).to(self.device) + if self.pretrained_model: self.net.load_model(self.pretrained_model, device=self.device) if not builtin: @@ -307,7 +319,7 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, f">>>> model diam_labels = {self.diam_labels: .3f} (mean diameter of training ROIs)" ) - self.net_type = "cellpose" + self.net_type = f"cellpose_{backbone}" def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, z_axis=None, normalize=True, invert=False, rescale=None, diameter=None, diff --git a/cellpose/segformer.py b/cellpose/segformer.py new file mode 100644 index 00000000..a15a2884 --- /dev/null +++ b/cellpose/segformer.py @@ -0,0 +1,78 @@ +""" +Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" +import torch +from torch import nn + +try: + import segmentation_models_pytorch as smp + + class Transformer(nn.Module): + """ Transformer encoder from segformer paper with MAnet decoder + (configuration from MEDIAR) + """ + def __init__(self, encoder="mit_b5", + encoder_weights=None, decoder="MAnet", + diam_mean=30.): + super().__init__() + net_fcn = smp.MAnet if decoder == "MAnet" else smp.FPN + self.encoder = encoder + self.decoder = decoder + self.net = net_fcn( + encoder_name=encoder, + encoder_weights=encoder_weights, + # (use "imagenet" pre-trained weights for encoder initialization if training) + in_channels=3, + classes=3, + activation=None) + self.nout = 3 + self.mkldnn = False + self.diam_mean = nn.Parameter(data=torch.ones(1) * diam_mean, + requires_grad=False) + self.diam_labels = nn.Parameter(data=torch.ones(1) * diam_mean, + requires_grad=False) + + def forward(self, X): + # have to convert to 3-chan (RGB) + if X.shape[1] < 3: + X = torch.cat( + (X, torch.zeros( + (X.shape[0], 3-X.shape[1], X.shape[2], X.shape[3]), device=X.device)), dim=1) + y = self.net(X) + return y, torch.zeros((X.shape[0], 256), device=X.device) + + @property + def device(self): + return next(self.parameters()).device + + def save_model(self, filename): + """ + Save the model to a file. + + Args: + filename (str): The path to the file where the model will be saved. + """ + torch.save(self.state_dict(), filename) + + def load_model(self, filename, device=None): + """ + Load the model from a file. + + Args: + filename (str): The path to the file where the model is saved. + device (torch.device, optional): The device to load the model on. Defaults to None. + """ + if (device is not None) and (device.type != "cpu"): + state_dict = torch.load(filename, map_location=device) + else: + self.__init__(encoder=self.encoder, decoder=self.decoder, + diam_mean=self.diam_mean) + state_dict = torch.load(filename, map_location=torch.device("cpu")) + + self.load_state_dict( + dict([(name, param) for name, param in state_dict.items()]), + strict=False) + +except Exception as e: + print(e) + print("need to install segmentation_models_pytorch to run transformer") From 5b2d834ab9d5421d89fe4a7be077d7c5f7157ca4 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Mon, 1 Apr 2024 15:18:42 -0400 Subject: [PATCH 7/9] updating cli for size model and setting batch size to 64 --- cellpose/__main__.py | 14 ++++++++++---- cellpose/models.py | 2 +- cellpose/train.py | 14 ++++++++------ 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/cellpose/__main__.py b/cellpose/__main__.py index 4eb60d07..c48c47c6 100644 --- a/cellpose/__main__.py +++ b/cellpose/__main__.py @@ -123,8 +123,7 @@ def main(): else: szmean = 30. builtin_size = (model_type == "cyto" or model_type == "cyto2" or - model_type == "nuclei" or model_type == "cyto3" or - model_type=="transformer_cp3") + model_type == "nuclei" or model_type == "cyto3") if len(args.image_path) > 0 and (args.train or args.train_size): raise ValueError("ERROR: cannot train model with single image input") @@ -312,8 +311,15 @@ def main(): ] if test_labels is not None else test_labels # data has already been normalized and reshaped sz_model.params = train.train_size(model.net, model.pretrained_model, - images, masks, test_images, - test_masks, channels=channels, + images, labels, train_files=image_names, + test_data=test_images, test_labels=test_labels, + test_files=image_names_test, + train_probs=train_probs, test_probs=test_probs, + load_files=load_files, channels=channels, + min_train_masks=args.min_train_masks, + channel_axis=args.channel_axis, rgb=(nchan==3), + nimg_per_epoch=args.nimg_per_epoch, normalize=(not args.no_norm), + nimg_test_per_epoch=args.nimg_test_per_epoch, batch_size=args.batch_size) if test_images is not None: predicted_diams, diams_style = sz_model.eval( diff --git a/cellpose/models.py b/cellpose/models.py index b73a9398..c77a740f 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -260,7 +260,7 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, all_models.extend(model_strings) if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]): builtin = False - if (not os.path.exists(model_type) and + if (not os.path.exists(pretrained_model_string) and ~np.any([pretrained_model_string == s for s in all_models])): pretrained_model_string = default_model models_logger.warning("model_type does not exist / has incorrect path") diff --git a/cellpose/train.py b/cellpose/train.py index 715f7478..aa70d656 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -337,10 +337,12 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None. train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None. train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None. + train_labels_files (list or None): List of training label file paths. Defaults to None. train_probs (List[float], optional): List of floats - probabilities for each image to be selected during training. Defaults to None. test_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for testing. Defaults to None. test_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for test_data, where 0=no masks; 1,2,...=mask labels. Defaults to None. test_files (List[str], optional): List of strings - file names for images in test_data (to save flows for future runs). Defaults to None. + test_labels_files (list or None): List of test label file paths. Defaults to None. test_probs (List[float], optional): List of floats - probabilities for each image to be selected during testing. Defaults to None. load_files (bool, optional): Boolean - whether to load images and labels from files. Defaults to True. batch_size (int, optional): Integer - number of patches to run simultaneously on the GPU. Defaults to 8. @@ -519,7 +521,7 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, test_labels_files=None, test_probs=None, load_files=True, min_train_masks=5, channels=None, channel_axis=None, rgb=False, normalize=True, nimg_per_epoch=None, nimg_test_per_epoch=None, - batch_size=128, scale_range=1.0, bsize=512, + batch_size=64, scale_range=1.0, bsize=512, l2_regularization=1.0, n_epochs=10): """Train the size model. @@ -543,7 +545,7 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, normalize (bool or dict, optional): Whether to normalize the data. Defaults to True. nimg_per_epoch (int, optional): The number of images per epoch. Defaults to None. nimg_test_per_epoch (int, optional): The number of test images per epoch. Defaults to None. - batch_size (int, optional): The batch size. Defaults to 128. + batch_size (int, optional): The batch size. Defaults to 64. l2_regularization (float, optional): The L2 regularization factor. Defaults to 1.0. n_epochs (int, optional): The number of epochs. Defaults to 10. @@ -600,7 +602,7 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, inds_batch = np.arange(ibatch, min(nimg_per_epoch, ibatch + batch_size)) inds = rperm[inds_batch] imgs, lbls = _get_batch(inds, data=train_data, labels=train_labels, - files=train_files, labels_files=train_labels_files, + files=train_files, **kwargs) diami = diam_train[inds].copy() imgi, lbl, scale = transforms.random_rotate_and_resize( @@ -632,7 +634,7 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, np.random.seed(0) styles_test = np.zeros((nimg_test_per_epoch, 256), np.float32) diams_test = np.zeros((nimg_test_per_epoch,), np.float32) - diam_test = np.zeros((nimg_test_per_epoch,), np.float32) + diams_test0 = np.zeros((nimg_test_per_epoch,), np.float32) if nimg_test != nimg_test_per_epoch: rperm = np.random.choice(np.arange(0, nimg_test), size=(nimg_test_per_epoch,), p=test_probs) @@ -655,12 +657,12 @@ def train_size(net, pretrained_model, train_data=None, train_labels=None, feat = net(imgi)[1] styles_test[inds_batch] = feat.cpu().numpy() diams_test[inds_batch] = np.log(diami) - np.log(diam_mean) + np.log(scale) - diam_test[inds_batch] = diamt + diams_test0[inds_batch] = diamt diam_test_pred = np.exp(A @ (styles_test - smean).T + np.log(diam_mean) + ymean) diam_test_pred = np.maximum(5., diam_test_pred) train_logger.info("test correlation: %0.4f" % - np.corrcoef(diam_test, diam_test_pred)[0, 1]) + np.corrcoef(diams_test0, diam_test_pred)[0, 1]) pretrained_size = str(pretrained_model) + "_size.npy" params = {"A": A, "smean": smean, "diam_mean": diam_mean, "ymean": ymean} From 8c8e9d8a93bb204b0f9787037d1109b87600a8c7 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sun, 7 Apr 2024 09:59:59 -0400 Subject: [PATCH 8/9] adding train_size for file sampling --- cellpose/__main__.py | 25 +-- cellpose/core.py | 2 +- cellpose/gui/menus.py | 6 - cellpose/io.py | 41 +--- cellpose/key/cellpose-data-writer.json | 13 -- cellpose/models.py | 103 +++++---- cellpose/train.py | 4 +- cellpose/transforms.py | 1 - paper/3.0/fig_utils.py | 12 -- paper/neurips/analysis.py | 160 ++++++++++++++ paper/neurips/fig_utils.py | 37 ++++ paper/neurips/figures.py | 288 +++++++++++++++++++++++++ setup.py | 1 - tests/test_import.py | 4 +- 14 files changed, 564 insertions(+), 133 deletions(-) delete mode 100644 cellpose/key/cellpose-data-writer.json create mode 100644 paper/neurips/analysis.py create mode 100644 paper/neurips/fig_utils.py create mode 100644 paper/neurips/figures.py diff --git a/cellpose/__main__.py b/cellpose/__main__.py index c48c47c6..5e301513 100644 --- a/cellpose/__main__.py +++ b/cellpose/__main__.py @@ -236,14 +236,15 @@ def main(): test_dir = None if len(args.test_dir) == 0 else args.test_dir images, labels, image_names, train_probs = None, None, None, None test_images, test_labels, image_names_test, test_probs = None, None, None, None + compute_flows = False if len(args.file_list) > 0: if os.path.exists(args.file_list): dat = np.load(args.file_list, allow_pickle=True).item() image_names = dat["train_files"] - image_names_test = dat["test_files"] - if "train_probs" in dat: - train_probs = dat["train_probs"] - test_probs = dat["test_probs"] + image_names_test = dat.get("test_files", None) + train_probs = dat.get("train_probs", None) + test_probs = dat.get("test_probs", None) + compute_flows = dat.get("compute_flows", False) load_files = False else: logger.critical(f"ERROR: {args.file_list} does not exist") @@ -263,7 +264,7 @@ def main(): channels = None else: nchan = 2 - + # model path szmean = args.diam_mean if not os.path.exists(pretrained_model) and model_type is None: @@ -291,14 +292,15 @@ def main(): test_data=test_images, test_labels=test_labels, test_files=image_names_test, train_probs=train_probs, test_probs=test_probs, - load_files=load_files, learning_rate=args.learning_rate, - weight_decay=args.weight_decay, channels=channels, + compute_flows=compute_flows, load_files=load_files, + normalize=(not args.no_norm), channels=channels, channel_axis=args.channel_axis, rgb=(nchan==3), - save_path=os.path.realpath(args.dir), save_every=args.save_every, + learning_rate=args.learning_rate, weight_decay=args.weight_decay, SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.batch_size, min_train_masks=args.min_train_masks, - nimg_per_epoch=args.nimg_per_epoch, normalize=(not args.no_norm), + nimg_per_epoch=args.nimg_per_epoch, nimg_test_per_epoch=args.nimg_test_per_epoch, + save_path=os.path.realpath(args.dir), save_every=args.save_every, model_name=args.model_name_out) model.pretrained_model = cpmodel_path logger.info(">>>> model trained and saved to %s" % cpmodel_path) @@ -306,9 +308,6 @@ def main(): # train size model if args.train_size: sz_model = models.SizeModel(cp_model=model, device=device) - masks = [lbl[0] for lbl in labels] - test_masks = [lbl[0] for lbl in test_labels - ] if test_labels is not None else test_labels # data has already been normalized and reshaped sz_model.params = train.train_size(model.net, model.pretrained_model, images, labels, train_files=image_names, @@ -322,6 +321,8 @@ def main(): nimg_test_per_epoch=args.nimg_test_per_epoch, batch_size=args.batch_size) if test_images is not None: + test_masks = [lbl[0] for lbl in test_labels + ] if test_labels is not None else test_labels predicted_diams, diams_style = sz_model.eval( test_images, channels=channels) ccs = np.corrcoef( diff --git a/cellpose/core.py b/cellpose/core.py index 822b8eb3..746c1d38 100644 --- a/cellpose/core.py +++ b/cellpose/core.py @@ -188,6 +188,7 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1, (faster if augment is False) Args: + net (class): cellpose network (model.net) imgs (np.ndarray): The input image or stack of images of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan]. batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8. rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0. @@ -240,7 +241,6 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1, return y, style - def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0.1): """ Run network on tiles of size [bsize x bsize] diff --git a/cellpose/gui/menus.py b/cellpose/gui/menus.py index 202dbacf..36ae7f14 100644 --- a/cellpose/gui/menus.py +++ b/cellpose/gui/menus.py @@ -70,12 +70,6 @@ def mainmenu(parent): file_menu.addAction(parent.saveFlows) parent.saveFlows.setEnabled(False) - parent.saveServer = QAction("Send manually labelled data to server", parent) - parent.saveServer.triggered.connect(lambda: save_server(parent)) - file_menu.addAction(parent.saveServer) - parent.saveServer.setEnabled(False) - - def editmenu(parent): main_menu = parent.menuBar() edit_menu = main_menu.addMenu("&Edit") diff --git a/cellpose/io.py b/cellpose/io.py index 2abcd6b4..776d1cfd 100644 --- a/cellpose/io.py +++ b/cellpose/io.py @@ -744,43 +744,4 @@ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[ imsave(os.path.join(flowdir, basename + "_flows" + suffix + ".tif"), (flows[0] * (2**16 - 1)).astype(np.uint16)) #save full flow data - imsave(os.path.join(flowdir, basename + "_dP" + suffix + ".tif"), flows[1]) - - -def save_server(parent=None, filename=None): - """ Uploads a *_seg.npy file to the bucket. - - Args: - parent (PyQt.MainWindow, optional): GUI window to grab file info from. Defaults to None. - filename (str, optional): if no GUI, send this file to server. Defaults to None. - """ - if parent is not None: - q = QMessageBox.question( - parent, "Send to server", - "Are you sure? Only send complete and fully manually segmented data.\n (do not send partially automated segmentations)", - QMessageBox.Yes | QMessageBox.No) - if q != QMessageBox.Yes: - return - else: - filename = parent.filename - - if filename is not None: - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "key/cellpose-data-writer.json") - bucket_name = "cellpose_data" - base = os.path.splitext(filename)[0] - source_file_name = base + "_seg.npy" - io_logger.info(f"sending {source_file_name} to server") - time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S.%f") - filestring = time + ".npy" - io_logger.info(f"name on server: {filestring}") - destination_blob_name = filestring - storage_client = storage.Client() - bucket = storage_client.bucket(bucket_name) - blob = bucket.blob(destination_blob_name) - - blob.upload_from_filename(source_file_name) - - io_logger.info("File {} uploaded to {}.".format(source_file_name, - destination_blob_name)) + imsave(os.path.join(flowdir, basename + "_dP" + suffix + ".tif"), flows[1]) \ No newline at end of file diff --git a/cellpose/key/cellpose-data-writer.json b/cellpose/key/cellpose-data-writer.json deleted file mode 100644 index 3feaf95b..00000000 --- a/cellpose/key/cellpose-data-writer.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "type": "service_account", - "project_id": "pachitariu-lab", - "private_key_id": "a112e329e705a5c51bfbba5277f52116e4bfc1af", - "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCsmcCIGjXqyWsl\n++lxWxzSL5QWQTPkIyP+j99gFJQ1c/Z4PYC7YrtIdfu/PsLwqBAvOHJCsrq0m+y0\nB2vqUJJ2mcvKQaWygAWKwoikFUI1AtjkH9GnSHxXmTH9WqXBw2pQv1xBfOvoMptJ\n8ArXyAMyBzSYN0xyOqEQEsrSzqEUvZw467YZmyf24ZhjFXze0Rrn1znkQprMB8GK\nPz7Vr6iiRPd8mtjzjS2WC0Fat+RIhh0o5lb93woGoFbWAhvuK8xadGOYP80Mcr/b\nsgFM1D2k/HHdjJiHRe7jwRMIRhA5478QwoBNN00W/zEgDV+8SSl9rs3yUAUJG9mf\nEHkdlzexAgMBAAECggEAUFbMyE0y9ZNFfYuxUGMxmiAtVOKKrdExiuca+VT625qb\nicJO7mn5dLP+NzmWcYA48FHc1XDt+O1vEyk1MP7J/cx+kClYYCq46aq9AWsnwxcN\nL7oj0zKpNfkHzL7p0rQMA4PfBFiKUi1kHNlPorrlyd6Su5tZyP3DRIEKyW8GiWko\n19DEcBUhe2uEW+claFSiy6fYfXFXtzYln8mWKAWjOxw6LcQBc6KRMdYh79d09/bj\nthnnNeMLK6FSiKTXuT/a84qzxNkj549H0ILwolVKn1vPe35WV7ZJJpugYvnbek5s\nZjwN2fMUa/GD9agbfS4LXbnC4hYwUnLhAn2zSEjlBQKBgQDWxwNoo0CuEB8s+VCN\n9J80vqTQfXSnSqoAS0BaDaivGoNUHtRGjXs+Qqrzlikk9032s9eRtX+q7uhR7wIo\nIilO1GhVCr+OYPWTBoJ/ARgDrhqgXPjyQ97wSDz9PHloFnGHf+HFcqzhlhNGaoS9\nQaZO3nhx3Du7YV8kaefm0iODnwKBgQDNumbAi0L9jTDguOtD/3w32MuodQfL35Ok\nRvdkQphZL2jFqiztlwVU11SOibnVc7ZAjFt4tPSS0veFQmC8JfnpdYSDk1NodwMG\nhC3UTyjrvIltEmPzSQczRMK5PCRZnvP6SBKnUSEsAk+bZO1nyFNPNNmn5J/LQ+24\n2MGAgxcCrwKBgB0pFBtm3udDJRh0GS3M4rjEkZgFEIuOJZq4nNodNKPhk6ceMHAL\n0YnYf2FnJ9rvANTYAhK0c8r/eOd27fIJAVbEnA2/0dZA79awcZNQ0LPfNZpERUCP\nWnuBM1ammU06jtt4z2yBb1uJhsBuwer4ON5Ick3zOuDsDYDiKCw8p7m9AoGBAIpp\nuAohaA/pN5JqN7eHI877KIKNQpKTOOVU7cth1thiQk6DITk021xqh7Riy0nmUR96\nj2xV6xsBn5DjyOutbUf6Tg6sR3jIYZu3wJHQNIruTVO6BM9BOfvvbkdsRFSb0jB4\n3zv9JKFUaLT3IZcqu4pV137THgOHD2DHTOEm0Yt3AoGBALMOKoajLFTBT+9bEpXZ\nB6ES4W1KNPKw+1Y0n2gGJxL+zAXFm1MJozJbyZSwZHb8TyOMNqP59+tzh080oUoc\nlMWJzS4xxuGx+JAULtDT4ko3+3Q19H3dyJNAW9SoY1lX47JMrEB1qYLNx7o78nzB\niCWxdweJjlpKijcUP9keCmIW\n-----END PRIVATE KEY-----\n", - "client_email": "cellpose-data-writer@pachitariu-lab.iam.gserviceaccount.com", - "client_id": "105326682635824364397", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "https://oauth2.googleapis.com/token", - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/cellpose-data-writer%40pachitariu-lab.iam.gserviceaccount.com" -} - diff --git a/cellpose/models.py b/cellpose/models.py index c77a740f..3cff3f49 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -25,7 +25,8 @@ MODEL_NAMES = [ "cyto3", "nuclei", "cyto2_cp3", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3", "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "cyto2", "cyto", - "transformer_cp3" + "transformer_cp3", "neurips_cellpose_default", "neurips_cellpose_transformer", + "neurips_grayscale_cyto2" ] MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt")) @@ -43,16 +44,13 @@ } def model_path(model_type, model_index=0): - if not os.path.exists(model_type): - torch_str = "torch" - if model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei": - basename = "%s%s_%d" % (model_type, torch_str, model_index) - else: - basename = model_type - return cache_model_path(basename) + torch_str = "torch" + if model_type == "cyto" or model_type == "cyto2" or model_type == "nuclei": + basename = "%s%s_%d" % (model_type, torch_str, model_index) else: - return model_type - + basename = model_type + return cache_model_path(basename) + def size_model_path(model_type): if os.path.exists(model_type): return model_type + "_size.npy" @@ -139,7 +137,7 @@ def __init__(self, gpu=False, model_type="cyto3", nchan=2, self.sz.model_type = model_type def eval(self, x, batch_size=8, channels=[0,0], channel_axis=None, invert=False, - normalize=True, diameter=30., do_3D=False, **kwargs): + normalize=True, diameter=30., do_3D=False, find_masks=True, **kwargs): """Run cellpose size model and mask model and get masks. Args: @@ -197,14 +195,12 @@ def eval(self, x, batch_size=8, channels=[0,0], channel_axis=None, invert=False, else: diams = diameter - tic = time.time() models_logger.info("~~~ FINDING MASKS ~~~") masks, flows, styles = self.cp.eval(x, channels=channels, channel_axis=channel_axis, batch_size=batch_size, normalize=normalize, invert=invert, diameter=diams, do_3D=do_3D, **kwargs) - models_logger.info(">>>> TOTAL TIME %0.2f sec" % (time.time() - tic0)) return masks, flows, styles, diams @@ -250,37 +246,47 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, nchan (int, optional): Number of channels to use as input to network, default is 2 (cyto + nuclei) or (nuclei + zeros). """ self.diam_mean = diam_mean - builtin = True + + ### set model path default_model = "cyto3" if backbone=="default" else "transformer_cp3" - if model_type is not None or (pretrained_model and - not os.path.exists(pretrained_model)): - pretrained_model_string = model_type if model_type is not None else default_model - model_strings = get_user_models() - all_models = MODEL_NAMES.copy() - all_models.extend(model_strings) - if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]): - builtin = False - if (not os.path.exists(pretrained_model_string) and - ~np.any([pretrained_model_string == s for s in all_models])): - pretrained_model_string = default_model - models_logger.warning("model_type does not exist / has incorrect path") - - if (pretrained_model and not os.path.exists(pretrained_model)): - models_logger.warning("pretrained model has incorrect path") - models_logger.info(f">> {pretrained_model_string} << model set to be used") - - if pretrained_model_string == "nuclei": + builtin = False + use_default = False + model_strings = get_user_models() + all_models = MODEL_NAMES.copy() + all_models.extend(model_strings) + + # check if pretrained_model is builtin or custom user model saved in .cellpose/models + # if yes, then set to model_type + if (pretrained_model and not Path(pretrained_model).exists() and + np.any([pretrained_model == s for s in all_models])): + model_type = pretrained_model + + # check if model_type is builtin or custom user model saved in .cellpose/models + if model_type is not None and np.any([model_type == s for s in all_models]): + if np.any([model_type == s for s in MODEL_NAMES]): + builtin = True + models_logger.info(f">> {model_type} << model set to be used") + if model_type == "nuclei": self.diam_mean = 17. + pretrained_model = model_path(model_type) + # if model_type is not None and does not exist, use default model + elif model_type is not None: + if Path(model_type).exists(): + pretrained_model = model_type else: - self.diam_mean = 30. - pretrained_model = model_path(pretrained_model_string) + models_logger.warning("model_type does not exist, using default model") + use_default = True + # if model_type is None... else: - builtin = False - if pretrained_model: - pretrained_model_string = pretrained_model - models_logger.info(f">>>> loading model {pretrained_model_string}") - - # assign network device + # if pretrained_model does not exist, use default model + if pretrained_model and not Path(pretrained_model).exists(): + models_logger.warning("pretrained_model path does not exist, using default model") + use_default = True + + builtin = True if use_default else builtin + self.pretrained_model = model_path(default_model) if use_default else pretrained_model + + ### assign model device self.mkldnn = None if device is None: sdevice, gpu = assign_device(use_torch=True, gpu=gpu) @@ -291,12 +297,11 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, if not self.gpu: self.mkldnn = check_mkl(True) - # create network + ### create neural network self.nchan = nchan self.nclasses = 3 nbase = [32, 64, 128, 256] self.nbase = [nchan, *nbase] - self.pretrained_model = pretrained_model if backbone=="default": self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn, @@ -306,7 +311,9 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, self.net = Transformer(encoder_weights="imagenet" if not self.pretrained_model else None, diam_mean=diam_mean).to(self.device) + ### load model weights if self.pretrained_model: + models_logger.info(f">>>> loading model {pretrained_model}") self.net.load_model(self.pretrained_model, device=self.device) if not builtin: self.diam_mean = self.net.diam_mean.data.cpu().numpy()[0] @@ -318,6 +325,9 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, models_logger.info( f">>>> model diam_labels = {self.diam_labels: .3f} (mean diameter of training ROIs)" ) + else: + models_logger.info(f">>>> no model weights loaded") + self.diam_labels = self.diam_mean self.net_type = f"cellpose_{backbone}" @@ -382,12 +392,14 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, """ if isinstance(x, list) or x.squeeze().ndim == 5: + self.timing = [] masks, styles, flows = [], [], [] tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO) nimg = len(x) iterator = trange(nimg, file=tqdm_out, mininterval=30) if nimg > 1 else range(nimg) for i in iterator: + tic = time.time() maski, flowi, stylei = self.eval( x[i], batch_size=batch_size, channels=channels[i] if channels is not None and @@ -409,6 +421,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, masks.append(maski) flows.append(flowi) styles.append(stylei) + self.timing.append(time.time() - tic) return masks, flows, styles else: @@ -499,7 +512,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non if rescale != 1.0: img = transforms.resize_image(img, rsz=rescale) yf, style = run_net(self.net, img, bsize=bsize, augment=augment, - tile=tile, tile_overlap=tile_overlap) + tile=tile, tile_overlap=tile_overlap) if resample: yf = transforms.resize_image(yf, shape[1], shape[2]) @@ -661,14 +674,15 @@ def eval(self, x, channels=None, channel_axis=None, normalize=True, invert=False - diam (np.ndarray): Final estimated diameters from images x or styles style after running both steps. - diam_style (np.ndarray): Estimated diameters from style alone. """ - if isinstance(x, list): + self.timing = [] diams, diams_style = [], [] nimg = len(x) tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO) iterator = trange(nimg, file=tqdm_out, mininterval=30) if nimg > 1 else range(nimg) for i in iterator: + tic = time.time() diam, diam_style = self.eval( x[i], channels=channels[i] if (channels is not None and len(channels) == len(x) and @@ -679,6 +693,7 @@ def eval(self, x, channels=None, channel_axis=None, normalize=True, invert=False batch_size=batch_size, progress=progress) diams.append(diam) diams_style.append(diam_style) + self.timing.append(time.time() - tic) return diams, diams_style diff --git a/cellpose/train.py b/cellpose/train.py index aa70d656..8c86ebb8 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -192,7 +192,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files ] train_labels_files = [tf for tf in train_labels_files if os.path.exists(tf)] - if test_data is not None or test_files is not None and test_labels_files is None: + if (test_data is not None or test_files is not None) and test_labels_files is None: test_labels_files = [ os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files ] @@ -415,7 +415,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, LR = LR[:-50] for i in range(10): LR = np.append(LR, LR[-1] / 2 * np.ones(5)) - n_epochs = len(LR) + LR = LR train_logger.info(f">>> n_epochs={n_epochs}, n_train={nimg}, n_test={nimg_test}") if not SGD: diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 516b5a06..402d0eed 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -38,7 +38,6 @@ def _taper_mask(ly=224, lx=224, sig=7.5): bsize // 2 - lx // 2:bsize // 2 + lx // 2 + lx % 2] return mask - def unaugment_tiles(y): """Reverse test-time augmentations for averaging (includes flipping of flowsY and flowsX). diff --git a/paper/3.0/fig_utils.py b/paper/3.0/fig_utils.py index 474a8c84..3dc0882d 100644 --- a/paper/3.0/fig_utils.py +++ b/paper/3.0/fig_utils.py @@ -10,17 +10,6 @@ from matplotlib.colors import ListedColormap from cellpose import utils -cmap_emb = ListedColormap(plt.get_cmap("gist_ncar")(np.linspace(0.05, 0.95), 100)) - -kp_colors = np.array([ - [0.55, 0.55, 0.55], - [0., 0., 1], - [0.8, 0, 0], - [1., 0.4, 0.2], - [0, 0.6, 0.4], - [0.2, 1, 0.5], -]) - default_font = 12 rcParams["font.family"] = "Arial" rcParams["savefig.dpi"] = 300 @@ -34,7 +23,6 @@ fs_title = 16 weight_title = "normal" - def plot_label(ltr, il, ax, trans, fs_title=20): ax.text( 0.0, diff --git a/paper/neurips/analysis.py b/paper/neurips/analysis.py new file mode 100644 index 00000000..7d55c446 --- /dev/null +++ b/paper/neurips/analysis.py @@ -0,0 +1,160 @@ +import os +import numpy as np +from cellpose import io, transforms, utils, models, dynamics, metrics, resnet_torch, denoise +from natsort import natsorted +from pathlib import Path +from glob import glob + +from cellpose.io import logger_setup + +def prediction_test_hidden(root): + """ root is path to Hidden folder """ + root = Path(root) + + logger_setup() + # path to images + fall = natsorted(glob((root / "images" / "*").as_posix())) + img_files = [f for f in fall if "_masks" not in f and "_flows" not in f] + + # load images + imgs = [io.imread(f) for f in img_files] + nimg = len(imgs) + + # for 3 channel model, normalize images and convert to 3 channels if needed + imgs_norm = [] + for img in imgs: + if img.ndim==2: + img = np.tile(img[:,:,np.newaxis], (1,1,3)) + img = transforms.normalize_img(img, axis=-1) + imgs_norm.append(img.transpose(2,0,1)) + + dat = {} + for mtype in ["default", "transformer"]: + if mtype=="default": + model = models.Cellpose(gpu=True, nchan=3, model_type="neurips_cellpose_default") + channels = None + normalize = False + diams = None # Cellpose will estimate diameter + elif mtype=="transformer": + model = models.CellposeModel(gpu=True, nchan=3, model_type="neurips_cellpose_transformer", backbone="transformer") + channels = None + normalize = False + diams = dat["diams_pred"] # (use diams from Cellpose default model for transformer) + + out = model.eval(imgs_norm, diameter=diams, + channels=channels, normalize=normalize, + tile_overlap=0.6, augment=True) + # predicted masks + dat[mtype] = out[0] + + if mtype=="default": + diams = out[-1] + dat["diams_pred"] = diams + dat[f"{mtype}_size_timing"] = model.sz.timing + dat[f"{mtype}_mask_timing"] = model.cp.timing + else: + dat[f"{mtype}_mask_timing"] = model.timing + + np.savez_compressed("neurips_test_results.npz", dat) + +def prediction_tuning(root, root2=None): + """ root is path to Tuning folder, root2 is path to mediar results """ + root = Path(root) + logger_setup() + + # path to images and masks + fall = natsorted(glob((root / "images" / "*").as_posix())) + # (exclude last image) + img_files = [f for f in fall if "_masks" not in f and "_flows" not in f][:-1] + mask_files = natsorted(glob((root / "labels" / "*").as_posix()))[:-1] + + # load images and masks + imgs = [io.imread(f) for f in img_files] + masks = [io.imread(f) for f in mask_files] + nimg = len(imgs) + + # for 3 channel model, normalize images and convert to 3 channels if needed + imgs_norm = [] + for img in imgs: + if img.ndim==2: + img = np.tile(img[:,:,np.newaxis], (1,1,3)) + img = transforms.normalize_img(img, axis=-1) + imgs_norm.append(img.transpose(2,0,1)) + + dat = {} + + ### RUN MODELS + model_types = ["grayscale", "default", "transformer", "maetal", "mediar"] + for mtype in model_types[:-1]: + print(mtype) + if mtype=="grayscale" or mtype=="maetal": + if mtype=="grayscale": + model = models.CellposeModel(gpu=True, model_type="neurips_grayscale_cyto2") + else: + ### need to download cellpose model from Ma et al + # https://github.com/JunMa11/NeurIPS-CellSeg/tree/main/cellpose-omnipose-KIT-GE + pretrained_model = "/home/carsen/Downloads/model.501776_epoch_499" + if not os.path.exists(pretrained_model): + print("need to download cellpose model from Ma et al; https://github.com/JunMa11/NeurIPS-CellSeg/tree/main/cellpose-omnipose-KIT-GE") + print("skipping Ma et al model") + del model_types[-2] + break + model = models.CellposeModel(gpu=True, pretrained_model=pretrained_model) + channels = [0, 0] + normalize = True + diams = None # CellposeModel will use mean diameter from training set + elif mtype=="default": + model = models.Cellpose(gpu=True, nchan=3, model_type="neurips_cellpose_default") + channels = None + normalize = False + diams = None # Cellpose will estimate diameter + elif mtype=="transformer": + model = models.CellposeModel(gpu=True, nchan=3, model_type="neurips_cellpose_transformer", backbone="transformer") + channels = None + normalize = False + diams = dat["diams_pred"] # (use diams from Cellpose default model for transformer) + + out = model.eval(imgs if mtype=="grayscale" else imgs_norm, diameter=diams, + channels=channels, normalize=normalize, + tile_overlap=0.6, augment=True) + if mtype=="default": + diams = out[-1] + dat["diams_pred"] = diams + + dat[mtype] = out[0] + + ### load Mediar results + if root2 is not None: + root2 = Path(root2) + masks_pred_mediar = [] + for imgf in img_files: + maskf = root2 / (os.path.splitext(os.path.split(imgf)[-1])[0] + "_label.tiff") + m = io.imread(maskf) + m = np.unique(m, return_inverse=True)[1].reshape(m.shape) + masks_pred_mediar.append(m) + + dat["mediar"] = masks_pred_mediar + else: + print("no path to mediar files specified") + print("skipping mediar") + del model_types[-1] + + ### EVALUATION + thresholds = np.arange(0.5, 1.05, 0.05) + dat["thresholds"] = thresholds + masks_true = [lbl.astype("uint32") for lbl in masks] + for mtype in model_types: + print(mtype) + masks_pred = dat[mtype] + ap, tp, fp, fn = metrics.average_precision(masks_true, masks_pred, threshold=thresholds) + f1 = 2 * tp / (2 * tp + fp + fn) + print(f"{mtype}, F1 score @ 0.5 = {np.median(f1[:,0]):.3f}") + + dat[mtype+"_f1"] = f1 + dat[mtype+"_tp"] = tp + dat[mtype+"_fp"] = fp + dat[mtype+"_fn"] = fn + + np.savez_compressed("neurips_eval_results.npz", dat) + + return imgs_norm, masks, dat \ No newline at end of file diff --git a/paper/neurips/fig_utils.py b/paper/neurips/fig_utils.py new file mode 100644 index 00000000..48070f3d --- /dev/null +++ b/paper/neurips/fig_utils.py @@ -0,0 +1,37 @@ +""" +Copyright © 2024 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. +""" +import string +import matplotlib +import matplotlib.pyplot as plt +import matplotlib.transforms as mtransforms +import numpy as np +from matplotlib import rcParams +from matplotlib.colors import ListedColormap +from cellpose import utils + +default_font = 12 +rcParams["font.family"] = "Arial" +rcParams["savefig.dpi"] = 300 +rcParams["axes.spines.top"] = False +rcParams["axes.spines.right"] = False +rcParams["axes.titlelocation"] = "left" +rcParams["axes.titleweight"] = "normal" +rcParams["font.size"] = default_font + +ltr = string.ascii_lowercase +fs_title = 16 +weight_title = "normal" + +def plot_label(ltr, il, ax, trans, fs_title=20): + ax.text( + 0.0, + 1.0, + ltr[il], + transform=ax.transAxes + trans, + va="bottom", + fontsize=fs_title, + fontweight="bold", + ) + il += 1 + return il \ No newline at end of file diff --git a/paper/neurips/figures.py b/paper/neurips/figures.py new file mode 100644 index 00000000..3d4bca8a --- /dev/null +++ b/paper/neurips/figures.py @@ -0,0 +1,288 @@ + +from fig_utils import * +import matplotlib.patheffects as pe + +def fig1(imgs_norm, masks_true, dat, timings, save_fig=False): + fig = plt.figure(figsize=(14, 5.5)) + thresholds = dat["thresholds"] + grid = plt.GridSpec(2, 5, figure=fig, left=0.05, right=0.98, top=0.94, bottom=0.09, + wspace=0.4, hspace=0.6) + il = 0 + transl = mtransforms.ScaledTranslation(-20 / 72, 7 / 72, fig.dpi_scale_trans) + transl1 = mtransforms.ScaledTranslation(-38 / 72, 7 / 72, fig.dpi_scale_trans) + + iex = 54 + img0 = np.clip(imgs_norm[iex].copy().transpose(1,2,0), 0, 1) + xlim = [300, 660] + ylim = [350, 650] + + cols = {"grayscale": [0.5, 0.5, 1], + "maetal": "b", + "default": "g", + "mediar": "r", + "transformer": [0,1,0]} + titles = {"grayscale": "Cellpose (impaired)", + "maetal": "Cellpose (Ma et al)", + "default": "Cellpose (default)", + "mediar": "Mediar", + "transformer": "Cellpose (transformer)"} + + ax = plt.subplot(grid[0,0]) + pos = ax.get_position().bounds + ax.set_position([pos[0] - 0.03, pos[1], pos[2] + 0.035, pos[3]]) + ax.imshow(img0)#, aspect="auto") + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.axis("off") + il = plot_label(ltr, il, ax, transl, fs_title) + ax.set_title("Example validation image") + + maskk = masks_true[iex].copy() + outlines_gt = utils.outlines_list(maskk, multiprocessing=False) + + pltmasks = [(0, 1, "maetal"), + (0, 2, "default"), + (1, 0, "mediar"), + (1, 1, "transformer"), + ] + + for k,pltmask in enumerate(pltmasks): + ax = plt.subplot(grid[pltmask[0], pltmask[1]]) + pos = ax.get_position().bounds + ax.set_position([pos[0] - 0.03, pos[1], pos[2] + 0.035, pos[3]]) + il = plot_label(ltr, il, ax, transl, fs_title) + ax.imshow(img0) + maskk = dat[pltmask[2]][iex].copy() + outlines = utils.outlines_list(maskk, multiprocessing=False) + for o in outlines_gt: + ax.plot(o[:, 0], o[:, 1], color=[0.7,0.4,1], lw=1, ls="-") + for o in outlines: + ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--") + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.axis("off") + if k==0: + ax.set_title("Cellpose (Ma et al, 2024)", color=cols[pltmask[2]]) + else: + ax.set_title(titles[pltmask[2]], color=cols[pltmask[2]]) + if k==0: + ax.text(-0.1, -0.1, "ground-truth", color=[0.7, 0.4, 1], transform=ax.transAxes, + ha="left", fontweight="normal", fontsize="large", + path_effects=[pe.withStroke(linewidth=1, foreground="k")]) + ax.text(-0.1, -0.22, "model", color=[1, 1, 0.3], transform=ax.transAxes, + ha="left", fontweight="normal", fontsize="large", + path_effects=[pe.withStroke(linewidth=1, foreground="k")]) + f1 = dat[pltmask[2]+"_f1"][iex, 0] + ax.text(1, -0.1, f"F1@0.5 = {f1:.2f}", transform=ax.transAxes, ha="right") + + ax = plt.subplot(grid[1,2]) + il = plot_label(ltr, il, ax, transl1, fs_title) + mtypes = ["default", "transformer", "mediar"] + dx = 0.4 + for k, mtype in enumerate(mtypes): + tsec = timings[:,k] + vp = ax.violinplot(tsec, positions=k*np.ones(1), bw_method=0.1, + showextrema=False, showmedians=False)#, quantiles=[[0.25, 0.5, 0.75]]) + ax.plot(dx*np.arange(-1, 2, 2) + k, + np.median(tsec) * np.ones(2), + color=cols[mtype]) + vp["bodies"][0].set_facecolor(cols[mtype]) + ax.text(k+0.2 if k>0 else k-0.1, -1, titles[mtype].replace(" (", "\n("), + color=cols[mtype], rotation=0, + va="top", ha="center") + ax.set_xticklabels([]) + ax.text(-0.1, 1.05, "Test set runtimes", + fontsize="large", transform=ax.transAxes) + ax.set_ylabel("runtime per image (sec.)") + + ax = plt.subplot(grid[:2, 3]) + il = plot_label(ltr, il, ax, transl1, fs_title) + f1s = np.array([[0.8612, 0.8346, 0.7976, 0.7013, 0.4116], + [0.8484, 0.8190, 0.7761, 0.6744, 0.3907], + [0.8263, 0.7903, 0.7371, 0.6063, 0.2911], + ]) + mtypes = ["default", "transformer", "mediar"] + for k, mtype in enumerate(mtypes): + ax.plot(np.arange(0.5, 1, 0.1), f1s[k], color=cols[mtype], lw=3) + ax.text(0.1, 0.5-k*0.13 if k<2 else 0.5-k*0.13+0.05, titles[mtype].replace(" (", "\n("), + color=cols[mtype], fontsize="large", transform=ax.transAxes) + ax.set_ylim([0, 0.9]) + ax.set_xlim([0.49, 0.91]) + ax.set_xticks([0.5, 0.7, 0.9]) + ax.set_xlabel("IoU threshold") + ax.set_ylabel("F1 score") + ax.set_title("Test set results") + + mtypes = ["default", "transformer", "mediar", "maetal", "grayscale"] + dx = 0.3 + stype = "f1" + ax = plt.subplot(grid[0,4]) + for k, mtype in enumerate(mtypes): #enumerate(model_types): + score = dat[f"{mtype}_{stype}"][:,0] + vp = ax.violinplot(score, positions=k*np.ones(1), bw_method=0.1, + showextrema=False, showmedians=False)#, quantiles=[[0.25, 0.5, 0.75]]) + ax.plot(dx*np.arange(-1, 2, 2) + k, + np.median(score) * np.ones(2), + color=cols[mtype]) + vp["bodies"][0].set_facecolor(cols[mtype]) + ax.text(k+0.2, -0.06, titles[mtype].replace(" (", "\n("), + color=cols[mtype], rotation=90, + va="top", ha="center") + ax.text(-0.1, 1.05, "Validation set scores", + fontsize="large", transform=ax.transAxes) + il = plot_label(ltr, il, ax, transl1, fs_title) + ax.set_ylabel("F1 score @ 0.5 IoU") + ax.set_xticks(np.arange(len(mtypes))) + ax.set_yticks(np.arange(0, 1.1, 0.2)) + ax.set_ylim([-0.01, 1.01]) + ax.set_xticklabels([]) + + ax = plt.subplot(grid[1,4]) + for k, mtype in enumerate(mtypes): + ax.errorbar(thresholds, np.median(dat[f"{mtype}_f1"], axis=0), + dat[f"{mtype}_f1"].std(axis=0) / ((dat[f"{mtype}_f1"].shape[0]-1)**0.5), + color=cols[mtype], lw=2, #if mtype=="grayscale" else 1, + ls="--" if mtype=="transformer" else "-", zorder=30 if mtype=="maetal" else 0) + ax.set_ylabel("F1 score") + ax.set_xlabel("IoU threshold") + ax.set_yticks(np.arange(0, 1.1, 0.2)) + ax.set_ylim([-0.01, 1.01]) + ax.set_xlim([0.49, 1.01]) + ax.set_xticks([0.5, 0.75, 1.0]) + + if save_fig: + fig.savefig("figs/fig1_neurips.pdf", dpi=100) + +def fig2(imgs_norm, masks_true, dat, type_names, types, emb, emb_test, save_fig=False): + ids = [0, 3, 56, 55, 81, 75] + + outlines_gt = [utils.outlines_list(masks_true[iex], multiprocessing=False) for iex in ids] + outlines_cp = [utils.outlines_list(dat["default"][iex], multiprocessing=False) for iex in ids] + outlines_m = [utils.outlines_list(dat["mediar"][iex], multiprocessing=False) for iex in ids] + + fig = plt.figure(figsize=(14,10)) + grid = plt.GridSpec(4, 6, figure=fig, left=0.025, right=0.98, top=0.97, bottom=0.04, + wspace=0.1, hspace=0.2) + il = 0 + transl = mtransforms.ScaledTranslation(-20 / 72, 14 / 72, fig.dpi_scale_trans) + transl1 = mtransforms.ScaledTranslation(-18 / 72, 7 / 72, fig.dpi_scale_trans) + + ylims = [[0, 500], [1550, 1950], [450, 700], [250, 500], [300, 700], [400, 800]] + xlims = [[0, 600], [500, 900], [300, 550], [100, 350], [0, 400], [200, 600]] + + for j in range(len(ids)): + iex = ids[j] + img0 = np.clip(imgs_norm[iex].transpose(1,2,0).copy(), 0, 1) + + ax = plt.subplot(grid[0,j]) + maskk = dat["default"][iex].copy() + ax.imshow(img0) + for o in outlines_gt[j]: + ax.plot(o[:, 0], o[:, 1], color=[0.7,0.4,1], lw=2, ls="-", rasterized=True) + for o in outlines_cp[j]: + ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--", rasterized=True) + ax.set_ylim(ylims[j]) + ax.set_xlim(xlims[j]) + ax.axis("off") + f1 = dat["default_f1"][iex,0] + ax.text(1, -0.1, f"F1@0.5 = {f1:.2f}", transform=ax.transAxes, ha="right") + if j==0: + ax.text(-0.1, 0.5, "Cellpose (default)", rotation=90, va="center", transform=ax.transAxes) + ax.set_title("Example validation images", y=1.07) + il = plot_label(ltr, il, ax, transl, fs_title) + ax.text(-0.1, -0.18, "ground-truth", color=[0.7, 0.4, 1], transform=ax.transAxes, + ha="left", fontweight="normal", fontsize="large", + path_effects=[pe.withStroke(linewidth=1, foreground="k")]) + ax.text(-0.1, -0.3, "model", color=[1, 1, 0.3], transform=ax.transAxes, + ha="left", fontweight="normal", fontsize="large", + path_effects=[pe.withStroke(linewidth=1, foreground="k")]) + + ax = plt.subplot(grid[1,j]) + maskk = dat["mediar"][iex].copy() + ax.imshow(img0) + for o in outlines_gt[j]: + ax.plot(o[:, 0], o[:, 1], color=[0.7,0.4,1], lw=2, ls="-", rasterized=True) + for o in outlines_m[j]: + ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--", rasterized=True) + ax.set_ylim(ylims[j]) + ax.set_xlim(xlims[j]) + f1 = dat["mediar_f1"][iex,0] + ax.text(1, -0.1, f"F1@0.5 = {f1:.2f}", transform=ax.transAxes, ha="right") + ax.axis("off") + if j==0: + ax.text(-0.1, 0.5, "Mediar", rotation=90, va="center", transform=ax.transAxes) + + ax = plt.subplot(grid[2:, :]) + pos = ax.get_position().bounds + ax.set_position([pos[0], pos[1], pos[2]-0.02, pos[3]-0.05]) + grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(1, 4, subplot_spec=ax, + wspace=0.15, hspace=0.15) + ax.remove() + cols = plt.get_cmap("tab10")(np.linspace(0, 1, 10)) + + ax = plt.subplot(grid1[0,0]) + cols0 = plt.get_cmap("Paired")(np.linspace(0, 1, 12)) + cols = np.zeros((len(type_names), 4)) + cols[:,-1] = 1 + cols[:2] = cols0[:2] + cols[2] = cols0[3] + cols[4] = cols0[-3] + cols[-2:] = cols0[4:6] + cols[3] = np.array([0,1.,1.,1.]) + cols[6] = cols0[6] + cols[7] = cols0[-1] + irand = np.random.permutation(len(emb)-100) + ax.scatter(emb[:-100,1][irand], emb[:-100,0][irand], color=cols[types[:-100]][irand],#, cmap="tab10", + s=1, alpha=0.5, marker="o", rasterized=True, zorder=-10) + new_names = ["Omnipose (fluor.)", "Omnipose (phase)", "Cellpose", "DeepBacs", "Livecell", "Ma et al, 2024", "Nuclei", "Tissuenet", "YeaZ (BF)", "YeaZ (phase)"] + ax.set_title("t-SNE of image style vectors\n(training set)", va="top", y=1.05) + torder = np.array([5, 2, 6, 4, 7, 0, 1, 3, 8, 9]) + for k in range(len(type_names)): + th = (torder==k).nonzero()[0][0] + ax.text(0.9, 0.93-0.045*th, new_names[k], color=cols[k], + transform=ax.transAxes, fontsize="small") + ax.axis("off") + il = plot_label(ltr, il, ax, transl1, fs_title) + + dx = 0.03 + ax = plt.subplot(grid1[0,1]) + pos = ax.get_position().bounds + ax.set_position([pos[0]+dx, pos[1], pos[2], pos[3]]) + ax.scatter(emb[:-100,1], emb[:-100,0], color=0.8*np.ones(3), s=1, rasterized=True) + s1 = ax.scatter(emb[-100:,1], emb[-100:,0], color="k", + s=50, marker="x", alpha=1, lw=0.5, rasterized=True) + s2 = ax.scatter(emb_test[:,1], emb_test[:,0], color="k", facecolors='none', + s=50, marker="o", alpha=1, lw=0.5, rasterized=True) + ax.axis("off") + ax.set_title("Validation and test set\n(Ma et al, 2024)", va="top", y=1.05) + ax.legend([s1, s2], ["validation", "test"], frameon=False, loc="upper left") + il = plot_label(ltr, il, ax, transl1, fs_title) + + ax = plt.subplot(grid1[0,2]) + pos = ax.get_position().bounds + ax.set_position([pos[0]+dx, pos[1], pos[2], pos[3]]) + pos = ax.get_position().bounds + ax.scatter(emb[:-100,1], emb[:-100,0], color=0.8*np.ones(3), s=1, rasterized=True) + im = ax.scatter(emb[-100:,1], emb[-100:,0], c=dat["default_f1"][:,0], lw=2, + s=60, marker="x", alpha=1, cmap="plasma", vmin=0, vmax=1, rasterized=True) + ax.axis("off") + cax = fig.add_axes([pos[0]+pos[2]-0.02, pos[1]+pos[3]-0.12, 0.005, 0.11]) + plt.colorbar(im, cax=cax) + ax.set_title("F1 score for Cellpose (default)") + il = plot_label(ltr, il, ax, transl1, fs_title) + + ax = plt.subplot(grid1[0,3]) + pos = ax.get_position().bounds + ax.set_position([pos[0]+dx, pos[1], pos[2], pos[3]]) + pos = ax.get_position().bounds + ax.scatter(emb[:-100,1], emb[:-100,0], color=0.8*np.ones(3), s=1, rasterized=True) + im = ax.scatter(emb[-100:,1], emb[-100:,0], c=dat["default_f1"][:,0] - dat["mediar_f1"][:,0], lw=2, + s=60, marker="x", alpha=1, cmap="coolwarm", vmin=-0.3, vmax=0.3, rasterized=True) + ax.axis("off") + cax = fig.add_axes([pos[0]+pos[2]-0.02, pos[1]+pos[3]-0.12, 0.005, 0.11]) + plt.colorbar(im, cax=cax) + ax.set_title("$\Delta$F1, Cellpose (default) - Mediar") + il = plot_label(ltr, il, ax, transl1, fs_title) + + if save_fig: + fig.savefig("figs/fig2_neurips.pdf", dpi=200) diff --git a/setup.py b/setup.py index 18a73d10..2a0ae5f0 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,6 @@ gui_deps = [ 'pyqtgraph>=0.11.0rc0', "pyqt6", "pyqt6.sip", 'qtpy', 'superqt', - 'google-cloud-storage' ] docs_deps = [ diff --git a/tests/test_import.py b/tests/test_import.py index 2cc9f01e..d0526e31 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -7,7 +7,9 @@ def test_cellpose_imports_without_error(): def test_model_zoo_imports_without_error(): from cellpose import models, denoise for model_name in models.MODEL_NAMES: - model = models.CellposeModel(model_type=model_name) + if "neurips" not in model_name: + model = models.CellposeModel(model_type=model_name, + backbone="transformer" if "transformer" in model_name else "default") def test_gui_imports_without_error(): From 518f92b57b3aed6a76641afc9a130ca34e3ae1a2 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sun, 7 Apr 2024 19:14:54 -0400 Subject: [PATCH 9/9] removing transformer from test --- tests/test_import.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_import.py b/tests/test_import.py index d0526e31..89cd7504 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -7,10 +7,8 @@ def test_cellpose_imports_without_error(): def test_model_zoo_imports_without_error(): from cellpose import models, denoise for model_name in models.MODEL_NAMES: - if "neurips" not in model_name: - model = models.CellposeModel(model_type=model_name, - backbone="transformer" if "transformer" in model_name else "default") - + if "neurips" not in model_name and "transformer" not in model_name: + model = models.CellposeModel(model_type=model_name) def test_gui_imports_without_error(): from cellpose import gui