Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More NumPy operation implementations #1498

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions dace/codegen/cppunparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,8 @@ def _Num(self, t):
# For complex values, use ``dtype_to_typeclass``
if isinstance(t_n, complex):
dtype = dtypes.dtype_to_typeclass(complex)
repr_n = f'{dtype}({t_n.real}, {t_n.imag})'


# Handle large integer values
if isinstance(t_n, int):
Expand All @@ -765,10 +767,8 @@ def _Num(self, t):
elif bits >= 64:
warnings.warn(f'Value wider than 64 bits encountered in expression ({t_n}), emitting as-is')

if repr_n.endswith("j"):
self.write("%s(0, %s)" % (dtype, repr_n.replace("inf", INFSTR)[:-1]))
else:
self.write(repr_n.replace("inf", INFSTR))
repr_n = repr_n.replace("inf", INFSTR)
self.write(repr_n)

def _List(self, t):
raise NotImplementedError('Invalid C++')
Expand Down
59 changes: 39 additions & 20 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
Shape = Union[ShapeTuple, ShapeList]
DependencyType = Dict[str, Tuple[SDFGState, Union[Memlet, nodes.Tasklet], Tuple[int]]]


if sys.version_info < (3, 8):
_simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num)
BytesConstant = ast.Bytes
Expand All @@ -65,15 +64,13 @@
NumConstant = ast.Constant
StrConstant = ast.Constant


if sys.version_info < (3, 9):
Index = ast.Index
ExtSlice = ast.ExtSlice
else:
Index = type(None)
ExtSlice = type(None)


if sys.version_info < (3, 12):
TypeAlias = type(None)
else:
Expand Down Expand Up @@ -4330,7 +4327,14 @@ def visit_Call(self, node: ast.Call, create_callbacks=False):
func = node.func.value

if func is None:
funcname = rname(node)
func_result = self.visit(node.func)
if isinstance(func_result, str):
if isinstance(node.func, ast.Attribute):
funcname = f'{func_result}.{node.func.attr}'
else:
funcname = func_result
else:
funcname = rname(node)
# Check if the function exists as an SDFG in a different module
modname = until(funcname, '.')
if ('.' in funcname and len(modname) > 0 and modname in self.globals
Expand Down Expand Up @@ -4426,7 +4430,7 @@ def visit_Call(self, node: ast.Call, create_callbacks=False):
arg = self.scope_vars[modname]
else:
# Fallback to (name, object)
arg = (modname, self.defined[modname])
arg = modname
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change handled already in the code somehow, i.e., is self.defined queried on demand?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It’s a code path that we didn’t cover before in coverage. The rest of the code assumes a string return type rather than a tuple, and crashes somewhere else.

args.append(arg)
# Otherwise, try to find a default implementation for the SDFG
elif not found_ufunc:
Expand Down Expand Up @@ -4623,12 +4627,18 @@ def _visitname(self, name: str, node: ast.AST):
self.sdfg.add_symbol(result.name, result.dtype)
return result

if name in self.closure.callbacks:
return name

if name in self.sdfg.arrays:
return name

if name in self.sdfg.symbols:
return name

if name in __builtins__:
return name

if name not in self.scope_vars:
raise DaceSyntaxError(self, node, 'Use of undefined variable "%s"' % name)
rname = self.scope_vars[name]
Expand Down Expand Up @@ -4673,30 +4683,39 @@ def visit_NameConstant(self, node: NameConstant):
return self.visit_Constant(node)

def visit_Attribute(self, node: ast.Attribute):
# If visiting an attribute, return attribute value if it's of an array or global
name = until(astutils.unparse(node), '.')
result = self._visitname(name, node)
result = self.visit(node.value)
if isinstance(result, (tuple, list, dict)):
if len(result) > 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the fix for attributes on expressions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly

raise DaceSyntaxError(
self, node.value, f'{type(result)} object cannot use attributes. Try storing the '
'object to a different variable first (e.g., ``a = result; a.attribute``')
else:
result = result[0]

if isinstance(result, str) and result in self.sdfg.arrays:
arr = self.sdfg.arrays[result]
elif isinstance(result, str) and result in self.scope_arrays:
arr = self.scope_arrays[result]
else:
return result
arr = None

# Try to find sub-SDFG attribute
func = oprepo.Replacements.get_attribute(type(arr), node.attr)
if func is not None:
# A new state is likely needed here, e.g., for transposition (ndarray.T)
self._add_state('%s_%d' % (type(node).__name__, node.lineno))
self.last_state.set_default_lineinfo(self.current_lineinfo)
result = func(self, self.sdfg, self.last_state, result)
self.last_state.set_default_lineinfo(None)
return result
if arr is not None:
func = oprepo.Replacements.get_attribute(type(arr), node.attr)
if func is not None:
# A new state is likely needed here, e.g., for transposition (ndarray.T)
self._add_state('%s_%d' % (type(node).__name__, node.lineno))
self.last_state.set_default_lineinfo(self.current_lineinfo)
result = func(self, self.sdfg, self.last_state, result)
self.last_state.set_default_lineinfo(None)
return result

# Otherwise, try to find compile-time attribute (such as shape)
try:
return getattr(arr, node.attr)
except KeyError:
if arr is not None:
return getattr(arr, node.attr)
return getattr(result, node.attr)
except (AttributeError, KeyError):
return result

def visit_List(self, node: ast.List):
Expand All @@ -4718,7 +4737,7 @@ def visit_Dict(self, node: ast.Dict):
def visit_Lambda(self, node: ast.Lambda):
# Return a string representation of the function
return astutils.unparse(node)

def visit_TypeAlias(self, node: TypeAlias):
raise NotImplementedError('Type aliases are not supported in DaCe')

Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ def global_value_to_node(self,
elif isinstance(value, symbolic.symbol):
# Symbols resolve to the symbol name
newnode = ast.Name(id=value.name, ctx=ast.Load())
elif isinstance(value, sympy.Basic): # Symbolic or constant expression
newnode = ast.parse(symbolic.symstr(value)).body[0].value
elif isinstance(value, ast.Name):
newnode = ast.Name(id=value.id, ctx=ast.Load())
elif (dtypes.isconstant(value) or isinstance(value, (StringLiteral, SDFG)) or hasattr(value, '__sdfg__')):
Expand Down