Skip to content

Commit

Permalink
BUG: vecdot signature
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed May 9, 2024
1 parent 2a9b913 commit 37568cf
Show file tree
Hide file tree
Showing 9 changed files with 19 additions and 55 deletions.
5 changes: 2 additions & 3 deletions numpy/__init__.pyi
Expand Up @@ -364,7 +364,6 @@ from numpy._core.numeric import (
convolve as convolve,
outer as outer,
tensordot as tensordot,
vecdot as vecdot,
roll as roll,
rollaxis as rollaxis,
moveaxis as moveaxis,
Expand Down Expand Up @@ -3321,7 +3320,7 @@ logical_and: _UFunc_Nin2_Nout1[L['logical_and'], L[20], L[True]]
logical_not: _UFunc_Nin1_Nout1[L['logical_not'], L[20], None]
logical_or: _UFunc_Nin2_Nout1[L['logical_or'], L[20], L[False]]
logical_xor: _UFunc_Nin2_Nout1[L['logical_xor'], L[19], L[False]]
matmul: _GUFunc_Nin2_Nout1[L['matmul'], L[19], None]
matmul: _GUFunc_Nin2_Nout1[L['matmul'], L[19], None, L["(n?,k),(k,m?)->(n?,m?)"]]
maximum: _UFunc_Nin2_Nout1[L['maximum'], L[21], None]
minimum: _UFunc_Nin2_Nout1[L['minimum'], L[21], None]
mod: _UFunc_Nin2_Nout1[L['remainder'], L[16], None]
Expand Down Expand Up @@ -3350,7 +3349,7 @@ tan: _UFunc_Nin1_Nout1[L['tan'], L[8], None]
tanh: _UFunc_Nin1_Nout1[L['tanh'], L[8], None]
true_divide: _UFunc_Nin2_Nout1[L['true_divide'], L[11], None]
trunc: _UFunc_Nin1_Nout1[L['trunc'], L[7], None]
vecdot: _GUFunc_Nin2_Nout1[L['vecdot'], L[19], None]
vecdot: _GUFunc_Nin2_Nout1[L['vecdot'], L[19], None, L["(n),(n)->()"]]

abs = absolute
acos = arccos
Expand Down
2 changes: 1 addition & 1 deletion numpy/_core/multiarray.py
Expand Up @@ -42,7 +42,7 @@
'set_legacy_print_mode',
'set_typeDict', 'shares_memory', 'typeinfo',
'unpackbits', 'unravel_index', 'vdot', 'where', 'zeros',
'_get_promotion_state', '_set_promotion_state']
'_get_promotion_state', '_set_promotion_state', 'vecdot']

# For backward compatibility, make sure pickle imports
# these functions from here
Expand Down
4 changes: 2 additions & 2 deletions numpy/_core/numeric.py
Expand Up @@ -18,7 +18,7 @@
fromstring, inner, lexsort, matmul, may_share_memory, min_scalar_type,
ndarray, nditer, nested_iters, promote_types, putmask, result_type,
shares_memory, vdot, where, zeros, normalize_axis_index,
_get_promotion_state, _set_promotion_state
_get_promotion_state, _set_promotion_state, vecdot
)

from . import overrides
Expand Down Expand Up @@ -53,7 +53,7 @@
'identity', 'allclose', 'putmask',
'flatnonzero', 'inf', 'nan', 'False_', 'True_', 'bitwise_not',
'full', 'full_like', 'matmul', 'shares_memory', 'may_share_memory',
'_get_promotion_state', '_set_promotion_state']
'_get_promotion_state', '_set_promotion_state', 'vecdot']


def _zeros_like_dispatcher(
Expand Down
33 changes: 0 additions & 33 deletions numpy/_core/numeric.pyi
Expand Up @@ -497,39 +497,6 @@ def tensordot(
axes: int | tuple[_ShapeLike, _ShapeLike] = ...,
) -> NDArray[object_]: ...

@overload
def vecdot(
x1: _ArrayLikeUnknown, x2: _ArrayLikeUnknown, axis: int = ...
) -> NDArray[Any]: ...
@overload
def vecdot(
x1: _ArrayLikeBool_co, x2: _ArrayLikeBool_co, axis: int = ...
) -> NDArray[np.bool]: ...
@overload
def vecdot(
x1: _ArrayLikeUInt_co, x2: _ArrayLikeUInt_co, axis: int = ...
) -> NDArray[unsignedinteger[Any]]: ...
@overload
def vecdot(
x1: _ArrayLikeInt_co, x2: _ArrayLikeInt_co, axis: int = ...
) -> NDArray[signedinteger[Any]]: ...
@overload
def vecdot(
x1: _ArrayLikeFloat_co, x2: _ArrayLikeFloat_co, axis: int = ...
) -> NDArray[floating[Any]]: ...
@overload
def vecdot(
x1: _ArrayLikeComplex_co, x2: _ArrayLikeComplex_co, axis: int = ...
) -> NDArray[complexfloating[Any, Any]]: ...
@overload
def vecdot(
x1: _ArrayLikeTD64_co, x2: _ArrayLikeTD64_co, axis: int = ...
) -> NDArray[timedelta64]: ...
@overload
def vecdot(
x1: _ArrayLikeObject_co, x2: _ArrayLikeObject_co, axis: int = ...
) -> NDArray[object_]: ...

@overload
def roll(
a: _ArrayLike[_SCT],
Expand Down
2 changes: 1 addition & 1 deletion numpy/_core/umath.py
Expand Up @@ -37,4 +37,4 @@
'multiply', 'negative', 'nextafter', 'not_equal', 'pi', 'positive',
'power', 'rad2deg', 'radians', 'reciprocal', 'remainder', 'right_shift',
'rint', 'sign', 'signbit', 'sin', 'sinh', 'spacing', 'sqrt', 'square',
'subtract', 'tan', 'tanh', 'true_divide', 'trunc', 'vecdot']
'subtract', 'tan', 'tanh', 'true_divide', 'trunc']
8 changes: 3 additions & 5 deletions numpy/_typing/_ufunc.pyi
Expand Up @@ -33,6 +33,7 @@ _4Tuple = tuple[_T, _T, _T, _T]
_NTypes = TypeVar("_NTypes", bound=int)
_IDType = TypeVar("_IDType", bound=Any)
_NameType = TypeVar("_NameType", bound=str)
_Signature = TypeVar("_Signature", bound=str)


