Skip to content

Commit

Permalink
Fix istft and add class TestMathErrors in ops/math_test.py (#19594
Browse files Browse the repository at this point in the history
)

* Fix and test math functions for jax backend

* run /workspaces/keras/shell/format.sh

* refix

* fix

* fix _get_complex_tensor_from_tuple

* fix

* refix

* Fix istft function to handle inputs with less than 2 dimensions

* fix

* Fix ValueError in istft function for inputs with less than 2 dimensions
  • Loading branch information
Faisal-Alsrheed committed Apr 29, 2024
1 parent 4cb5671 commit 61d85f3
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
6 changes: 6 additions & 0 deletions keras/src/backend/jax/math.py
Expand Up @@ -204,6 +204,12 @@ def istft(
x = _get_complex_tensor_from_tuple(x)
dtype = jnp.real(x).dtype

if len(x.shape) < 2:
raise ValueError(
f"Input `x` must have at least 2 dimensions. "
f"Received shape: {x.shape}"
)

expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1)
l_pad = (fft_length - sequence_length) // 2
r_pad = fft_length - sequence_length - l_pad
Expand Down
9 changes: 9 additions & 0 deletions keras/src/ops/linalg_test.py
Expand Up @@ -101,6 +101,15 @@ def test_qr(self):
self.assertEqual(q.shape, qref_shape)
self.assertEqual(r.shape, rref_shape)

def test_qr_invalid_mode(self):
# backend agnostic error message
x = np.array([[1, 2], [3, 4]])
invalid_mode = "invalid_mode"
with self.assertRaisesRegex(
ValueError, "Expected one of {'reduced', 'complete'}."
):
linalg.qr(x, mode=invalid_mode)

def test_solve(self):
a = KerasTensor([None, 20, 20])
b = KerasTensor([None, 20, 5])
Expand Down
88 changes: 88 additions & 0 deletions keras/src/ops/math_test.py
@@ -1,5 +1,6 @@
import math

import jax.numpy as jnp
import numpy as np
import pytest
import scipy.signal
Expand Down Expand Up @@ -1256,3 +1257,90 @@ def test_undefined_fft_length_and_last_dimension(self):
expected_shape = real_part.shape[:-1] + (None,)

self.assertEqual(output_spec.shape, expected_shape)


class TestMathErrors(testing.TestCase):

@pytest.mark.skipif(
backend.backend() != "jax", reason="Testing Jax errors only"
)
def test_segment_sum_no_num_segments(self):
data = jnp.array([1, 2, 3, 4])
segment_ids = jnp.array([0, 0, 1, 1])
with self.assertRaisesRegex(
ValueError,
"Argument `num_segments` must be set when using the JAX backend.",
):
kmath.segment_sum(data, segment_ids)

@pytest.mark.skipif(
backend.backend() != "jax", reason="Testing Jax errors only"
)
def test_segment_max_no_num_segments(self):
data = jnp.array([1, 2, 3, 4])
segment_ids = jnp.array([0, 0, 1, 1])
with self.assertRaisesRegex(
ValueError,
"Argument `num_segments` must be set when using the JAX backend.",
):
kmath.segment_max(data, segment_ids)

def test_stft_invalid_input_type(self):
# backend agnostic error message
x = np.array([1, 2, 3, 4])
sequence_length = 2
sequence_stride = 1
fft_length = 4
with self.assertRaisesRegex(TypeError, "`float32` or `float64`"):
kmath.stft(x, sequence_length, sequence_stride, fft_length)

def test_invalid_fft_length(self):
# backend agnostic error message
x = np.array([1.0, 2.0, 3.0, 4.0])
sequence_length = 4
sequence_stride = 1
fft_length = 2
with self.assertRaisesRegex(ValueError, "`fft_length` must equal or"):
kmath.stft(x, sequence_length, sequence_stride, fft_length)

def test_stft_invalid_window(self):
# backend agnostic error message
x = np.array([1.0, 2.0, 3.0, 4.0])
sequence_length = 2
sequence_stride = 1
fft_length = 4
window = "invalid_window"
with self.assertRaisesRegex(ValueError, "If a string is passed to"):
kmath.stft(
x, sequence_length, sequence_stride, fft_length, window=window
)

def test_stft_invalid_window_shape(self):
# backend agnostic error message
x = np.array([1.0, 2.0, 3.0, 4.0])
sequence_length = 2
sequence_stride = 1
fft_length = 4
window = np.ones((sequence_length + 1))
with self.assertRaisesRegex(ValueError, "The shape of `window` must"):
kmath.stft(
x, sequence_length, sequence_stride, fft_length, window=window
)

def test_istft_invalid_window_shape_2D_inputs(self):
# backend agnostic error message
x = (np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]]))
sequence_length = 2
sequence_stride = 1
fft_length = 4
incorrect_window = np.ones((sequence_length + 1,))
with self.assertRaisesRegex(
ValueError, "The shape of `window` must be equal to"
):
kmath.istft(
x,
sequence_length,
sequence_stride,
fft_length,
window=incorrect_window,
)

0 comments on commit 61d85f3

Please sign in to comment.