Support TFLite and TFJS conversion for CenterNet multi-class keypoints #11101

Open · wants to merge 2 commits into master
107 changes: 42 additions & 65 deletions research/object_detection/meta_architectures/center_net_meta_arch.py
@@ -4433,10 +4433,8 @@ def _postprocess_keypoints_multi_class(self, prediction_dict, classes,

This is the most general keypoint postprocessing function which supports
multiple keypoint tasks (e.g. human and dog keypoints) and multiple object
detection classes. Note that it is the most expensive postprocessing logics
and is currently not tf.lite/tf.js compatible. See
_postprocess_keypoints_single_class if you plan to export the model in more
portable format.
detection classes. Note that it is a more expensive postprocessing logic
compared to _postprocess_keypoints_single_class.

Args:
prediction_dict: a dictionary holding predicted tensors, returned from the
@@ -4460,90 +4458,69 @@ def _postprocess_keypoints_multi_class(self, prediction_dict, classes,
keypoint_scores: a [batch_size, max_detections, num_total_keypoints]
float32 tensor with keypoint scores.
"""
total_num_keypoints = sum(len(kp_dict.keypoint_indices) for kp_dict
in self._kp_params_dict.values())
batch_size, max_detections = _get_shape(classes, 2)
kpt_coords_for_example_list = []
kpt_scores_for_example_list = []
kpt_coords_combined = []
kpt_scores_combined = []
batch_size, _ = _get_shape(classes, 2)

for ex_ind in range(batch_size):
# The tensors that host the keypoint coordinates and scores for all
# instances and all keypoints. They will be updated by scatter_nd_add for
# each keypoint tasks.
kpt_coords_for_example_all_det = tf.zeros(
[max_detections, total_num_keypoints, 2])
kpt_scores_for_example_all_det = tf.zeros(
[max_detections, total_num_keypoints])
kpt_coords_for_example = []
kpt_scores_for_example = []

for task_name, kp_params in self._kp_params_dict.items():
keypoint_heatmap = prediction_dict[
get_keypoint_name(task_name, KEYPOINT_HEATMAP)][-1]
keypoint_offsets = prediction_dict[
get_keypoint_name(task_name, KEYPOINT_OFFSET)][-1]
keypoint_regression = prediction_dict[
get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1]
instance_inds = self._get_instance_indices(
classes, num_detections, ex_ind, kp_params.class_id)

# Gather the feature map locations corresponding to the object class.
y_indices_for_kpt_class = tf.gather(y_indices, instance_inds, axis=1)
x_indices_for_kpt_class = tf.gather(x_indices, instance_inds, axis=1)
if boxes is None:
boxes_for_kpt_class = None
else:
boxes_for_kpt_class = tf.gather(boxes, instance_inds, axis=1)

# Postprocess keypoints and scores for class and single image. Shapes
# are [1, num_instances_i, num_keypoints_i, 2] and
# [1, num_instances_i, num_keypoints_i], respectively. Note that
# num_instances_i and num_keypoints_i refers to the number of
# instances and keypoints for class i, respectively.
# Postprocess keypoints and scores for class and single image.
# Shapes are [1, max_detections, num_keypoints, 2] and
# [1, max_detections, num_keypoints], respectively.
(kpt_coords_for_class, kpt_scores_for_class, _) = (
self._postprocess_keypoints_for_class_and_image(
keypoint_heatmap,
keypoint_offsets,
keypoint_regression,
classes,
y_indices_for_kpt_class,
x_indices_for_kpt_class,
boxes_for_kpt_class,
y_indices,
x_indices,
boxes,
ex_ind,
kp_params,
))

# Prepare the indices for scatter_nd. The resulting combined_inds has
# the shape of [num_instances_i * num_keypoints_i, 2], where the first
# column corresponds to the instance IDs and the second column
# corresponds to the keypoint IDs.
kpt_inds = tf.constant(kp_params.keypoint_indices, dtype=tf.int32)
kpt_inds = tf.expand_dims(kpt_inds, axis=0)
instance_inds_expand = tf.expand_dims(instance_inds, axis=-1)
kpt_inds_expand = kpt_inds * tf.ones_like(instance_inds_expand)
instance_inds_expand = instance_inds_expand * tf.ones_like(kpt_inds)
combined_inds = tf.stack(
[instance_inds_expand, kpt_inds_expand], axis=2)
combined_inds = tf.reshape(combined_inds, [-1, 2])

# Reshape the keypoint coordinates/scores to [num_instances_i *
# num_keypoints_i, 2]/[num_instances_i * num_keypoints_i] to be used
# by scatter_nd_add.
kpt_coords_for_class = tf.reshape(kpt_coords_for_class, [-1, 2])
kpt_scores_for_class = tf.reshape(kpt_scores_for_class, [-1])
kpt_coords_for_example_all_det = tf.tensor_scatter_nd_add(
kpt_coords_for_example_all_det,
combined_inds, kpt_coords_for_class)
kpt_scores_for_example_all_det = tf.tensor_scatter_nd_add(
kpt_scores_for_example_all_det,
combined_inds, kpt_scores_for_class)

kpt_coords_for_example_list.append(
tf.expand_dims(kpt_coords_for_example_all_det, axis=0))
kpt_scores_for_example_list.append(
tf.expand_dims(kpt_scores_for_example_all_det, axis=0))
# Set all keypoint coordinates and scores to zeros except for those
# whose class corresponds to the task in the current iteration.
mask_for_class = classes[ex_ind] == kp_params.class_id
mask_scores_for_class = mask_for_class[..., tf.newaxis]
mask_coords_for_class = mask_scores_for_class[..., tf.newaxis]
kpt_coords_for_class = tf2.where(mask_coords_for_class,
kpt_coords_for_class,
tf.zeros_like(
kpt_coords_for_class))
kpt_scores_for_class = tf2.where(mask_scores_for_class,
kpt_scores_for_class,
tf.zeros_like(
kpt_scores_for_class))

kpt_coords_for_example.append(kpt_coords_for_class)
kpt_scores_for_example.append(kpt_scores_for_class)

# Concatenate keypoints and scores from all classes in the example.
# Shapes are [1, max_detections, num_total_keypoints, 2] and
# [1, max_detections, num_total_keypoints], respectively.
kpt_coords_for_example = tf.concat(kpt_coords_for_example, axis=2)
kpt_scores_for_example = tf.concat(kpt_scores_for_example, axis=2)

kpt_coords_combined.append(kpt_coords_for_example)
kpt_scores_combined.append(kpt_scores_for_example)

# Concatenate all keypoints and scores from all examples in the batch.
# Shapes are [batch_size, max_detections, num_total_keypoints, 2] and
# [batch_size, max_detections, num_total_keypoints], respectively.
keypoints = tf.concat(kpt_coords_for_example_list, axis=0)
keypoint_scores = tf.concat(kpt_scores_for_example_list, axis=0)
keypoints = tf.concat(kpt_coords_combined, axis=0)
keypoint_scores = tf.concat(kpt_scores_combined, axis=0)

return keypoints, keypoint_scores
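The heart of this change: the old code scattered per-class keypoint results into a preallocated [max_detections, num_total_keypoints] buffer via tf.tensor_scatter_nd_add with dynamically gathered instance indices, which (per the old docstring) was not tf.lite/tf.js compatible; the new code computes keypoints for every detection, zeroes out rows whose class does not match the task with a broadcast tf2.where, and concatenates tasks along the keypoint axis, so every shape stays static. A minimal standalone sketch of that masking pattern (not code from this PR: toy shapes, a hypothetical mask_for helper, and the batch dimension dropped for brevity):

```python
import tensorflow as tf

# Toy setup: 4 detections; task A owns keypoints [0, 1] for class 0,
# task B owns keypoints [2, 3, 4] for class 1.
classes = tf.constant([0, 1, 0, 1])          # [max_detections]
kpts_task_a = tf.random.uniform([4, 2, 2])   # [max_detections, kpts_a, 2]
kpts_task_b = tf.random.uniform([4, 3, 2])   # [max_detections, kpts_b, 2]

def mask_for(task_class_id, kpt_coords):
  # Zero out keypoints of detections whose class is not this task's class,
  # mirroring the broadcast tf.where masking added by this PR.
  mask = tf.equal(classes, task_class_id)[:, tf.newaxis, tf.newaxis]
  return tf.where(mask, kpt_coords, tf.zeros_like(kpt_coords))

# Concatenating along the keypoint axis gives every detection a row of
# [num_total_keypoints, 2]; slots owned by the other task stay zero.
keypoints = tf.concat(
    [mask_for(0, kpts_task_a), mask_for(1, kpts_task_b)], axis=1)
print(keypoints.shape)  # (4, 5, 2)
```

Every op in this path (equal, where, concat) runs on statically shaped tensors, while the scatter-with-computed-indices pattern it replaces is the kind of dynamic indexing the TFLite/TFJS converters struggled with.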

@@ -2529,6 +2529,167 @@ def graph_fn():
self.assertAllClose(detections['detection_scores'][0][:num_detections],
[0.675])

@parameterized.parameters(
{
'candidate_ranking_mode': 'min_distance',
'argmax_postprocessing': False
},
{
'candidate_ranking_mode': 'score_distance_ratio',
'argmax_postprocessing': True
})
def test_postprocess_multi_class(self, candidate_ranking_mode,
argmax_postprocessing):
"""Test the postprocess function for multiple classes."""
feature_extractor = DummyFeatureExtractor(
channel_means=(1.0, 2.0, 3.0),
channel_stds=(10., 20., 30.),
bgr_ordering=False,
num_feature_outputs=2,
stride=4)
image_resizer_fn = functools.partial(
preprocessor.resize_to_range,
min_dimension=128,
max_dimension=128,
pad_to_max_dimension=True)

kp_params_1 = cnma.KeypointEstimationParams(
task_name='kpt_task_1',
class_id=0,
keypoint_indices=[0, 1],
keypoint_std_dev=[0.00001] * 2,
classification_loss=losses.WeightedSigmoidClassificationLoss(),
localization_loss=losses.L1LocalizationLoss(),
keypoint_candidate_score_threshold=0.1,
candidate_ranking_mode=candidate_ranking_mode,
argmax_postprocessing=argmax_postprocessing)
kp_params_2 = cnma.KeypointEstimationParams(
task_name='kpt_task_2',
class_id=1,
keypoint_indices=[2, 3, 4],
keypoint_std_dev=[0.00001] * 3,
classification_loss=losses.WeightedSigmoidClassificationLoss(),
localization_loss=losses.L1LocalizationLoss(),
keypoint_candidate_score_threshold=0.1,
candidate_ranking_mode=candidate_ranking_mode,
argmax_postprocessing=argmax_postprocessing)
model = cnma.CenterNetMetaArch(
is_training=True,
add_summaries=False,
num_classes=2,
feature_extractor=feature_extractor,
image_resizer_fn=image_resizer_fn,
object_center_params=get_fake_center_params(),
object_detection_params=get_fake_od_params(),
keypoint_params_dict={
'kpt_task_1': kp_params_1,
'kpt_task_2': kp_params_2,
})
max_detection = model._center_params.max_box_predictions
kp_params_dict = model._kp_params_dict
num_keypoints_task_1 = len(kp_params_dict['kpt_task_1'].keypoint_indices)
num_keypoints_task_2 = len(kp_params_dict['kpt_task_2'].keypoint_indices)
num_keypoints = num_keypoints_task_1 + num_keypoints_task_2

class_center = np.zeros((1, 32, 32, 2), dtype=np.float32)
height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
offset = np.zeros((1, 32, 32, 2), dtype=np.float32)

class_probs = np.zeros(2)
class_probs[0] = _logit(0.75)
class_probs[1] = _logit(0.75)
class_center[0, 16, 16] = class_probs
height_width[0, 16, 16] = [5, 10]
offset[0, 16, 16] = [.25, .5]

keypoint_heatmaps_task_1 = np.ones(
(1, 32, 32, num_keypoints_task_1), dtype=np.float32) * _logit(0.01)
keypoint_offsets_task_1 = np.zeros(
(1, 32, 32, num_keypoints_task_1 * 2), dtype=np.float32)
keypoint_regression_task_1 = np.random.randn(1, 32, 32,
num_keypoints_task_1 * 2)

keypoint_regression_task_1[0, 16, 16] = [
-1., -1.,
-1., 1.]
keypoint_heatmaps_task_1[0, 14, 14, 0] = _logit(0.9)
keypoint_heatmaps_task_1[0, 14, 18, 1] = _logit(0.05) # Note the low score.

keypoint_heatmaps_task_2 = np.ones(
(1, 32, 32, num_keypoints_task_2), dtype=np.float32) * _logit(0.01)
keypoint_offsets_task_2 = np.zeros(
(1, 32, 32, num_keypoints_task_2 * 2), dtype=np.float32)
keypoint_regression_task_2 = np.random.randn(1, 32, 32,
num_keypoints_task_2 * 2)

keypoint_regression_task_2[0, 16, 16] = [
-1., -1.,
-1., 1.,
1, -1]
keypoint_heatmaps_task_2[0, 14, 14, 0] = _logit(0.9)
keypoint_heatmaps_task_2[0, 14, 18, 1] = _logit(0.9)
keypoint_heatmaps_task_2[0, 14, 18, 2] = _logit(0.05) # Note the low score.

class_center = tf.constant(class_center)
height_width = tf.constant(height_width)
offset = tf.constant(offset)
keypoint_heatmaps_task_1 = tf.constant(
keypoint_heatmaps_task_1, dtype=tf.float32)
keypoint_offsets_task_1 = tf.constant(
keypoint_offsets_task_1, dtype=tf.float32)
keypoint_regression_task_1 = tf.constant(
keypoint_regression_task_1, dtype=tf.float32)
keypoint_heatmaps_task_2 = tf.constant(
keypoint_heatmaps_task_2, dtype=tf.float32)
keypoint_offsets_task_2 = tf.constant(
keypoint_offsets_task_2, dtype=tf.float32)
keypoint_regression_task_2 = tf.constant(
keypoint_regression_task_2, dtype=tf.float32)

prediction_dict = {
cnma.OBJECT_CENTER: [class_center],
cnma.BOX_SCALE: [height_width],
cnma.BOX_OFFSET: [offset],
cnma.get_keypoint_name(kp_params_1.task_name, cnma.KEYPOINT_HEATMAP):
[keypoint_heatmaps_task_1],
cnma.get_keypoint_name(kp_params_1.task_name, cnma.KEYPOINT_OFFSET):
[keypoint_offsets_task_1],
cnma.get_keypoint_name(kp_params_1.task_name, cnma.KEYPOINT_REGRESSION):
[keypoint_regression_task_1],
cnma.get_keypoint_name(kp_params_2.task_name, cnma.KEYPOINT_HEATMAP):
[keypoint_heatmaps_task_2],
cnma.get_keypoint_name(kp_params_2.task_name, cnma.KEYPOINT_OFFSET):
[keypoint_offsets_task_2],
cnma.get_keypoint_name(kp_params_2.task_name, cnma.KEYPOINT_REGRESSION):
[keypoint_regression_task_2]
}

def graph_fn():
detections = model.postprocess(prediction_dict,
tf.constant([[128, 128, 3]]))
return detections

detections = self.execute_cpu(graph_fn, [])

self.assertAllClose(detections['detection_boxes'][0, 0],
np.array([55, 46, 75, 86]) / 128.0)
self.assertAllClose(detections['detection_scores'][0],
[.75, .75, .5, .5, .5])

self.assertAllEqual(detections['detection_classes'][0], [0, 1, 0, 1, 0])
self.assertEqual(detections['num_detections'], [5])
self.assertAllEqual([1, max_detection, num_keypoints, 2],
detections['detection_keypoints'].shape)
self.assertAllClose(
[[0.4375, 0.4375], [0.46875, 0.53125], [0, 0], [0, 0], [0, 0]],
detections['detection_keypoints'][0, 0, :, :])
self.assertAllClose(
[[0, 0], [0, 0], [0.4375, 0.4375], [0.4375, 0.5625],
[0.53125, 0.46875]],
detections['detection_keypoints'][0, 1, :, :])
self.assertAllEqual([1, max_detection, num_keypoints],
detections['detection_keypoint_scores'].shape)
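For readers checking the assertions: the expected coordinates follow from the test geometry (this is my reading of the numbers, assuming standard CenterNet keypoint decoding; the 128x128 input at stride 4 gives a 32x32 feature map, so coordinates normalize by 32):

```python
# Task 1, detection 0 (class 0):
assert 14 / 32 == 0.4375    # keypoint 0 snaps to the heatmap peak at (14, 14),
                            # score 0.9 > keypoint_candidate_score_threshold=0.1
assert 15 / 32 == 0.46875   # keypoint 1 falls back to the regressed location,
assert 17 / 32 == 0.53125   # center (16, 16) + regression (-1., +1.), because
                            # the candidate at (14, 18) scores only 0.05 < 0.1
# Task 2, detection 1 (class 1):
assert 18 / 32 == 0.5625    # keypoint 3 snaps to the peak at (14, 18)
```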

@parameterized.parameters(
{
'candidate_ranking_mode': 'min_distance',