Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix istft and add class TestMathErrors in ops/math_test.py #19594

Merged
2 changes: 1 addition & 1 deletion keras/api/_tf_keras/keras/losses/__init__.py
Expand Up @@ -8,9 +8,9 @@
from keras.src.losses import get
from keras.src.losses import serialize
from keras.src.losses.loss import Loss
from keras.src.losses.losses import CTC
from keras.src.losses.losses import BinaryCrossentropy
from keras.src.losses.losses import BinaryFocalCrossentropy
from keras.src.losses.losses import CTC
from keras.src.losses.losses import CategoricalCrossentropy
from keras.src.losses.losses import CategoricalFocalCrossentropy
from keras.src.losses.losses import CategoricalHinge
Expand Down
3 changes: 3 additions & 0 deletions keras/src/backend/jax/math.py
Expand Up @@ -204,6 +204,9 @@ def istft(
x = _get_complex_tensor_from_tuple(x)
dtype = jnp.real(x).dtype

if len(x.shape) < 2:
raise ValueError("Input `x` must have at least 2 dimensions.")
Faisal-Alsrheed marked this conversation as resolved.
Show resolved Hide resolved

expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1)
l_pad = (fft_length - sequence_length) // 2
r_pad = fft_length - sequence_length - l_pad
Expand Down
187 changes: 187 additions & 0 deletions keras/src/backend/jax/math_test.py
@@ -0,0 +1,187 @@
import jax
import jax.numpy as jnp
import pytest

from keras.src import backend
from keras.src import testing
from keras.src.backend.jax.math import _get_complex_tensor_from_tuple
from keras.src.backend.jax.math import istft
from keras.src.backend.jax.math import qr
from keras.src.backend.jax.math import segment_max
from keras.src.backend.jax.math import segment_sum
from keras.src.backend.jax.math import stft


@pytest.mark.skipif(
backend.backend() != "jax", reason="Testing Jax functions only"
)
class TestJaxMathErrors(testing.TestCase):
Faisal-Alsrheed marked this conversation as resolved.
Show resolved Hide resolved

def test_segment_sum_no_num_segments(self):
data = jnp.array([1, 2, 3, 4])
segment_ids = jnp.array([0, 0, 1, 1])
with self.assertRaisesRegex(
ValueError,
"Argument `num_segments` must be set when using the JAX backend.",
):
segment_sum(data, segment_ids)

def test_segment_max_no_num_segments(self):
data = jnp.array([1, 2, 3, 4])
segment_ids = jnp.array([0, 0, 1, 1])
with self.assertRaisesRegex(
ValueError,
"Argument `num_segments` must be set when using the JAX backend.",
):
segment_max(data, segment_ids)

def test_qr_invalid_mode(self):
x = jnp.array([[1, 2], [3, 4]])
invalid_mode = "invalid_mode"
with self.assertRaisesRegex(
ValueError, "Expected one of {'reduced', 'complete'}."
):
qr(x, mode=invalid_mode)

def test_get_complex_tensor_from_tuple_creates_complex_object(self):
real = jnp.array([[1.0, 2.0, 3.0]])
imag = jnp.array([[4.0, 5.0, 6.0]])
complex_tensor = _get_complex_tensor_from_tuple((real, imag))
self.assertTrue(
jnp.iscomplexobj(complex_tensor),
"The output should be a complex object.",
)

def test_get_complex_tensor_from_tuple_correct_real_part(self):
real = jnp.array([[1.0, 2.0, 3.0]])
imag = jnp.array([[4.0, 5.0, 6.0]])
complex_tensor = _get_complex_tensor_from_tuple((real, imag))
self.assertTrue(
jnp.array_equal(jnp.real(complex_tensor), real),
"The real parts should match.",
)

def test_get_complex_tensor_from_tuple_correct_imaginary_part(self):
real = jnp.array([[1.0, 2.0, 3.0]])
imag = jnp.array([[4.0, 5.0, 6.0]])
complex_tensor = _get_complex_tensor_from_tuple((real, imag))
self.assertTrue(
jnp.array_equal(jnp.imag(complex_tensor), imag),
"The imaginary parts should match.",
)

def test_invalid_get_complex_tensor_from_tuple_input_type(self):
with self.assertRaisesRegex(ValueError, "Input `x` should be a tuple"):
_get_complex_tensor_from_tuple(jnp.array([1.0, 2.0, 3.0]))

