Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Missing atomic in WCR code generation. #1528

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion .github/workflows/fpga-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
# and overflowed in the year 2022, run the FPGA tests pretending like it's January 1st 2021.
# faketime -f "@2021-01-01 00:00:00" pytest -n auto --cov-report=xml --cov=dace --tb=short -m "fpga"
# Try running without faketime
pytest -n auto --cov-report=xml --cov=dace --tb=short -m "fpga"
pytest --cov-report=xml --cov=dace --tb=short -m "fpga"

coverage report
coverage xml
Expand Down
55 changes: 18 additions & 37 deletions dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ def __init__(self, library_filename, program_name):
:param program_name: Name of the DaCe program (for use in finding
the stub library loader).
"""
self._stub_filename = os.path.join(
os.path.dirname(os.path.realpath(library_filename)),
f'libdacestub_{program_name}.{Config.get("compiler", "library_extension")}')
self._stub_filename = os.path.join(os.path.dirname(os.path.realpath(library_filename)),
f'libdacestub_{program_name}.{Config.get("compiler", "library_extension")}')
self._library_filename = os.path.realpath(library_filename)
self._stub = None
self._lib = None
Expand Down Expand Up @@ -219,7 +218,6 @@ def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None):
self.has_gpu_code = True
break


def get_exported_function(self, name: str, restype=None) -> Optional[Callable[..., Any]]:
"""
Tries to find a symbol by name in the compiled SDFG, and convert it to a callable function
Expand All @@ -233,7 +231,6 @@ def get_exported_function(self, name: str, restype=None) -> Optional[Callable[..
except KeyError: # Function not found
return None


def get_state_struct(self) -> ctypes.Structure:
""" Attempt to parse the SDFG source code and extract the state struct. This method will parse the first
consecutive entries in the struct that are pointers. As soon as a non-pointer or other unparseable field is
Expand All @@ -247,7 +244,6 @@ def get_state_struct(self) -> ctypes.Structure:

return ctypes.cast(self._libhandle, ctypes.POINTER(self._try_parse_state_struct())).contents


def _try_parse_state_struct(self) -> Optional[Type[ctypes.Structure]]:
from dace.codegen.targets.cpp import mangle_dace_state_struct_name # Avoid import cycle
# the path of the main sdfg file containing the state struct
Expand Down Expand Up @@ -375,7 +371,6 @@ def _get_error_text(self, result: Union[str, int]) -> str:
else:
return result


def __call__(self, *args, **kwargs):
"""
Forwards the Python call to the compiled ``SDFG``.
Expand All @@ -400,13 +395,12 @@ def __call__(self, *args, **kwargs):
elif len(args) > 0 and self.argnames is not None:
kwargs.update(
# `_construct_args` will handle all of its arguments as kwargs.
{aname: arg for aname, arg in zip(self.argnames, args)}
)
argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here.
# Return values are cached in `self._lastargs`.
{aname: arg
for aname, arg in zip(self.argnames, args)})
argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here.
# Return values are cached in `self._lastargs`.
return self.fast_call(argtuple, initargtuple, do_gpu_check=True)


def fast_call(
self,
callargs: Tuple[Any, ...],
Expand Down Expand Up @@ -455,15 +449,13 @@ def fast_call(
self._lib.unload()
raise


def __del__(self):
if self._initialized is True:
self.finalize()
self._initialized = False
self._libhandle = ctypes.c_void_p(0)
self._lib.unload()


def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
"""
Main function that controls argument construction for calling
Expand All @@ -486,7 +478,7 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
typedict = self._typedict
if len(kwargs) > 0:
# Construct mapping from arguments to signature
arglist = []
arglist = []
argtypes = []
argnames = []
for a in sig:
Expand Down Expand Up @@ -536,10 +528,9 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
'you are doing, you can override this error in the '
'configuration by setting compiler.allow_view_arguments '
'to True.')
elif (not isinstance(atype, (dt.Array, dt.Structure)) and
not isinstance(atype.dtype, dtypes.callback) and
not isinstance(arg, (atype.dtype.type, sp.Basic)) and
not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)):
elif (not isinstance(atype, (dt.Array, dt.Structure)) and not isinstance(atype.dtype, dtypes.callback)
and not isinstance(arg, (atype.dtype.type, sp.Basic))
and not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)):
is_int = isinstance(arg, int)
if is_int and atype.dtype.type == np.int64:
pass
Expand Down Expand Up @@ -573,29 +564,23 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
# Retain only the element datatype for upcoming checks and casts
arg_ctypes = tuple(at.dtype.as_ctypes() for at in argtypes)

constants = self.sdfg.constants
callparams = tuple(
(actype(arg.get())
if isinstance(arg, symbolic.symbol)
else arg, actype, atype, aname
)
for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames)
if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants))
)
constants = self.sdfg.constants
callparams = tuple((arg, actype, atype, aname)
for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames)
if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants)))

symbols = self._free_symbols
initargs = tuple(
actype(arg) if not isinstance(arg, ctypes._SimpleCData) else arg
for arg, actype, atype, aname in callparams
if aname in symbols
)
actype(arg) if not isinstance(arg, ctypes._SimpleCData) else arg for arg, actype, atype, aname in callparams
if aname in symbols)

try:
# Replace arrays with their base host/device pointers
newargs = [None] * len(callparams)
for i, (arg, actype, atype, _) in enumerate(callparams):
if dtypes.is_array(arg):
newargs[i] = ctypes.c_void_p(_array_interface_ptr(arg, atype.storage)) # `c_void_p` is subclass of `ctypes._SimpleCData`.
newargs[i] = ctypes.c_void_p(_array_interface_ptr(
arg, atype.storage)) # `c_void_p` is subclass of `ctypes._SimpleCData`.
elif not isinstance(arg, (ctypes._SimpleCData)):
newargs[i] = actype(arg)
else:
Expand All @@ -607,11 +592,9 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
self._lastargs = newargs, initargs
return self._lastargs


def clear_return_values(self):
self._create_new_arrays = True


def _create_array(self, _: str, dtype: np.dtype, storage: dtypes.StorageType, shape: Tuple[int],
strides: Tuple[int], total_size: int):
ndarray = np.ndarray
Expand All @@ -636,7 +619,6 @@ def ndarray(*args, buffer=None, **kwargs):
# Create an array with the properties of the SDFG array
return ndarray(shape, dtype, buffer=zeros(total_size, dtype), strides=strides)


def _initialize_return_values(self, kwargs):
# Obtain symbol values from arguments and constants
syms = dict()
Expand Down Expand Up @@ -687,7 +669,6 @@ def _initialize_return_values(self, kwargs):
arr = self._create_array(*shape_desc)
self._return_arrays.append(arr)


def _convert_return_values(self):
# Return the values as they would be from a Python function
if self._return_arrays is None or len(self._return_arrays) == 0:
Expand Down
3 changes: 3 additions & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3431,6 +3431,9 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
self._add_aug_assignment(node, rtarget, wtarget, result, op, boolarr)
else:
self._add_assignment(node, wtarget, result, op, boolarr)
if op and not independent:
# NOTE: Assuming WCR on the memlet
self.outputs[new_name][1].wcr = LambdaProperty.from_string('lambda x, y: x {} y'.format(op))

# Connect states properly when there is output indirection
if output_indirection:
Expand Down
7 changes: 2 additions & 5 deletions dace/frontend/python/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@


def ndarray(shape, dtype=numpy.float64, *args, **kwargs):
""" Returns a numpy ndarray where all symbols have been evaluated to
numbers and types are converted to numpy types. """
repldict = {sym: sym.get() for sym in symbolic.symlist(shape).values()}
new_shape = [int(s.subs(repldict) if symbolic.issymbolic(s) else s) for s in shape]
""" Returns a numpy ndarray where all types are converted to numpy types. """
new_dtype = dtype.type if isinstance(dtype, dtypes.typeclass) else dtype
return numpy.ndarray(shape=new_shape, dtype=new_dtype, *args, **kwargs)
return numpy.ndarray(shape=shape, dtype=new_dtype, *args, **kwargs)


stream: Type[Deque[T]] = deque
Expand Down
10 changes: 1 addition & 9 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2125,16 +2125,8 @@ def specialize(self, symbols: Dict[str, Any]):

:param symbols: Values to specialize.
"""
# Set symbol values to add
syms = {
# If symbols are passed, extract the value. If constants are
# passed, use them directly.
name: val.get() if isinstance(val, dace.symbolic.symbol) else val
for name, val in symbols.items()
}

# Update constants
for k, v in syms.items():
for k, v in symbols.items():
self.add_constant(str(k), v)

def is_loaded(self) -> bool:
Expand Down
46 changes: 7 additions & 39 deletions dace/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,10 @@ def __new__(cls, name=None, dtype=DEFAULT_SYMBOL_TYPE, **assumptions):

self.dtype = dtype
self._constraints = []
self.value = None
return self

def set(self, value):
warnings.warn('symbol.set is deprecated, use keyword arguments', DeprecationWarning)
if value is not None:
# First, check constraints
self.check_constraints(value)

self.value = self.dtype(value)

def __getstate__(self):
return dict(self.assumptions0, **{'value': self.value, 'dtype': self.dtype, '_constraints': self._constraints})
return dict(self.assumptions0, **{'dtype': self.dtype, '_constraints': self._constraints})

def _eval_subs(self, old, new):
"""
Expand All @@ -85,15 +76,6 @@ def _eval_subs(self, old, new):
except AttributeError:
return None

def is_initialized(self):
return self.value is not None

def get(self):
warnings.warn('symbol.get is deprecated, use keyword arguments', DeprecationWarning)
if self.value is None:
raise UnboundLocalError('Uninitialized symbol value for \'' + self.name + '\'')
return self.value

def set_constraints(self, constraint_list):
try:
iter(constraint_list)
Expand Down Expand Up @@ -141,9 +123,6 @@ def check_constraints(self, value):
if fail is not None:
raise RuntimeError('Value %s invalidates constraint %s for symbol %s' % (str(value), str(fail), self.name))

def get_or_return(self, uninitialized_ret):
return self.value or uninitialized_ret


class SymExpr(object):
""" Symbolic expressions with support for an overapproximation expression.
Expand Down Expand Up @@ -287,13 +266,6 @@ def __gt__(self, other):
SymbolicType = Union[sympy.Basic, SymExpr]


def symvalue(val):
""" Returns the symbol value if it is a symbol. """
if isinstance(val, symbol):
return val.get()
return val


# http://stackoverflow.com/q/3844948/
def _checkEqualIvo(lst):
return not lst or lst.count(lst[0]) == len(lst)
Expand Down Expand Up @@ -333,9 +305,8 @@ def symlist(values):
return result


def evaluate(expr: Union[sympy.Basic, int, float],
symbols: Dict[Union[symbol, str], Union[int, float]]) -> \
Union[int, float, numpy.number]:
def evaluate(expr: Union[sympy.Basic, int, float], symbols: Dict[Union[symbol, str],
Union[int, float]]) -> Union[int, float, numpy.number]:
"""
Evaluates an expression to a constant based on a mapping from symbols
to values.
Expand All @@ -356,9 +327,7 @@ def evaluate(expr: Union[sympy.Basic, int, float],
return expr

# Evaluate all symbols
syms = {(sname if isinstance(sname, sympy.Symbol) else symbol(sname)):
sval.get() if isinstance(sval, symbol) else sval
for sname, sval in symbols.items()}
syms = {(sname if isinstance(sname, sympy.Symbol) else symbol(sname)): sval for sname, sval in symbols.items()}

# Filter out `None` values, callables, and iterables but not strings (for SymPy 1.12)
syms = {
Expand Down Expand Up @@ -1028,7 +997,7 @@ def visit_IfExp(self, node):
self.visit(node.orelse)],
keywords=[])
return ast.copy_location(new_node, node)

def visit_Subscript(self, node):
if isinstance(node.value, ast.Attribute):
attr = ast.Subscript(value=ast.Name(id=node.value.attr, ctx=ast.Load()), slice=node.slice, ctx=ast.Load())
Expand Down Expand Up @@ -1405,8 +1374,7 @@ def equal(a: SymbolicType, b: SymbolicType, is_length: bool = True) -> Union[boo
return sympy.ask(sympy.Q.is_true(sympy.Eq(*args)))


def symbols_in_code(code: str, potential_symbols: Set[str] = None,
symbols_to_ignore: Set[str] = None) -> Set[str]:
def symbols_in_code(code: str, potential_symbols: Set[str] = None, symbols_to_ignore: Set[str] = None) -> Set[str]:
"""
Tokenizes a code string for symbols and returns a set thereof.

Expand All @@ -1419,7 +1387,7 @@ def symbols_in_code(code: str, potential_symbols: Set[str] = None,
if potential_symbols is not None and len(potential_symbols) == 0:
# Don't bother tokenizing for an empty set of potential symbols
return set()

tokens = set(re.findall(_NAME_TOKENS, code))
if potential_symbols is not None:
tokens &= potential_symbols
Expand Down