Skip to content

Commit

Permalink
Browse files: browse the repository at this point in the history
  • Loading branch information
JanKleine committed Feb 14, 2024
1 parent b54a3c6 commit 7c5d0b6
Show file tree
Hide file tree
Showing 30 changed files with 1,232 additions and 381 deletions.
63 changes: 24 additions & 39 deletions dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ def __init__(self, library_filename, program_name):
:param program_name: Name of the DaCe program (for use in finding
the stub library loader).
"""
self._stub_filename = os.path.join(
os.path.dirname(os.path.realpath(library_filename)),
f'libdacestub_{program_name}.{Config.get("compiler", "library_extension")}')
self._stub_filename = os.path.join(os.path.dirname(os.path.realpath(library_filename)),
f'libdacestub_{program_name}.{Config.get("compiler", "library_extension")}')
self._library_filename = os.path.realpath(library_filename)
self._stub = None
self._lib = None
Expand Down Expand Up @@ -159,6 +158,8 @@ def _array_interface_ptr(array: Any, storage: dtypes.StorageType) -> int:
"""
if hasattr(array, 'data_ptr'):
return array.data_ptr()
if isinstance(array, ctypes.Array):
return ctypes.addressof(array)

if storage == dtypes.StorageType.GPU_Global:
try:
Expand Down Expand Up @@ -219,7 +220,6 @@ def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None):
self.has_gpu_code = True
break


def get_exported_function(self, name: str, restype=None) -> Optional[Callable[..., Any]]:
"""
Tries to find a symbol by name in the compiled SDFG, and convert it to a callable function
Expand All @@ -233,7 +233,6 @@ def get_exported_function(self, name: str, restype=None) -> Optional[Callable[..
except KeyError: # Function not found
return None


def get_state_struct(self) -> ctypes.Structure:
""" Attempt to parse the SDFG source code and extract the state struct. This method will parse the first
consecutive entries in the struct that are pointers. As soon as a non-pointer or other unparseable field is
Expand All @@ -247,7 +246,6 @@ def get_state_struct(self) -> ctypes.Structure:

return ctypes.cast(self._libhandle, ctypes.POINTER(self._try_parse_state_struct())).contents


def _try_parse_state_struct(self) -> Optional[Type[ctypes.Structure]]:
from dace.codegen.targets.cpp import mangle_dace_state_struct_name # Avoid import cycle
# the path of the main sdfg file containing the state struct
Expand Down Expand Up @@ -375,7 +373,6 @@ def _get_error_text(self, result: Union[str, int]) -> str:
else:
return result


