Skip to content

Commit

Permalink
Refactor dtype conversion in convert_to_tensor function and fix to_ml…
Browse files Browse the repository at this point in the history
…x_dtype
  • Loading branch information
Faisal-Alsrheed committed Apr 28, 2024
1 parent e20dee0 commit 90b0f58
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions keras/src/backend/mlx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,32 +59,35 @@ def __array__(self, dtype=None):
def convert_to_tensor(x, dtype=None, sparse=None):
if sparse:
raise ValueError("`sparse=True` is not supported with mlx backend")
# Convert the dtype to mlx.Dtype early to avoid comparison issues later.
# Convert the dtype to a compatible numpy dtype if necessary.
mlx_dtype = to_mlx_dtype(dtype) if dtype is not None else None
np_dtype = (
np.dtype(mlx_dtype) if mlx_dtype else None
) # Ensure dtype is compatible with numpy

if is_tensor(x):
# Only cast if necessary to avoid redundant operations
if dtype and x.dtype != mlx_dtype:
x = x.astype(mlx_dtype)
x = x.astype(np_dtype)
return x

if isinstance(x, Variable):
# Adjust Variable objects if dtype is different
if dtype and standardize_dtype(dtype) != x.dtype:
x = x.value.astype(mlx_dtype)
x = x.value.astype(np_dtype)
return x.value

if isinstance(x, np.ndarray):
# Ensure dtype consistency for numpy arrays
if dtype and x.dtype != np.dtype(mlx_dtype):
x = x.astype(mlx_dtype)
return mx.array(x, dtype=mlx_dtype)
if dtype and x.dtype != np_dtype:
x = x.astype(np_dtype)
return mx.array(x, dtype=np_dtype)

# Handle conversion for lists to ensure dtype uniformity
if isinstance(x, list):
return mx.array(x, dtype=mlx_dtype)
return mx.array(x, dtype=np_dtype)

return mx.array(x, dtype=mlx_dtype)
return mx.array(x, dtype=np_dtype)


def convert_to_tensors(*xs):
Expand Down

0 comments on commit 90b0f58

Please sign in to comment.