Skip to content

Commit

Permalink
Merge pull request #8906 from louisamand/main
Browse files Browse the repository at this point in the history
Add support for reflected dunder methods in jitclass
  • Loading branch information
sklam committed Apr 18, 2023
2 parents 746ae3f + 4d98923 commit 74a0840
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
12 changes: 12 additions & 0 deletions docs/source/user/jitclass.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,19 @@ The following dunder methods may be defined for jitclasses:
* ``__iand__``
* ``__ior__``
* ``__ixor__``
* ``__radd__``
* ``__rfloordiv__``
* ``__rlshift__``
* ``__rmatmul__``
* ``__rmod__``
* ``__rmul__``
* ``__rpow__``
* ``__rrshift__``
* ``__rsub__``
* ``__rtruediv__``
* ``__rand__``
* ``__ror__``
* ``__rxor__``

Refer to the `Python Data Model documentation
<https://docs.python.org/3/reference/datamodel.html>`_ for descriptions of
Expand Down
12 changes: 12 additions & 0 deletions numba/experimental/jitclass/boxing.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,19 @@ def _specialize_box(typ):
"__iand__",
"__ior__",
"__ixor__",
"__radd__",
"__rfloordiv__",
"__rlshift__",
"__rmatmul__",
"__rmod__",
"__rmul__",
"__rpow__",
"__rrshift__",
"__rsub__",
"__rtruediv__",
"__rand__",
"__ror__",
"__rxor__",
}
for name, func in typ.methods.items():
if name == "__init__":
Expand Down
76 changes: 76 additions & 0 deletions numba/tests/test_jitclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1943,6 +1943,82 @@ def __init__(self, array):
tuple(jit_matrix_at.arr.ravel())
)

def test_arithmetic_logical_reflection(self):
class OperatorsDefined:
def __init__(self, x):
self.x = x

def __radd__(self, other):
return other.x + self.x

def __rsub__(self, other):
return other.x - self.x

def __rmul__(self, other):
return other.x * self.x

def __rtruediv__(self, other):
return other.x / self.x

def __rfloordiv__(self, other):
return other.x // self.x

def __rmod__(self, other):
return other.x % self.x

def __rpow__(self, other):
return other.x ** self.x

def __rlshift__(self, other):
return other.x << self.x

def __rrshift__(self, other):
return other.x >> self.x

def __rand__(self, other):
return other.x & self.x

def __rxor__(self, other):
return other.x ^ self.x

def __ror__(self, other):
return other.x | self.x

class NoOperatorsDefined:
def __init__(self, x):
self.x = x

float_op = ["+", "-", "*", "**", "/", "//", "%"]
int_op = [*float_op, "<<", ">>" , "&", "^", "|"]

for test_type, test_op, test_value in [
(int32, int_op, (2, 4)),
(float64, float_op, (2., 4.)),
(float64[::1], float_op,
(np.array([1., 2., 4.]), np.array([20., -24., 1.])))
]:
spec = {"x": test_type}
JitOperatorsDefined = jitclass(OperatorsDefined, spec)
JitNoOperatorsDefined = jitclass(NoOperatorsDefined, spec)

py_ops_defined = OperatorsDefined(test_value[0]) # noqa: F841
py_ops_not_defined = NoOperatorsDefined(test_value[1]) # noqa: F841

jit_ops_defined = JitOperatorsDefined(test_value[0]) # noqa: F841
jit_ops_not_defined = JitNoOperatorsDefined(test_value[1]) # noqa: F841 E501

for op in test_op:
if not ("array" in str(test_type)):
self.assertEqual(
eval(f"py_ops_not_defined {op} py_ops_defined"),
eval(f"jit_ops_not_defined {op} jit_ops_defined")
)
else:
self.assertTupleEqual(
tuple(eval(f"py_ops_not_defined {op} py_ops_defined")),
tuple(eval(f"jit_ops_not_defined {op} jit_ops_defined"))
)

def test_implicit_hash_compiles(self):
# Ensure that classes with __hash__ implicitly defined as None due to
# the presence of __eq__ are correctly handled by ignoring the __hash__
Expand Down

0 comments on commit 74a0840

Please sign in to comment.