class _SupportsArrayUFunc(Protocol):
Expand Down Expand Up @@ -366,7 +367,7 @@ class _UFunc_Nin2_Nout2(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: i
signature: str | _4Tuple[None | str] = ...,
) -> _2Tuple[NDArray[Any]]: ...

class _GUFunc_Nin2_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: ignore[misc]
class _GUFunc_Nin2_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType, _Signature]): # type: ignore[misc]
@property
def __name__(self) -> _NameType: ...
@property
Expand All @@ -379,11 +380,8 @@ class _GUFunc_Nin2_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType]): # type:
def nout(self) -> Literal[1]: ...
@property
def nargs(self) -> Literal[3]: ...

# NOTE: In practice the only gufunc in the main namespace is `matmul`,
# so we can use its signature here
@property
def signature(self) -> Literal["(n?,k),(k,m?)->(n?,m?)"]: ...
def signature(self) -> _Signature: ...
@property
def reduce(self) -> None: ...
@property
Expand Down
6 changes: 0 additions & 6 deletions numpy/typing/tests/data/reveal/numeric.pyi
Expand Up @@ -91,12 +91,6 @@ assert_type(np.tensordot(AR_i8, AR_c16), npt.NDArray[np.complexfloating[Any, Any
assert_type(np.tensordot(AR_i8, AR_m), npt.NDArray[np.timedelta64])
assert_type(np.tensordot(AR_O, AR_O), npt.NDArray[np.object_])

assert_type(np.vecdot(AR_i8, AR_i8), npt.NDArray[np.signedinteger[Any]])
assert_type(np.vecdot(AR_b, AR_b), npt.NDArray[np.bool])
assert_type(np.vecdot(AR_b, AR_u8), npt.NDArray[np.unsignedinteger[Any]])
assert_type(np.vecdot(AR_i8, AR_b), npt.NDArray[np.signedinteger[Any]])
assert_type(np.vecdot(AR_i8, AR_f8), npt.NDArray[np.floating[Any]])

assert_type(np.isscalar(i8), bool)
assert_type(np.isscalar(AR_i8), bool)
assert_type(np.isscalar(B), bool)
Expand Down
10 changes: 10 additions & 0 deletions numpy/typing/tests/data/reveal/ufuncs.pyi
Expand Up @@ -76,6 +76,16 @@ assert_type(np.matmul.identity, None)
assert_type(np.matmul(AR_f8, AR_f8), Any)
assert_type(np.matmul(AR_f8, AR_f8, axes=[(0, 1), (0, 1), (0, 1)]), Any)

assert_type(np.vecdot.__name__, Literal["vecdot"])
assert_type(np.vecdot.ntypes, Literal[19])
assert_type(np.vecdot.identity, None)
assert_type(np.vecdot.nin, Literal[2])
assert_type(np.vecdot.nout, Literal[1])
assert_type(np.vecdot.nargs, Literal[3])
assert_type(np.vecdot.signature, Literal["(n),(n)->()"])
assert_type(np.vecdot.identity, None)
assert_type(np.vecdot(AR_f8, AR_f8), Any)

assert_type(np.bitwise_count.__name__, Literal['bitwise_count'])
assert_type(np.bitwise_count.ntypes, Literal[11])
assert_type(np.bitwise_count.identity, None)
Expand Down
4 changes: 0 additions & 4 deletions tools/ci/array-api-skips.txt
Expand Up @@ -13,7 +13,3 @@ array_api_tests/test_signatures.py::test_func_signature[reshape]
# missing 'descending' keyword arguments
array_api_tests/test_signatures.py::test_func_signature[argsort]
array_api_tests/test_signatures.py::test_func_signature[sort]

# TODO: check why in CI `inspect.signature(np.vecdot)` returns (*arg, **kwarg)
# instead of raising ValueError. mtsokol: couldn't reproduce locally
array_api_tests/test_signatures.py::test_func_signature[vecdot]

0 comments on commit 37568cf

Please sign in to comment.