Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't attempt to register overloads that aren't for this target in BaseContext and related fixes #9454

Merged
merged 7 commits into from Apr 26, 2024
6 changes: 6 additions & 0 deletions numba/core/errors.py
Expand Up @@ -720,6 +720,12 @@ def __init__(self, kind, target_hw, hw_clazz):
super().__init__(msg)


class NonexistentTargetError(InternalError):
"""For signalling that a target that does not exist was requested.
"""
pass


class RequireLiteralValue(TypingError):
"""
For signalling that a function's typing requires a constant value for
Expand Down
2 changes: 1 addition & 1 deletion numba/core/extending.py
Expand Up @@ -190,7 +190,7 @@ def get(arr):
def decorate(overload_func):
template = make_overload_attribute_template(
typ, attr, overload_func,
inline=kwargs.get('inline', 'never'),
**kwargs
stuartarchibald marked this conversation as resolved.
Show resolved Hide resolved
)
infer_getattr(template)
overload(overload_func, **kwargs)(overload_func)
Expand Down
5 changes: 3 additions & 2 deletions numba/core/target_extension.py
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from numba.core.registry import DelayedRegistry, CPUDispatcher
from numba.core.decorators import jit
from numba.core.errors import InternalTargetMismatchError, NumbaValueError
from numba.core.errors import (InternalTargetMismatchError,
NonexistentTargetError)
from threading import local as tls


Expand All @@ -18,7 +19,7 @@ def __getitem__(self, item):
msg = "No target is registered against '{}', known targets:\n{}"
known = '\n'.join([f"{k: <{10}} -> {v}"
for k, v in target_registry.items()])
raise NumbaValueError(msg.format(item, known)) from None
raise NonexistentTargetError(msg.format(item, known)) from None


# Registry mapping target name strings to Target classes
Expand Down
44 changes: 44 additions & 0 deletions numba/core/typing/context.py
Expand Up @@ -435,9 +435,53 @@ def install_registry(self, registry):
except KeyError:
loader = templates.RegistryLoader(registry)
self._registries[registry] = loader

from numba.core.target_extension import (get_local_target,
resolve_target_str)
current_target = get_local_target(self)

def is_for_this_target(ftcls):
metadata = getattr(ftcls, 'metadata', None)
if metadata is None:
return True

target_str = metadata.get('target')
if target_str is None:
return True

# There may be pending registrations for nonexistent targets.
# Ideally it would be impossible to leave a registration pending
# for an invalid target, but in practice this is exceedingly
# difficult to guard against - many things are registered at import
# time, and eagerly reporting an error when registering for invalid
# targets would require that all target registration code is
# executed prior to all typing registrations during the import
# process; attempting to enforce this would impose constraints on
# execution order during import that would be very difficult to
# resolve and maintain in the presence of typical code maintenance.
# Furthermore, these constraints would be imposed not only on
# Numba internals, but also on its dependents.
#
# Instead of that enforcement, we simply catch any occurrences of
# registrations for targets that don't exist, and report that
# they're not for this target. They will then not be encountered
# again during future typing context refreshes (because the
# loader's new registrations are a stream_list that doesn't yield
# previously-yielded items).
stuartarchibald marked this conversation as resolved.
Show resolved Hide resolved
try:
ft_target = resolve_target_str(target_str)
except errors.NonexistentTargetError:
return False

return current_target.inherits_from(ft_target)

for ftcls in loader.new_registrations('functions'):
if not is_for_this_target(ftcls):
continue
self.insert_function(ftcls(self))
for ftcls in loader.new_registrations('attributes'):
if not is_for_this_target(ftcls):
continue
self.insert_attributes(ftcls(self))
for gv, gty in loader.new_registrations('globals'):
existing = self._lookup_global(gv)
Expand Down
33 changes: 13 additions & 20 deletions numba/core/typing/templates.py
Expand Up @@ -16,7 +16,6 @@
from numba.core.errors import (
TypingError,
InternalError,
InternalTargetMismatchError,
)
from numba.core.cpu_options import InlineOptions

Expand Down Expand Up @@ -1096,24 +1095,18 @@ def _init_once(self):
"""
attr = self._attr

try:
registry = self._get_target_registry('method')
except InternalTargetMismatchError:
# Target mismatch. Do not register attribute lookup here.
pass
else:
lower_builtin = registry.lower

@lower_builtin((self.key, attr), self.key, types.VarArg(types.Any))
def method_impl(context, builder, sig, args):
typ = sig.args[0]
typing_context = context.typing_context
fnty = self._get_function_type(typing_context, typ)
sig = self._get_signature(typing_context, fnty, sig.args, {})
call = context.get_function(fnty, sig)
# Link dependent library
context.add_linking_libs(getattr(call, 'libs', ()))
return call(builder, args)
registry = self._get_target_registry('method')

@registry.lower((self.key, attr), self.key, types.VarArg(types.Any))
def method_impl(context, builder, sig, args):
typ = sig.args[0]
typing_context = context.typing_context
fnty = self._get_function_type(typing_context, typ)
sig = self._get_signature(typing_context, fnty, sig.args, {})
call = context.get_function(fnty, sig)
# Link dependent library
context.add_linking_libs(getattr(call, 'libs', ()))
return call(builder, args)

