Fix ops.ctc_decode #19633

Merged 8 commits on Apr 28, 2024

31 changes: 16 additions & 15 deletions keras/src/backend/jax/nn.py
@@ -593,7 +593,7 @@ def ctc_loss(target, output, target_length, output_length, mask_index=0):
output = convert_to_tensor(output)
target_length = convert_to_tensor(target_length, "int32")
output_length = convert_to_tensor(output_length, "int32")
batch_size, _, num_classes = output.shape
batch_size, max_input_length, num_classes = output.shape
batch_size, max_label_length = target.shape
log_epsilon = -1e5

@@ -610,7 +610,7 @@ def _lengths_to_paddings(lengths, max_length):
return jnp.logical_not(elem_valid)

target_paddings = _lengths_to_paddings(target_length, max_label_length)
output_paddings = _lengths_to_paddings(output_length, max_label_length)
output_paddings = _lengths_to_paddings(output_length, max_input_length)
target_paddings = target_paddings.astype(output.dtype)
output_paddings = output_paddings.astype(output.dtype)

@@ -690,12 +690,12 @@ def loop_body(prev, x):

def _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=True,
mask_index=None,
):
inputs = convert_to_tensor(inputs)
sequence_length = convert_to_tensor(sequence_length, dtype="int32")
sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32")
batch_size, max_length, num_classes = inputs.shape

if mask_index is None:
@@ -705,7 +705,7 @@ def _ctc_greedy_decode(
scores = jnp.max(inputs, axis=-1)

seqlen_mask = jnp.arange(max_length)[None, :]
seqlen_mask = seqlen_mask >= sequence_length[:, None]
seqlen_mask = seqlen_mask >= sequence_lengths[:, None]

indices = jnp.where(seqlen_mask, mask_index, indices)
scores = jnp.where(seqlen_mask, 0.0, scores)
@@ -715,34 +715,35 @@ def _ctc_greedy_decode(
repeat_mask = jnp.pad(repeat_mask, ((0, 0), (1, 0)))
indices = jnp.where(repeat_mask, mask_index, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
# We set to -1 for blank labels
invalid_mask = indices == mask_index
indices = jnp.where(invalid_mask, -1, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
order = jnp.expand_dims(jnp.arange(max_length), axis=0) # [1, N]
order = jnp.tile(order, (batch_size, 1)) # [B, N]
order = jnp.where(invalid_mask, max_length, order)
order = jnp.argsort(order, axis=-1)
indices = jnp.take_along_axis(indices, order, axis=-1)

# We set to -1 for blank labels
indices = jnp.where(invalid_mask, -1, indices)
scores = -jnp.sum(scores, axis=1)[:, None]
indices = jnp.expand_dims(indices, axis=0)
return indices, scores


def _ctc_beam_search_decode(
inputs,
sequence_length,
sequence_lengths,
beam_width=100,
top_paths=1,
mask_index=None,
):
inputs = convert_to_tensor(inputs)
sequence_length = convert_to_tensor(sequence_length)
sequence_lengths = convert_to_tensor(sequence_lengths)

batch_size, max_seq_len, num_classes = inputs.shape
inputs = jnn.log_softmax(inputs)
seqlen_mask = jnp.arange(max_seq_len)[None, :] >= sequence_length[:, None]
seqlen_mask = jnp.arange(max_seq_len)[None, :] >= sequence_lengths[:, None]

if mask_index is None:
mask_index = num_classes - 1
@@ -895,12 +896,12 @@ def _decode_batch(

def ctc_decode(
inputs,
sequence_length,
sequence_lengths,
strategy="greedy",
beam_width=100,
top_paths=1,
merge_repeated=True,
mask_index=None,
mask_index=0,
):
inputs = convert_to_tensor(inputs)
dtype = backend.result_type(inputs.dtype, "float32")
@@ -909,14 +910,14 @@
if strategy == "greedy":
return _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=merge_repeated,
mask_index=mask_index,
)
elif strategy == "beam_search":
return _ctc_beam_search_decode(
inputs,
sequence_length,
sequence_lengths,
beam_width=beam_width,
top_paths=top_paths,
mask_index=mask_index,
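A quick illustration of the `output_paddings` fix in `ctc_loss` above: the padding mask for the model output has to be built from the time axis (`max_input_length`), not the label axis. A minimal NumPy sketch mirroring the `_lengths_to_paddings` helper, with made-up lengths:

```python
import numpy as np

def lengths_to_paddings(lengths, max_length):
    # True where a step lies beyond the valid length (i.e. is padding).
    return np.arange(max_length)[None, :] >= lengths[:, None]

output_length = np.array([6, 4])           # valid time steps per sample
max_input_length, max_label_length = 6, 3  # time axis vs. label axis

output_paddings = lengths_to_paddings(output_length, max_input_length)
print(output_paddings.shape)               # (2, 6), matching output's time axis
# Building this mask with max_label_length (the previous behaviour) gives a
# (2, 3) array that cannot line up with the (batch, max_input_length) output.
```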
31 changes: 16 additions & 15 deletions keras/src/backend/numpy/nn.py
@@ -621,7 +621,7 @@ def ctc_loss(target, output, target_length, output_length, mask_index=0):
output = convert_to_tensor(output)
target_length = convert_to_tensor(target_length, "int32")
output_length = convert_to_tensor(output_length, "int32")
batch_size, _, num_classes = output.shape
batch_size, max_input_length, num_classes = output.shape
batch_size, max_label_length = target.shape
log_epsilon = -1e5

@@ -638,7 +638,7 @@ def _lengths_to_paddings(lengths, max_length):
return np.logical_not(elem_valid)

target_paddings = _lengths_to_paddings(target_length, max_label_length)
output_paddings = _lengths_to_paddings(output_length, max_label_length)
output_paddings = _lengths_to_paddings(output_length, max_input_length)
target_paddings = target_paddings.astype(output.dtype)
output_paddings = output_paddings.astype(output.dtype)

@@ -729,12 +729,12 @@ def np_scan(f, init, xs):

def _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=True,
mask_index=None,
):
inputs = convert_to_tensor(inputs)
sequence_length = convert_to_tensor(sequence_length, dtype="int32")
sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32")
batch_size, max_length, num_classes = inputs.shape

if mask_index is None:
@@ -744,7 +744,7 @@ def _ctc_greedy_decode(
scores = np.max(inputs, axis=-1)

seqlen_mask = np.arange(max_length)[None, :]
seqlen_mask = seqlen_mask >= sequence_length[:, None]
seqlen_mask = seqlen_mask >= sequence_lengths[:, None]

indices = np.where(seqlen_mask, mask_index, indices)
scores = np.where(seqlen_mask, 0.0, scores)
@@ -754,34 +754,35 @@ def _ctc_greedy_decode(
repeat_mask = np.pad(repeat_mask, ((0, 0), (1, 0)))
indices = np.where(repeat_mask, mask_index, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
# We set to -1 for blank labels
invalid_mask = indices == mask_index
indices = np.where(invalid_mask, -1, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
order = np.expand_dims(np.arange(max_length), axis=0) # [1, N]
order = np.tile(order, (batch_size, 1)) # [B, N]
order = np.where(invalid_mask, max_length, order)
order = np.argsort(order, axis=-1)
indices = np.take_along_axis(indices, order, axis=-1)

# We set to -1 for blank labels
indices = np.where(invalid_mask, -1, indices)
scores = -np.sum(scores, axis=1)[:, None]
indices = np.expand_dims(indices, axis=0)
return indices, scores


def _ctc_beam_search_decode(
inputs,
sequence_length,
sequence_lengths,
beam_width=100,
top_paths=1,
mask_index=None,
):
inputs = convert_to_tensor(inputs)
sequence_length = convert_to_tensor(sequence_length)
sequence_lengths = convert_to_tensor(sequence_lengths)

batch_size, max_seq_len, num_classes = inputs.shape
inputs = log_softmax(inputs, axis=-1)
seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_length[:, None]
seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_lengths[:, None]

if mask_index is None:
mask_index = num_classes - 1
@@ -936,12 +937,12 @@ def np_scan_only_carry(f, init, xs):

def ctc_decode(
inputs,
sequence_length,
sequence_lengths,
strategy="greedy",
beam_width=100,
top_paths=1,
merge_repeated=True,
mask_index=None,
mask_index=0,
):
inputs = convert_to_tensor(inputs)
dtype = backend.result_type(inputs.dtype, "float32")
@@ -950,14 +951,14 @@
if strategy == "greedy":
return _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=merge_repeated,
mask_index=mask_index,
)
elif strategy == "beam_search":
return _ctc_beam_search_decode(
inputs,
sequence_length,
sequence_lengths,
beam_width=beam_width,
top_paths=top_paths,
mask_index=mask_index,
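The reordering change in `_ctc_greedy_decode` above is easiest to see on a concrete array: blanks are now mapped to -1 before the argsort-based shuffle, because `invalid_mask` indexes pre-shuffle positions. A small NumPy sketch with made-up values:

```python
import numpy as np

mask_index = 0
indices = np.array([[2, 0, 0, 1, 0]])       # one sequence; class 0 is the blank
batch_size, max_length = indices.shape

invalid_mask = indices == mask_index
indices = np.where(invalid_mask, -1, indices)        # blanks -> -1 first

order = np.tile(np.arange(max_length)[None, :], (batch_size, 1))
order = np.where(invalid_mask, max_length, order)    # push blank slots to the end
order = np.argsort(order, axis=-1)
indices = np.take_along_axis(indices, order, axis=-1)

print(indices)  # [[ 2  1 -1 -1 -1]]
# Applying invalid_mask *after* the shuffle (the previous order of operations)
# would hit the wrong positions, since the mask refers to pre-shuffle slots.
```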
27 changes: 22 additions & 5 deletions keras/src/backend/tensorflow/nn.py
@@ -802,12 +802,12 @@ def ctc_loss(

def ctc_decode(
inputs,
sequence_length,
sequence_lengths,
strategy="greedy",
beam_width=100,
top_paths=1,
merge_repeated=True,
mask_index=None,
mask_index=0,
):
inputs = convert_to_tensor(inputs)
input_shape = tf.shape(inputs)
@@ -817,18 +817,27 @@
dtype = backend.result_type(inputs.dtype, "float32")
inputs = tf.cast(inputs, dtype)

sequence_length = convert_to_tensor(sequence_length, dtype="int32")
sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32")
if strategy == "greedy":
(decoded, scores) = tf.nn.ctc_greedy_decoder(
inputs=inputs,
sequence_length=sequence_length,
sequence_length=sequence_lengths,
merge_repeated=merge_repeated,
blank_index=mask_index,
)
elif strategy == "beam_search":
# Move `mask_index` column to the last position since this is the
# default for `tf.nn.ctc_beam_search_decoder`
if mask_index is not None:
inputs_before = inputs[..., :mask_index]
inputs_mask = inputs[..., mask_index : mask_index + 1]
inputs_after = inputs[..., mask_index + 1 :]
inputs = tf.concat(
[inputs_before, inputs_after, inputs_mask], axis=-1
)
(decoded, scores) = tf.nn.ctc_beam_search_decoder(
inputs=inputs,
sequence_length=sequence_length,
sequence_length=sequence_lengths,
beam_width=beam_width,
top_paths=top_paths,
)
@@ -845,6 +854,14 @@ def ctc_decode(
decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
decoded_dense = tf.stack(decoded_dense, axis=0)
decoded_dense = tf.cast(decoded_dense, "int32")

# We need to recover the labels because we swapped the indices earlier
if strategy == "beam_search" and mask_index is not None:
if mask_index < 0:
mask_index = mask_index + input_shape[-1]
decoded_dense = tf.where(
decoded_dense >= mask_index, decoded_dense + 1, decoded_dense
)
return decoded_dense, scores


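For the beam-search path, `tf.nn.ctc_beam_search_decoder` expects the blank to be the last class, so the new code moves the `mask_index` column to the end and then shifts the decoded labels back. A toy NumPy illustration of that bookkeeping (values are arbitrary):

```python
import numpy as np

num_classes, mask_index = 5, 2
logits = np.arange(num_classes, dtype="float32")[None, None, :]  # (batch, time, classes)

inputs_before = logits[..., :mask_index]
inputs_mask = logits[..., mask_index : mask_index + 1]
inputs_after = logits[..., mask_index + 1 :]
shuffled = np.concatenate([inputs_before, inputs_after, inputs_mask], axis=-1)
print(shuffled[0, 0])   # [0. 1. 3. 4. 2.]: old class 2 (the blank) is now last

# Decoded labels come back in the shuffled class space, so every label that was
# >= mask_index must be shifted up by one to recover the original class ids.
decoded = np.array([0, 1, 2, 3])                      # labels in shuffled space
recovered = np.where(decoded >= mask_index, decoded + 1, decoded)
print(recovered)        # [0 1 3 4]
```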
19 changes: 10 additions & 9 deletions keras/src/backend/torch/nn.py
@@ -775,12 +775,12 @@ def ctc_loss(

def _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=True,
mask_index=None,
):
inputs = convert_to_tensor(inputs)
sequence_length = convert_to_tensor(sequence_length, dtype="int32")
sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32")
batch_size, max_length, num_classes = inputs.shape

if mask_index is None:
@@ -791,7 +791,7 @@ def _ctc_greedy_decode(
scores = torch.max(inputs, axis=-1)[0]

seqlen_mask = torch.arange(max_length, device=indices.device)[None, :]
seqlen_mask = seqlen_mask >= sequence_length[:, None]
seqlen_mask = seqlen_mask >= sequence_lengths[:, None]

indices = torch.where(seqlen_mask, mask_index, indices)
scores = torch.where(seqlen_mask, 0.0, scores)
@@ -801,8 +801,11 @@ def _ctc_greedy_decode(
repeat = tnn.pad(repeat, (1, 0, 0, 0))
indices = torch.where(repeat, mask_index, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
# We set to -1 for blank labels
invalid_mask = indices == mask_index
indices = torch.where(invalid_mask, -1, indices)

# We rearrange the indices by moving `mask_index` to the end of the array
order = torch.unsqueeze(
torch.arange(max_length, device=indices.device), dim=0
) # [1, N]
@@ -811,21 +814,19 @@
order = torch.argsort(order, dim=-1)
indices = torch.take_along_dim(indices, order, dim=-1)

# We set to -1 for blank labels
indices = torch.where(invalid_mask, -1, indices)
scores = -torch.sum(scores, axis=1)[:, None]
indices = torch.unsqueeze(indices, dim=0)
return indices, scores


def ctc_decode(
inputs,
sequence_length,
sequence_lengths,
strategy="greedy",
beam_width=100,
top_paths=1,
merge_repeated=True,
mask_index=None,
mask_index=0,
):
inputs = convert_to_tensor(inputs)
dtype = backend.result_type(inputs.dtype, "float32")
@@ -834,7 +835,7 @@
if strategy == "greedy":
return _ctc_greedy_decode(
inputs,
sequence_length,
sequence_lengths,
merge_repeated=merge_repeated,
mask_index=mask_index,
)
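With all four backends now sharing the `sequence_lengths` argument and the `mask_index=0` default, a call through the public op looks roughly like this (a sketch that assumes the `keras.ops.ctc_decode` wrapper mirrors the backend signatures shown above; shapes and values are illustrative):

```python
import numpy as np
from keras import ops

batch, timesteps, num_classes = 2, 6, 5    # class 0 acts as the blank
logits = np.random.uniform(size=(batch, timesteps, num_classes)).astype("float32")
lengths = np.array([6, 4], dtype="int32")  # valid time steps per sample

decoded, scores = ops.ctc_decode(
    logits,
    sequence_lengths=lengths,              # renamed from `sequence_length`
    strategy="greedy",
    mask_index=0,                          # now the default
)
# Per the greedy backends above: `decoded` is (top_paths, batch, timesteps)
# with -1 marking blanks/padding, and `scores` is (batch, top_paths).
```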
12 changes: 7 additions & 5 deletions keras/src/losses/losses.py
@@ -1933,14 +1933,16 @@ def ctc(y_true, y_pred):
f"Received: y_pred.shape={ops.shape(y_pred)}"
)

batch_length = ops.cast(ops.shape(y_true)[0], dtype="int32")
input_length = ops.cast(ops.shape(y_pred)[1], dtype="int32")
label_length = ops.cast(ops.shape(y_true)[1], dtype="int32")
mask_index = 0
batch_length = ops.shape(y_pred)[0]
input_length = ops.shape(y_pred)[1]
input_length = input_length * ops.ones((batch_length,), dtype="int32")
label_length = label_length * ops.ones((batch_length,), dtype="int32")
label_length = ops.cast(
ops.sum(y_true != mask_index, axis=-1), dtype="int32"
)

return ops.ctc_loss(
y_true, y_pred, label_length, input_length, mask_index=0
y_true, y_pred, label_length, input_length, mask_index=mask_index
)


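The loss now derives per-sample label lengths by counting the non-blank entries of `y_true` rather than assuming every row uses its full width; a quick NumPy sketch of that computation, using the same targets as the test below:

```python
import numpy as np

mask_index = 0
y_true = np.array([[1, 2, 1, 0],
                   [1, 2, 0, 2]])

# Count entries that are not the blank/mask class, per sample.
label_length = np.sum(y_true != mask_index, axis=-1).astype("int32")
print(label_length)  # [3 3]: previously both rows would have used length 4
```

This tighter label length is what changes the expected value in `losses_test.py` below from 4.389582 to 2.448645.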
2 changes: 1 addition & 1 deletion keras/src/losses/losses_test.py
@@ -1387,7 +1387,7 @@ def test_correctness(self):
logits = (np.arange(24).reshape((2, 4, 3)).astype("float32") - 12) / 100
y_true = np.array(([[1, 2, 1, 0], [1, 2, 0, 2]]))
output = losses.CTC()(y_true, logits)
self.assertAllClose(output, 4.389582)
self.assertAllClose(output, 2.448645)


class DiceTest(testing.TestCase):