Skip to content

Commit

Permalink
Fix ops.ctc_decode (#19633)
Browse files (browse the repository at this point in the history)
* Fix greedy ctc decode

* Remove print

* Fix `tf.nn.ctc_beam_search_decoder`

* Change default `mask_index` to `0`

* Fix losses test

* Update
  • Loading branch information
james77777778 committed Apr 28, 2024
1 parent d5c9540 commit 880f0cd
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 70 deletions.
31 changes: 16 additions & 15 deletions keras/src/backend/jax/nn.py
Expand Up @@ -593,7 +593,7 @@ def ctc_loss(target, output, target_length, output_length, mask_index=0):
output = convert_to_tensor(output)
target_length = convert_to_tensor(target_length, "int32")
output_length = convert_to_tensor(output_length, "int32")
batch_size, _, num_classes = output.shape
batch_size, max_input_length, num_classes = output.shape
batch_size, max_label_length = target.shape
log_epsilon = -1e5

Expand All @@ -610,7 +610,7 @@ def _lengths_to_paddings(lengths, max_length):
return jnp.logical_not(elem_valid)

target_paddings = _lengths_to_paddings(target_length, max_label_length)
output_paddings = _lengths_to_paddings(output_length, max_label_length)
output_paddings = _lengths_to_paddings(output_length, max_input_length)
target_paddings = target_paddings.astype(output.dtype)
output_paddings = output_paddings.astype(output.dtype)

Expand Down Expand Up @@ -690,12 +690,12 @@ def loop_body(prev, x):

def _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=True,
mask_index=None,
):
inputs = convert_to_tensor(inputs)
sequence_length = convert_to_tensor(sequence_length, dtype="int32")
sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32")
batch_size, max_length, num_classes = inputs.shape

