MAINT: Cleanup vecdot's signature, typing, and importing #26313

Merged
merged 2 commits into from May 13, 2024
5 changes: 2 additions & 3 deletions numpy/__init__.pyi
@@ -364,7 +364,6 @@ from numpy._core.numeric import (
convolve as convolve,
outer as outer,
tensordot as tensordot,
-vecdot as vecdot,
roll as roll,
rollaxis as rollaxis,
moveaxis as moveaxis,
@@ -3321,7 +3320,7 @@ logical_and: _UFunc_Nin2_Nout1[L['logical_and'], L[20], L[True]]
logical_not: _UFunc_Nin1_Nout1[L['logical_not'], L[20], None]
logical_or: _UFunc_Nin2_Nout1[L['logical_or'], L[20], L[False]]
logical_xor: _UFunc_Nin2_Nout1[L['logical_xor'], L[19], L[False]]
-matmul: _GUFunc_Nin2_Nout1[L['matmul'], L[19], None]
+matmul: _GUFunc_Nin2_Nout1[L['matmul'], L[19], None, L["(n?,k),(k,m?)->(n?,m?)"]]
maximum: _UFunc_Nin2_Nout1[L['maximum'], L[21], None]
minimum: _UFunc_Nin2_Nout1[L['minimum'], L[21], None]
mod: _UFunc_Nin2_Nout1[L['remainder'], L[16], None]
@@ -3350,7 +3349,7 @@ tan: _UFunc_Nin1_Nout1[L['tan'], L[8], None]
tanh: _UFunc_Nin1_Nout1[L['tanh'], L[8], None]
true_divide: _UFunc_Nin2_Nout1[L['true_divide'], L[11], None]
trunc: _UFunc_Nin1_Nout1[L['trunc'], L[7], None]
-vecdot: _GUFunc_Nin2_Nout1[L['vecdot'], L[19], None]
+vecdot: _GUFunc_Nin2_Nout1[L['vecdot'], L[19], None, L["(n),(n)->()"]]

abs = absolute
acos = arccos
4 changes: 2 additions & 2 deletions numpy/_core/multiarray.py
@@ -35,8 +35,8 @@
'empty', 'empty_like', 'error', 'flagsobj', 'flatiter', 'format_longfloat',
'frombuffer', 'fromfile', 'fromiter', 'fromstring',
'get_handler_name', 'get_handler_version', 'inner', 'interp',
-'interp_complex', 'is_busday', 'lexsort', 'matmul', 'may_share_memory',
-'min_scalar_type', 'ndarray', 'nditer', 'nested_iters',
+'interp_complex', 'is_busday', 'lexsort', 'matmul', 'vecdot',
+'may_share_memory', 'min_scalar_type', 'ndarray', 'nditer', 'nested_iters',
'normalize_axis_index', 'packbits', 'promote_types', 'putmask',
'ravel_multi_index', 'result_type', 'scalar', 'set_datetimeparse_function',
'set_legacy_print_mode',
6 changes: 3 additions & 3 deletions numpy/_core/numeric.py
@@ -18,7 +18,7 @@
fromstring, inner, lexsort, matmul, may_share_memory, min_scalar_type,
ndarray, nditer, nested_iters, promote_types, putmask, result_type,
shares_memory, vdot, where, zeros, normalize_axis_index,
-_get_promotion_state, _set_promotion_state
+_get_promotion_state, _set_promotion_state, vecdot
)

from . import overrides
@@ -52,8 +52,8 @@
'isclose', 'isscalar', 'binary_repr', 'base_repr', 'ones',
'identity', 'allclose', 'putmask',
'flatnonzero', 'inf', 'nan', 'False_', 'True_', 'bitwise_not',
-'full', 'full_like', 'matmul', 'shares_memory', 'may_share_memory',
-'_get_promotion_state', '_set_promotion_state']
+'full', 'full_like', 'matmul', 'vecdot', 'shares_memory',
+'may_share_memory', '_get_promotion_state', '_set_promotion_state']
Comment on lines +55 to +56
Member Author

I wanted vecdot to be closer to matmul in imports/__all__ as they are both gufuncs.



def _zeros_like_dispatcher(
33 changes: 0 additions & 33 deletions numpy/_core/numeric.pyi
@@ -497,39 +497,6 @@ def tensordot(
axes: int | tuple[_ShapeLike, _ShapeLike] = ...,
) -> NDArray[object_]: ...

-@overload
-def vecdot(
-x1: _ArrayLikeUnknown, x2: _ArrayLikeUnknown, axis: int = ...
-) -> NDArray[Any]: ...
-@overload
-def vecdot(
-x1: _ArrayLikeBool_co, x2: _ArrayLikeBool_co, axis: int = ...
-) -> NDArray[np.bool]: ...
-@overload
-def vecdot(
-x1: _ArrayLikeUInt_co, x2: _ArrayLikeUInt_co, axis: int = ...
-) -> NDArray[unsignedinteger[Any]]: ...
-@overload
-def vecdot(
-x1: _ArrayLikeInt_co, x2: _ArrayLikeInt_co, axis: int = ...
-) -> NDArray[signedinteger[Any]]: ...
-@overload
-def vecdot(
-x1: _ArrayLikeFloat_co, x2: _ArrayLikeFloat_co, axis: int = ...
-) -> NDArray[floating[Any]]: ...
-@overload
-def vecdot(
-x1: _ArrayLikeComplex_co, x2: _ArrayLikeComplex_co, axis: int = ...
-) -> NDArray[complexfloating[Any, Any]]: ...
-@overload
-def vecdot(
-x1: _ArrayLikeTD64_co, x2: _ArrayLikeTD64_co, axis: int = ...
-) -> NDArray[timedelta64]: ...
-@overload
-def vecdot(
-x1: _ArrayLikeObject_co, x2: _ArrayLikeObject_co, axis: int = ...
-) -> NDArray[object_]: ...
Member

Why can these type annotations be deleted?

Member

Ah, because vecdot is a ufunc now and _ufunc.pyi handles it for us, right?

Member Author

Yes! For example, there is no def matmul(...) in the .pyi files either.
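To illustrate the approach discussed here, the following is a minimal sketch of how a ufunc can be typed as an annotated instance of a generic class rather than as a `def` with overloads. `GUFuncStub`, `matmul_stub`, and `vecdot_stub` are hypothetical names chosen for this example; they are not NumPy's actual typing classes, which live in numpy/_typing/_ufunc.pyi and carry more type parameters.

```python
from typing import Generic, Literal, TypeVar

# Type parameters bounded by `str`, similar in spirit to `_NameType` and
# the `_Signature` TypeVar added in this PR (names here are hypothetical).
_NameT = TypeVar("_NameT", bound=str)
_SigT = TypeVar("_SigT", bound=str)


class GUFuncStub(Generic[_NameT, _SigT]):
    """Hypothetical stand-in for a typing-only gufunc class."""

    def __init__(self, name: _NameT, signature: _SigT) -> None:
        self._name = name
        self._signature = signature

    @property
    def name(self) -> _NameT:
        return self._name

    @property
    def signature(self) -> _SigT:
        # The return type is the type parameter, so each instance can
        # advertise its own literal signature string.
        return self._signature


# Each function is declared as an annotated instance, not as a `def`.
matmul_stub: GUFuncStub[
    Literal["matmul"], Literal["(n?,k),(k,m?)->(n?,m?)"]
] = GUFuncStub("matmul", "(n?,k),(k,m?)->(n?,m?)")
vecdot_stub: GUFuncStub[
    Literal["vecdot"], Literal["(n),(n)->()"]
] = GUFuncStub("vecdot", "(n),(n)->()")
```

With the signature as a type parameter, a type checker can distinguish `vecdot_stub.signature` from `matmul_stub.signature` without needing a separate class per function.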


@overload
def roll(
a: _ArrayLike[_SCT],
2 changes: 1 addition & 1 deletion numpy/_core/umath.py
@@ -37,4 +37,4 @@
'multiply', 'negative', 'nextafter', 'not_equal', 'pi', 'positive',
'power', 'rad2deg', 'radians', 'reciprocal', 'remainder', 'right_shift',
'rint', 'sign', 'signbit', 'sin', 'sinh', 'spacing', 'sqrt', 'square',
-'subtract', 'tan', 'tanh', 'true_divide', 'trunc', 'vecdot']
+'subtract', 'tan', 'tanh', 'true_divide', 'trunc']
8 changes: 3 additions & 5 deletions numpy/_typing/_ufunc.pyi
@@ -33,6 +33,7 @@ _4Tuple = tuple[_T, _T, _T, _T]
_NTypes = TypeVar("_NTypes", bound=int)
_IDType = TypeVar("_IDType", bound=Any)
_NameType = TypeVar("_NameType", bound=str)
+_Signature = TypeVar("_Signature", bound=str)


class _SupportsArrayUFunc(Protocol):
@@ -366,7 +367,7 @@ class _UFunc_Nin2_Nout2(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: i
signature: str | _4Tuple[None | str] = ...,
) -> _2Tuple[NDArray[Any]]: ...

-class _GUFunc_Nin2_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: ignore[misc]
+class _GUFunc_Nin2_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType, _Signature]): # type: ignore[misc]
@property
def __name__(self) -> _NameType: ...
@property
@@ -379,11 +380,8 @@ class _GUFunc_Nin2_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType]): # type:
def nout(self) -> Literal[1]: ...
@property
def nargs(self) -> Literal[3]: ...
-
-# NOTE: In practice the only gufunc in the main namespace is `matmul`,
-# so we can use its signature here
@property
-def signature(self) -> Literal["(n?,k),(k,m?)->(n?,m?)"]: ...
+def signature(self) -> _Signature: ...
@property
def reduce(self) -> None: ...
@property
6 changes: 0 additions & 6 deletions numpy/typing/tests/data/reveal/numeric.pyi
@@ -91,12 +91,6 @@ assert_type(np.tensordot(AR_i8, AR_c16), npt.NDArray[np.complexfloating[Any, Any
assert_type(np.tensordot(AR_i8, AR_m), npt.NDArray[np.timedelta64])
assert_type(np.tensordot(AR_O, AR_O), npt.NDArray[np.object_])

-assert_type(np.vecdot(AR_i8, AR_i8), npt.NDArray[np.signedinteger[Any]])
-assert_type(np.vecdot(AR_b, AR_b), npt.NDArray[np.bool])
-assert_type(np.vecdot(AR_b, AR_u8), npt.NDArray[np.unsignedinteger[Any]])
-assert_type(np.vecdot(AR_i8, AR_b), npt.NDArray[np.signedinteger[Any]])
-assert_type(np.vecdot(AR_i8, AR_f8), npt.NDArray[np.floating[Any]])

assert_type(np.isscalar(i8), bool)
assert_type(np.isscalar(AR_i8), bool)
assert_type(np.isscalar(B), bool)
10 changes: 10 additions & 0 deletions numpy/typing/tests/data/reveal/ufuncs.pyi
@@ -76,6 +76,16 @@ assert_type(np.matmul.identity, None)
assert_type(np.matmul(AR_f8, AR_f8), Any)
assert_type(np.matmul(AR_f8, AR_f8, axes=[(0, 1), (0, 1), (0, 1)]), Any)

+assert_type(np.vecdot.__name__, Literal["vecdot"])
+assert_type(np.vecdot.ntypes, Literal[19])
+assert_type(np.vecdot.identity, None)
+assert_type(np.vecdot.nin, Literal[2])
+assert_type(np.vecdot.nout, Literal[1])
+assert_type(np.vecdot.nargs, Literal[3])
+assert_type(np.vecdot.signature, Literal["(n),(n)->()"])
+assert_type(np.vecdot.identity, None)
+assert_type(np.vecdot(AR_f8, AR_f8), Any)
Member

There are a bunch of other tests like this that got deleted. Why only this one now? Or just an oversight?

Member

It would improve my confidence that the typing changes are correct if all the tests that used to be in reveal/numeric.pyi were reproduced here.

Member Author

The tests in reveal/numeric.pyi were tied to the typing stubs in numpy/_core/numeric.pyi, which I added when vecdot was implemented in pure Python. Those stubs should have been removed later, when vecdot was reimplemented as a gufunc.

As of now, the vecdot typing stub comes from _GUFunc_Nin2_Nout1 only, and assert_type(np.vecdot(AR_f8, AR_f8), Any) is the only call test that is valid with the current typing (the same holds for matmul).
I wanted to mimic the matmul implementation, as they are both gufuncs.

If we want to keep the removed tests, the question is how to proceed: either remove the GUFunc-based typing stub and restore the numpy/_core/numeric.pyi stubs (in which case the same refactoring could apply to matmul and others), or stick with the GUFunc typing stub.

Member

> assert_type(np.vecdot(AR_f8, AR_f8), Any) is the only test that is valid with the current typing

Can you explain a little more why e.g. assert_type(np.vecdot(AR_i8, AR_i8), npt.NDArray[np.signedinteger[Any]]) isn't valid? Is this a limitation of the automatically generated typing for gufuncs?

If so, it sounds like the hand-written type overrides for vecdot should be used and similar overrides should be written for matmul, or the automatically generated typing should be improved to capture the semantics of the old hand-written types. Users might write types using those in numpy 2.0 (unless you're looking to backport this to 2.0, which might not be possible given timing) so we shouldn't make the types less expressive or possibly break typing tests by changing semantics.

I'll defer to @BvB93 or others who are more experienced dealing with types about all this though.

Member Author

IIUC, it's a limitation of the GUFunc typing stub. As far as I'm concerned, we can keep the vecdot overrides as before.

Member

So this is an unfortunate limitation of callables that are not just normal Python functions: the latter are somewhat unique in that it is very easy to customize the signature of each individual function at the instance level, a degree of flexibility that in pretty much all other cases is only possible at the class level.

Unfortunately (in this context), ufuncs have a bunch of extra methods and attributes, none of which can be adequately typed with normal Python functions (as was the case prior to this PR), which forces the use of a single ufunc class (or a small set of them, as we do here with these typing-only nin-/nout-based ufunc classes).

The end result is that, unless you're willing to go through the massive task of creating a separate ufunc class with a customized signature for every single ufunc, you're stuck with a less precise signature, one that is too imprecise to infer the output type in the likes of assert_type(np.vecdot(AR_i8, AR_i8), npt.NDArray[np.signedinteger[Any]]) (a limitation that unfortunately applies to all ufuncs).
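The trade-off described above can be sketched with plain Python typing. This is only an illustration under simplifying assumptions: `dot_fn`, `SharedBinaryOp`, and `dot_obj` are hypothetical names, lists stand in for arrays, and the shared class is far smaller than NumPy's typing-only ufunc classes.

```python
from typing import Any, overload


# Option A: a plain function. Overloads give a precise return type per
# input type, but a function cannot expose typed ufunc-style attributes.
@overload
def dot_fn(x1: list[int], x2: list[int]) -> int: ...
@overload
def dot_fn(x1: list[float], x2: list[float]) -> float: ...
def dot_fn(x1, x2):
    return sum(a * b for a, b in zip(x1, x2))


# Option B: one class shared by every two-input callable. Attributes such
# as `nin` can be typed, but `__call__` must accept anything and therefore
# returns `Any` for all instances.
class SharedBinaryOp:
    def __init__(self, name: str, impl: Any) -> None:
        self.name = name
        self._impl = impl

    @property
    def nin(self) -> int:
        return 2

    def __call__(self, x1: Any, x2: Any) -> Any:
        return self._impl(x1, x2)


dot_obj = SharedBinaryOp("dot", dot_fn)

precise = dot_fn([1, 2, 3], [4, 5, 6])     # a type checker can infer `int`
imprecise = dot_obj([1, 2, 3], [4, 5, 6])  # a type checker only sees `Any`
```

Both calls compute the same value at runtime; the difference is only in what a static type checker can infer about the result.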


assert_type(np.bitwise_count.__name__, Literal['bitwise_count'])
assert_type(np.bitwise_count.ntypes, Literal[11])
assert_type(np.bitwise_count.identity, None)
4 changes: 2 additions & 2 deletions tools/ci/array-api-skips.txt
@@ -14,6 +14,6 @@ array_api_tests/test_signatures.py::test_func_signature[reshape]
array_api_tests/test_signatures.py::test_func_signature[argsort]
array_api_tests/test_signatures.py::test_func_signature[sort]

-# TODO: check why in CI `inspect.signature(np.vecdot)` returns (*arg, **kwarg)
-# instead of raising ValueError. mtsokol: couldn't reproduce locally
+# ufuncs signature on linux is always <Signature (*args, **kwargs)>
+# np.vecdot is the only ufunc with a keyword argument which causes a failure
Member

Do the array API tests also check that vecdot's signature is correct using duck typing? Probably worth doing both - check the signature and validate that the signature is correct!

Member Author

The array-api-tests suite has test_signatures.py for testing signatures. Separately, functions are tested by calling them with args and kwargs, as test_vecdot in test_linalg.py does.

array_api_tests/test_signatures.py::test_func_signature[vecdot]
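
The skip-file comment and the reply above separate two kinds of checks: inspecting a callable's signature and actually calling it. The sketch below shows, with standard-library tools only, why the first can fail while the second succeeds. `opaque_vecdot`, `explicit_vecdot`, and `has_named_parameter` are hypothetical stand-ins written for this example; they are not NumPy or array-api-tests code, and the explicit signature is assumed to mirror the `vecdot(x1, x2, /, *, axis=-1)` form used by the array API standard.

```python
import inspect


def opaque_vecdot(*args, **kwargs):
    """Hypothetical callable whose introspected signature is opaque."""
    x1, x2 = args
    # `axis` is accepted through **kwargs but ignored in this 1-D sketch.
    return sum(a * b for a, b in zip(x1, x2))


def explicit_vecdot(x1, x2, /, *, axis=-1):
    """Hypothetical callable that exposes its parameters to introspection."""
    return sum(a * b for a, b in zip(x1, x2))


def has_named_parameter(func, name):
    """Return True if `name` appears in the introspected signature."""
    return name in inspect.signature(func).parameters


# Calling with the keyword works for both callables.
called_opaque = opaque_vecdot([1, 2, 3], [4, 5, 6], axis=-1)
called_explicit = explicit_vecdot([1, 2, 3], [4, 5, 6], axis=-1)

# Introspection only finds `axis` on the explicit one.
opaque_has_axis = has_named_parameter(opaque_vecdot, "axis")
explicit_has_axis = has_named_parameter(explicit_vecdot, "axis")
```

A signature test that looks for a named `axis` parameter would reject the opaque callable even though a call-based test passes, which is consistent with keeping the `test_func_signature[vecdot]` entry in the skip list.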