Why does my model perform so poorly on the TissueNet dataset? #650

Open
wzr0108 opened this issue Jan 15, 2023 · 4 comments

wzr0108 commented Jan 15, 2023

I cropped the training images so that both height and width are 256 pixels.
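The cropping step itself is not shown in the issue. Here is a minimal sketch of one way to do it. It assumes `train_x`/`train_y` are NumPy arrays shaped `(N, H, W, C)` and that non-overlapping tiles are wanted. `crop_to_tiles` is a hypothetical helper, not part of deepcell.

```python
import numpy as np


def crop_to_tiles(X, y, size=256):
    """Hypothetical helper: split (N, H, W, C) image and label arrays into
    non-overlapping size x size tiles. Rows/columns that do not fill a
    whole tile are dropped."""
    n, h, w = X.shape[:3]
    rows, cols = h // size, w // size
    # Trim both arrays to a whole number of tiles
    X = X[:, :rows * size, :cols * size]
    y = y[:, :rows * size, :cols * size]

    def _tile(a):
        c = a.shape[-1]
        a = a.reshape(n, rows, size, cols, size, c)
        a = a.transpose(0, 1, 3, 2, 4, 5)  # (N, rows, cols, size, size, C)
        return a.reshape(n * rows * cols, size, size, c)

    return _tile(X), _tile(y)
```

Cells that straddle a tile border are cut in two, so each tile can contain partial objects along its edges.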
My training code is as follows:

import os

from tensorflow.keras.optimizers import Adam

from deepcell import image_generators
from deepcell.model_zoo.panopticnet import PanopticNet
from deepcell.utils.train_utils import count_gpus, get_callbacks, rate_scheduler

# train_x, train_y, val_x, val_y, seed, n_epoch and semantic_loss are defined elsewhere (not shown)

classes = {
    'inner_distance': 1,  # inner distance
    'outer_distance': 1,  # outer distance
    'fgbg': 2,  # foreground/background separation
}
norm_method = 'whole_image'  # data normalization
model = PanopticNet(
    backbone='resnet50',
    input_shape=train_x.shape[1:],
    norm_method=norm_method,
    num_semantic_classes=classes)

model_name = 'watershed_centroid_nuclear_general_whole'

lr = 1e-4
optimizer = Adam(learning_rate=lr, clipnorm=0.001)
lr_sched = rate_scheduler(lr=lr, decay=0.977)

batch_size = 8
min_objects = 3

transforms = list(classes.keys())
transforms_kwargs = {'outer-distance': {'erosion_width': 0}}

# use augmentation for training but not validation
datagen = image_generators.SemanticDataGenerator(
    rotation_range=180,
    shear_range=0,
    zoom_range=(0.75, 1.25),
    horizontal_flip=True,
    vertical_flip=True)

datagen_val = image_generators.SemanticDataGenerator(
    rotation_range=0,
    shear_range=0,
    zoom_range=0,
    horizontal_flip=0,
    vertical_flip=0)

train_data = datagen.flow(
    {'X': train_x, 'y': train_y},
    seed=seed,
    transforms=transforms,
    transforms_kwargs=transforms_kwargs,
    min_objects=min_objects,
    batch_size=batch_size)

val_data = datagen_val.flow(
    {'X': val_x, 'y': val_y},
    seed=seed,
    transforms=transforms,
    transforms_kwargs=transforms_kwargs,
    min_objects=min_objects,
    batch_size=batch_size)

loss = {}

# Give losses for all of the semantic heads
for layer in model.layers:
    if layer.name.startswith('semantic_'):
        n_classes = layer.output_shape[-1]
        loss[layer.name] = semantic_loss(n_classes)

model.compile(loss=loss, optimizer=optimizer)

model_path = os.path.join("./model_tis_norm_method_whole_seed42_aug", model_name)
loss_path = os.path.join("./loss_tis_norm_method_whole_seed42_aug", model_name)

num_gpus = count_gpus()

print('Training on', num_gpus, 'GPUs.')

train_callbacks = get_callbacks(
    model_path,
    lr_sched=lr_sched,
    tensorboard_log_dir="./logs_tis_norm_method_whole_seed42_aug",
    save_weights_only=num_gpus >= 2,
    monitor='val_loss',
    verbose=1)

loss_history = model.fit(
    train_data,
    steps_per_epoch=train_data.y.shape[0] // batch_size,
    epochs=n_epoch,
    validation_data=val_data,
    validation_steps=val_data.y.shape[0] // batch_size,
    callbacks=train_callbacks)
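The learning-rate schedule above is built with `rate_scheduler(lr=1e-4, decay=0.977)`. Assume that scheduler multiplies the base rate by `decay ** epoch`, an exponential per-epoch decay (check `deepcell.utils.train_utils` to confirm). Under that assumption, this pure-Python sketch shows how quickly the rate shrinks. `exponential_decay` is a hypothetical stand-in, not the deepcell function.

```python
import math


def exponential_decay(lr=1e-4, decay=0.977):
    """Hypothetical stand-in: returns schedule(epoch) = lr * decay ** epoch."""
    def schedule(epoch):
        return lr * decay ** int(epoch)
    return schedule


sched = exponential_decay(lr=1e-4, decay=0.977)
for epoch in (0, 10, 30, 100):
    print(f"epoch {epoch:3d}: lr = {sched(epoch):.2e}")

# Number of epochs for the rate to fall to half its starting value
half_life = math.log(2) / -math.log(0.977)
print(f"half-life ~ {half_life:.1f} epochs")
```

Under this assumption the rate halves roughly every 30 epochs, so the value of `n_epoch` (not shown) determines how far it has decayed by the end of training.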

My recall and precision results are as follows:

DeepWatershed - Remove no pixels
____________Object-based statistics____________

Number of true cells:		 124530
Number of predicted cells:	 104750

Correct detections:  67983	Recall: 0.5459%
Incorrect detections: 36767	Precision: 0.649%


Gained detections: 21142	Perc Error 36.9551%
Missed detections: 21709	Perc Error 37.9462%
Splits: 84		Perc Error 0.1468%
Merges: 12205		Perc Error 21.3337%
Catastrophes: 2070		Perc Error 3.6182%
Gained detectionsfrom split: 84
Missed detectionsfrom merge: 16545
True detectionsin catastrophe: 8074
Pred detectionsin catastrophe: 5322
SEG: 0.6857 

Average Pixel IOU (Jaccard Index): 0.727 

DeepWatershed - Remove objects < 100 pixels
____________Object-based statistics____________

Number of true cells:		 101783
Number of predicted cells:	 98473

Correct detections:  65329	Recall: 0.6418%
Incorrect detections: 33144	Precision: 0.6634%


Gained detections: 23251	Perc Error 50.7088%
Missed detections: 12935	Perc Error 28.2103%
Splits: 104		Perc Error 0.2268%
Merges: 8694		Perc Error 18.961%
Catastrophes: 868		Perc Error 1.893%
Gained detectionsfrom split: 104
Missed detectionsfrom merge: 12199
True detectionsin catastrophe: 3390
Pred detectionsin catastrophe: 1859
SEG: 0.6991 

Average Pixel IOU (Jaccard Index): 0.6946 
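The printed `Recall` and `Precision` values are fractions even though they carry a `%` sign. For example, 67983 / 124530 ≈ 0.5459, which is about 54.6% recall. Each `Perc Error` value equals that error count divided by the sum of the gained, missed, split, merge and catastrophe counts. This sketch recomputes both from the counts in the "Remove no pixels" run. `summarize` is a hypothetical helper.

```python
def summarize(n_true, n_pred, correct, errors):
    """Hypothetical helper: recall, precision, F1 and each error type's share."""
    recall = correct / n_true
    precision = correct / n_pred
    f1 = 2 * precision * recall / (precision + recall)
    total_errors = sum(errors.values())
    shares = {name: count / total_errors for name, count in errors.items()}
    return recall, precision, f1, shares


# Counts copied from the "DeepWatershed - Remove no pixels" output above
recall, precision, f1, shares = summarize(
    n_true=124530, n_pred=104750, correct=67983,
    errors={"gained": 21142, "missed": 21709, "splits": 84,
            "merges": 12205, "catastrophes": 2070},
)
print(f"recall={recall:.2%} precision={precision:.2%} F1={f1:.2%}")
for name, share in shares.items():
    print(f"{name}: {share:.2%} of all errors")
```

Read this way, merges account for about a fifth of all errors in that run, alongside roughly equal numbers of gained and missed detections.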

wzr0108 commented Jan 15, 2023

My testing code is as follows:

from timeit import default_timer

import numpy as np
from skimage.morphology import remove_small_objects

from deepcell.model_zoo.panopticnet import PanopticNet
from deepcell_toolbox.deep_watershed import deep_watershed
from deepcell_toolbox.metrics import Metrics

# test_x, test_y and model_path are defined elsewhere (not shown)

classes = {
    'inner_distance': 1,
    'outer_distance': 1,
}

norm_method = 'whole_image'  # data normalization
prediction_model = PanopticNet(
    backbone='resnet50',
    input_shape=test_x.shape[1:],
    norm_method=norm_method,
    num_semantic_heads=2,
    num_semantic_classes=classes,
    location=True,  # should always be true
    include_top=True)

print("load model from %s" % model_path)
prediction_model.load_weights(model_path, by_name=True)

start = default_timer()

test_images = prediction_model.predict(test_x)
watershed_time = default_timer() - start

print('Watershed segmentation of shape', test_images[0].shape, 'in', watershed_time, 'seconds.')

y_pred = []
for i in range(test_images[0].shape[0]):
    mask = deep_watershed(
        [t[[i]] for t in test_images],
        min_distance=10,
        detection_threshold=0.1,
        distance_threshold=0.01,
        exclude_border=False,
        small_objects_threshold=0)
    y_pred.append(mask[0])  # inside the loop so every image's mask is kept

y_pred = np.stack(y_pred, axis=0)
# y_pred = np.expand_dims(y_pred, axis=-1)
y_true = test_y.copy()

print('DeepWatershed - Remove no pixels')
m = Metrics('DeepWatershed - Remove no pixels', seg=False)
m.calc_object_stats(y_true, y_pred)
print('\n')

# remove objects smaller than 100 pixels from both predictions and ground truth
for i in range(y_pred.shape[0]):
    y_pred[i] = remove_small_objects(y_pred[i].astype(int), min_size=100)
    y_true[i] = remove_small_objects(y_true[i].astype(int), min_size=100)

print('DeepWatershed - Remove objects < 100 pixels')
m = Metrics('DeepWatershed - Remove 100 pixels', seg=False)
m.calc_object_stats(y_true, y_pred)
print('\n')
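Because `remove_small_objects` is applied to `y_true` as well as `y_pred`, the second run is scored against a different ground truth. The number of true cells drops from 124,530 to 101,783, which implies that 22,747 annotated objects (about 18%) have fewer than 100 pixels. The following NumPy sketch counts such objects in one label image before filtering. `count_small_objects` is a hypothetical helper. It assumes non-negative integer labels with 0 as background and mirrors the `size < min_size` cut-off used by `remove_small_objects`.

```python
import numpy as np


def count_small_objects(label_image, min_size=100):
    """Hypothetical helper: count labelled objects with fewer than min_size pixels.

    Assumes a non-negative integer label image where 0 is background.
    """
    sizes = np.bincount(np.asarray(label_image, dtype=np.int64).ravel())
    sizes = sizes[1:]  # drop the background label 0
    present = sizes > 0  # ignore label values that never occur
    return int(np.count_nonzero(present & (sizes < min_size)))


# Example: a 25-pixel, a 100-pixel and a 6-pixel object (label 3 unused)
lab = np.zeros((20, 20), dtype=int)
lab[0:5, 0:5] = 1
lab[10:20, 10:20] = 2
lab[0:2, 15:18] = 4
print(count_small_objects(lab, min_size=100))  # → 2
```

Summing this over each `y_true[i]` before the filtering loop should roughly account for the drop in the true-cell count between the two runs.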

msschwartz21 (Member) commented

Hi @wzr0108, are you trying to recapitulate the results in our Mesmer model? If so, I noticed that you set up your model a bit differently from how Mesmer was trained. An example training notebook is available here.

The key differences that I noticed:

  • I can't tell whether you are training on one or two input channels, but Mesmer was trained with both nuclear and membrane channels as input to the model.
  • The semantic classes are set up differently. For reference, here is the configuration from the Mesmer script: num_semantic_classes=[1, 3, 1, 3]  # inner distance, pixelwise, inner distance, pixelwise
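In the Mesmer configuration quoted above, each 1-channel head predicts an inner-distance map and each 3-class head a pixelwise classification. As we understand deepcell's pixelwise transform, the three classes are background, cell interior and cell edge. The NumPy sketch below builds such targets from a 2D label image using 4-connectivity and a channel order of our own choosing. `pixelwise_targets` is a hypothetical helper, not deepcell's implementation, which may differ in edge definition and channel order.

```python
import numpy as np


def pixelwise_targets(labels):
    """Hypothetical helper: one-hot (background, interior, edge) targets
    from a 2D integer label image (0 = background)."""
    labels = np.asarray(labels)
    h, w = labels.shape
    # "edge" padding so the image border itself is not treated as a cell boundary
    padded = np.pad(labels, 1, mode="edge")
    differs = np.zeros(labels.shape, dtype=bool)
    # mark pixels whose up/down/left/right neighbour carries a different label
    for dy, dx in ((0, 1), (2, 1), (1, 0), (1, 2)):
        differs |= padded[dy:dy + h, dx:dx + w] != labels
    foreground = labels > 0
    edge = foreground & differs
    interior = foreground & ~differs
    background = ~foreground
    return np.stack([background, interior, edge], axis=-1).astype(np.float32)


lab = np.zeros((5, 5), dtype=int)
lab[1:4, 1:4] = 1  # one 3x3 cell: 8 edge pixels around 1 interior pixel
targets = pixelwise_targets(lab)
print(targets[..., 1].sum(), targets[..., 2].sum())
```

With touching cells, pixels on either side of the shared border are marked as edge, which is what lets a pixelwise head help separate merged nuclei.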

Let me know if that helps or if you have other questions!

wzr0108 commented Jan 17, 2023

Thank you for your reply. I am training on the nuclear channel only. How do I change the configuration?

rossbar (Contributor) commented Jan 17, 2023

> Thank you for your reply. I am training on nuclear channel, how do i change the configuration

The linked notebook on model training is likely the best place to start.
