Add option for either batch or layer norm in tracking model (#598)
* Add option for either batch or layer norm in tracking model

* Bump rc version

* pep8
msschwartz21 committed May 21, 2022
1 parent fc13bbd commit c3e67d3
Showing 2 changed files with 31 additions and 14 deletions.
2 changes: 1 addition & 1 deletion deepcell/_version.py
@@ -27,7 +27,7 @@
 __title__ = 'DeepCell'
 __description__ = 'Deep learning for single cell image segmentation'
 __url__ = 'https://github.com/vanvalenlab/deepcell-tf'
-__version__ = '0.12.0rc1'
+__version__ = '0.12.0rc2'
 __download_url__ = '{}/tarball/{}'.format(__url__, __version__)
 __author__ = 'The Van Valen Lab'
 __author_email__ = 'vanvalen@caltech.edu'
43 changes: 30 additions & 13 deletions deepcell/model_zoo/tracking.py
@@ -43,7 +43,7 @@
 from tensorflow.keras.layers import Add, Subtract, Dense, Reshape
 from tensorflow.keras.layers import MaxPool3D
 from tensorflow.keras.layers import Activation, Softmax
-from tensorflow.keras.layers import BatchNormalization, Lambda
+from tensorflow.keras.layers import LayerNormalization, BatchNormalization, Lambda
 from tensorflow.keras.regularizers import l2
 
 from spektral.layers import GCSConv, GCNConv, GATConv
@@ -222,6 +222,7 @@ class GNNTrackingModel(object):
             Additional kwargs for the graph layers can be encoded in the following format
             ``<layer name>-kwarg:value-kwarg:value``
         appearance_shape (tuple): shape of each object's appearance tensor
+        norm_layer (str): Must be one of {'layer', 'batch'}
     """
     def __init__(self,
                  max_cells=39,
@@ -231,7 +232,8 @@ def __init__(self,
                  embedding_dim=64,
                  n_layers=3,
                  graph_layer='gcs',
-                 appearance_shape=(32, 32, 1)):
+                 appearance_shape=(32, 32, 1),
+                 norm_layer='batch'):
 
         self.n_filters = n_filters
         self.encoder_dim = encoder_dim
@@ -253,6 +255,17 @@
             raise ValueError('Invalid graph_layer: {}'.format(graph_layer_name))
         self.graph_layer = graph_layer
 
+        norm_options = {'layer', 'batch'}
+        if norm_layer not in norm_options:
+            raise ValueError('Invalid normalization layer {}. Must be one of {}.'.format(
+                norm_layer, norm_options))
+        if norm_layer == 'layer':
+            self.norm_layer = LayerNormalization
+            self.norm_layer_prefix = 'ln'
+        elif norm_layer == 'batch':
+            self.norm_layer = BatchNormalization
+            self.norm_layer_prefix = 'bn'
+
         # Use inputs to build expected shapes
         base_shape = [self.track_length, self.max_cells]
         self.appearance_shape = tuple(base_shape + list(appearance_shape))
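A quick usage sketch of the new option (a hypothetical snippet based only on the constructor and validation shown in this hunk; it is not part of the commit):

    from deepcell.model_zoo.tracking import GNNTrackingModel

    # Default behavior is unchanged: batch normalization ('bn_*' layer names)
    tm_bn = GNNTrackingModel(norm_layer='batch')

    # New option: layer normalization ('ln_*' layer names)
    tm_ln = GNNTrackingModel(norm_layer='layer')
    print(tm_ln.norm_layer_prefix)  # 'ln'

    # Anything else raises the ValueError added above
    GNNTrackingModel(norm_layer='group')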
@@ -308,12 +321,12 @@ def get_appearance_encoder(self):
                        strides=1,
                        padding='same',
                        use_bias=False, name='conv3d_ae{}'.format(i))(x)
-            x = BatchNormalization(axis=-1, name='bn_ae{}'.format(i))(x)
+            x = self.norm_layer(axis=-1, name='{}_ae{}'.format(self.norm_layer_prefix, i))(x)
             x = Activation('relu', name='relu_ae{}'.format(i))(x)
             x = MaxPool3D(pool_size=(1, 2, 2))(x)
         x = Lambda(lambda t: tf.squeeze(t, axis=(2, 3)))(x)
         x = Dense(self.encoder_dim, name='dense_aeout')(x)
-        x = BatchNormalization(axis=-1, name='bn_aeout')(x)
+        x = self.norm_layer(axis=-1, name='{}_aeout'.format(self.norm_layer_prefix))(x)
         x = Activation('relu', name='appearance_embedding')(x)
         return Model(inputs=inputs, outputs=x)
 
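The drop-in swap works because both Keras layers are constructed and called with the same axis/name signature; they differ only in which axes they compute statistics over. A minimal standalone illustration (the demo tensor and layer names are hypothetical, not part of the diff):

    import tensorflow as tf
    from tensorflow.keras.layers import BatchNormalization, LayerNormalization

    x = tf.random.normal((2, 8, 64))  # (batch, cells, features)

    # Batch norm: per-feature statistics across the batch; keeps running
    # averages for inference, so behavior differs between train and test.
    y_bn = BatchNormalization(axis=-1, name='bn_demo')(x, training=True)

    # Layer norm: per-sample statistics across the feature axis; no batch
    # dependence, so train and test behave identically.
    y_ln = LayerNormalization(axis=-1, name='ln_demo')(x)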
@@ -322,7 +335,7 @@ def get_morphology_encoder(self):
         inputs = Input(shape=morph_shape, name='encoder_morph_input')
         x = inputs
         x = Dense(self.encoder_dim, name='dense_me')(x)
-        x = BatchNormalization(axis=-1, name='bn_me')(x)
+        x = self.norm_layer(axis=-1, name='{}_me'.format(self.norm_layer_prefix))(x)
         x = Activation('relu', name='morphology_embedding')(x)
         return Model(inputs=inputs, outputs=x)
 
@@ -331,7 +344,7 @@ def get_centroid_encoder(self):
         inputs = Input(shape=centroid_shape, name='encoder_centroid_input')
         x = inputs
         x = Dense(self.encoder_dim, name='dense_ce')(x)
-        x = BatchNormalization(axis=-1, name='bn_ce')(x)
+        x = self.norm_layer(axis=-1, name='{}_ce'.format(self.norm_layer_prefix))(x)
         x = Activation('relu', name='centroid_embedding')(x)
         return Model(inputs=inputs, outputs=x)
 
@@ -346,11 +359,11 @@ def get_delta_encoders(self):
         a = Activation('relu', name='relu_des')
 
         x_0 = d(inputs)
-        x_0 = BatchNormalization(axis=-1, name='bn_des0')(x_0)
+        x_0 = self.norm_layer(axis=-1, name='{}_des0'.format(self.norm_layer_prefix))(x_0)
         x_0 = a(x_0)
 
         x_1 = d(inputs_across_frames)
-        x_1 = BatchNormalization(axis=-1, name='bn_des1')(x_1)
+        x_1 = self.norm_layer(axis=-1, name='{}_des1'.format(self.norm_layer_prefix))(x_1)
         x_1 = a(x_1)
 
         delta_encoder = Model(inputs=inputs, outputs=x_0)
@@ -374,7 +387,8 @@ def get_neighborhood_encoder(self):
         # Concatenate features
         node_features = Concatenate(axis=-1)([app_features, morph_features, centroid_features])
         node_features = Dense(self.n_filters, name='dense_ne0')(node_features)
-        node_features = BatchNormalization(axis=-1, name='bn_ne0')(node_features)
+        node_features = self.norm_layer(axis=-1, name='{}_ne0'.format(self.norm_layer_prefix)
+                                        )(node_features)
         node_features = Activation('relu', name='relu_ne0')(node_features)
 
         # Apply graph convolution
@@ -403,13 +417,15 @@
                 raise ValueError('Unexpected graph_layer: {}'.format(graph_layer_name))
 
             node_features = graph_layer([node_features, adj])
-            node_features = BatchNormalization(axis=-1,
-                                               name='bn_ne{}'.format(i + 1))(node_features)
+            node_features = self.norm_layer(axis=-1,
+                                            name='{}_ne{}'.format(self.norm_layer_prefix, i + 1)
+                                            )(node_features)
             node_features = Activation('relu', name='relu_ne{}'.format(i + 1))(node_features)
 
         concat = Concatenate(axis=-1)([app_features, morph_features, node_features])
         node_features = Dense(self.embedding_dim, name='dense_nef')(concat)
-        node_features = BatchNormalization(axis=-1, name='bn_nef')(node_features)
+        node_features = self.norm_layer(axis=-1, name='{}_nef'.format(self.norm_layer_prefix)
+                                        )(node_features)
         node_features = Activation('relu', name='relu_nef')(node_features)
 
         inputs = [app_input, morph_input, centroid_input, adj_input]
@@ -576,7 +592,8 @@ def get_tracking_decoder(self):
         embedding = Concatenate(axis=-1)([embedding_input, deltas_input])
 
         embedding = Dense(self.n_filters, name='dense_td0')(embedding)
-        embedding = BatchNormalization(axis=-1, name='bn_td0')(embedding)
+        embedding = self.norm_layer(axis=-1, name='{}_td0'.format(self.norm_layer_prefix)
+                                    )(embedding)
         embedding = Activation('relu', name='relu_td0')(embedding)
 
         # TODO: set to n_classes
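Design note: throughout the diff, self.norm_layer holds the layer class itself rather than an instance, and each call site instantiates it fresh, so every encoder and decoder block gets its own weights and a unique name. A minimal sketch of that pattern outside the tracking model (the class and names here are hypothetical):

    from tensorflow.keras.layers import BatchNormalization, LayerNormalization

    class NormPicker:
        def __init__(self, norm_layer='batch'):
            # Store the class, not an instance; each call to block()
            # then constructs an independent normalization layer.
            self.norm_layer = {'batch': BatchNormalization,
                               'layer': LayerNormalization}[norm_layer]

        def block(self, x, name):
            return self.norm_layer(axis=-1, name=name)(x)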