def _resolve(self, typ, attr):
if self._attr != attr:
Expand Down Expand Up @@ -1162,7 +1155,7 @@ def get_template_info(self):
return types.BoundFunction(MethodTemplate, typ)


def make_overload_attribute_template(typ, attr, overload_func, inline,
def make_overload_attribute_template(typ, attr, overload_func, inline='never',
prefer_literal=False,
base=_OverloadAttributeTemplate,
**kwargs):
Expand Down
30 changes: 27 additions & 3 deletions numba/cuda/tests/cudapy/test_overload.py
@@ -1,7 +1,8 @@
from numba import cuda, njit
from numba.core.extending import overload
from numba import cuda, njit, types
from numba.core.errors import TypingError
from numba.core.extending import overload, overload_attribute
from numba.cuda.testing import CUDATestCase, skip_on_cudasim, unittest

from numba.tests.test_extending import mydummy_type, MyDummyType
import numpy as np
stuartarchibald marked this conversation as resolved.
Show resolved Hide resolved


Expand Down Expand Up @@ -295,6 +296,29 @@ def kernel(x):
expected = GENERIC_TARGET_OL_CALLS_TARGET_OL * GENERIC_TARGET_OL
self.check_overload_cpu(kernel, expected)

def test_overload_attribute_target(self):
@overload_attribute(MyDummyType, 'cuda_only', target='cuda')
def ov_dummy_cuda_attr(obj):
def imp(obj):
return 42

return imp

# Ensure that we cannot use the CUDA target-specific attribute on the
# CPU, and that an appropriate typing error is raised
with self.assertRaisesRegex(TypingError,
"Unknown attribute 'cuda_only'"):
@njit(types.void(mydummy_type))
def illegal_target_attr_use(x):
return x.cuda_only

# Ensure that the CUDA target-specific attribute is usable and works
# correctly when the target is CUDA - note eager compilation via
# signature
@cuda.jit(types.void(types.int64[::1], mydummy_type))
def cuda_target_attr_use(res, dummy):
res[0] = dummy.cuda_only


if __name__ == '__main__':
unittest.main()
37 changes: 34 additions & 3 deletions numba/tests/test_target_extension.py
Expand Up @@ -12,7 +12,8 @@
from functools import cached_property
import numpy as np
from numba import njit, types
from numba.extending import overload, intrinsic, overload_classmethod
from numba.extending import (overload, overload_attribute,
overload_classmethod, intrinsic)
from numba.core.target_extension import (
JitDecorator,
target_registry,
Expand Down Expand Up @@ -41,6 +42,7 @@
from numba.core.compiler import CompilerBase, DefaultPassBuilder
from numba.core.compiler_machinery import FunctionPass, register_pass
from numba.core.typed_passes import PreLowerStripPhis
from numba.tests.test_extending import mydummy_type, MyDummyType
stuartarchibald marked this conversation as resolved.
Show resolved Hide resolved

# Define a new target, this target extends GPU, this places the DPU in the
# target hierarchy as a type of GPU.
Expand Down Expand Up @@ -281,6 +283,10 @@ def typing_context(self):
class DPUDispatcher(Dispatcher):
targetdescr = dpu_target

def compile(self, sig):
with target_override('dpu'):
return super().compile(sig)


# Register a dispatcher for the DPU target, a lot of the code uses this
# internally to work out what to do RE compilation
Expand Down Expand Up @@ -317,6 +323,8 @@ def dispatcher_wrapper(self):
if "nopython" in self._kwargs:
topt["nopython"] = True

topt['target_backend'] = 'dpu'
stuartarchibald marked this conversation as resolved.
Show resolved Hide resolved

# It would be easy to specialise the default compilation pipeline for
# this target here.
pipeline_class = compiler.Compiler
Expand Down Expand Up @@ -512,7 +520,7 @@ def dpu_foo():

def test_invalid_target_jit(self):

with self.assertRaises(errors.NumbaValueError) as raises:
with self.assertRaises(errors.NonexistentTargetError) as raises:
@njit(_target='invalid_silicon')
def foo():
pass
Expand All @@ -529,7 +537,7 @@ def bar():

# This is a typing error at present as it fails during typing when the
# overloads are walked.
with self.assertRaises(errors.TypingError) as raises:
with self.assertRaises(errors.NonexistentTargetError) as raises:
@overload(bar, target='invalid_silicon')
def ol_bar():
return lambda : None
Expand Down Expand Up @@ -704,6 +712,29 @@ def foo():
from numba.core.runtime import nrt
self.assertIsInstance(r, nrt.MemInfo)

def test_overload_attribute_target(self):
@overload_attribute(MyDummyType, 'dpu_only', target='dpu')
def ov_dummy_dpu_attr(obj):
def imp(obj):
return 42

return imp

# Ensure that we cannot use the DPU target-specific attribute on the
# CPU, and that an appropriate typing error is raised
with self.assertRaisesRegex(errors.TypingError,
"Unknown attribute 'dpu_only'"):
@njit(types.void(mydummy_type))
def illegal_target_attr_use(x):
return x.dpu_only

# Ensure that the DPU target-specific attribute is usable and works
# correctly when the target is DPU - note eager compilation via
# signature
@djit(types.void(types.int64[::1], mydummy_type))
def cuda_target_attr_use(res, dummy):
res[0] = dummy.dpu_only


class TestTargetOffload(TestCase):
"""In this use case the CPU compilation pipeline is extended with a new
Expand Down