MLX Backend #19571

Open
6 of 59 tasks
lkarthee opened this issue Apr 20, 2024 · 7 comments
Labels
stat:contributions welcome A pull request to fix this issue would be welcome.

Comments

@lkarthee
Contributor

lkarthee commented Apr 20, 2024

Issue for tracking and coordinating mlx backend work:

mlx.math

mlx.numpy

mlx.image

mlx.nn

  • max_pool
  • avg_pool
  • conv
  • depthwise_conv
  • separable_conv
  • conv_transpose
  • ctc_loss

mlx.rnn

  • rnn
  • lstm
  • gru

mlx.linalg

mlx.core

@lkarthee
Contributor Author

lkarthee commented Apr 20, 2024

PyTest Output
=========================================================================== test session starts ============================================================================
platform darwin -- Python 3.12.2, pytest-8.1.1, pluggy-1.4.0 -- /Users/kartheek/erlang-ws/github-ws/latest/keras/.venv/bin/python3.12
cachedir: .pytest_cache
rootdir: /Users/kartheek/erlang-ws/github-ws/latest/keras
configfile: pyproject.toml
plugins: cov-5.0.0
collected 6 items

keras/src/ops/operation_test.py::OperationTest::test_autoconfig PASSED                                                                                               [ 16%]
keras/src/ops/operation_test.py::OperationTest::test_eager_call PASSED                                                                                               [ 33%]
keras/src/ops/operation_test.py::OperationTest::test_input_conversion FAILED                                                                                         [ 50%]
keras/src/ops/operation_test.py::OperationTest::test_serialization PASSED                                                                                            [ 66%]
keras/src/ops/operation_test.py::OperationTest::test_symbolic_call PASSED                                                                                            [ 83%]
keras/src/ops/operation_test.py::OperationTest::test_valid_naming PASSED                                                                                             [100%]

================================================================================= FAILURES =================================================================================
___________________________________________________________________ OperationTest.test_input_conversion ____________________________________________________________________

self = <keras.src.ops.operation_test.OperationTest testMethod=test_input_conversion>

    def test_input_conversion(self):
        x = np.ones((2,))
        y = np.ones((2,))
        z = knp.ones((2,))  # mix
        if backend.backend() == "torch":
            z = z.cpu()
        op = OpWithMultipleInputs()
>       out = op(x, y, z)

keras/src/ops/operation_test.py:152:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
keras/src/utils/traceback_utils.py:113: in error_handler
    return fn(*args, **kwargs)
keras/src/ops/operation.py:56: in __call__
    return self.call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <Operation name=op_with_multiple_inputs>, x = array([1., 1.]), y = array([1., 1.])
z = <[ValueError('item can only be called on arrays of size 1.') raised in repr()] array object at 0x13f7450c0>

    def call(self, x, y, z=None):
        # `z` has to be put first due to the order of operations issue with
        # torch backend.
>       return 3 * z + x + 2 * y
E       ValueError: Cannot perform addition on an mlx.core.array and ndarray

keras/src/ops/operation_test.py:14: ValueError
========================================================================= short test summary info ==========================================================================
FAILED keras/src/ops/operation_test.py::OperationTest::test_input_conversion - ValueError: Cannot perform addition on an mlx.core.array and ndarray
======================================================================= 1 failed, 5 passed in 0.13s ========================================================================

Any idea how to fix this test case? add(mx_array, numpy_array) works, but the same operation fails when using the + operator. Should we skip this test for the mlx backend?

@fchollet
Member

> How to fix this test case any idea ? add(mx_array, numpy_array) works but fails when using + operator. Should we skip this test for mlx backend ?

It's not fixable on our side; we should file an issue with the MLX repo. The + operator hits array.__add__, which is implemented on their side.
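The point about array.__add__ can be illustrated in pure Python. For a + b, Python calls a.__add__(b) first and falls back to b.__radd__(a) only if __add__ returns NotImplemented; if __add__ raises, the fallback never happens. The error in the traceback suggests that mlx.core.array.__add__ raises a ValueError for an ndarray operand, which would explain why the fix has to land in MLX. The classes below (StrictArray, PoliteArray, ForeignArray) are hypothetical toy stand-ins, not MLX or NumPy code.

```python
class StrictArray:
    """Toy array whose __add__ raises on foreign operands (assumed MLX-like behavior)."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        if isinstance(other, StrictArray):
            return StrictArray(self.value + other.value)
        # Raising here stops Python from ever trying other.__radd__.
        raise ValueError(
            f"Cannot perform addition on a StrictArray and {type(other).__name__}"
        )


class PoliteArray:
    """Toy array whose __add__ defers to the other operand."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        if isinstance(other, PoliteArray):
            return PoliteArray(self.value + other.value)
        # Returning NotImplemented lets Python call other.__radd__(self).
        return NotImplemented


class ForeignArray:
    """Toy stand-in for an array type from another library (e.g. an ndarray)."""

    def __init__(self, value):
        self.value = value

    def __radd__(self, other):
        # Handles `other + self` after other.__add__ returned NotImplemented.
        return ForeignArray(other.value + self.value)


def try_add(a, b):
    """Return a + b, or the exception the addition raised."""
    try:
        return a + b
    except (TypeError, ValueError) as exc:
        return exc


print(type(try_add(PoliteArray(1), ForeignArray(2))).__name__)  # ForeignArray
print(try_add(StrictArray(1), ForeignArray(2)))  # Cannot perform addition on a StrictArray and ForeignArray
```

A library that wants mixed-type arithmetic with NumPy arrays to work would typically either return NotImplemented from __add__ or convert the foreign operand itself before adding.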

@fchollet fchollet added the stat:contributions welcome A pull request to fix this issue would be welcome. label Apr 20, 2024
@Faisal-Alsrheed
Contributor

Thank you for the list.

I am working on:

keras/backend/mlx/nn.py:conv
keras/backend/mlx/nn.py:depthwise_conv
keras/backend/mlx/nn.py:separable_conv
keras/backend/mlx/nn.py:conv_transpose

@lkarthee
Contributor Author

I am working on segment_sum, segment_max, max_pool and avg_pool. Thank you.

@yrahul3910

I want to take a stab at arctan2 (I'm a first-time contributor, so I want to start small). I'm working with the MLX team to see if I can add the required support there first, and then I'll add the implementation here.

@lkarthee
Contributor Author

lkarthee commented May 2, 2024

Thank you @yrahul3910, please go ahead with adding the arctan2 implementation.

@lkarthee
Contributor Author

lkarthee commented May 6, 2024

mx.matmul and mx.tensordot work only for bfloat16, float16, and float32.

FAILED keras/src/ops/numpy_test.py::NumpyDtypeTest::test_tensordot_('int16', 'bool') - ValueError: [matmul] Only real floating point types are supported but int16 and bool were provided which results in int16, which is not a real floating point type.

@fchollet How should we handle this? We could cast integer arguments to float32 when both are integers, in which case the result would be float32. If we go this route, we have to modify the test cases in numpy_test.py for mlx. Do you have any suggestions?
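A minimal sketch of the proposed casting rule, using NumPy as a stand-in for MLX. float_only_tensordot is a hypothetical wrapper that mimics the dtype restriction reported above (NumPy has no bfloat16, so only float16 and float32 are accepted), and tensordot_with_cast applies the proposal: when neither operand is floating point, both are cast to float32. For mixed integer/float inputs the sketch assumes the non-float operand is cast to the float operand's dtype; the thread does not specify that case.

```python
import numpy as np

# Stand-in for the dtype restriction reported for mx.matmul / mx.tensordot.
# NumPy has no bfloat16, so only float16 and float32 are listed here.
SUPPORTED_FLOATS = (np.dtype("float16"), np.dtype("float32"))


def float_only_tensordot(a, b, axes=2):
    """Hypothetical wrapper that rejects non-float inputs, mimicking MLX."""
    for x in (a, b):
        if x.dtype not in SUPPORTED_FLOATS:
            raise ValueError(
                f"Only real floating point types are supported, got {x.dtype}"
            )
    return np.tensordot(a, b, axes=axes)


def tensordot_with_cast(a, b, axes=2):
    """Cast integer/bool operands so the float-only kernel accepts them."""
    a, b = np.asarray(a), np.asarray(b)
    floats = [x.dtype for x in (a, b) if np.issubdtype(x.dtype, np.floating)]
    if not floats:
        # Proposal from the thread: both operands integer/bool -> float32.
        target = np.dtype("float32")
    else:
        # Assumption: otherwise promote among the floating operands only.
        # (A float64 operand would still be rejected by the wrapper.)
        target = np.result_type(*floats)
    return float_only_tensordot(a.astype(target), b.astype(target), axes=axes)


# The ('int16', 'bool') case from the failing test above.
a = np.arange(6, dtype=np.int16).reshape(2, 3)
b = np.ones((3, 2), dtype=bool)
out = tensordot_with_cast(a, b, axes=1)
print(out.dtype)  # float32
```

Because the result dtype changes from int16 to float32 under this rule, the dtype expectations in numpy_test.py would need mlx-specific cases. An alternative that keeps the existing expectations is casting the float32 result back to the promoted integer dtype, although float32 represents integers exactly only up to 2**24, so large integer results could lose precision.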


5 participants