Skip to content

Commit

Permalink
Refactor dtype conversion in convert_to_tensor and to_mlx_dtype functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Faisal-Alsrheed committed Apr 28, 2024
1 parent 8b9331a commit 6dc86da
Showing 1 changed file with 32 additions and 23 deletions.
55 changes: 32 additions & 23 deletions keras/src/backend/mlx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,6 @@
}


def to_mlx_dtype(dtype):
    """Convert a dtype specification to the corresponding MLX dtype.

    Args:
        dtype: An `mx.Dtype`, a dtype name string, or any object accepted
            by `np.dtype` (for example a NumPy dtype or a Python type).

    Returns:
        The `mx.Dtype` registered in `MLX_DTYPES` for the given dtype, or
        `dtype` itself when it is already an `mx.Dtype`.

    Raises:
        ValueError: If the dtype name has no entry in `MLX_DTYPES`.
    """
    # Already an MLX dtype: nothing to convert.
    if isinstance(dtype, mx.Dtype):
        return dtype

    # Strings are used as lookup keys verbatim; anything else is normalized
    # through NumPy to obtain its canonical dtype name.
    dtype_name = dtype if isinstance(dtype, str) else np.dtype(dtype).name

    mlx_dtype = MLX_DTYPES.get(dtype_name)
    if mlx_dtype is None:
        raise ValueError(f"Unsupported dtype for MLX: {dtype}")
    return mlx_dtype


class Variable(KerasVariable):
def _initialize(self, value):
self._value = convert_to_tensor(value, dtype=self._dtype)
Expand All @@ -67,15 +46,44 @@ def __array__(self, dtype=None):
return value


def to_mlx_dtype(dtype):
    """Converts a Keras dtype to the corresponding MLX dtype.

    Args:
        dtype: An `mx.Dtype`, a dtype name string, or any object accepted
            by `np.dtype` (for example a NumPy dtype or a Python type).

    Returns:
        `dtype` unchanged when it is already an `mx.Dtype`; otherwise the
        `mx.Dtype` that `MLX_DTYPES` maps the dtype name to.

    Raises:
        ValueError: If the dtype name is not a key of `MLX_DTYPES`.
    """
    # Pass MLX dtypes straight through.
    if isinstance(dtype, mx.Dtype):
        return dtype

    # A string is treated as the dtype name as-is; other inputs go through
    # NumPy so that equivalent specifications resolve to one name.
    if isinstance(dtype, str):
        dtype_name = dtype
    else:
        dtype_name = np.dtype(dtype).name

    mlx_dtype = MLX_DTYPES.get(dtype_name)
    if mlx_dtype is None:
        raise ValueError(f"Unsupported dtype for MLX: {dtype}")
    return mlx_dtype


def convert_to_tensor(x, dtype=None, sparse=None):
"""Converts the input x to an MLX tensor, handling various input types."""
try:
print(
f"Starting conversion: Input type {type(x)}, dtype requested: {dtype}"
)
print(f"Input: {x}, dtype: {dtype}, sparse: {sparse}")

if sparse:
print("`sparse=True` is not supported with mlx backend")
raise ValueError("`sparse=True` is not supported with mlx backend")

mlx_dtype = to_mlx_dtype(dtype) if dtype is not None else None
print(f"mlx_dtype: {mlx_dtype}")

Expand All @@ -93,9 +101,10 @@ def convert_to_tensor(x, dtype=None, sparse=None):

if isinstance(x, np.ndarray):
print("x is an instance of np.ndarray")
# Ensure compatibility with MLX by converting int64 to int32
if x.dtype == np.int64:
x = x.astype(np.int32)
x = x.astype(standardize_dtype(x.dtype))
x = x.astype(standardize_dtype(x.dtype)) # Standardize dtype
return mx.array(x, dtype=mlx_dtype)

if isinstance(x, list):
Expand All @@ -107,7 +116,7 @@ def to_scalar_list(x):
return [to_scalar_list(xi) for xi in x]
elif isinstance(x, mx.array):
if x.ndim == 0:
return x.item()
return x.item() # Convert 0-dim array to scalar
else:
return x.tolist()
else:
Expand Down

0 comments on commit 6dc86da

Please sign in to comment.