[WIP] Fix backend/mlx/core and backend/common/dtypes for MLX + Improve integration_tests/numerical_test.py #19619

Draft

Faisal-Alsrheed wants to merge 66 commits into keras-team:mlx from Faisal-Alsrheed:mlx--conve (base: mlx)
Changes from all commits · 66 commits
All commits are by Faisal-Alsrheed:

9648968 Implement convolution operation for MLX backend
e73ec18 Fix dilation_rate standardization in conv function
804a7be Fix validation check for input and kernel channels in conv function
00160e7 Refactor validation check for input and kernel channels in conv function
f1bb716 Refactor conv function to improve readability and fix validation checks
55fbace Refactor _transpose_conv_kernel function to support both channels_fir…
365f48e Refactor _transpose_conv_kernel and conv functions to improve readabi…
894c5d5 Refactor numerical_test.py to support MLX backend and fix validation …
1b2f2bf Refactor conv function
437b7a0 Fix validation check for input and kernel channels in conv function
5a62e65 Refactor _transpose_conv_kernel function to support both channels_fir…
a35a56c Refactor numerical_test.py to include data_format parameter in build_…
f6278f0 Refactor _transpose_conv_kernel function to support both channels_fir…
4372351 Refactor _transpose_conv_kernel function to handle both channels_firs…
9549880 Refactor _transpose_conv_kernel function to handle both channels_firs…
6514899 Refactor _transpose_conv_kernel function to handle both channels_firs…
682072e Refactor data_format parameter in conv function to use channels_last …
6d7f5dd Refactor build_keras_model function in numerical_test.py to use consi…
b96047e Refactor conv function to include support for groups parameter
9e25cb4 Refactor _transpose_spatial_inputs function to handle both channels_f…
9f0aa30 Refactor build_keras_model function in numerical_test.py to include i…
e63913d Refactor build_keras_model function in numerical_test.py to update in…
a6941c1 Refactor build_keras_model function in numerical_test.py to remove in…
b9d116d fix
4390780 fix
5fa46eb fix
ec9ab5b fix _transpose_spatial_inputs and _transpose_conv_kernel
e535415 Refactor conv function for debugging
cea85e9 Refactor conv function for better debugging
6f4a809 Refactor numerical_test.py to test channels_last
385c5a5 Refactor conv v1
bdc3ef1 Refactor numerical_test.py to update Conv2D layer in build_keras_mode…
c801dcc No numerical_test.py
40ce8c9 fix def conv
9de2732 Transposing kernel shape
0d77386 (C_out, C_in, H, W)
66c5c70 kernel.transpose((3, 0, 1, 2))
3ce2103 Refactor conv function
929c31f Update data type in numerical_test.py to float16 for testing mlx
3c743f7 Update back the data type in numerical_test.py to float32
cc207e7 Fix dtype conversion in convert_to_tensor function
6a4aa6b Fix to_mlx_dtype
c99db80 Refactor dtype conversion in to_mlx_dtype and convert_to_tensor funct…
e20dee0 Refactor dtype conversion in convert_to_tensor function
90b0f58 Refactor dtype conversion in convert_to_tensor function and fix to_ml…
ae91ef9 Add debugging print statements and comments to core.py
fb27c5b Refactor dtype conversion in to_mlx_dtype and convert_to_tensor funct…
8b9331a Refactor dtype conversion in convert_to_tensor function and add debug…
6dc86da Refactor dtype conversion in convert_to_tensor and to_mlx_dtype funct…
690e243 No /workspaces/keras/integration_tests/numerical_test.py
77314fe add debugging print statements
8e040db fix convert_to_tensor
0668b46 Refactor dtype conversion in convert_to_tensor and to_mlx_dtype funct…
59c9af9 Refactor dtype handle MLX
19171b7 Refactor dtype handle MLX
44a33ec Refactor dtype handle MLX
62be51b Refactor dtype handling and add MLX backend support
7f75c3a Refactor dtype handling and add explicit handling for float8 dtypes
82776dd comments in result_type
fc78a65 Add core test file for mlx
02dc92a Refactor dtype handling
de83556 no numerical_test.py
213a7fc improve numerical_test.py and standardize_dtype
2e40687 Add DEBUGGING = True
c3bfdcb Refactor numerical_test.py and variables.py
4eb82b3 Refactor dtype handling and add MLX backend support
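Several of the commits above (9de2732, 0d77386, 66c5c70) iterate on reordering the convolution kernel axes for MLX. As a NumPy-only illustration of what `kernel.transpose((3, 0, 1, 2))` does to a Keras-layout Conv2D kernel — the `(H, W, C_in, C_out)` source layout is the Keras convention, while the target layout MLX expects is an assumption here, not something the commit messages confirm:

```python
import numpy as np

# Keras stores Conv2D kernels as (H, W, C_in, C_out).
kernel = np.zeros((3, 3, 8, 16))

# transpose((3, 0, 1, 2)) moves the output-channel axis to the front,
# producing (C_out, H, W, C_in).
moved = kernel.transpose((3, 0, 1, 2))
print(moved.shape)  # (16, 3, 3, 8)
```

Note that this differs from the `(C_out, C_in, H, W)` layout named in commit 0d77386, which would require `transpose((3, 2, 0, 1))` instead — part of why these commits iterate on the axis order.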
```diff
@@ -70,9 +70,11 @@ def convert_to_tensor(x, dtype=None, sparse=None):
             return x.value.astype(mlx_dtype)
         return x.value
 
-    if isinstance(x, np.ndarray):
-        if x.dtype == np.int64:
-            x = x.astype(np.int32)
+    if isinstance(x, mx.array):
+        if x.dtype == mx.int64:
+            x = x.astype(mx.int32)
+        elif x.dtype == mx.float64:
+            x = x.astype(mx.float32)
     x = x.astype(standardize_dtype(x.dtype))
     return mx.array(x, dtype=mlx_dtype)
```

Review comment on the `mx.float64` line:

> there is no mx.float64 ?
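The review question is well taken: MLX does not provide a 64-bit float dtype, so the intent of this branch is to narrow unsupported 64-bit dtypes before handing data to MLX. A NumPy-only sketch of that downcast rule (`downcast_for_mlx` is a hypothetical helper for illustration, not the PR's code):

```python
import numpy as np

# MLX has no float64 dtype, and this backend also narrows int64 to
# int32, so map both 64-bit dtypes to their 32-bit counterparts.
_DOWNCASTS = {
    np.dtype("int64"): np.int32,
    np.dtype("float64"): np.float32,
}

def downcast_for_mlx(x):
    """Return x with 64-bit dtypes narrowed to 32-bit; others unchanged."""
    x = np.asarray(x)
    target = _DOWNCASTS.get(x.dtype)
    return x.astype(target) if target is not None else x

print(downcast_for_mlx(np.array([1.0])).dtype)  # float32
print(downcast_for_mlx(np.array([1], dtype=np.int16)).dtype)  # int16 (unchanged)
```

This keeps the conversion lossless for dtypes MLX supports and only narrows the two dtypes it cannot represent, which is also why removing the cast can crash MLX, as discussed in the review thread below.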
New test file (all 116 lines added, `@@ -0,0 +1,116 @@`):

```python
import mlx.core as mx
import numpy as np
import pytest

from keras.src import backend
from keras.src import testing
from keras.src.backend.mlx import core


@pytest.mark.skipif(
    backend.backend() != "mlx",
    reason="Testing core MLX backend functionality",
)
class TestVariableMethods(testing.TestCase):
    def test_initialize(self):
        v = core.Variable(5, "int32")
        self.assertEqual(v._value, mx.array(5, dtype=mx.int32))

    def test_direct_assign(self):
        v = core.Variable(5, "int32")
        v._direct_assign(10)
        self.assertEqual(v._value, mx.array(10, dtype=mx.int32))

    def test_convert_to_tensor(self):
        v = core.Variable(5, "int32")
        tensor = v._convert_to_tensor(10)
        self.assertIsInstance(tensor, mx.array)
        self.assertEqual(tensor, mx.array(10, dtype=mx.int32))

    def test_array_conversion(self):
        v = core.Variable(mx.array([1, 2, 3]), "int32")
        arr = v.__array__()
        arr_mx = mx.array(arr)  # Convert arr to a mlx array
        self.assertTrue(mx.array_equal(arr_mx, mx.array([1, 2, 3])))

    def test_array_conversion_multidimensional(self):
        v = core.Variable(mx.array([[1, 2, 3], [4, 5, 6]]), "int32")
        arr = v.__array__()
        arr_mx = mx.array(arr)
        self.assertTrue(
            mx.array_equal(arr_mx, mx.array([[1, 2, 3], [4, 5, 6]]))
        )

    def test_null_initialization(self):
        with self.assertRaises(TypeError):
            core.Variable(None, "float32")

    def test_to_mlx_dtype(self):
        self.assertEqual(core.to_mlx_dtype("float32"), mx.float32)
        with self.assertRaises(ValueError):
            core.to_mlx_dtype("unsupported_dtype")

    def test_convert_to_tensor_exceptions(self):
        with self.assertRaises(ValueError):
            core.convert_to_tensor(10, sparse=True)

    def test_convert_to_numpy(self):
        arr = mx.array([1, 2, 3])
        np.testing.assert_array_equal(core.convert_to_numpy(arr), arr)

    def test_is_tensor(self):
        self.assertTrue(core.is_tensor(mx.array([1, 2, 3])))
        self.assertFalse(core.is_tensor([1, 2, 3]))

    def test_shape(self):
        arr = mx.array([1, 2, 3])
        self.assertEqual(core.shape(arr), (3,))

    def test_cast(self):
        tensor = core.cast([1, 2, 3], "float32")
        self.assertEqual(tensor.dtype, mx.float32)

    def test_tensor_to_numpy_and_back(self):
        tensor = core.cast(mx.array([1.5, 2.5, 3.5]), "float32")
        numpy_arr = core.convert_to_numpy(tensor)
        tensor_back = core.convert_to_tensor(numpy_arr, "float32")
        np.testing.assert_array_equal(tensor, tensor_back)

    def test_with_scalar_values(self):
        scalar = 5
        tensor = core.cast(scalar, "int32")
        self.assertEqual(tensor, mx.array(5, dtype=mx.int32))

    def test_with_zero_size_array(self):
        empty_arr = np.array([])
        tensor = core.convert_to_tensor(empty_arr, "float32")
        self.assertEqual(tensor.size, 0)

    def test_cond(self):
        result = core.cond(True, lambda: "true", lambda: "false")
        self.assertEqual(result, "true")

    def test_vectorized_map(self):
        result = core.vectorized_map(lambda x: x * 2, mx.array([1, 2, 3]))
        self.assertTrue(mx.array_equal(result, mx.array([2, 4, 6])))

    def test_scatter(self):
        zeros = mx.zeros((4,))
        result = core.scatter(mx.array([1]), mx.array([10]), zeros.shape)
        self.assertTrue(mx.array_equal(result, mx.array([0, 10, 0, 0])))

    def test_cond_complex_condition(self):
        result = core.cond(False, lambda: "true", lambda: "false")
        self.assertEqual(result, "false")

    def test_vectorized_map_complex_function(self):
        result = core.vectorized_map(lambda x: x * x + 2, mx.array([1, 2, 3]))
        self.assertTrue(mx.array_equal(result, mx.array([3, 6, 11])))

    def test_while_loop(self):
        result = core.while_loop(lambda x: x < 5, lambda x: x + 1, [0])
        self.assertEqual(result, (5,))

    def test_fori_loop(self):
        result = core.fori_loop(0, 5, lambda i, x: x + i, 0)
        self.assertEqual(result, 10)
```
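The `test_scatter` case above expects values to land in a zero-initialized tensor of the given shape. A NumPy-only sketch of those semantics (`scatter_sketch` is a hypothetical stand-in for `core.scatter`, whose exact accumulation behavior for duplicate indices the test does not exercise):

```python
import numpy as np

def scatter_sketch(indices, values, shape):
    """Place values into a zero tensor of `shape` at `indices`."""
    out = np.zeros(shape)
    # np.add.at accumulates, so duplicate indices would sum their values.
    np.add.at(out, indices, values)
    return out

print(scatter_sketch(np.array([1]), np.array([10]), (4,)).tolist())
# [0.0, 10.0, 0.0, 0.0] — matches the expectation in test_scatter
```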
If we remove this cast, mlx crashes when running some test cases on my local machine. I logged a bug on the mlx repo (see the mlx backend issue).
It did not crash in the CI tests on this PR — is that because CI is not running on macOS? I'm confused, as I thought MLX runs only on macOS.
https://github.com/keras-team/keras/actions/runs/8907542811/job/24461598952#step:1:2