def test_invalid_get_complex_tensor_from_tuple_input_length(self):
with self.assertRaisesRegex(ValueError, "Input `x` should be a tuple"):
_get_complex_tensor_from_tuple(
(
jnp.array([1.0, 2.0, 3.0]),
jnp.array([4.0, 5.0, 6.0]),
jnp.array([7.0, 8.0, 9.0]),
)
)

def test_get_complex_tensor_from_tuple_mismatched_shapes(self):
real = jnp.array([1.0, 2.0, 3.0])
imag = jnp.array([4.0, 5.0])
with self.assertRaisesRegex(ValueError, "Both the real and imaginary"):
_get_complex_tensor_from_tuple((real, imag))

def test_invalid_not_float_get_complex_tensor_from_tuple_dtype(self):
real = jnp.array([[1, 2, 3]])
imag = jnp.array([[4.0, 5.0, 6.0]])
expected_message = "is not of type float"
with self.assertRaisesRegex(ValueError, expected_message):
_get_complex_tensor_from_tuple((real, imag))

def test_get_complex_tensor_from_tuple_complex_tensor_creation(self):
real = jnp.array([1.0, 2.0])
imag = jnp.array([3.0, 4.0])
expected_complex = jax.lax.complex(real, imag)
result = _get_complex_tensor_from_tuple((real, imag))
self.assertTrue(
jnp.array_equal(result, expected_complex),
msg="Complex tensor not created correctly.",
)

def test_get_complex_tensor_from_tuple_output_completeness(self):
real = jnp.array([1.0, 2.0])
imag = jnp.array([3.0, 4.0])
complex_tensor = _get_complex_tensor_from_tuple((real, imag))
self.assertEqual(
jnp.real(complex_tensor)[0],
real[0],
msg="Real parts are not aligned.",
)
self.assertEqual(
jnp.imag(complex_tensor)[0],
imag[0],
msg="Imaginary parts are not aligned.",
)

def test_stft_invalid_input_type(self):
x = jnp.array([1, 2, 3, 4])
sequence_length = 2
sequence_stride = 1
fft_length = 4
with self.assertRaisesRegex(TypeError, "`float32` or `float64`"):
stft(x, sequence_length, sequence_stride, fft_length)

def test_invalid_fft_length(self):
x = jnp.array([1.0, 2.0, 3.0, 4.0])
sequence_length = 4
sequence_stride = 1
fft_length = 2
with self.assertRaisesRegex(ValueError, "`fft_length` must equal or"):
stft(x, sequence_length, sequence_stride, fft_length)

def test_stft_invalid_window(self):
x = jnp.array([1.0, 2.0, 3.0, 4.0])
sequence_length = 2
sequence_stride = 1
fft_length = 4
window = "invalid_window"
with self.assertRaisesRegex(ValueError, "If a string is passed to"):
stft(x, sequence_length, sequence_stride, fft_length, window=window)

def test_stft_invalid_window_shape(self):
x = jnp.array([1.0, 2.0, 3.0, 4.0])
sequence_length = 2
sequence_stride = 1
fft_length = 4
window = jnp.ones((sequence_length + 1))
with self.assertRaisesRegex(ValueError, "The shape of `window` must"):
stft(x, sequence_length, sequence_stride, fft_length, window=window)

def test_istft_invalid_window_shape_2D_inputs(self):
x = (jnp.array([[1.0, 2.0]]), jnp.array([[3.0, 4.0]]))
sequence_length = 2
sequence_stride = 1
fft_length = 4
incorrect_window = jnp.ones((sequence_length + 1,))
with self.assertRaisesRegex(
ValueError, "The shape of `window` must be equal to"
):
istft(
x,
sequence_length,
sequence_stride,
fft_length,
window=incorrect_window,
)

def test_istft_1D_inputs(self):
real = jnp.array([1.0, 2.0, 3.0, 4.0])
imag = jnp.array([1.0, 2.0, 3.0, 4.0])
x = (real, imag)
sequence_length = 3
sequence_stride = 1
fft_length = 4
window = jnp.ones((sequence_length,))
with self.assertRaisesRegex(ValueError, "Input `x` must have at least"):
istft(
x, sequence_length, sequence_stride, fft_length, window=window
)