def __call__(self, *args, **kwargs):
"""
Forwards the Python call to the compiled ``SDFG``.
Expand All @@ -400,13 +397,12 @@ def __call__(self, *args, **kwargs):
elif len(args) > 0 and self.argnames is not None:
kwargs.update(
# `_construct_args` will handle all of its arguments as kwargs.
{aname: arg for aname, arg in zip(self.argnames, args)}
)
argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here.
# Return values are cached in `self._lastargs`.
{aname: arg
for aname, arg in zip(self.argnames, args)})
argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here.
# Return values are cached in `self._lastargs`.
return self.fast_call(argtuple, initargtuple, do_gpu_check=True)


def fast_call(
self,
callargs: Tuple[Any, ...],
Expand Down Expand Up @@ -455,15 +451,13 @@ def fast_call(
self._lib.unload()
raise


def __del__(self):
if self._initialized is True:
self.finalize()
self._initialized = False
self._libhandle = ctypes.c_void_p(0)
self._lib.unload()


def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
"""
Main function that controls argument construction for calling
Expand All @@ -486,7 +480,7 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
typedict = self._typedict
if len(kwargs) > 0:
# Construct mapping from arguments to signature
arglist = []
arglist = []
argtypes = []
argnames = []
for a in sig:
Expand Down Expand Up @@ -516,13 +510,15 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
if atype.optional is False: # If array cannot be None
raise TypeError(f'Passing a None value to a non-optional array in argument "{a}"')
# Otherwise, None values are passed as null pointers below
elif isinstance(arg, ctypes._Pointer):
pass
else:
raise TypeError(f'Passing an object (type {type(arg).__name__}) to an array in argument "{a}"')
elif is_array and not is_dtArray:
# GPU scalars and return values are pointers, so this is fine
if atype.storage != dtypes.StorageType.GPU_Global and not a.startswith('__return'):
raise TypeError(f'Passing an array to a scalar (type {atype.dtype.ctype}) in argument "{a}"')
elif (is_dtArray and is_ndarray and not isinstance(atype, dt.StructArray)
elif (is_dtArray and is_ndarray and not isinstance(atype, dt.ContainerArray)
and atype.dtype.as_numpy_dtype() != arg.dtype):
# Make exception for vector types
if (isinstance(atype.dtype, dtypes.vector) and atype.dtype.vtype.as_numpy_dtype() == arg.dtype):
Expand All @@ -536,10 +532,9 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
'you are doing, you can override this error in the '
'configuration by setting compiler.allow_view_arguments '
'to True.')
elif (not isinstance(atype, (dt.Array, dt.Structure)) and
not isinstance(atype.dtype, dtypes.callback) and
not isinstance(arg, (atype.dtype.type, sp.Basic)) and
not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)):
elif (not isinstance(atype, (dt.Array, dt.Structure)) and not isinstance(atype.dtype, dtypes.callback)
and not isinstance(arg, (atype.dtype.type, sp.Basic))
and not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)):
is_int = isinstance(arg, int)
if is_int and atype.dtype.type == np.int64:
pass
Expand Down Expand Up @@ -573,30 +568,24 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
# Retain only the element datatype for upcoming checks and casts
arg_ctypes = tuple(at.dtype.as_ctypes() for at in argtypes)

constants = self.sdfg.constants
callparams = tuple(
(actype(arg.get())
if isinstance(arg, symbolic.symbol)
else arg, actype, atype, aname
)
for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames)
if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants))
)
constants = self.sdfg.constants
callparams = tuple((actype(arg.get()) if isinstance(arg, symbolic.symbol) else arg, actype, atype, aname)
for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames)
if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants)))

symbols = self._free_symbols
initargs = tuple(
actype(arg) if not isinstance(arg, ctypes._SimpleCData) else arg
for arg, actype, atype, aname in callparams
if aname in symbols
)
actype(arg) if not isinstance(arg, (ctypes._SimpleCData, ctypes._Pointer)) else arg
for arg, actype, atype, aname in callparams if aname in symbols)

try:
# Replace arrays with their base host/device pointers
newargs = [None] * len(callparams)
for i, (arg, actype, atype, _) in enumerate(callparams):
if dtypes.is_array(arg):
newargs[i] = ctypes.c_void_p(_array_interface_ptr(arg, atype.storage)) # `c_void_p` is subclass of `ctypes._SimpleCData`.
elif not isinstance(arg, (ctypes._SimpleCData)):
newargs[i] = ctypes.c_void_p(_array_interface_ptr(
arg, atype.storage)) # `c_void_p` is subclass of `ctypes._SimpleCData`.
elif not isinstance(arg, (ctypes._SimpleCData, ctypes._Pointer)):
newargs[i] = actype(arg)
else:
newargs[i] = arg
Expand All @@ -607,11 +596,9 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
self._lastargs = newargs, initargs
return self._lastargs


def clear_return_values(self):
self._create_new_arrays = True


def _create_array(self, _: str, dtype: np.dtype, storage: dtypes.StorageType, shape: Tuple[int],
strides: Tuple[int], total_size: int):
ndarray = np.ndarray
Expand All @@ -636,7 +623,6 @@ def ndarray(*args, buffer=None, **kwargs):
# Create an array with the properties of the SDFG array
return ndarray(shape, dtype, buffer=zeros(total_size, dtype), strides=strides)