if mask_index is None:
Expand All @@ -705,7 +705,7 @@ def _ctc_greedy_decode(
scores = jnp.max(inputs, axis=-1)

seqlen_mask = jnp.arange(max_length)[None, :]
seqlen_mask = seqlen_mask >= sequence_length[:, None]
seqlen_mask = seqlen_mask >= sequence_lengths[:, None]

indices = jnp.where(seqlen_mask, mask_index, indices)
scores = jnp.where(seqlen_mask, 0.0, scores)
Expand All @@ -715,34 +715,35 @@ def _ctc_greedy_decode(
repeat_mask = jnp.pad(repeat_mask, ((0, 0), (1, 0)))
indices = jnp.where(repeat_mask, mask_index, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
# We set to -1 for blank labels
invalid_mask = indices == mask_index
indices = jnp.where(invalid_mask, -1, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
order = jnp.expand_dims(jnp.arange(max_length), axis=0) # [1, N]
order = jnp.tile(order, (batch_size, 1)) # [B, N]
order = jnp.where(invalid_mask, max_length, order)
order = jnp.argsort(order, axis=-1)
indices = jnp.take_along_axis(indices, order, axis=-1)

# We set to -1 for blank labels
indices = jnp.where(invalid_mask, -1, indices)
scores = -jnp.sum(scores, axis=1)[:, None]
indices = jnp.expand_dims(indices, axis=0)
return indices, scores


def _ctc_beam_search_decode(
inputs,
sequence_length,
sequence_lengths,
beam_width=100,
top_paths=1,
mask_index=None,
):
inputs = convert_to_tensor(inputs)
sequence_length = convert_to_tensor(sequence_length)
sequence_lengths = convert_to_tensor(sequence_lengths)

batch_size, max_seq_len, num_classes = inputs.shape
inputs = jnn.log_softmax(inputs)
seqlen_mask = jnp.arange(max_seq_len)[None, :] >= sequence_length[:, None]
seqlen_mask = jnp.arange(max_seq_len)[None, :] >= sequence_lengths[:, None]

if mask_index is None:
mask_index = num_classes - 1
Expand Down Expand Up @@ -895,12 +896,12 @@ def _decode_batch(

def ctc_decode(
inputs,
sequence_length,
sequence_lengths,
strategy="greedy",
beam_width=100,
top_paths=1,
merge_repeated=True,
mask_index=None,
mask_index=0,
):
inputs = convert_to_tensor(inputs)
dtype = backend.result_type(inputs.dtype, "float32")
Expand All @@ -909,14 +910,14 @@ def ctc_decode(
if strategy == "greedy":
return _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=merge_repeated,
mask_index=mask_index,
)
elif strategy == "beam_search":
return _ctc_beam_search_decode(
inputs,
sequence_length,
sequence_lengths,
beam_width=beam_width,
top_paths=top_paths,
mask_index=mask_index,
Expand Down
31 changes: 16 additions & 15 deletions keras/src/backend/numpy/nn.py
Expand Up @@ -621,7 +621,7 @@ def ctc_loss(target, output, target_length, output_length, mask_index=0):
output = convert_to_tensor(output)
target_length = convert_to_tensor(target_length, "int32")
output_length = convert_to_tensor(output_length, "int32")
batch_size, _, num_classes = output.shape
batch_size, max_input_length, num_classes = output.shape
batch_size, max_label_length = target.shape
log_epsilon = -1e5

Expand All @@ -638,7 +638,7 @@ def _lengths_to_paddings(lengths, max_length):
return np.logical_not(elem_valid)

target_paddings = _lengths_to_paddings(target_length, max_label_length)
output_paddings = _lengths_to_paddings(output_length, max_label_length)
output_paddings = _lengths_to_paddings(output_length, max_input_length)
target_paddings = target_paddings.astype(output.dtype)
output_paddings = output_paddings.astype(output.dtype)

Expand Down Expand Up @@ -729,12 +729,12 @@ def np_scan(f, init, xs):

def _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=True,
mask_index=None,
):
inputs = convert_to_tensor(inputs)
sequence_length = convert_to_tensor(sequence_length, dtype="int32")
sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32")
batch_size, max_length, num_classes = inputs.shape

if mask_index is None:
Expand All @@ -744,7 +744,7 @@ def _ctc_greedy_decode(
scores = np.max(inputs, axis=-1)

seqlen_mask = np.arange(max_length)[None, :]
seqlen_mask = seqlen_mask >= sequence_length[:, None]
seqlen_mask = seqlen_mask >= sequence_lengths[:, None]

indices = np.where(seqlen_mask, mask_index, indices)
scores = np.where(seqlen_mask, 0.0, scores)
Expand All @@ -754,34 +754,35 @@ def _ctc_greedy_decode(
repeat_mask = np.pad(repeat_mask, ((0, 0), (1, 0)))
indices = np.where(repeat_mask, mask_index, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
# We set to -1 for blank labels
invalid_mask = indices == mask_index
indices = np.where(invalid_mask, -1, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
order = np.expand_dims(np.arange(max_length), axis=0) # [1, N]
order = np.tile(order, (batch_size, 1)) # [B, N]
order = np.where(invalid_mask, max_length, order)
order = np.argsort(order, axis=-1)
indices = np.take_along_axis(indices, order, axis=-1)

# We set to -1 for blank labels
indices = np.where(invalid_mask, -1, indices)
scores = -np.sum(scores, axis=1)[:, None]
indices = np.expand_dims(indices, axis=0)
return indices, scores


def _ctc_beam_search_decode(
inputs,
sequence_length,
sequence_lengths,
beam_width=100,
top_paths=1,
mask_index=None,
):
inputs = convert_to_tensor(inputs)
sequence_length = convert_to_tensor(sequence_length)
sequence_lengths = convert_to_tensor(sequence_lengths)

batch_size, max_seq_len, num_classes = inputs.shape
inputs = log_softmax(inputs, axis=-1)
seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_length[:, None]
seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_lengths[:, None]

if mask_index is None:
mask_index = num_classes - 1
Expand Down Expand Up @@ -936,12 +937,12 @@ def np_scan_only_carry(f, init, xs):

def ctc_decode(
inputs,
sequence_length,
sequence_lengths,
strategy="greedy",
beam_width=100,
top_paths=1,
merge_repeated=True,
mask_index=None,
mask_index=0,
):
inputs = convert_to_tensor(inputs)
dtype = backend.result_type(inputs.dtype, "float32")
Expand All @@ -950,14 +951,14 @@ def ctc_decode(
if strategy == "greedy":
return _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=merge_repeated,
mask_index=mask_index,
)
elif strategy == "beam_search":
return _ctc_beam_search_decode(
inputs,
sequence_length,
sequence_lengths,
beam_width=beam_width,
top_paths=top_paths,
mask_index=mask_index,
Expand Down
27 changes: 22 additions & 5 deletions keras/src/backend/tensorflow/nn.py
Expand Up @@ -802,12 +802,12 @@ def ctc_loss(

def ctc_decode(
inputs,
sequence_length,
sequence_lengths,
strategy="greedy",
beam_width=100,
top_paths=1,
merge_repeated=True,
mask_index=None,
mask_index=0,
):
inputs = convert_to_tensor(inputs)
input_shape = tf.shape(inputs)
Expand All @@ -817,18 +817,27 @@ def ctc_decode(
dtype = backend.result_type(inputs.dtype, "float32")
inputs = tf.cast(inputs, dtype)

sequence_length = convert_to_tensor(sequence_length, dtype="int32")
sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32")
if strategy == "greedy":
(decoded, scores) = tf.nn.ctc_greedy_decoder(
inputs=inputs,
sequence_length=sequence_length,
sequence_length=sequence_lengths,
merge_repeated=merge_repeated,
blank_index=mask_index,
)
elif strategy == "beam_search":
# Move `mask_index` column to the last position since this is the
# default for `tf.nn.ctc_beam_search_decoder`
if mask_index is not None:
inputs_before = inputs[..., :mask_index]
inputs_mask = inputs[..., mask_index : mask_index + 1]
inputs_after = inputs[..., mask_index + 1 :]
inputs = tf.concat(
[inputs_before, inputs_after, inputs_mask], axis=-1
)
(decoded, scores) = tf.nn.ctc_beam_search_decoder(
inputs=inputs,
sequence_length=sequence_length,
sequence_length=sequence_lengths,
beam_width=beam_width,
top_paths=top_paths,
)
Expand All @@ -845,6 +854,14 @@ def ctc_decode(
decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
decoded_dense = tf.stack(decoded_dense, axis=0)
decoded_dense = tf.cast(decoded_dense, "int32")

# We need to recover the labels because we swapped the indices earlier
if strategy == "beam_search" and mask_index is not None:
if mask_index < 0:
mask_index = mask_index + input_shape[-1]
decoded_dense = tf.where(
decoded_dense >= mask_index, decoded_dense + 1, decoded_dense
)
return decoded_dense, scores


Expand Down
19 changes: 10 additions & 9 deletions keras/src/backend/torch/nn.py
Expand Up @@ -775,12 +775,12 @@ def ctc_loss(

def _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=True,
mask_index=None,
):
inputs = convert_to_tensor(inputs)
sequence_length = convert_to_tensor(sequence_length, dtype="int32")
sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32")
batch_size, max_length, num_classes = inputs.shape

if mask_index is None:
Expand All @@ -791,7 +791,7 @@ def _ctc_greedy_decode(
scores = torch.max(inputs, axis=-1)[0]

seqlen_mask = torch.arange(max_length, device=indices.device)[None, :]
seqlen_mask = seqlen_mask >= sequence_length[:, None]
seqlen_mask = seqlen_mask >= sequence_lengths[:, None]

indices = torch.where(seqlen_mask, mask_index, indices)
scores = torch.where(seqlen_mask, 0.0, scores)
Expand All @@ -801,8 +801,11 @@ def _ctc_greedy_decode(
repeat = tnn.pad(repeat, (1, 0, 0, 0))
indices = torch.where(repeat, mask_index, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
# We set to -1 for blank labels
invalid_mask = indices == mask_index
indices = torch.where(invalid_mask, -1, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
order = torch.unsqueeze(
torch.arange(max_length, device=indices.device), dim=0
) # [1, N]
Expand All @@ -811,21 +814,19 @@ def _ctc_greedy_decode(
order = torch.argsort(order, dim=-1)
indices = torch.take_along_dim(indices, order, dim=-1)

# We set to -1 for blank labels
indices = torch.where(invalid_mask, -1, indices)
scores = -torch.sum(scores, axis=1)[:, None]
indices = torch.unsqueeze(indices, dim=0)
return indices, scores


def ctc_decode(
inputs,
sequence_length,
sequence_lengths,
strategy="greedy",
beam_width=100,
top_paths=1,
merge_repeated=True,
mask_index=None,
mask_index=0,
):
inputs = convert_to_tensor(inputs)
dtype = backend.result_type(inputs.dtype, "float32")
Expand All @@ -834,7 +835,7 @@ def ctc_decode(
if strategy == "greedy":
return _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=merge_repeated,
mask_index=mask_index,
)
Expand Down
12 changes: 7 additions & 5 deletions keras/src/losses/losses.py
Expand Up @@ -1933,14 +1933,16 @@ def ctc(y_true, y_pred):
f"Received: y_pred.shape={ops.shape(y_pred)}"
)

batch_length = ops.cast(ops.shape(y_true)[0], dtype="int32")
input_length = ops.cast(ops.shape(y_pred)[1], dtype="int32")
label_length = ops.cast(ops.shape(y_true)[1], dtype="int32")
mask_index = 0
batch_length = ops.shape(y_pred)[0]
input_length = ops.shape(y_pred)[1]
input_length = input_length * ops.ones((batch_length,), dtype="int32")
label_length = label_length * ops.ones((batch_length,), dtype="int32")
label_length = ops.cast(
ops.sum(y_true != mask_index, axis=-1), dtype="int32"
)

return ops.ctc_loss(
y_true, y_pred, label_length, input_length, mask_index=0
y_true, y_pred, label_length, input_length, mask_index=mask_index
)


Expand Down
2 changes: 1 addition & 1 deletion keras/src/losses/losses_test.py
Expand Up @@ -1387,7 +1387,7 @@ def test_correctness(self):
logits = (np.arange(24).reshape((2, 4, 3)).astype("float32") - 12) / 100
y_true = np.array(([[1, 2, 1, 0], [1, 2, 0, 2]]))
output = losses.CTC()(y_true, logits)
self.assertAllClose(output, 4.389582)
self.assertAllClose(output, 2.448645)


class DiceTest(testing.TestCase):
Expand Down

0 comments on commit 880f0cd

Please sign in to comment.