diff --git a/numba/core/errors.py b/numba/core/errors.py index 0d339583ff..892b230bd3 100644 --- a/numba/core/errors.py +++ b/numba/core/errors.py @@ -720,6 +720,12 @@ def __init__(self, kind, target_hw, hw_clazz): super().__init__(msg) +class NonexistentTargetError(InternalError): + """For signalling that a target that does not exist was requested. + """ + pass + + class RequireLiteralValue(TypingError): """ For signalling that a function's typing requires a constant value for diff --git a/numba/core/extending.py b/numba/core/extending.py index 48339598d0..e4da910998 100644 --- a/numba/core/extending.py +++ b/numba/core/extending.py @@ -190,7 +190,7 @@ def get(arr): def decorate(overload_func): template = make_overload_attribute_template( typ, attr, overload_func, - inline=kwargs.get('inline', 'never'), + **kwargs ) infer_getattr(template) overload(overload_func, **kwargs)(overload_func) diff --git a/numba/core/target_extension.py b/numba/core/target_extension.py index d0546671c6..91e5df0998 100644 --- a/numba/core/target_extension.py +++ b/numba/core/target_extension.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod from numba.core.registry import DelayedRegistry, CPUDispatcher from numba.core.decorators import jit -from numba.core.errors import InternalTargetMismatchError, NumbaValueError +from numba.core.errors import (InternalTargetMismatchError, + NonexistentTargetError) from threading import local as tls @@ -18,7 +19,7 @@ def __getitem__(self, item): msg = "No target is registered against '{}', known targets:\n{}" known = '\n'.join([f"{k: <{10}} -> {v}" for k, v in target_registry.items()]) - raise NumbaValueError(msg.format(item, known)) from None + raise NonexistentTargetError(msg.format(item, known)) from None # Registry mapping target name strings to Target classes diff --git a/numba/core/typing/context.py b/numba/core/typing/context.py index 5b180c0c60..dd50eca923 100644 --- a/numba/core/typing/context.py +++ b/numba/core/typing/context.py @@ -435,9 +435,53 @@ def install_registry(self, registry): except KeyError: loader = templates.RegistryLoader(registry) self._registries[registry] = loader + + from numba.core.target_extension import (get_local_target, + resolve_target_str) + current_target = get_local_target(self) + + def is_for_this_target(ftcls): + metadata = getattr(ftcls, 'metadata', None) + if metadata is None: + return True + + target_str = metadata.get('target') + if target_str is None: + return True + + # There may be pending registrations for nonexistent targets. + # Ideally it would be impossible to leave a registration pending + # for an invalid target, but in practice this is exceedingly + # difficult to guard against - many things are registered at import + # time, and eagerly reporting an error when registering for invalid + # targets would require that all target registration code is + # executed prior to all typing registrations during the import + # process; attempting to enforce this would impose constraints on + # execution order during import that would be very difficult to + # resolve and maintain in the presence of typical code maintenance. + # Furthermore, these constraints would be imposed not only on + # Numba internals, but also on its dependents. + # + # Instead of that enforcement, we simply catch any occurrences of + # registrations for targets that don't exist, and report that + # they're not for this target. They will then not be encountered + # again during future typing context refreshes (because the + # loader's new registrations are a stream_list that doesn't yield + # previously-yielded items). + try: + ft_target = resolve_target_str(target_str) + except errors.NonexistentTargetError: + return False + + return current_target.inherits_from(ft_target) + for ftcls in loader.new_registrations('functions'): + if not is_for_this_target(ftcls): + continue self.insert_function(ftcls(self)) for ftcls in loader.new_registrations('attributes'): + if not is_for_this_target(ftcls): + continue self.insert_attributes(ftcls(self)) for gv, gty in loader.new_registrations('globals'): existing = self._lookup_global(gv) diff --git a/numba/core/typing/templates.py b/numba/core/typing/templates.py index 74097fdc0b..2b371cfe72 100644 --- a/numba/core/typing/templates.py +++ b/numba/core/typing/templates.py @@ -16,7 +16,6 @@ from numba.core.errors import ( TypingError, InternalError, - InternalTargetMismatchError, ) from numba.core.cpu_options import InlineOptions @@ -1096,24 +1095,18 @@ def _init_once(self): """ attr = self._attr - try: - registry = self._get_target_registry('method') - except InternalTargetMismatchError: - # Target mismatch. Do not register attribute lookup here. - pass - else: - lower_builtin = registry.lower - - @lower_builtin((self.key, attr), self.key, types.VarArg(types.Any)) - def method_impl(context, builder, sig, args): - typ = sig.args[0] - typing_context = context.typing_context - fnty = self._get_function_type(typing_context, typ) - sig = self._get_signature(typing_context, fnty, sig.args, {}) - call = context.get_function(fnty, sig) - # Link dependent library - context.add_linking_libs(getattr(call, 'libs', ())) - return call(builder, args) + registry = self._get_target_registry('method') + + @registry.lower((self.key, attr), self.key, types.VarArg(types.Any)) + def method_impl(context, builder, sig, args): + typ = sig.args[0] + typing_context = context.typing_context + fnty = self._get_function_type(typing_context, typ) + sig = self._get_signature(typing_context, fnty, sig.args, {}) + call = context.get_function(fnty, sig) + # Link dependent library + context.add_linking_libs(getattr(call, 'libs', ())) + return call(builder, args) def _resolve(self, typ, attr): if self._attr != attr: @@ -1162,7 +1155,7 @@ def get_template_info(self): return types.BoundFunction(MethodTemplate, typ) -def make_overload_attribute_template(typ, attr, overload_func, inline, +def make_overload_attribute_template(typ, attr, overload_func, inline='never', prefer_literal=False, base=_OverloadAttributeTemplate, **kwargs): diff --git a/numba/cuda/tests/cudapy/test_overload.py b/numba/cuda/tests/cudapy/test_overload.py index 2764678ec6..412fe2434c 100644 --- a/numba/cuda/tests/cudapy/test_overload.py +++ b/numba/cuda/tests/cudapy/test_overload.py @@ -1,7 +1,8 @@ -from numba import cuda, njit -from numba.core.extending import overload +from numba import cuda, njit, types +from numba.core.errors import TypingError +from numba.core.extending import overload, overload_attribute +from numba.core.typing.typeof import typeof from numba.cuda.testing import CUDATestCase, skip_on_cudasim, unittest - import numpy as np @@ -295,6 +296,32 @@ def kernel(x): expected = GENERIC_TARGET_OL_CALLS_TARGET_OL * GENERIC_TARGET_OL self.check_overload_cpu(kernel, expected) + def test_overload_attribute_target(self): + MyDummy, MyDummyType = self.make_dummy_type() + mydummy_type = typeof(MyDummy()) + + @overload_attribute(MyDummyType, 'cuda_only', target='cuda') + def ov_dummy_cuda_attr(obj): + def imp(obj): + return 42 + + return imp + + # Ensure that we cannot use the CUDA target-specific attribute on the + # CPU, and that an appropriate typing error is raised + with self.assertRaisesRegex(TypingError, + "Unknown attribute 'cuda_only'"): + @njit(types.int64(mydummy_type)) + def illegal_target_attr_use(x): + return x.cuda_only + + # Ensure that the CUDA target-specific attribute is usable and works + # correctly when the target is CUDA - note eager compilation via + # signature + @cuda.jit(types.void(types.int64[::1], mydummy_type)) + def cuda_target_attr_use(res, dummy): + res[0] = dummy.cuda_only + if __name__ == '__main__': unittest.main() diff --git a/numba/tests/test_target_extension.py b/numba/tests/test_target_extension.py index c472b1c71c..a78b234b81 100644 --- a/numba/tests/test_target_extension.py +++ b/numba/tests/test_target_extension.py @@ -12,7 +12,8 @@ from functools import cached_property import numpy as np from numba import njit, types -from numba.extending import overload, intrinsic, overload_classmethod +from numba.extending import (overload, overload_attribute, + overload_classmethod, intrinsic) from numba.core.target_extension import ( JitDecorator, target_registry, @@ -32,6 +33,7 @@ from numba.core.codegen import CPUCodegen, JITCodeLibrary from numba.core.callwrapper import PyCallWrapper from numba.core.imputils import RegistryLoader, Registry +from numba.core.typing.typeof import typeof from numba import _dynfunc import llvmlite.binding as ll from llvmlite import ir as llir @@ -281,6 +283,10 @@ def typing_context(self): class DPUDispatcher(Dispatcher): targetdescr = dpu_target + def compile(self, sig): + with target_override('dpu'): + return super().compile(sig) + # Register a dispatcher for the DPU target, a lot of the code uses this # internally to work out what to do RE compilation @@ -512,7 +518,7 @@ def dpu_foo(): def test_invalid_target_jit(self): - with self.assertRaises(errors.NumbaValueError) as raises: + with self.assertRaises(errors.NonexistentTargetError) as raises: @njit(_target='invalid_silicon') def foo(): pass @@ -529,7 +535,7 @@ def bar(): # This is a typing error at present as it fails during typing when the # overloads are walked. - with self.assertRaises(errors.TypingError) as raises: + with self.assertRaises(errors.NonexistentTargetError) as raises: @overload(bar, target='invalid_silicon') def ol_bar(): return lambda : None @@ -704,6 +710,32 @@ def foo(): from numba.core.runtime import nrt self.assertIsInstance(r, nrt.MemInfo) + def test_overload_attribute_target(self): + MyDummy, MyDummyType = self.make_dummy_type() + mydummy_type = typeof(MyDummy()) + + @overload_attribute(MyDummyType, 'dpu_only', target='dpu') + def ov_dummy_dpu_attr(obj): + def imp(obj): + return 42 + + return imp + + # Ensure that we cannot use the DPU target-specific attribute on the + # CPU, and that an appropriate typing error is raised + with self.assertRaisesRegex(errors.TypingError, + "Unknown attribute 'dpu_only'"): + @njit(types.int64(mydummy_type)) + def illegal_target_attr_use(x): + return x.dpu_only + + # Ensure that the DPU target-specific attribute is usable and works + # correctly when the target is DPU - note eager compilation via + # signature + @djit(types.void(types.int64[::1], mydummy_type)) + def cuda_target_attr_use(res, dummy): + res[0] = dummy.dpu_only + class TestTargetOffload(TestCase): """In this use case the CPU compilation pipeline is extended with a new