def _initialize_return_values(self, kwargs):
# Obtain symbol values from arguments and constants
syms = dict()
Expand Down Expand Up @@ -687,7 +673,6 @@ def _initialize_return_values(self, kwargs):
arr = self._create_array(*shape_desc)
self._return_arrays.append(arr)


def _convert_return_values(self):
# Return the values as they would be from a Python function
if self._return_arrays is None or len(self._return_arrays) == 0:
Expand Down
4 changes: 2 additions & 2 deletions dace/codegen/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,11 @@ def get_copy_dispatcher(self, src_node, dst_node, edge, sdfg, state):
dst_is_data = True

# Skip copies to/from views where edge matches
if src_is_data and isinstance(src_node.desc(sdfg), (dt.StructureView, dt.View)):
if src_is_data and isinstance(src_node.desc(sdfg), dt.View):
e = sdutil.get_view_edge(state, src_node)
if e is edge:
return None
if dst_is_data and isinstance(dst_node.desc(sdfg), (dt.StructureView, dt.View)):
if dst_is_data and isinstance(dst_node.desc(sdfg), dt.View):
e = sdutil.get_view_edge(state, dst_node)
if e is edge:
return None
Expand Down
21 changes: 18 additions & 3 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def copy_expr(
packed_types=False,
):
data_desc = sdfg.arrays[data_name]
# TODO: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs?
tokens = data_name.split('.')
if len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure):
name = data_name.replace('.', '->')
else:
name = data_name
ptrname = ptr(data_name, data_desc, sdfg, dispatcher.frame)
if relative_offset:
s = memlet.subset
Expand Down Expand Up @@ -99,6 +105,7 @@ def copy_expr(
# get conf flag
decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces")

# TODO: Study structures on FPGAs. Should probably use 'name' instead of 'data_name' here.
expr = fpga.fpga_ptr(
data_name,
data_desc,
Expand All @@ -112,7 +119,7 @@ def copy_expr(
and not isinstance(data_desc, data.View),
decouple_array_interfaces=decouple_array_interfaces)
else:
expr = ptr(data_name, data_desc, sdfg, dispatcher.frame)
expr = ptr(name, data_desc, sdfg, dispatcher.frame)

add_offset = offset_cppstr != "0"

Expand Down Expand Up @@ -344,7 +351,7 @@ def make_const(expr: str) -> str:
is_scalar = False
elif defined_type == DefinedType.Scalar:
typedef = defined_ctype if is_scalar else (defined_ctype + '*')
if is_write is False:
if is_write is False and not isinstance(desc, data.Structure):
typedef = make_const(typedef)
ref = '&' if is_scalar else ''
defined_type = DefinedType.Scalar if is_scalar else DefinedType.Pointer
Expand Down Expand Up @@ -578,17 +585,25 @@ def cpp_array_expr(sdfg,
desc = (sdfg.arrays[memlet.data] if referenced_array is None else referenced_array)
offset_cppstr = cpp_offset_expr(desc, s, o, packed_veclen, indices=indices)

# TODO: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs?
tokens = memlet.data.split('.')
if len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure):
name = memlet.data.replace('.', '->')
else:
name = memlet.data

if with_brackets:
if fpga.is_fpga_array(desc):
# get conf flag
decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces")
# TODO: Study structures on FPGAs. Should probably use 'name' instead of 'memlet.data' here.
ptrname = fpga.fpga_ptr(memlet.data,
desc,
sdfg,
subset,
decouple_array_interfaces=decouple_array_interfaces)
else:
ptrname = ptr(memlet.data, desc, sdfg, codegen)
ptrname = ptr(name, desc, sdfg, codegen)
return "%s[%s]" % (ptrname, offset_cppstr)
else:
return offset_cppstr
Expand Down

0 comments on commit 7c5d0b6

Please sign in to comment.