Skip to content

Commit

Permalink
add debugging print statements
Browse files Browse the repository at this point in the history
  • Loading branch information
Faisal-Alsrheed committed Apr 28, 2024
1 parent 690e243 commit 77314fe
Showing 1 changed file with 31 additions and 33 deletions.
64 changes: 31 additions & 33 deletions keras/src/backend/mlx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,63 +72,61 @@ def to_mlx_dtype(dtype):
return standardized_dtype


import logging

logging.basicConfig(level=logging.DEBUG)


def debug_print(message):
logging.debug(message)


def convert_to_tensor(x, dtype=None, sparse=None):
"""Converts the input x to an MLX tensor, handling various input types."""
try:
print(
f"Starting conversion: Input type {type(x)}, dtype requested: {dtype}"
debug_print(
f"Starting conversion: Input type {type(x)}, dtype requested: {dtype}, input: {x}, sparse: {sparse}"
)
print(f"Input: {x}, dtype: {dtype}, sparse: {sparse}")

if sparse:
print("`sparse=True` is not supported with mlx backend")
debug_print("`sparse=True` is not supported with mlx backend")
raise ValueError("`sparse=True` is not supported with mlx backend")

mlx_dtype = to_mlx_dtype(dtype) if dtype is not None else None
print(f"mlx_dtype: {mlx_dtype}")
debug_print(f"mlx_dtype resolved to: {mlx_dtype}")

if is_tensor(x):
print("x is a tensor")
if dtype is None:
return x
return x.astype(mlx_dtype)
debug_print("Input is already a tensor")
return x.astype(mlx_dtype) if dtype else x

if isinstance(x, Variable):
print("x is an instance of Variable")
if dtype and standardize_dtype(dtype) != x.dtype:
return x.value.astype(mlx_dtype)
return x.value
debug_print("Input is an instance of Variable")
return (
x.value.astype(mlx_dtype)
if dtype and standardize_dtype(dtype) != x.dtype
else x.value
)

if isinstance(x, np.ndarray):
print("x is an instance of np.ndarray")
# Ensure compatibility with MLX by converting int64 to int32
debug_print("Input is an instance of np.ndarray")
if x.dtype == np.int64:
x = x.astype(np.int32)
x = x.astype(standardize_dtype(x.dtype)) # Standardize dtype
debug_print("Converted int64 to int32 due to MLX compatibility")
x = x.astype(standardize_dtype(x.dtype))
return mx.array(x, dtype=mlx_dtype)

if isinstance(x, list):
print("x is a list")

def to_scalar_list(x):
print(f"Converting to scalar list: {x}")
if isinstance(x, list):
return [to_scalar_list(xi) for xi in x]
elif isinstance(x, mx.array):
if x.ndim == 0:
return x.item() # Convert 0-dim array to scalar
else:
return x.tolist()
else:
return x

return mx.array(to_scalar_list(x), dtype=mlx_dtype)
debug_print("Input is a list")
converted_list = [
convert_to_tensor(item, dtype=mlx_dtype) for item in x
]
return mx.array(converted_list, dtype=mlx_dtype)

print("Returning mx.array")
debug_print("Returning mx.array for the input")
return mx.array(x, dtype=mlx_dtype)

except Exception as e:
print(f"Failed to convert tensor: {e}")
debug_print(f"Failed to convert tensor: {e}")
raise


Expand Down

0 comments on commit 77314fe

Please sign in to comment.