Skip to content

Commit

Permalink
fix(ops): specify NonZero output dtype and add test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
lpizzinidev committed Apr 28, 2024
1 parent 54e15eb commit 895334a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3881,7 +3881,7 @@ def call(self, x):
return backend.numpy.nonzero(x)

def compute_output_spec(self, x):
return KerasTensor([None] * len(x.shape))
return KerasTensor([None] * len(x.shape), dtype="int32")


@keras_export(["keras.ops.nonzero", "keras.ops.numpy.nonzero"])
Expand Down
5 changes: 4 additions & 1 deletion keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7102,7 +7102,10 @@ def test_nonzero(self, dtype):
self.assertEqual(
standardize_dtype(knp.nonzero(x)[0].dtype), expected_dtype
)
# TODO: verify Nonzero
self.assertEqual(
standardize_dtype(knp.Nonzero().symbolic_call(x)[0].dtype),
expected_dtype,
)

@parameterized.named_parameters(
named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
Expand Down

0 comments on commit 895334a

Please sign in to comment.