Skip to content

Commit

Permalink
Add debugging print statements and comments to core.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Faisal-Alsrheed committed Apr 28, 2024
1 parent 90b0f58 commit ae91ef9
Showing 1 changed file with 37 additions and 18 deletions.
55 changes: 37 additions & 18 deletions keras/src/backend/mlx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,17 @@


def to_mlx_dtype(dtype):
"""Convert input dtype to mlx.core.Dtype, handling strings and mlx.core.Dtype inputs."""
print(f"Input dtype: {dtype}")
if isinstance(dtype, mx.Dtype):
print("dtype is an instance of mx.Dtype")
return dtype
print("Standardizing dtype...")
standardized_dtype = MLX_DTYPES.get(standardize_dtype(dtype), None)
print(f"Standardized dtype: {standardized_dtype}")
if standardized_dtype is None:
print(f"Unsupported dtype for MLX: {dtype}")
raise ValueError(f"Unsupported dtype for MLX: {dtype}")
print("Returning standardized dtype")
return standardized_dtype


Expand All @@ -57,37 +62,51 @@ def __array__(self, dtype=None):


def convert_to_tensor(x, dtype=None, sparse=None):
print(f"Input: {x}, dtype: {dtype}, sparse: {sparse}")
if sparse:
print("`sparse=True` is not supported with mlx backend")
raise ValueError("`sparse=True` is not supported with mlx backend")
# Convert the dtype to a compatible numpy dtype if necessary.
mlx_dtype = to_mlx_dtype(dtype) if dtype is not None else None
np_dtype = (
np.dtype(mlx_dtype) if mlx_dtype else None
) # Ensure dtype is compatible with numpy
print(f"mlx_dtype: {mlx_dtype}")

if is_tensor(x):
# Only cast if necessary to avoid redundant operations
if dtype and x.dtype != mlx_dtype:
x = x.astype(np_dtype)
return x
print("x is a tensor")
if dtype is None:
return x
return x.astype(mlx_dtype)

if isinstance(x, Variable):
# Adjust Variable objects if dtype is different
print("x is an instance of Variable")
if dtype and standardize_dtype(dtype) != x.dtype:
x = x.value.astype(np_dtype)
return x.value.astype(mlx_dtype)
return x.value

if isinstance(x, np.ndarray):
# Ensure dtype consistency for numpy arrays
if dtype and x.dtype != np_dtype:
x = x.astype(np_dtype)
return mx.array(x, dtype=np_dtype)
print("x is an instance of np.ndarray")
if x.dtype == np.int64:
x = x.astype(np.int32)
x = x.astype(standardize_dtype(x.dtype))
return mx.array(x, dtype=mlx_dtype)

# Handle conversion for lists to ensure dtype uniformity
if isinstance(x, list):
return mx.array(x, dtype=np_dtype)
print("x is a list")

def to_scalar_list(x):
print(f"Converting to scalar list: {x}")
if isinstance(x, list):
return [to_scalar_list(xi) for xi in x]
elif isinstance(x, mx.array):
if x.ndim == 0:
return x.item()
else:
return x.tolist()
else:
return x

return mx.array(to_scalar_list(x), dtype=mlx_dtype)

return mx.array(x, dtype=np_dtype)
print("Returning mx.array")
return mx.array(x, dtype=mlx_dtype)


def convert_to_tensors(*xs):
Expand Down

0 comments on commit ae91ef9

Please sign in to comment.