[WIP] Fix backend/mlx/core and backend/common/dtypes for MLX + Improve integration_tests/numerical_test.py #19619

Draft · wants to merge 66 commits into base: mlx

Commits (66)
9648968
Implement convolution operation for MLX backend
Faisal-Alsrheed Apr 26, 2024
e73ec18
Fix dilation_rate standardization in conv function
Faisal-Alsrheed Apr 26, 2024
804a7be
Fix validation check for input and kernel channels in conv function
Faisal-Alsrheed Apr 26, 2024
00160e7
Refactor validation check for input and kernel channels in conv function
Faisal-Alsrheed Apr 26, 2024
f1bb716
Refactor conv function to improve readability and fix validation checks
Faisal-Alsrheed Apr 26, 2024
55fbace
Refactor _transpose_conv_kernel function to support both channels_fir…
Faisal-Alsrheed Apr 26, 2024
365f48e
Refactor _transpose_conv_kernel and conv functions to improve readabi…
Faisal-Alsrheed Apr 26, 2024
894c5d5
Refactor numerical_test.py to support MLX backend and fix validation …
Faisal-Alsrheed Apr 26, 2024
1b2f2bf
Refactor conv function
Faisal-Alsrheed Apr 26, 2024
437b7a0
Fix validation check for input and kernel channels in conv function
Faisal-Alsrheed Apr 26, 2024
5a62e65
Refactor _transpose_conv_kernel function to support both channels_fir…
Faisal-Alsrheed Apr 26, 2024
a35a56c
Refactor numerical_test.py to include data_format parameter in build_…
Faisal-Alsrheed Apr 26, 2024
f6278f0
Refactor _transpose_conv_kernel function to support both channels_fir…
Faisal-Alsrheed Apr 26, 2024
4372351
Refactor _transpose_conv_kernel function to handle both channels_firs…
Faisal-Alsrheed Apr 26, 2024
9549880
Refactor _transpose_conv_kernel function to handle both channels_firs…
Faisal-Alsrheed Apr 26, 2024
6514899
Refactor _transpose_conv_kernel function to handle both channels_firs…
Faisal-Alsrheed Apr 26, 2024
682072e
Refactor data_format parameter in conv function to use channels_last …
Faisal-Alsrheed Apr 26, 2024
6d7f5dd
Refactor build_keras_model function in numerical_test.py to use consi…
Faisal-Alsrheed Apr 26, 2024
b96047e
Refactor conv function to include support for groups parameter
Faisal-Alsrheed Apr 26, 2024
9e25cb4
Refactor _transpose_spatial_inputs function to handle both channels_f…
Faisal-Alsrheed Apr 26, 2024
9f0aa30
Refactor build_keras_model function in numerical_test.py to include i…
Faisal-Alsrheed Apr 26, 2024
e63913d
Refactor build_keras_model function in numerical_test.py to update in…
Faisal-Alsrheed Apr 26, 2024
a6941c1
Refactor build_keras_model function in numerical_test.py to remove in…
Faisal-Alsrheed Apr 26, 2024
b9d116d
fix
Faisal-Alsrheed Apr 26, 2024
4390780
fix
Faisal-Alsrheed Apr 26, 2024
5fa46eb
fix
Faisal-Alsrheed Apr 26, 2024
ec9ab5b
fix _transpose_spatial_inputs and _transpose_conv_kernel
Faisal-Alsrheed Apr 26, 2024
e535415
Refactor conv function for debugging
Faisal-Alsrheed Apr 27, 2024
cea85e9
Refactor conv function for better debugging
Faisal-Alsrheed Apr 27, 2024
6f4a809
Refactor numerical_test.py to test channels_last
Faisal-Alsrheed Apr 27, 2024
385c5a5
Refactor conv v1
Faisal-Alsrheed Apr 27, 2024
bdc3ef1
Refactor numerical_test.py to update Conv2D layer in build_keras_mode…
Faisal-Alsrheed Apr 27, 2024
c801dcc
No numerical_test.py
Faisal-Alsrheed Apr 27, 2024
40ce8c9
fix def conv
Faisal-Alsrheed Apr 27, 2024
9de2732
Transposing kernel shape
Faisal-Alsrheed Apr 27, 2024
0d77386
(C_out, C_in, H, W)
Faisal-Alsrheed Apr 27, 2024
66c5c70
kernel.transpose((3, 0, 1, 2))
Faisal-Alsrheed Apr 27, 2024
3ce2103
Refactor conv function
Faisal-Alsrheed Apr 28, 2024
929c31f
Update data type in numerical_test.py to float16 for testing mlx
Faisal-Alsrheed Apr 28, 2024
3c743f7
Update back the data type in numerical_test.py to float32
Faisal-Alsrheed Apr 28, 2024
cc207e7
Fix dtype conversion in convert_to_tensor function
Faisal-Alsrheed Apr 28, 2024
6a4aa6b
Fix to_mlx_dtype
Faisal-Alsrheed Apr 28, 2024
c99db80
Refactor dtype conversion in to_mlx_dtype and convert_to_tensor funct…
Faisal-Alsrheed Apr 28, 2024
e20dee0
Refactor dtype conversion in convert_to_tensor function
Faisal-Alsrheed Apr 28, 2024
90b0f58
Refactor dtype conversion in convert_to_tensor function and fix to_ml…
Faisal-Alsrheed Apr 28, 2024
ae91ef9
Add debugging print statements and comments to core.py
Faisal-Alsrheed Apr 28, 2024
fb27c5b
Refactor dtype conversion in to_mlx_dtype and convert_to_tensor funct…
Faisal-Alsrheed Apr 28, 2024
8b9331a
Refactor dtype conversion in convert_to_tensor function and add debug…
Faisal-Alsrheed Apr 28, 2024
6dc86da
Refactor dtype conversion in convert_to_tensor and to_mlx_dtype funct…
Faisal-Alsrheed Apr 28, 2024
690e243
No /workspaces/keras/integration_tests/numerical_test.py
Faisal-Alsrheed Apr 28, 2024
77314fe
add debugging print statements
Faisal-Alsrheed Apr 28, 2024
8e040db
fix convert_to_tensor
Faisal-Alsrheed Apr 28, 2024
0668b46
Refactor dtype conversion in convert_to_tensor and to_mlx_dtype funct…
Faisal-Alsrheed Apr 28, 2024
59c9af9
Refactor dtype handle MLX
Faisal-Alsrheed Apr 29, 2024
19171b7
Refactor dtype handle MLX
Faisal-Alsrheed Apr 29, 2024
44a33ec
Refactor dtype handle MLX
Faisal-Alsrheed Apr 30, 2024
62be51b
Refactor dtype handling and add MLX backend support
Faisal-Alsrheed Apr 30, 2024
7f75c3a
Refactor dtype handling and add explicit handling for float8 dtypes
Faisal-Alsrheed Apr 30, 2024
82776dd
comments in result_type
Faisal-Alsrheed Apr 30, 2024
fc78a65
Add core test file for mlx
Faisal-Alsrheed Apr 30, 2024
02dc92a
Refactor dtype handling
Faisal-Alsrheed Apr 30, 2024
de83556
no numerical_test.py
Faisal-Alsrheed Apr 30, 2024
213a7fc
improve numerical_test.py and standardize_dtype
Faisal-Alsrheed May 1, 2024
2e40687
Add DEBUGGING = True
Faisal-Alsrheed May 1, 2024
c3bfdcb
Refactor numerical_test.py and variables.py
Faisal-Alsrheed May 1, 2024
4eb82b3
Refactor dtype handling and add MLX backend support
Faisal-Alsrheed May 1, 2024
111 changes: 74 additions & 37 deletions integration_tests/numerical_test.py
@@ -6,6 +6,7 @@
 keras.backend.set_image_data_format("channels_last")
 tf_keras.backend.set_image_data_format("channels_last")
 
+DEBUGGING = True
 NUM_CLASSES = 10
 BATCH_SIZE = 32
 EPOCHS = 1
@@ -95,45 +96,81 @@ def predict_model(model, x):
 
 
 def numerical_test():
-    x_train, y_train = build_mnist_data(NUM_CLASSES)
-    keras_model = build_keras_model(keras, NUM_CLASSES)
-    tf_keras_model = build_keras_model(tf_keras, NUM_CLASSES)
-
-    # Make sure both model have same weights before training
-    weights = [weight.numpy() for weight in keras_model.weights]
-    tf_keras_model.set_weights(weights)
-
-    for kw, kcw in zip(keras_model.weights, tf_keras_model.weights):
-        np.testing.assert_allclose(kw.numpy(), kcw.numpy())
-
-    compile_model(keras_model)
-    compile_model(tf_keras_model)
-
-    print("Checking training histories:")
-    keras_history = train_model(keras_model, x_train, y_train)
-    tf_keras_history = train_model(tf_keras_model, x_train, y_train)
-    check_history(keras_history, tf_keras_history)
-    print("Training histories match.")
-    print()
+    print("Building data and creating models:")
+    try:
+        x_train, y_train = build_mnist_data(NUM_CLASSES)
+        keras_model = build_keras_model(keras, NUM_CLASSES)
+        tf_keras_model = build_keras_model(tf_keras, NUM_CLASSES)
+        print("Data building and model creation passed.")
+    except Exception as e:
+        print("Data building and model creation failed with error:", e)
+        if not DEBUGGING:
+            raise
+
+    print("Setting and checking weights:")
+    try:
+        weights = [weight.numpy() for weight in keras_model.weights]
+        tf_keras_model.set_weights(weights)
+        for kw, kcw in zip(keras_model.weights, tf_keras_model.weights):
+            np.testing.assert_allclose(kw.numpy(), kcw.numpy())
+        print("Weight setting and checking passed.")
+    except Exception as e:
+        print("Weight setting and checking failed with error:", e)
+        if not DEBUGGING:
+            raise
+
+    print("Compiling models:")
+    try:
+        compile_model(keras_model)
+        compile_model(tf_keras_model)
+        print("Model compilation passed.")
+    except Exception as e:
+        print("Model compilation failed with error:", e)
+        if not DEBUGGING:
+            raise
+
+    print("Training models and checking histories:")
+    try:
+        keras_history = train_model(keras_model, x_train, y_train)
+        tf_keras_history = train_model(tf_keras_model, x_train, y_train)
+        check_history(keras_history, tf_keras_history)
+        print("Training and history checking passed.")
+    except Exception as e:
+        print("Training and history checking failed with error:", e)
+        if not DEBUGGING:
+            raise
 
     print("Checking trained weights:")
-    for kw, kcw in zip(keras_model.weights, tf_keras_model.weights):
-        np.testing.assert_allclose(kw.numpy(), kcw.numpy(), atol=1e-3)
-    print("Trained weights match.")
-    print()
-
-    print("Checking predict:")
-    outputs1 = predict_model(keras_model, x_train)
-    outputs2 = predict_model(tf_keras_model, x_train)
-    np.testing.assert_allclose(outputs1, outputs2, atol=1e-3)
-    print("Predict results match.")
-    print()
-
-    print("Checking evaluate:")
-    score1 = eval_model(keras_model, x_train, y_train)
-    score2 = eval_model(tf_keras_model, x_train, y_train)
-    np.testing.assert_allclose(score1, score2, atol=1e-3)
-    print("Evaluate results match.")
+    try:
+        for kw, kcw in zip(keras_model.weights, tf_keras_model.weights):
+            np.testing.assert_allclose(kw.numpy(), kcw.numpy(), atol=1e-3)
+        print("Trained weights checking passed.")
+    except Exception as e:
+        print("Trained weights checking failed with error:", e)
+        if not DEBUGGING:
+            raise
+
+    print("Predicting with models:")
+    try:
+        outputs1 = predict_model(keras_model, x_train)
+        outputs2 = predict_model(tf_keras_model, x_train)
+        np.testing.assert_allclose(outputs1, outputs2, atol=1e-3)
+        print("Prediction passed.")
+    except Exception as e:
+        print("Prediction failed with error:", e)
+        if not DEBUGGING:
+            raise
+
+    print("Evaluating models:")
+    try:
+        score1 = eval_model(keras_model, x_train, y_train)
+        score2 = eval_model(tf_keras_model, x_train, y_train)
+        np.testing.assert_allclose(score1, score2, atol=1e-3)
+        print("Evaluation passed.")
+    except Exception as e:
+        print("Evaluation failed with error:", e)
+        if not DEBUGGING:
+            raise
 
 
 if __name__ == "__main__":
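Side note (not part of the diff): the six staged try/except blocks above all follow the same pattern, so they could be collapsed into one helper with the same DEBUGGING semantics. A minimal sketch; the helper name run_stage is hypothetical:

    def run_stage(name, fn, debugging=True):
        # Announce the stage, run it, and either swallow or re-raise the
        # failure depending on the flag, mirroring the blocks above.
        print(f"{name}:")
        try:
            result = fn()
            print(f"{name} passed.")
            return result
        except Exception as e:
            print(f"{name} failed with error:", e)
            if not debugging:
                raise

    # Example usage:
    # run_stage("Compiling models", lambda: compile_model(keras_model))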
24 changes: 15 additions & 9 deletions keras/src/backend/common/dtypes.py
@@ -141,9 +141,16 @@ def _least_upper_bound(*nodes):
         bounds = [UB[n] for n in N]
     except KeyError:
         dtype = next(n for n in N if n not in UB)
-        raise ValueError(
-            f"{dtype=} is not a valid dtype for Keras type promotion."
-        )
+        # Special handling for float8 types
+        if dtype.startswith("float8"):
+            raise ValueError(
+                "There is no implicit conversions from float8 dtypes to others."
+                f" You must cast it internally. Received dtype='{dtype}'"
+            )
+        else:
+            raise ValueError(
+                f"{dtype=} is not a valid dtype for Keras type promotion."
+            )
     CUB = set.intersection(*bounds)
     LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
     if len(LUB) == 1:
@@ -300,12 +307,11 @@ def result_type(*dtypes):
         # If no dtypes provided, default to floatx, this matches
         # `ops.convert_to_tensor([])`
         return config.floatx()
+
+    standardized_dtypes = []
     for dtype in dtypes:
-        if dtype in FLOAT8_TYPES:
-            raise ValueError(
-                "There is no implicit conversions from float8 dtypes to others."
-                f" You must cast it internally. Received: {dtypes}"
-            )
+        if dtype is None:
+            standardized_dtypes.append(config.floatx())
     return _lattice_result_type(
-        *(config.floatx() if arg is None else arg for arg in dtypes),
+        *(config.floatx() if arg is None else arg for arg in dtypes)
     )
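For orientation, the intended behavior after this change, sketched from the two hunks above (the concrete promotion results are assumptions based on Keras's standard type lattice):

    from keras.src.backend.common import dtypes

    dtypes.result_type("int8", "float32")  # -> "float32" via the lattice
    dtypes.result_type(None)               # -> config.floatx(), e.g. "float32"

    # float8 inputs now fail inside _least_upper_bound (first hunk) instead of
    # being rejected up front in result_type:
    dtypes.result_type("float8_e4m3fn", "bfloat16")  # raises ValueError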
12 changes: 11 additions & 1 deletion keras/src/backend/common/dtypes_test.py
@@ -57,6 +57,14 @@ def test_result_type_with_python_scalar_types(self, dtype1, dtype2):
     def test_result_type_with_tensor(self, dtype1, dtype2):
         import jax.numpy as jnp
 
+        from keras import backend
+
+        # Skip float64 tests for MLX backend because it is not supported.
+        if (
+            dtype1 == "float64" or dtype2 == "float64"
+        ) and backend.backend() == "mlx":
+            self.skipTest("Unsupported dtype for MLX: float64")
+
         x1 = ops.ones((1,), dtype=dtype1)
         x2 = ops.ones((1,), dtype=dtype2)
         x1_jax = jnp.ones((1,), dtype=dtype1)
@@ -221,11 +229,13 @@ def test_least_upper_bound_with_no_common_upper_bound(self):
         ):
             dtypes._least_upper_bound("test_dtype1", "test_dtype2")
 
-    def test_invalid_float8_dtype(self):
+    def test_invalid_float8_dtype_e4m3fn(self):
         with self.assertRaisesRegex(
             ValueError, "There is no implicit conversions from float8 dtypes"
         ):
             dtypes.result_type("float8_e4m3fn", "bfloat16")
+
+    def test_invalid_float8_dtype_e5m2(self):
         with self.assertRaisesRegex(
             ValueError, "There is no implicit conversions from float8 dtypes"
         ):
11 changes: 8 additions & 3 deletions keras/src/backend/common/variables.py
@@ -507,13 +507,18 @@ def initialize_all_variables():
 def standardize_dtype(dtype):
     if dtype is None:
         return config.floatx()
+
+    # Convert MLX data types to strings for comparison
+    if hasattr(dtype, "__module__") and dtype.__module__.startswith("mlx."):
+        dtype = str(dtype).split(".")[-1]
+
     dtype = dtypes.PYTHON_DTYPES_MAP.get(dtype, dtype)
+
+    # Existing logic for other backends
     if hasattr(dtype, "name"):
         dtype = dtype.name
     elif hasattr(dtype, "__str__") and (
-        "torch" in str(dtype)
-        or "jax.numpy" in str(dtype)
-        or "mlx" in str(dtype)
+        "torch" in str(dtype) or "jax.numpy" in str(dtype)
     ):
         dtype = str(dtype).split(".")[-1]
     elif hasattr(dtype, "__name__"):
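A quick illustration of the new MLX branch, assuming str(mx.float32) renders as "mlx.core.float32", which is what the __module__ check and the split(".") in the hunk rely on:

    import mlx.core as mx

    from keras.src.backend.common.variables import standardize_dtype

    # mx.float32.__module__ is "mlx.core", so the branch strips the module
    # prefix and hands the bare name "float32" to the existing logic below.
    assert standardize_dtype(mx.float32) == "float32"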
6 changes: 6 additions & 0 deletions keras/src/backend/common/variables_test.py
@@ -175,6 +175,12 @@ def test_standardize_dtype(self, dtype):
                 f"jax backend does not support {dtype} without x64 enabled"
             )
 
+        if backend.backend() == "mlx" and dtype in (
+            "float8_e4m3fn",
+            "float8_e5m2",
+            "float64",
+        ):
+            self.skipTest(f"MLX backend does not support dtype {dtype}")
         x = backend.convert_to_tensor(np.zeros(()), dtype)
         actual = standardize_dtype(x.dtype)
         self.assertEqual(actual, dtype)
8 changes: 5 additions & 3 deletions keras/src/backend/mlx/core.py
@@ -70,9 +70,11 @@ def convert_to_tensor(x, dtype=None, sparse=None):
             return x.value.astype(mlx_dtype)
         return x.value
 
-    if isinstance(x, np.ndarray):
Review comment from @lkarthee (Contributor), May 4, 2024:

If we remove this cast, mlx crashes when running some test cases on my local machine. I logged a bug on the mlx repo (see the mlx backend issue).

It did not crash in the CI tests on this PR; is that because the CI does not run on macOS? I am confused, as I thought mlx runs only on macOS.

https://github.com/keras-team/keras/actions/runs/8907542811/job/24461598952#step:1:2

2024-05-01T09:07:10.5470064Z Current runner version: '2.316.0'
2024-05-01T09:07:10.5494922Z ##[group]Operating System
2024-05-01T09:07:10.5495654Z Ubuntu
2024-05-01T09:07:10.5496034Z 22.04.4
2024-05-01T09:07:10.5496669Z LTS
2024-05-01T09:07:10.5497089Z ##[endgroup]
2024-05-01T09:07:10.5497521Z ##[group]Runner Image
2024-05-01T09:07:10.5498174Z Image: ubuntu-22.04
2024-05-01T09:07:10.5498671Z Version: 20240422.1.0
2024-05-01T09:07:10.5499877Z Included Software: https://github.com/actions/runner-images/blob/ubuntu22/20240422.1/images/ubuntu/Ubuntu2204-Readme.md
2024-05-01T09:07:10.5501763Z Image Release: https://github.com/actions/runner-images/releases/tag/ubuntu22%2F20240422.1

-        if x.dtype == np.int64:
-            x = x.astype(np.int32)
+    if isinstance(x, mx.array):
+        if x.dtype == mx.int64:
+            x = x.astype(mx.int32)
+        elif x.dtype == mx.float64:
Review comment from @lkarthee (Contributor), May 4, 2024:

There is no mx.float64?

+            x = x.astype(mx.float32)
     x = x.astype(standardize_dtype(x.dtype))
     return mx.array(x, dtype=mlx_dtype)

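For context, a minimal sketch of the int64 narrowing performed by the np.ndarray branch this hunk removes, which the review comment above argues should stay (mlx reportedly crashes without it on some machines):

    import mlx.core as mx
    import numpy as np

    # The deleted branch narrowed int64 NumPy arrays to int32 before calling
    # mx.array(...); per the review comment, skipping this cast can crash mlx.
    x = np.arange(4, dtype=np.int64)
    if x.dtype == np.int64:
        x = x.astype(np.int32)
    t = mx.array(x)
    assert t.dtype == mx.int32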
116 changes: 116 additions & 0 deletions keras/src/backend/mlx/core_test.py
@@ -0,0 +1,116 @@
import mlx.core as mx
import numpy as np
import pytest

from keras.src import backend
from keras.src import testing
from keras.src.backend.mlx import core


@pytest.mark.skipif(
    backend.backend() != "mlx",
    reason="Testing core MLX backend functionality",
)
class TestVariableMethods(testing.TestCase):
    def test_initialize(self):
        v = core.Variable(5, "int32")
        self.assertEqual(v._value, mx.array(5, dtype=mx.int32))

    def test_direct_assign(self):
        v = core.Variable(5, "int32")
        v._direct_assign(10)
        self.assertEqual(v._value, mx.array(10, dtype=mx.int32))

    def test_convert_to_tensor(self):
        v = core.Variable(5, "int32")
        tensor = v._convert_to_tensor(10)
        self.assertIsInstance(tensor, mx.array)
        self.assertEqual(tensor, mx.array(10, dtype=mx.int32))

    def test_array_conversion(self):
        v = core.Variable(mx.array([1, 2, 3]), "int32")
        arr = v.__array__()
        arr_mx = mx.array(arr)  # Convert arr to a mlx array
        self.assertTrue(mx.array_equal(arr_mx, mx.array([1, 2, 3])))

    def test_array_conversion_multidimensional(self):
        v = core.Variable(mx.array([[1, 2, 3], [4, 5, 6]]), "int32")
        arr = v.__array__()
        arr_mx = mx.array(arr)
        self.assertTrue(
            mx.array_equal(arr_mx, mx.array([[1, 2, 3], [4, 5, 6]]))
        )

    def test_null_initialization(self):
        with self.assertRaises(TypeError):
            core.Variable(None, "float32")

    def test_to_mlx_dtype(self):
        self.assertEqual(core.to_mlx_dtype("float32"), mx.float32)
        with self.assertRaises(ValueError):
            core.to_mlx_dtype("unsupported_dtype")

    def test_convert_to_tensor_exceptions(self):
        with self.assertRaises(ValueError):
            core.convert_to_tensor(10, sparse=True)

    def test_convert_to_numpy(self):
        arr = mx.array([1, 2, 3])
        np.testing.assert_array_equal(core.convert_to_numpy(arr), arr)

    def test_is_tensor(self):
        self.assertTrue(core.is_tensor(mx.array([1, 2, 3])))
        self.assertFalse(core.is_tensor([1, 2, 3]))

    def test_shape(self):
        arr = mx.array([1, 2, 3])
        self.assertEqual(core.shape(arr), (3,))

    def test_cast(self):
        tensor = core.cast([1, 2, 3], "float32")
        self.assertEqual(tensor.dtype, mx.float32)

    def test_tensor_to_numpy_and_back(self):
        tensor = core.cast(mx.array([1.5, 2.5, 3.5]), "float32")
        numpy_arr = core.convert_to_numpy(tensor)
        tensor_back = core.convert_to_tensor(numpy_arr, "float32")
        np.testing.assert_array_equal(tensor, tensor_back)

    def test_with_scalar_values(self):
        scalar = 5
        tensor = core.cast(scalar, "int32")
        self.assertEqual(tensor, mx.array(5, dtype=mx.int32))

    def test_with_zero_size_array(self):
        empty_arr = np.array([])
        tensor = core.convert_to_tensor(empty_arr, "float32")
        self.assertEqual(tensor.size, 0)

    def test_cond(self):
        result = core.cond(True, lambda: "true", lambda: "false")
        self.assertEqual(result, "true")

    def test_vectorized_map(self):
        result = core.vectorized_map(lambda x: x * 2, mx.array([1, 2, 3]))
        self.assertTrue(mx.array_equal(result, mx.array([2, 4, 6])))

    def test_scatter(self):
        zeros = mx.zeros((4,))
        result = core.scatter(mx.array([1]), mx.array([10]), zeros.shape)
        self.assertTrue(mx.array_equal(result, mx.array([0, 10, 0, 0])))

    def test_cond_complex_condition(self):
        result = core.cond(False, lambda: "true", lambda: "false")
        self.assertEqual(result, "false")

    def test_vectorized_map_complex_function(self):
        result = core.vectorized_map(lambda x: x * x + 2, mx.array([1, 2, 3]))
        self.assertTrue(mx.array_equal(result, mx.array([3, 6, 11])))

    def test_while_loop(self):
        result = core.while_loop(lambda x: x < 5, lambda x: x + 1, [0])
        self.assertEqual(result, (5,))

    def test_fori_loop(self):
        result = core.fori_loop(0, 5, lambda i, x: x + i, 0)
        self.assertEqual(result, 10)
1 change: 0 additions & 1 deletion keras/src/backend/mlx/trainer.py
@@ -3,7 +3,6 @@
 
 from keras.src import backend
 from keras.src import callbacks as callbacks_module
-from keras.src import ops
 from keras.src import optimizers as optimizers_module
 from keras.src import tree
 from keras.src.backend.common import standardize_dtype