Skip to content

Commit

Permalink
fix(axes): Handle np.ravel and using poorly-initialized arrays
Browse files Browse the repository at this point in the history
Also add tests for np.ravel and np.ma.ravel
  • Loading branch information
Jacob-Stevens-Haas committed Feb 29, 2024
1 parent a0e36dd commit 098d231
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
22 changes: 22 additions & 0 deletions pysindy/utils/axes.py
Expand Up @@ -346,11 +346,23 @@ def __array_finalize__(self, obj) -> None:
elif all(
(
isinstance(obj, AxesArray),
hasattr(obj, "_ax_map"),
not hasattr(self, "_ax_map"),
self.shape == obj.shape,
)
):
self._ax_map = _AxisMapping(obj.axes, obj.ndim)
# Using a poorly-initialized AxesArray
# Occurs in MaskedArray.ravel, used in some plotting. MaskedArray views
# of AxesArray lose the axes attributes, and then the _ax_map attributes.
# See numpy.ma.core:asanyarray
elif all(
(
isinstance(obj, AxesArray),
not hasattr(obj, "_ax_map"),
)
):
self._ax_map = _AxisMapping({"ax_unk": 0}, in_ndim=1)
# maybe add errors for incompatible views?

def __array_ufunc__(
Expand Down Expand Up @@ -418,6 +430,16 @@ def decorator(func):
return decorator


@_implements(np.ravel)
def ravel(a, order="C"):
out = np.ravel(np.asarray(a), order=order)
is_1d_already = len(a.shape) == 1
if is_1d_already:
return AxesArray(out, a.axes)
else:
return AxesArray(out, {"ax_unk": 0})


@_implements(np.ix_)
def ix_(*args: AxesArray):
calc = np.ix_(*(np.asarray(arg) for arg in args))
Expand Down
21 changes: 21 additions & 0 deletions test/utils/test_axes.py
Expand Up @@ -629,6 +629,27 @@ def test_tensordot_list_axes():
assert_array_equal(result, super_result)


def test_ravel_1d():
arr = AxesArray(np.array([1, 2]), axes={"ax_a": 0})
result = np.ravel(arr)
assert_array_equal(result, arr)
assert result.axes == arr.axes


def test_ravel_nd():
arr = AxesArray(np.array([[1, 2], [3, 4]]), axes={"ax_a": 0, "ax_b": 1})
result = np.ravel(arr)
expected = np.ravel(np.asarray(arr))
assert_array_equal(result, expected)
assert result.axes == {"ax_unk": 0}


def test_ma_ravel():
arr = AxesArray(np.array([1, 2]), axes={"ax_a": 0})
marr = np.ma.MaskedArray(arr)
np.ma.ravel(marr)


@pytest.mark.skip
def test_einsum_implicit():
...
Expand Down

0 comments on commit 098d231

Please sign in to comment.